"...git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "35badc0892275c35818ca39800ec55d9c7342c8f"
Unverified commit 7fed393a, authored by Matthew Douglas and committed by GitHub
Browse files

Fix restoration of quant_storage for CPU offloading (#1279)



* Fix restoration of quant_storage for CPU offloading

* Clarify comment on default quant_storage in Params4bit.from_prequantized()

* fix to make quant_storage dynamic based on serialized dtype

* delete obsolete comment

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent e3ae243b
...@@ -282,10 +282,13 @@ class Params4bit(torch.nn.Parameter): ...@@ -282,10 +282,13 @@ class Params4bit(torch.nn.Parameter):
self.compress_statistics = self.quant_state.nested self.compress_statistics = self.quant_state.nested
self.quant_type = self.quant_state.quant_type self.quant_type = self.quant_state.quant_type
self.bnb_quantized = True self.bnb_quantized = True
self.quant_storage = data.dtype
return self return self
def _quantize(self, device): def _quantize(self, device):
w = self.data.contiguous().cuda(device) w = self.data.contiguous().to(device)
w_4bit, quant_state = bnb.functional.quantize_4bit( w_4bit, quant_state = bnb.functional.quantize_4bit(
w, w,
blocksize=self.blocksize, blocksize=self.blocksize,
...@@ -333,6 +336,7 @@ class Params4bit(torch.nn.Parameter): ...@@ -333,6 +336,7 @@ class Params4bit(torch.nn.Parameter):
blocksize=self.blocksize, blocksize=self.blocksize,
compress_statistics=self.compress_statistics, compress_statistics=self.compress_statistics,
quant_type=self.quant_type, quant_type=self.quant_type,
quant_storage=self.quant_storage,
) )
return new_param return new_param
...@@ -450,7 +454,7 @@ class Linear4bit(nn.Linear): ...@@ -450,7 +454,7 @@ class Linear4bit(nn.Linear):
# since we registered the module, we can recover the state here # since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1 assert self.weight.shape[1] == 1
if not isinstance(self.weight, Params4bit): if not isinstance(self.weight, Params4bit):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
self.weight.quant_state = self.quant_state self.weight.quant_state = self.quant_state
else: else:
print( print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment