Unverified Commit 04997b00 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #102 from shi27feng/master

modify spmm to support multi-dimensional tensor
parents 5fc07a56 4f8a6cc9
...@@ -15,13 +15,13 @@ def spmm(index, value, m, n, matrix): ...@@ -15,13 +15,13 @@ def spmm(index, value, m, n, matrix):
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
assert n == matrix.size(0) assert n == matrix.size(-2)
row, col = index row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix[col] out = matrix.index_select(-2, col)
out = out * value.unsqueeze(-1) out = out * value.unsqueeze(-1)
out = scatter_add(out, row, dim=0, dim_size=m) out = scatter_add(out, row, dim=-2, dim_size=m)
return out return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment