Commit de44be1a authored by Ruslan Svirschevski
Browse files

unshared nested_quant_map

parent ffd46ce1
...@@ -651,7 +651,7 @@ class QuantState: ...@@ -651,7 +651,7 @@ class QuantState:
def as_dict(self, packed=False): def as_dict(self, packed=False):
""" """
returns dict of tensors and strings to use in serialization via _save_to_state_dict() returns dict of tensors and strings to use in serialization via _save_to_state_dict()
param: packed -- returns dict[str, torch.Tensor] for state_dict param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
""" """
qs_dict = { qs_dict = {
'quant_type': self.quant_type, 'quant_type': self.quant_type,
...@@ -665,13 +665,14 @@ class QuantState: ...@@ -665,13 +665,14 @@ class QuantState:
qs_dict.update({ qs_dict.update({
'nested_absmax': self.state2.absmax, 'nested_absmax': self.state2.absmax,
'nested_blocksize': self.state2.blocksize, 'nested_blocksize': self.state2.blocksize,
'nested_quant_map': self.state2.code, 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors
'nested_dtype': str(self.state2.dtype).strip('torch.'), 'nested_dtype': str(self.state2.dtype).strip('torch.'),
'nested_offset': self.offset.item(), 'nested_offset': self.offset.item(),
}) })
if not packed: if not packed:
return qs_dict return qs_dict
# packed format allows serialization of non-tensor components, critical for saving in safetensors format
qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment