Unverified Commit 0698e91a authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Use Pytorch dense computation in SpMM tests (#5096)

parent 774709d3
......@@ -37,18 +37,16 @@ def test_spmm(create_func, shape, nnz, out_dim):
grad = torch.randn_like(sparse_result)
sparse_result.backward(grad)
adj = sparse_matrix_to_torch_sparse(A)
adj = sparse_matrix_to_dense(A)
XX = clone_detach_and_grad(X)
torch_sparse_result = torch.sparse.mm(
adj, XX.view(-1, 1) if out_dim is None else XX
)
dense_result = torch.matmul(adj, XX)
if out_dim is None:
torch_sparse_result = torch_sparse_result.view(-1)
torch_sparse_result.backward(grad)
assert torch.allclose(sparse_result, torch_sparse_result, atol=1e-05)
dense_result = dense_result.view(-1)
dense_result.backward(grad)
assert torch.allclose(sparse_result, dense_result, atol=1e-05)
assert torch.allclose(X.grad, XX.grad, atol=1e-05)
assert torch.allclose(
adj.grad.coalesce().to_dense(),
dense_mask(adj.grad, A),
sparse_matrix_to_dense(val_like(A, A.val.grad)),
atol=1e-05,
)
......
......@@ -107,8 +107,7 @@ def rand_csc_uncoalesced(shape, nnz, dev):
def sparse_matrix_to_dense(A: SparseMatrix):
dense = A.dense()
dense.requires_grad_()
return dense
return clone_detach_and_grad(dense)
def sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment