Unverified commit 06029dd6, authored by Titus, committed by GitHub

Merge pull request #1081 from akx/ruff-format

Reformat Python code with Ruff
Parents: fd723b78, 5a4263f4
@@ -7,15 +7,18 @@ from bitsandbytes.triton.triton_utils import is_triton_available
 from tests.helpers import TRUE_FALSE
 
-@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
-                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.skipif(
+    not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+    reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
+)
 @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
 def test_switchback(vector_wise_quantization):
     for dim in [83]:
         for batch in [13]:
             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
+            switchback = (
+                SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
+            )
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)
@@ -38,23 +41,23 @@ def test_switchback(vector_wise_quantization):
             err_sb = (out_standard - out_sb).abs().mean()
             err_baseline = (out_standard - out_baseline).abs().mean()
-            print('OUT', err_sb, err_baseline)
+            print("OUT", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
             err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
-            print('GW2', err_sb, err_baseline)
+            print("GW2", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
             err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
-            print('GW1', err_sb, err_baseline)
+            print("GW1", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (x1.grad - x2.grad).abs().mean()
             err_baseline = (x1.grad - x3.grad).abs().mean()
-            print('GX1', err_sb, err_baseline)
+            print("GX1", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
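
The changes above are purely mechanical, of the kind typically produced by running `ruff format` at the repository root (the project's exact Ruff configuration is not shown in this PR): over-long calls are split one argument per line with a trailing comma, long chained expressions are wrapped in parentheses, and string literals are normalized to double quotes. Below is a minimal, self-contained sketch of those conventions; `do_work` is a hypothetical helper used only for illustration, not part of this PR or of bitsandbytes.

```python
# Illustrative sketch of the conventions Ruff's formatter applies in this PR.
# `do_work` is a hypothetical function, not from the bitsandbytes codebase.

def do_work(first, second, reason=""):
    return f"{first}/{second}: {reason}"

# Before formatting, this call sat on one over-long, single-quoted line:
#   result = do_work('a', 'b', reason='a long explanatory message that overflows the configured line limit')

# After formatting: one argument per line, a trailing comma, double quotes.
result = do_work(
    "a",
    "b",
    reason="a long explanatory message that overflows the configured line limit",
)
print(result)
```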