Commit 9a4ea01b authored by Feng's avatar Feng
Browse files

modify spmm to support multi-dimensional tensor

a simple modification to the code of spmm() can make it support also 3D tensor (might think as batched):
suppose dense matrix has dimensions (B, N, F), the nonzero has dimensions (B, num_edges), all graphs share the same topology (i.e., the same adjacency matrix)
parent 2ab84ed5
...@@ -15,13 +15,13 @@ def spmm(index, value, m, n, matrix): ...@@ -15,13 +15,13 @@ def spmm(index, value, m, n, matrix):
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
assert n == matrix.size(0) assert n == matrix.shape[-2]
row, col = index row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix[col] out = matrix[..., col, :]
out = out * value.unsqueeze(-1) out = out * value.unsqueeze(-1)
out = scatter_add(out, row, dim=0, dim_size=m) out = scatter_add(out, row, dim=-2, dim_size=m)
return out return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment