Commit 55ebaac7 authored by Matthew Douglas's avatar Matthew Douglas
Browse files

Tests: don't require grad on weights for test_kbit_backprop

parent 318a86e3
...@@ -285,9 +285,6 @@ module_dict = { ...@@ -285,9 +285,6 @@ module_dict = {
@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(device, module): def test_kbit_backprop(device, module):
if device == "cpu":
pytest.xfail("Test is not yet supported on CPU")
b = 16 b = 16
dim1 = 36 dim1 = 36
dim2 = 84 dim2 = 84
...@@ -295,14 +292,15 @@ def test_kbit_backprop(device, module): ...@@ -295,14 +292,15 @@ def test_kbit_backprop(device, module):
# dim2 = 83 # dim2 = 83
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)]) ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
# ref[1].weight.requires_grad = False
torch.nn.init.kaiming_normal_(ref[0].weight) torch.nn.init.kaiming_normal_(ref[0].weight)
torch.nn.init.kaiming_normal_(ref[1].weight) torch.nn.init.kaiming_normal_(ref[1].weight)
ref[1].weight.requires_grad_(False)
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)]) kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
kbit[0].weight.detach().copy_(ref[0].weight) kbit[0].weight.detach().copy_(ref[0].weight)
kbit[1].weight.detach().copy_(ref[1].weight) kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias) kbit[0].bias.detach().copy_(ref[0].bias)
kbit[1].bias.detach().copy_(ref[1].bias) kbit[1].bias.detach().copy_(ref[1].bias)
kbit[1].weight.requires_grad_(False)
ref = ref.half().to(device) ref = ref.half().to(device)
kbit = kbit.half().to(device) kbit = kbit.half().to(device)
kbit = kbit.half().to(device) kbit = kbit.half().to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment