Commit 87f88af4 authored by Matthew Douglas

Enable loading prequantized weights with bf16/fp16/fp32 quant_storage type for FSDP

parent 2621e1af
@@ -273,6 +273,7 @@ class Params4bit(torch.nn.Parameter):
         quantized_stats: Dict[str, Any],
         requires_grad: bool = False,
         device="cuda",
+        module: Optional["Linear4bit"] = None,
         **kwargs,
     ) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -284,6 +285,10 @@ class Params4bit(torch.nn.Parameter):
         self.bnb_quantized = True
+        self.quant_storage = data.dtype
+        self.module = module
+        if self.module is not None:
+            self.module.quant_state = self.quant_state
         return self
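The change lets `from_prequantized` infer the storage dtype from the packed tensor itself (instead of assuming uint8) and, via the new `module` argument, hand the rebuilt `quant_state` back to the owning `Linear4bit`. The following is a minimal round-trip sketch, not taken from the commit: it assumes a CUDA device, a bitsandbytes build that includes this change, illustrative layer sizes, and that the checkpoint side serializes the statistics with `QuantState.as_dict(packed=True)` (the format `QuantState.from_dict` consumes in the diff above).

import torch
from bitsandbytes.nn import Linear4bit, Params4bit

# 1. Quantize a layer whose 4-bit codes are packed into a bf16 tensor, so FSDP
#    can flat-shard it alongside ordinary bf16 parameters.
layer = Linear4bit(256, 256, compute_dtype=torch.bfloat16, quant_storage=torch.bfloat16)
layer = layer.to("cuda")  # quantization happens on the host-to-device move

# 2. Serialize the pieces a checkpoint would hold: the packed weight tensor and
#    the quantization statistics.
packed_weight = layer.weight.data
quantized_stats = layer.weight.quant_state.as_dict(packed=True)

# 3. Reload the prequantized weight. With this commit, quant_storage is taken
#    from packed_weight.dtype (bf16 here) rather than assumed to be uint8, and
#    passing module= pushes the rebuilt quant_state onto the owning Linear4bit.
fresh = Linear4bit(256, 256, compute_dtype=torch.bfloat16, quant_storage=torch.bfloat16)
fresh.weight = Params4bit.from_prequantized(
    data=packed_weight,
    quantized_stats=quantized_stats,
    device="cuda",
    module=fresh,
)

assert fresh.weight.quant_storage == torch.bfloat16
assert fresh.quant_state is fresh.weight.quant_state

Propagating the quant_state onto the module at load time means the restored layer can run its first forward pass without having to recover the state from the parameter separately.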