"src/vscode:/vscode.git/clone" did not exist on "2f7a417d1fb11bd242ad7f9098bb9fdf77c54422"
Commit e9b87112 authored by justheuristic's avatar justheuristic
Browse files

un-fuse bias

parent 56a074f6
......@@ -316,15 +316,14 @@ class MatMul8bitLt(torch.autograd.Function):
if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A_dtype)
delayed_bias = None
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A_dtype).add_(bias)
delayed_bias = bias
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)
output.addmm_(subA, state.subB)
# 5. Save state
ctx.state = state
......@@ -341,6 +340,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
output = output.to(A_dtype)
if delayed_bias is not None:
output.add_(delayed_bias)
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
return clone_func(output.view(output_shape))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment