Commit 0de1a449 authored by justheuristic's avatar justheuristic
Browse files

change order

parent e9b87112
......@@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0)
# Cast grad_output to fp16
grad_output_dtype = grad_output.dtype
......@@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function):
-1, grad_output.shape[-1]
).contiguous()
grad_A = grad_B = grad_bias = None
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
......@@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function):
else:
raise Exception('State must contain either CBt or CB matrix for backward')
if req_gradBias:
grad_bias = grad_output.sum(0)
# Cast grad_A back to grad_output_dtype
grad_output = grad_output.to(grad_output_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment