Commit 7b764d35 authored by Mitchell Wortsman

adding half() cast

parent 2489d819
@@ -415,8 +415,8 @@ class MatMulFP8(torch.autograd.Function):
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
         fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
-        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
         output = torch.matmul(fp8A, fp8B)
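
For readers skimming the hunk above: the forward pass does not run a real FP8 kernel; it simulates FP8 by quantizing each operand against an 8-bit code book and immediately dequantizing it, then doing an ordinary `torch.matmul` on the rounded values. A minimal standalone sketch of that round trip, reusing the `bitsandbytes.functional` calls visible in the diff (the `fw_code` argument and the 1024 block size come from the hunk; the helper name and everything else is illustrative):

```python
import torch
import bitsandbytes.functional as F

def simulated_fp8_matmul(A, B, fw_code):
    # Round-trip A through blockwise 8-bit quantization: values snap to the code book.
    cA, state_A = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
    fp8A = F.dequantize_blockwise(cA, state_A, blocksize=1024).to(A.dtype)

    # B now takes the non-blockwise path introduced by this commit
    # (a single absmax scale for the whole tensor).
    cB, state_B = F.quantize(B.float(), code=fw_code)
    fp8B = F.dequantize(cB, state_B).to(B.dtype)

    # The matmul itself runs in the original dtype on the rounded tensors.
    return torch.matmul(fp8A, fp8B)
```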
@@ -450,9 +450,13 @@ class MatMulFP8(torch.autograd.Function):
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t())
-        if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
+        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+        if req_gradB:
+            if fp8A.ndim == 3:
+                fp8At = fp8A.transpose(2, 1)
+            elif fp8A.ndim == 2:
+                fp8At = fp8A.t()
+            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
         return grad_A, grad_B, None, None, None
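
The `ndim` branch in the new `req_gradB` path is there because `Tensor.t()` only accepts tensors with at most two dimensions, so a batched 3-D activation needs its last two dimensions swapped explicitly. A small shape illustration (all shapes arbitrary):

```python
import torch

fp8A = torch.randn(4, 16, 32)    # (batch, seq, in_features)
fp8out = torch.randn(4, 16, 8)   # (batch, seq, out_features)

# fp8A.t() raises a RuntimeError on a 3-D tensor;
# transpose(2, 1) swaps the last two dims and keeps the batch dim intact.
grad_B = torch.matmul(fp8A.transpose(2, 1), fp8out)
print(grad_B.shape)              # torch.Size([4, 32, 8])
```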
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
@@ -326,10 +326,11 @@ class Linear8bitLt(nn.Linear):
             self.init_8bit_state()
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
         #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
         if not self.state.has_fp16_weights:
             if not self.state.memory_efficient_backward and self.state.CB is not None:
@@ -344,6 +345,28 @@ class Linear8bitLt(nn.Linear):
         return out

+
+class Linear8bitLtThresh(Linear8bitLt):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features,
+            output_features,
+            bias=bias,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
+            index=index
+        )
+
 class LinearFP8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
@@ -361,3 +384,33 @@ class LinearFP8(nn.Linear):
         return out
+
+
+class LinearInt8(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+        return out
+
+
+class LinearInt8Cast(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+        return out
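
A hedged usage sketch for the layers this commit adds and re-exports from `bitsandbytes.nn` (constructor signatures and the lazy `create_linear_map(True, 8)` code-book setup are taken from the diff; the dimensions, device handling, and dtypes below are illustrative assumptions about this fork, not a documented API):

```python
import torch
import bitsandbytes as bnb

# LinearInt8 lazily builds a signed 8-bit linear code on first forward and
# routes its matmul through bnb.matmul_fp8 with that code for both passes.
layer = bnb.nn.LinearInt8(1024, 4096, bias=True).cuda()
x = torch.randn(8, 1024, device="cuda")
y = layer(x)

# LinearInt8Cast is the same layer but casts activations and weights to half
# before the quantized matmul, mirroring the half() cast added in this commit.
layer_cast = bnb.nn.LinearInt8Cast(1024, 4096).cuda().half()
y_cast = layer_cast(x.half())

# Linear8bitLtThresh is Linear8bitLt with the outlier threshold defaulting to 6.0.
layer_thresh = bnb.nn.Linear8bitLtThresh(1024, 4096, has_fp16_weights=False).cuda()
```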