Commit fca68194 authored by rusty1s's avatar rusty1s
Browse files

fix test on pytorch 1.6.0

parent 947e0369
...@@ -51,12 +51,12 @@ def test_spmm_half_precision(): ...@@ -51,12 +51,12 @@ def test_spmm_half_precision():
src_dense[:, 2:4] = 0 # Remove multiple columns. src_dense[:, 2:4] = 0 # Remove multiple columns.
src = SparseTensor.from_dense(src_dense) src = SparseTensor.from_dense(src_dense)
other = torch.randn((2, 8, 2), dtype=torch.half, device='cpu') other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
expected = src_dense @ other expected = (src_dense.to(torch.float) @ other).to(torch.half)
out = src @ other out = src @ other.to(torch.half)
assert torch.allclose(expected, out, atol=1e-6) assert torch.allclose(expected, out, atol=1e-2)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment