Commit dcecbb26 authored by Max Ryabinin's avatar Max Ryabinin
Browse files

Add force_no_igemmlt to test params

parent 24609b66
...@@ -69,9 +69,9 @@ def test_linear_no_igemmlt(): ...@@ -69,9 +69,9 @@ def test_linear_no_igemmlt():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda", @pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
list(product([False, True], [False, True], [False, True]))) list(product([False, True], [False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda): def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
linear = torch.nn.Linear(32, 96) linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half) x = torch.randn(3, 32, dtype=torch.half)
...@@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri ...@@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
has_fp16_weights=has_fp16_weights, has_fp16_weights=has_fp16_weights,
threshold=6.0, threshold=6.0,
) )
if force_no_igemmlt:
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params( linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
) )
...@@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri ...@@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
has_fp16_weights=has_fp16_weights, has_fp16_weights=has_fp16_weights,
threshold=6.0, threshold=6.0,
) )
if force_no_igemmlt:
new_linear_custom.state.force_no_igemmlt = True
if deserialize_before_cuda: if deserialize_before_cuda:
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment