Unverified Commit 49609323 authored by Matthew Douglas's avatar Matthew Douglas Committed by GitHub
Browse files

Fix torch.compile issue for LLM.int8() with threshold=0 (#1581)

parent 90bbe147
@@ -84,6 +84,13 @@ def get_inverse_transform_indices(
     return permuted_tile_indices
# torch.compiler.is_compiling() was added in torch 2.3; older releases
# expose the equivalent check as torch._dynamo.is_compiling.
try:
    _is_compiling = torch.compiler.is_compiling
except AttributeError:
    _is_compiling = torch._dynamo.is_compiling
 @deprecated(
     "This function is deprecated and will be removed in a future release.",
     category=FutureWarning,
@@ -174,7 +181,7 @@ class MatMul8bitLt(torch.autograd.Function):
         input_shape = A.shape

         # Cast A to fp16
-        if A.dtype != torch.float16:
+        if A.dtype != torch.float16 and not _is_compiling():
             warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         if len(A.shape) == 3:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment