Unverified Commit 31034b4f authored by Chetan Kumar Verma's avatar Chetan Kumar Verma Committed by GitHub
Browse files

Update unit tests for HPU (#1682)

parent 29564ad6
......@@ -284,7 +284,8 @@ module_dict = {
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(device, module):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_kbit_backprop(device, module, dtype):
b = 16
dim1 = 36
dim2 = 84
......@@ -298,24 +299,28 @@ def test_kbit_backprop(device, module):
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
if device == "hpu" and isinstance(kbit[1], bnb.nn.Linear4bit) and kbit[1].weight.quant_type == "fp4":
pytest.skip("FP4 is not supported on HPU")
if (
device == "hpu"
and isinstance(kbit[1], bnb.nn.Linear4bit)
and not is_supported_on_hpu(kbit[1].weight.quant_type, dtype)
):
pytest.skip("This configuration not supported on HPU")
kbit[0].weight.detach().copy_(ref[0].weight)
kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias)
kbit[1].bias.detach().copy_(ref[1].bias)
kbit[1].weight.requires_grad_(False)
ref = ref.half().to(device)
kbit = kbit.half().to(device)
kbit = kbit.half().to(device)
ref = ref.to(device=device, dtype=dtype)
kbit = kbit.to(device=device, dtype=dtype)
kbit = kbit.to(device=device, dtype=dtype)
errs1 = []
errs2 = []
relerrs1 = []
relerrs2 = []
for i in range(100):
batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
batch = torch.randn(b, dim1, device=device, dtype=dtype)
out1 = ref(batch)
out2 = kbit(batch)
out1.mean().backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment