"vscode:/vscode.git/clone" did not exist on "f3aff2fa60c110c25df671b6f99ffb26727cb8ae"
Commit fca68194 authored by rusty1s's avatar rusty1s
Browse files

fix test on pytorch 1.6.0

parent 947e0369
......@@ -51,12 +51,12 @@ def test_spmm_half_precision():
src_dense[:, 2:4] = 0 # Remove multiple columns.
src = SparseTensor.from_dense(src_dense)
other = torch.randn((2, 8, 2), dtype=torch.half, device='cpu')
other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
expected = src_dense @ other
out = src @ other
expected = (src_dense.to(torch.float) @ other).to(torch.half)
out = src @ other.to(torch.half)
assert torch.allclose(expected, out, atol=1e-6)
assert torch.allclose(expected, out, atol=1e-2)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment