Commit 965fd5d5 authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

test update

parent 4c11d6dc
...@@ -56,11 +56,11 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias): ...@@ -56,11 +56,11 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
compute_dtype=compute_dtype, compute_dtype=compute_dtype,
compress_statistics=compress_statistics, compress_statistics=compress_statistics,
quant_type=quant_type, quant_type=quant_type,
device=device, device='meta',
) )
linear_q2.weight = weight2.to(device) linear_q2.weight = weight2.to(device)
if bias: if bias:
linear_q2.bias.data = bias_data2 linear_q2.bias = torch.nn.Parameter(bias_data2)
# matching # matching
a, b = linear_q.weight, linear_q2.weight a, b = linear_q.weight, linear_q2.weight
...@@ -93,7 +93,7 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias): ...@@ -93,7 +93,7 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
assert torch.equal(a, b) assert torch.equal(a, b)
# Forward test # Forward test
x = torch.rand(42, linear_q.shape[-1], device=device) x = torch.rand(42, layer_shape[0], device=device)
a = linear_q(x) a = linear_q(x)
b = linear_q2(x) b = linear_q2(x)
assert a.device == b.device assert a.device == b.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment