"projects/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "fbd3c679acb6e8ac61c0ca4bc1242a9e665d9ace"
Commit ffd46ce1 authored by Ruslan Svirschevski

fixes for init and tests

parent 726f1470
@@ -659,7 +659,7 @@ class QuantState:
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
-            'shape': tuple(self.shape) if self.nested else None,
+            'shape': tuple(self.shape),
         }
         if self.nested:
             qs_dict.update({
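The functional change above makes 'shape' part of every serialized QuantState, not only nested ones: a state saved with compress_statistics=False would otherwise carry no record of the packed weight's original dimensions. A minimal, hypothetical stand-in (TinyQuantState is not the real bitsandbytes class) just to show the round trip this field enables:

from dataclasses import dataclass
import torch

@dataclass
class TinyQuantState:
    absmax: torch.Tensor
    shape: torch.Size
    blocksize: int

    def as_dict(self):
        return {
            "absmax": self.absmax,
            "blocksize": self.blocksize,
            "shape": tuple(self.shape),  # always serialized after this fix
        }

    @classmethod
    def from_dict(cls, d):
        return cls(absmax=d["absmax"], blocksize=d["blocksize"], shape=torch.Size(d["shape"]))

state = TinyQuantState(absmax=torch.ones(4), shape=torch.Size((300, 400)), blocksize=64)
restored = TinyQuantState.from_dict(state.as_dict())
assert restored.shape == state.shape  # the packed weight can be reshaped on load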
@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     device = "cuda"
     layer_shape = (300, 400)
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,
+        device="meta",  # TODO: consider both CPU, meta and CUDA creation
     )
     new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
-    linear_q.weight = new_weight.to(device)
+    linear_q.weight = new_weight
     if bias:
-        linear_q.bias.data = linear.bias.data.to(device)
+        linear_q.bias = torch.nn.Parameter(linear.bias)
+    linear_q = linear_q.to(device)
     # saving to state_dict:
     sd = linear_q.state_dict()
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
     weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
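The test now builds Linear4bit on the meta device and only moves it to CUDA once real parameters are attached. Meta-device construction is standard PyTorch behavior: parameters carry shape and dtype but no storage, so the placeholder layer costs nothing to create. A small illustration in plain PyTorch (not bitsandbytes-specific):

import torch

skeleton = torch.nn.Linear(300, 400, device="meta")  # no memory is allocated
print(skeleton.weight.is_meta)   # True
print(skeleton.weight.shape)     # torch.Size([400, 300])
# Real tensors are attached afterwards (in the test: the Params4bit weight and
# the copied bias), and only then is the module moved with .to(device).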
@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,  # TODO create on meta device to save loading time
+        device="meta",
     )
     # loading weights from state_dict:
-    linear_q2.weight = weight2.to(device)
+    linear_q2.weight = weight2
     if bias:
         linear_q2.bias = torch.nn.Parameter(bias_data2)
+    linear_q2 = linear_q2.to(device)
     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
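Putting the two hunks together, a condensed sketch of the round trip this test exercises; it assumes a CUDA device, float16 weights, and nf4 quantization, and reuses the bitsandbytes names shown in the diff rather than documenting the library's full API:

import torch
import bitsandbytes as bnb

device = "cuda"
linear = torch.nn.Linear(300, 400, dtype=torch.float16, device="cpu")

# 1) build the quantized layer on "meta", attach real weights, move to CUDA once
linear_q = bnb.nn.Linear4bit(
    linear.in_features,
    linear.out_features,
    bias=True,
    compute_dtype=torch.float16,
    compress_statistics=True,
    quant_type="nf4",
    device="meta",
)
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
linear_q.bias = torch.nn.Parameter(linear.bias)
linear_q = linear_q.to(device)  # quantization happens on the move to CUDA

# 2) serialize, then rebuild a second layer from the saved state_dict
sd = linear_q.state_dict()
bias_data = sd.pop("bias", None)
weight_data = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data)

linear_q2 = bnb.nn.Linear4bit(
    linear.in_features, linear.out_features, bias=True,
    compute_dtype=torch.float16, compress_statistics=True,
    quant_type="nf4", device="meta",
)
linear_q2.weight = weight2
linear_q2.bias = torch.nn.Parameter(bias_data)
linear_q2 = linear_q2.to(device)  # now comparable to linear_q, as asserted below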
@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         state_path_4bit
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
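The compression targets in the last hunk follow roughly from the storage format. A back-of-the-envelope check, assuming the default 4-bit blocksize of 64 and an uncompressed fp32 absmax per block (the quantization code table and the serialized shape add a small fixed overhead, which is why the targets sit slightly above these figures and why the ratio keeps dropping as the weight matrix grows):

# each fp32 weight is 32 bits; after quantization it is 4 bits of data plus a
# shared fp32 absmax per 64-element block, i.e. 32/64 = 0.5 extra bits per weight
bits_per_weight = 4 + 32 / 64          # 4.5 bits
print(bits_per_weight / 32)            # 0.140625 -> just under the 0.143 target
print(bits_per_weight / 16)            # 0.28125  -> just under the 0.29 target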