Unverified Commit 3eac464a authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Fix the sddmm backward problem in mock_sparse. (#5070)

parent 4aff59f7
...@@ -54,7 +54,9 @@ def sddmm( ...@@ -54,7 +54,9 @@ def sddmm(
"scalar values. " "scalar values. "
) )
# PyTorch's sddmm operator only supports CSR format. # PyTorch's sddmm operator only supports CSR format.
res = torch.sparse.sampled_addmm(A.adj.to_sparse_csr(), mat1, mat2) res = torch.sparse.sampled_addmm(
A.adj.to_sparse_csr(), mat1, mat2
).to_sparse_coo()
return SparseMatrix(A.row, A.col, res.values(), A.adj.shape) return SparseMatrix(A.row, A.col, res.values(), A.adj.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment