"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "186bfd2eb9ee44db129d1dac70dcabe01461227b"
Commit 7f0773ae authored by Tim Dettmers

Added backprop test for Linear8bitLt and LinearFP4.
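
For context, a minimal sketch (not part of the commit) of the behavior the new test exercises, assuming a CUDA device and a bitsandbytes version that provides bnb.nn.LinearFP4; the layer sizes and tensor shapes here are illustrative:

import torch
import bitsandbytes as bnb

# The FP4 layer stores its weight 4-bit-quantized once moved to the GPU,
# yet backward still produces fp16 gradients for upstream tensors.
layer = bnb.nn.LinearFP4(64, 32).half().cuda()
x = torch.randn(8, 64, dtype=torch.float16, device="cuda", requires_grad=True)
layer(x).mean().backward()
print(x.grad.shape)  # torch.Size([8, 64])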

parent c0c352b3
@@ -375,7 +375,7 @@ def test_linear8bitlt_accumulated_gradient():
@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
+    l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
    assert l1.weight.dtype == torch.int8
    l1.eval()

@@ -506,3 +506,41 @@ def test_linear_kbit_fp32_bias(module):
    o1 = l1(b1)
    assert l1.bias is None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
def test_kbit_backprop(module):
    b = 17
    dim1 = 37
    dim2 = 83

    # fp16 reference model and a model whose second layer is k-bit quantized,
    # both initialized with identical weights and biases.
    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        # Gradients reach the first (fp16) layer only by flowing backward
        # through the quantized layer, so this exercises its backward pass.
        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
        torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)

        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad.sum().item() == 0
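
The loose tolerances (atol=0.008, rtol=0.05) leave room for the error that 8-bit/FP4 quantization introduces into the matmul; exact gradient equality is not expected. To run only this test, something like "pytest -k test_kbit_backprop" should work (the test file path is not shown in this diff).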