Commit 965fd5d5 authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

test update

parent 4c11d6dc
......@@ -56,11 +56,11 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
device=device,
device='meta',
)
linear_q2.weight = weight2.to(device)
if bias:
linear_q2.bias.data = bias_data2
linear_q2.bias = torch.nn.Parameter(bias_data2)
# matching
a, b = linear_q.weight, linear_q2.weight
......@@ -93,7 +93,7 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
assert torch.equal(a, b)
# Forward test
x = torch.rand(42, linear_q.shape[-1], device=device)
x = torch.rand(42, layer_shape[0], device=device)
a = linear_q(x)
b = linear_q2(x)
assert a.device == b.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment