You need to sign in or sign up before continuing.
Unverified Commit b7e60cac authored by Matthew Douglas's avatar Matthew Douglas Committed by GitHub
Browse files

Improve torch.compile support for int8 with torch>=2.8 nightly (#1617)

parent 46442b03
......@@ -236,7 +236,8 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.state = state
ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
ctx.dtype_A = A.dtype
ctx.dtype_bias = None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA, A)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment