Commit 55ebaac7 authored by Matthew Douglas

Tests: don't require grad on weights for test_kbit_backprop

parent 318a86e3
@@ -285,9 +285,6 @@ module_dict = {
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(device, module):
if device == "cpu":
pytest.xfail("Test is not yet supported on CPU")
b = 16
dim1 = 36
dim2 = 84
@@ -295,14 +292,15 @@ def test_kbit_backprop(device, module):
# dim2 = 83
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
# ref[1].weight.requires_grad = False
torch.nn.init.kaiming_normal_(ref[0].weight)
torch.nn.init.kaiming_normal_(ref[1].weight)
ref[1].weight.requires_grad_(False)
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
kbit[0].weight.detach().copy_(ref[0].weight)
kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias)
kbit[1].bias.detach().copy_(ref[1].bias)
kbit[1].weight.requires_grad_(False)
ref = ref.half().to(device)
kbit = kbit.half().to(device)
kbit = kbit.half().to(device)
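For context, here is a minimal standalone sketch (not part of the commit) of the pattern the test now uses: freezing the second layer's weight with `requires_grad_(False)` while gradients still flow back through that layer to the first one. The dimensions mirror the test; a plain `torch.nn.Linear` stands in for the quantized `module`.

```python
import torch
import torch.nn as nn

# Two-layer reference model, as in the test (dim1=36, dim2=84, out=128).
ref = nn.Sequential(nn.Linear(36, 84), nn.Linear(84, 128))
ref[1].weight.requires_grad_(False)  # freeze only the second layer's weight

x = torch.randn(16, 36, requires_grad=True)
ref(x).mean().backward()

assert ref[0].weight.grad is not None  # earlier layer still receives gradients
assert ref[1].weight.grad is None      # frozen weight accumulates no gradient
assert ref[1].bias.grad is not None    # bias remains trainable
```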