Unverified Commit 46442b03 authored by Matthew Douglas, committed by GitHub

Support 4bit torch.compile fullgraph with PyTorch nightly (#1616)

parent c244e983
@@ -290,6 +290,13 @@ class Params4bit(torch.nn.Parameter):
         return self
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
     def _quantize(self, device):
         w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
@@ -486,7 +493,7 @@ class Linear4bit(nn.Linear):
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 class LinearFP4(Linear4bit):
...
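In effect, routing every torch-level call on `Params4bit` through `DisableTorchFunctionSubclass` (and calling `.t()` on the plain `self.weight.data` tensor in the forward pass) keeps the parameter subclass out of the traced graph. Below is a minimal usage sketch of what the commit title describes, assuming a CUDA device and a recent PyTorch nightly; the layer shapes and dtype are illustrative, not taken from this commit:

```python
import torch
import bitsandbytes as bnb

# Moving the module to the GPU triggers 4-bit quantization of the weight.
linear = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16).cuda()

# With this change, fullgraph=True should no longer hit graph breaks
# caused by the Params4bit tensor subclass.
compiled = torch.compile(linear, fullgraph=True)

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
out = compiled(x)
```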