Commit 4bc050b1 authored by rusty1s's avatar rusty1s
Browse files

add torch.jit.script support to spmm

parent 0adb9cab
# import torch from torch import Tensor
from torch_scatter import scatter_add from torch_scatter import scatter_add
def spmm(index, value, m, n, matrix): def spmm(index: Tensor, value: Tensor, m: int, n: int,
matrix: Tensor) -> Tensor:
"""Matrix product of sparse matrix with dense matrix. """Matrix product of sparse matrix with dense matrix.
Args: Args:
...@@ -17,7 +18,7 @@ def spmm(index, value, m, n, matrix): ...@@ -17,7 +18,7 @@ def spmm(index, value, m, n, matrix):
assert n == matrix.size(-2) assert n == matrix.size(-2)
row, col = index row, col = index[0], index[1]
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix.index_select(-2, col) out = matrix.index_select(-2, col)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment