Commit 0a0b531f authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

removed optional data=None in from_prequantized

parent 7a117e44
...@@ -154,10 +154,7 @@ class Params4bit(torch.nn.Parameter): ...@@ -154,10 +154,7 @@ class Params4bit(torch.nn.Parameter):
return self return self
@classmethod @classmethod
def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, device='cuda', **kwargs): def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
if data is None:
weight_key = [k for k in quantized_stats if k.endswith(".weight")][0]
data = quantized_stats.pop(weight_key)
self = torch.Tensor._make_subclass(cls, data.to(device)) self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment