"vscode:/vscode.git/clone" did not exist on "ff9f05c5e361990422c3bda801344aba4b71c127"
Unverified Commit 3eac464a authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Fix the sddmm backward problem in mock_sparse. (#5070)

parent 4aff59f7
......@@ -54,7 +54,9 @@ def sddmm(
"scalar values. "
)
# PyTorch's sddmm operator only supports CSR format.
res = torch.sparse.sampled_addmm(A.adj.to_sparse_csr(), mat1, mat2)
res = torch.sparse.sampled_addmm(
A.adj.to_sparse_csr(), mat1, mat2
).to_sparse_coo()
return SparseMatrix(A.row, A.col, res.values(), A.adj.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment