Commit 4bc050b1 authored by rusty1s's avatar rusty1s
Browse files

add torch.jit.script support to spmm

parent 0adb9cab
# import torch
from torch import Tensor
from torch_scatter import scatter_add
def spmm(index, value, m, n, matrix):
def spmm(index: Tensor, value: Tensor, m: int, n: int,
matrix: Tensor) -> Tensor:
"""Matrix product of sparse matrix with dense matrix.
Args:
......@@ -17,7 +18,7 @@ def spmm(index, value, m, n, matrix):
assert n == matrix.size(-2)
row, col = index
row, col = index[0], index[1]
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix.index_select(-2, col)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment