Unverified Commit b7e60cac authored by Matthew Douglas's avatar Matthew Douglas Committed by GitHub
Browse files

Improve torch.compile support for int8 with torch>=2.8 nightly (#1617)

parent 46442b03
...@@ -236,7 +236,8 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -236,7 +236,8 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.state = state ctx.state = state
ctx.grad_shape = input_shape ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype ctx.dtype_A = A.dtype
ctx.dtype_bias = None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]): if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA, A) ctx.tensors = (CAt, subA, A)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment