Commit e2a75769 authored by dbaranchuk's avatar dbaranchuk
Browse files

bug fix

parent 4dd475ce
......@@ -388,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
grad_bias = grad_output.sum(0)
# Cast grad_A back to grad_output_dtype
grad_output.to(grad_output_dtype)
grad_output = grad_output.to(grad_output_dtype)
return grad_A, grad_B, None, grad_bias, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment