test_spmm.py 375 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from torch_sparse import spmm


def test_spmm():
    """Check ``spmm`` against a hand-computed sparse x dense product.

    The sparse operand, given below in COO form, is the 3x3 matrix

        [[1, 0, 2],
         [0, 4, 0],
         [1, 3, 0]]

    and it is multiplied by a dense 3x2 matrix.
    """
    # COO coordinates: the first row holds row indices, the second row holds
    # column indices (equivalent to stacking a row tensor and a col tensor).
    index = torch.tensor([
        [0, 0, 1, 2, 2],
        [0, 2, 1, 0, 1],
    ])
    value = torch.tensor([1, 2, 4, 1, 3])
    num_rows = 3

    dense = torch.tensor([
        [1, 4],
        [2, 5],
        [3, 6],
    ])
    # Row by row: 1*[1,4] + 2*[3,6], 4*[2,5], 1*[1,4] + 3*[2,5].
    expected = [[7, 16], [8, 20], [7, 19]]

    result = spmm(index, value, num_rows, dense)
    assert result.tolist() == expected