"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "01a80807de9727fe9ccb1b35d1ea447647738111"
Unverified Commit 56ce60b0 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Fix SpMM and SDDMM test precision problem (#5060)

parent 1f2fcae3
...@@ -44,9 +44,10 @@ def test_spmm(create_func, shape, nnz, out_dim): ...@@ -44,9 +44,10 @@ def test_spmm(create_func, shape, nnz, out_dim):
if out_dim is None: if out_dim is None:
torch_sparse_result = torch_sparse_result.view(-1) torch_sparse_result = torch_sparse_result.view(-1)
torch_sparse_result.backward(grad) torch_sparse_result.backward(grad)
assert torch.allclose(sparse_result, torch_sparse_result) assert torch.allclose(sparse_result, torch_sparse_result, atol=1e-05)
assert torch.allclose(X.grad, XX.grad) assert torch.allclose(X.grad, XX.grad, atol=1e-05)
assert torch.allclose( assert torch.allclose(
adj.grad.coalesce().to_dense(), adj.grad.coalesce().to_dense(),
sparse_matrix_to_dense(val_like(A, A.val.grad)), sparse_matrix_to_dense(val_like(A, A.val.grad)),
atol=1e-05,
) )
...@@ -44,7 +44,7 @@ def test_sddmm(create_func, shape, nnz, hidden): ...@@ -44,7 +44,7 @@ def test_sddmm(create_func, shape, nnz, hidden):
dense_val = dense_result[row, col] * A_val_clone dense_val = dense_result[row, col] * A_val_clone
dense_val.backward(grad) dense_val.backward(grad)
assert torch.allclose(dense_val, sparse_result.val) assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)
assert torch.allclose(dense_C.grad, C.grad) assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)
assert torch.allclose(dense_B.grad, B.grad) assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)
assert torch.allclose(A_val_clone.grad, A.val.grad) assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment