"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "799c1cd21beff84e50ac4ab7a480e715780da2de"
Unverified Commit c8f564d5 authored by Tim Dettmers's avatar Tim Dettmers Committed by GitHub
Browse files

Merge pull request #868 from poedator/fix_1108

Fix for 4bit without compress_statistics.
parents bbbed83a 079d7afe
......@@ -642,7 +642,7 @@ class QuantState:
blocksize=qs_dict['blocksize'],
code=qs_dict['quant_map'].to(device),
dtype=getattr(torch, qs_dict['dtype']),
shape=torch.Size(qs_dict['shape']),
shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None,
offset=offset,
state2=state2,
)
......@@ -651,7 +651,7 @@ class QuantState:
def as_dict(self, packed=False):
"""
returns dict of tensors and strings to use in serialization via _save_to_state_dict()
param: packed -- returns dict[str, torch.Tensor] for state_dict
param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
"""
qs_dict = {
'quant_type': self.quant_type,
......@@ -659,19 +659,20 @@ class QuantState:
'blocksize': self.blocksize,
'quant_map': self.code,
'dtype': str(self.dtype).strip('torch.'),
'shape': tuple(self.shape) if self.nested else None,
'shape': tuple(self.shape),
}
if self.nested:
qs_dict.update({
'nested_absmax': self.state2.absmax,
'nested_blocksize': self.state2.blocksize,
'nested_quant_map': self.state2.code,
'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors
'nested_dtype': str(self.state2.dtype).strip('torch.'),
'nested_offset': self.offset.item(),
})
if not packed:
return qs_dict
# packed format allows serialization of non-tensor components, critical for saving in safetensors format
qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
......
......@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
device = "cuda"
layer_shape = (300, 400)
linear = torch.nn.Linear(*layer_shape, dtype=original_dtype) # original layer
linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
# Quantizing original layer
linear_q = bnb.nn.Linear4bit(
......@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
device=device,
device="meta",
)
new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
linear_q.weight = new_weight.to(device)
linear_q.weight = new_weight
if bias:
linear_q.bias.data = linear.bias.data.to(device)
linear_q.bias = torch.nn.Parameter(linear.bias)
linear_q = linear_q.to(device)
# saving to state_dict:
sd = linear_q.state_dict()
# restoring from state_dict:
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
# creating new layer with same params:
linear_q2 = bnb.nn.Linear4bit(
linear.in_features,
......@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
device=device, # TODO create on meta device to save loading time
device="meta",
)
# loading weights from state_dict:
linear_q2.weight = weight2.to(device)
linear_q2.weight = weight2
if bias:
linear_q2.bias = torch.nn.Parameter(bias_data2)
linear_q2 = linear_q2.to(device)
# MATCHING
a, b = linear_q.weight, linear_q2.weight
......@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
state_path_4bit
)
size_ratio = size_4 / size_orig
target_compression = 0.143 if original_dtype == torch.float32 else 0.285
target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases
ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
assert size_ratio < target_compression, ratio_error_msg
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment