Commit dcecbb26 authored by Max Ryabinin's avatar Max Ryabinin
Browse files

Add force_no_igemmlt to test params

parent 24609b66
......@@ -69,9 +69,9 @@ def test_linear_no_igemmlt():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
list(product([False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
list(product([False, True], [False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)
......@@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
)
......@@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
new_linear_custom.state.force_no_igemmlt = True
if deserialize_before_cuda:
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment