Commit ffd46ce1 authored by Ruslan Svirschevski

fixes for init and tests

parent 726f1470
@@ -659,7 +659,7 @@ class QuantState:
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
-            'shape': tuple(self.shape) if self.nested else None,
+            'shape': tuple(self.shape),
         }
         if self.nested:
             qs_dict.update({
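Note on the hunk above: `shape` is now serialized unconditionally instead of only when nested (second-stage) quantization is enabled, so a quant state saved with compress_statistics=False can still be restored to the right tensor shape. A minimal sketch of the effect, assuming `as_dict()` is the method being edited here and using bitsandbytes' quantize_4bit:

import torch
import bitsandbytes.functional as F

# Quantize a CUDA tensor to NF4; quant_state holds the metadata needed
# to dequantize, including the original shape.
weight = torch.randn(300, 400, dtype=torch.float16, device="cuda")
qweight, quant_state = F.quantize_4bit(weight, quant_type="nf4")

qs_dict = quant_state.as_dict()  # assumed method name, per the hunk above
# Before this fix, 'shape' was None unless nested quantization was on;
# now it always records the original tensor shape.
assert qs_dict["shape"] == (300, 400)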
@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     device = "cuda"
     layer_shape = (300, 400)
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,
+        device="meta",  # TODO: consider both CPU, meta and CUDA creation
     )
     new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
-    linear_q.weight = new_weight.to(device)
+    linear_q.weight = new_weight
     if bias:
-        linear_q.bias.data = linear.bias.data.to(device)
+        linear_q.bias = torch.nn.Parameter(linear.bias)
+    linear_q = linear_q.to(device)

     # saving to state_dict:
     sd = linear_q.state_dict()

     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
     weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,  # TODO create on meta device to save loading time
+        device="meta",
     )
     # loading weights from state_dict:
-    linear_q2.weight = weight2.to(device)
+    linear_q2.weight = weight2
     if bias:
         linear_q2.bias = torch.nn.Parameter(bias_data2)
+    linear_q2 = linear_q2.to(device)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
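The reshuffled test now follows the intended save/load pattern end to end: build the layer on the meta device, attach the weight and bias, move everything to the target device in one `.to(device)` call, serialize via `state_dict()`, and rebuild with `Params4bit.from_prequantized`. A condensed sketch of that round trip, hardcoding bias=True and quant_type="nf4" where the test parametrizes them:

import torch
import bitsandbytes as bnb

linear = torch.nn.Linear(300, 400, dtype=torch.float32, device="cpu")

# Quantize: create the 4-bit layer on "meta", attach params, move to CUDA.
linear_q = bnb.nn.Linear4bit(300, 400, bias=True, quant_type="nf4", device="meta")
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
linear_q.bias = torch.nn.Parameter(linear.bias)
linear_q = linear_q.to("cuda")

# Save: the quantization statistics travel inside the state_dict.
sd = linear_q.state_dict()

# Restore: pop the raw tensors, rebuild the weight from the remaining stats.
bias2 = sd.pop("bias", None)
weight_data = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data)

linear_q2 = bnb.nn.Linear4bit(300, 400, bias=True, quant_type="nf4", device="meta")
linear_q2.weight = weight2
linear_q2.bias = torch.nn.Parameter(bias2)
linear_q2 = linear_q2.to("cuda")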
@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         state_path_4bit
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
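The thresholds follow from rough storage arithmetic (an assumption based on the default blocksize of 64 with one fp32 absmax per block): each weight costs about 4 + 32/64 = 4.5 bits, landing just under 0.143 of fp32 and just under 0.29 of fp16, with the fixed quant-state overhead shrinking relative to larger weight shapes, as the new comment notes.

# Rough storage arithmetic (assumes default blocksize 64, fp32 absmax per block):
bits_per_param = 4 + 32 / 64   # 4-bit weight + amortized absmax -> 4.5 bits
print(bits_per_param / 32)     # ~0.141 vs float32, under the 0.143 target
print(bits_per_param / 16)     # ~0.281 vs float16, under the 0.29 target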