test_convert.py 717 Bytes
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse


def test_convert_scipy():
    """Round-trip a small sparse matrix through SciPy and back.

    ``to_scipy`` is called with COO ``(index, value)`` data and an ``N x N``
    shape; ``from_scipy`` converts the result back. The recovered indices
    (``out[0]``) and values (``out[1]``) must equal the originals exactly,
    element by element and in the same order.
    """
    # Row 0 holds column indices, row 1... no: index[0] are row ids and
    # index[1] are column ids; entries are listed row-major and contain no
    # duplicates, so an order-preserving round trip reproduces them as-is.
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    # Use torch.tensor with an explicit dtype instead of the legacy
    # torch.Tensor constructor; float32 matches what torch.Tensor produced.
    value = torch.tensor([1, 2, 4, 1, 3], dtype=torch.float)
    N = 3

    out = from_scipy(to_scipy(index, value, N, N))
    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()


def test_convert_torch_sparse():
    """Round-trip a small sparse matrix through a torch sparse tensor and back.

    ``to_torch_sparse`` is called with COO ``(index, value)`` data and an
    ``N x N`` shape; the result is coalesced and passed to
    ``from_torch_sparse``. The recovered indices (``out[0]``) and values
    (``out[1]``) must equal the originals exactly, in the same order.
    """
    # index[0] are row ids and index[1] are column ids; entries are already
    # sorted row-major with no duplicates, so coalescing leaves them in place.
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    # Use torch.tensor with an explicit dtype instead of the legacy
    # torch.Tensor constructor; float32 matches what torch.Tensor produced.
    value = torch.tensor([1, 2, 4, 1, 3], dtype=torch.float)
    N = 3

    # NOTE(review): coalesce() presumably satisfies a precondition of
    # from_torch_sparse (coalesced input) — confirm against torch_sparse.
    out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()