Unverified Commit 56ce60b0 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Fix SpMM and SDDMM test precision problem (#5060)

parent 1f2fcae3
......@@ -44,9 +44,10 @@ def test_spmm(create_func, shape, nnz, out_dim):
if out_dim is None:
torch_sparse_result = torch_sparse_result.view(-1)
torch_sparse_result.backward(grad)
assert torch.allclose(sparse_result, torch_sparse_result)
assert torch.allclose(X.grad, XX.grad)
assert torch.allclose(sparse_result, torch_sparse_result, atol=1e-05)
assert torch.allclose(X.grad, XX.grad, atol=1e-05)
assert torch.allclose(
adj.grad.coalesce().to_dense(),
sparse_matrix_to_dense(val_like(A, A.val.grad)),
atol=1e-05,
)
......@@ -44,7 +44,7 @@ def test_sddmm(create_func, shape, nnz, hidden):
dense_val = dense_result[row, col] * A_val_clone
dense_val.backward(grad)
assert torch.allclose(dense_val, sparse_result.val)
assert torch.allclose(dense_C.grad, C.grad)
assert torch.allclose(dense_B.grad, B.grad)
assert torch.allclose(A_val_clone.grad, A.val.grad)
assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)
assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)
assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)
assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment