Unverified commit d9333aa9, authored by Matthew Douglas and committed by GitHub

Improvement for torch.compile support on Params4bit (#1673)

parent 11df723f
@@ -291,13 +291,6 @@ class Params4bit(torch.nn.Parameter):
         return self
 
-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-
-        with torch._C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
-
     def _quantize(self, device):
         w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
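The hunk above deletes Params4bit's custom `__torch_function__`, so the subclass falls back on the handler it inherits from `torch.nn.Parameter`; given the commit title, the removed override was presumably getting in the way of torch.compile tracing. A minimal sketch of that fallback behavior, using a hypothetical `PlainParam` subclass rather than Params4bit itself:

```python
import torch

class PlainParam(torch.nn.Parameter):
    # No __torch_function__ override here: operations on this subclass go
    # through the handler inherited from torch.nn.Parameter, which is the
    # situation Params4bit is in after this change.
    pass

p = PlainParam(torch.randn(4, 4))

# The inherited handler does not re-wrap results in the subclass type,
# so ordinary ops return plain tensors.
out = p @ p.t()
print(type(out))  # <class 'torch.Tensor'>
```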
@@ -270,10 +270,7 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
 def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
-    if device == "cpu" and quant_type == "fp4":
-        pytest.skip("FP4 is not supported for CPU")
-
-    if fullgraph and torch.__version__ < (2, 8):
+    if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
        pytest.skip("fullgraph mode requires torch 2.8 or higher")
 
     if device == "cuda" and platform.system() == "Windows":
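For context, the test touched above compiles a 4-bit quantized linear module. A rough, self-contained sketch of that scenario (the shapes, dtype, and CUDA device below are illustrative assumptions, not the test's actual parametrization):

```python
import torch
import bitsandbytes as bnb

# Build an NF4-quantized linear layer; moving it to the GPU quantizes the
# weight, which is stored as a Params4bit parameter.
layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4")
layer = layer.to("cuda")

# Compile the module, mirroring one of the modes the test parametrizes over.
compiled = torch.compile(layer, mode="default")

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
out = compiled(x)
print(out.shape)  # torch.Size([8, 64])
```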