"tests/vscode:/vscode.git/clone" did not exist on "4b265390f45e6aa8b40d8b090c4c94ffc5402cdc"
Unverified Commit 0698e91a authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Use Pytorch dense computation in SpMM tests (#5096)

parent 774709d3
...@@ -37,18 +37,16 @@ def test_spmm(create_func, shape, nnz, out_dim): ...@@ -37,18 +37,16 @@ def test_spmm(create_func, shape, nnz, out_dim):
grad = torch.randn_like(sparse_result) grad = torch.randn_like(sparse_result)
sparse_result.backward(grad) sparse_result.backward(grad)
adj = sparse_matrix_to_torch_sparse(A) adj = sparse_matrix_to_dense(A)
XX = clone_detach_and_grad(X) XX = clone_detach_and_grad(X)
torch_sparse_result = torch.sparse.mm( dense_result = torch.matmul(adj, XX)
adj, XX.view(-1, 1) if out_dim is None else XX
)
if out_dim is None: if out_dim is None:
torch_sparse_result = torch_sparse_result.view(-1) dense_result = dense_result.view(-1)
torch_sparse_result.backward(grad) dense_result.backward(grad)
assert torch.allclose(sparse_result, torch_sparse_result, atol=1e-05) assert torch.allclose(sparse_result, dense_result, atol=1e-05)
assert torch.allclose(X.grad, XX.grad, atol=1e-05) assert torch.allclose(X.grad, XX.grad, atol=1e-05)
assert torch.allclose( assert torch.allclose(
adj.grad.coalesce().to_dense(), dense_mask(adj.grad, A),
sparse_matrix_to_dense(val_like(A, A.val.grad)), sparse_matrix_to_dense(val_like(A, A.val.grad)),
atol=1e-05, atol=1e-05,
) )
......
...@@ -107,8 +107,7 @@ def rand_csc_uncoalesced(shape, nnz, dev): ...@@ -107,8 +107,7 @@ def rand_csc_uncoalesced(shape, nnz, dev):
def sparse_matrix_to_dense(A: SparseMatrix): def sparse_matrix_to_dense(A: SparseMatrix):
dense = A.dense() dense = A.dense()
dense.requires_grad_() return clone_detach_and_grad(dense)
return dense
def sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None): def sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment