Commit 9851a10b authored by Tim Dettmers

Added cast to fp4 layer for speed.

parent c93a90d0
@@ -404,10 +404,10 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (CAt, subA)
+            ctx.tensors = (CAt, subA, A)
             ctx.tensor_states = (SCAt, state.idx)
         else:
-            ctx.tensors = [None, None]
+            ctx.tensors = [None, None, A]
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
@@ -420,7 +420,7 @@ class MatMul8bitLt(torch.autograd.Function):
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
-        CAt, subA = ctx.tensors
+        CAt, subA, A = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
@@ -436,6 +436,7 @@ class MatMul8bitLt(torch.autograd.Function):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
+            #grad_B = torch.matmul(grad_output.t(), A)
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
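In the MatMul8bitLt change above, forward() now stows the fp16 input A in ctx.tensors in both branches, so backward() can unpack it uniformly as CAt, subA, A = ctx.tensors. Its only use so far is the commented-out line under if req_gradB:, which hints at computing the weight gradient directly in fp16 instead of going through the int8 transform/igemmlt path. A minimal, hedged sketch of that direct computation, assuming out = A @ B.t() with grad_output of shape (batch, out_features) and A of shape (batch, in_features); the function name direct_grad_B is illustrative only:

import torch

def direct_grad_B(grad_output: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # For out = A @ B.t(), the gradient w.r.t. B is grad_output.t() @ A,
    # with shape (out_features, in_features). This mirrors the commented-out
    # line in the diff above; it is a sketch, not the library's code path.
    return torch.matmul(grad_output.t(), A)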
@@ -190,7 +190,11 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
+        inp_dtype = x.dtype
+        x = x.to(torch.float16)
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias.half(), quant_state=self.weight.quant_state)
+        out = out.to(inp_dtype)
         return out
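The LinearFP4 change is a cast-and-restore pattern: remember the caller's dtype, cast the activation and bias to fp16 for the FP4 matmul, then cast the result back so surrounding layers see the original dtype. Below is a self-contained sketch of the same pattern with a plain fp16 linear standing in for bnb.matmul_fp4, so it runs without bitsandbytes; the name fp16_cast_linear and the use of F.linear are illustrative, not the library's API. Note that the diff calls self.bias.half() unconditionally, which assumes bias is not None; the sketch guards that case.

from typing import Optional

import torch
import torch.nn.functional as F

def fp16_cast_linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    inp_dtype = x.dtype                        # remember the caller's dtype (e.g. float32 or bfloat16)
    x = x.to(torch.float16)                    # compute in fp16, as the quantized kernel expects
    w = weight.to(torch.float16)
    b = None if bias is None else bias.half()
    out = F.linear(x, w, b)                    # stand-in for bnb.matmul_fp4(x, weight.t(), ...)
    return out.to(inp_dtype)                   # restore the original dtype for downstream layers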