Commit 7d1c9cfe authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

extra feats in constructor Params4bit

parent f1ef74f8
......@@ -168,8 +168,11 @@ class Params4bit(torch.nn.Parameter):
if data.device.type != "cuda":
raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
cls.requires_grad = requires_grad,
cls.requires_grad = requires_grad
cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
cls.blocksize = cls.quant_state.blocksize
cls.compress_statistics = cls.quant_state.nested
cls.quant_type = cls.quant_state.quant_type
self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
return self, state_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment