test_convert.py 1.75 KB
Newer Older
rusty1s's avatar
fixes  
rusty1s committed
1
import time
2
3
4
import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse
rusty1s's avatar
fixes  
rusty1s committed
5
6
from torch_sparse.storage import SparseStorage
from scipy.io import loadmat
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


def test_convert_scipy():
    """Round-trip a small sparse matrix through ``to_scipy``/``from_scipy``.

    The index and value tensors returned by ``from_scipy`` must match the
    ones that were handed to ``to_scipy``.
    """
    # Two rows of five entries each; presumably (row, col) coordinates of
    # the non-zeros -- confirm against ``to_scipy``.
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    # ``torch.tensor`` replaces the legacy ``torch.Tensor(data)`` constructor;
    # float literals keep the values in the default floating dtype.
    value = torch.tensor([1.0, 2.0, 4.0, 1.0, 3.0])
    size = 3  # passed twice below; presumably the matrix is size x size

    scipy_mat = to_scipy(index, value, size, size)
    out = from_scipy(scipy_mat)

    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()


def test_convert_torch_sparse():
    """Round-trip a small sparse matrix through the torch sparse converters.

    The index and value tensors returned by ``from_torch_sparse`` must match
    the ones that were handed to ``to_torch_sparse``.
    """
    # Two rows of five entries each; presumably (row, col) coordinates of
    # the non-zeros -- confirm against ``to_torch_sparse``.
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    # ``torch.tensor`` replaces the legacy ``torch.Tensor(data)`` constructor;
    # float literals keep the values in the default floating dtype.
    value = torch.tensor([1.0, 2.0, 4.0, 1.0, 3.0])
    size = 3  # passed twice below; presumably the matrix is size x size

    # The converted tensor is coalesced before being converted back.
    sparse = to_torch_sparse(index, value, size, size).coalesce()
    out = from_torch_sparse(sparse)

    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()
rusty1s's avatar
fixes  
rusty1s committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60


def _time_recompute(storage, attr, repeats=100):
    """Return the seconds spent reading ``storage.<attr>`` ``repeats`` times.

    After every read the private ``_<attr>`` slot is reset to ``None``;
    presumably this drops a cached value so the next read recomputes it.
    The slot is left as ``None`` when the function returns.
    """
    private_slot = f'_{attr}'
    # Synchronize so previously queued GPU work is not counted in the timing.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        getattr(storage, attr)
        setattr(storage, private_slot, None)
    # Wait for the launched kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    return time.perf_counter() - start


def test_ind2ptr():
    """Check (and roughly time) ``row`` <-> ``rowptr`` in ``SparseStorage``.

    A benchmark matrix is loaded from ``benchmark/<name>.mat``, its CSR
    ``indptr`` and COO ``row``/``col`` arrays are moved to the GPU, and the
    values ``SparseStorage`` derives for ``rowptr`` (from ``row``) and for
    ``row`` (from ``rowptr``) are compared against the SciPy reference.
    The elapsed time of 100 repeated derivations is printed for each
    direction.

    Raises:
        unittest.SkipTest: if no CUDA device is available or the benchmark
            file does not exist, so the test is reported as skipped instead
            of erroring out.
    """
    # Local imports keep this block self-contained.
    import os
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest('test_ind2ptr requires a CUDA device')

    name = ('DIMACS10', 'citationCiteseer')[1]
    path = f'benchmark/{name}.mat'
    if not os.path.exists(path):
        raise unittest.SkipTest(f'benchmark matrix {path} not found')

    mat = loadmat(path)['Problem'][0][0][2]
    # Same conversion chain as before: CSR -> COO -> CSR for the reference
    # row pointers, then COO again for the reference row/col indices.
    mat = mat.tocsr().tocoo()

    mat = mat.tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long).cuda()
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long).cuda()
    col = torch.from_numpy(mat.col).to(torch.long).cuda()

    # row -> rowptr
    storage = SparseStorage(row=row, col=col)
    print(_time_recompute(storage, 'rowptr'))
    assert storage.rowptr.tolist() == rowptr.tolist()

    # rowptr -> row
    storage = SparseStorage(rowptr=rowptr, col=col)
    print(_time_recompute(storage, 'row'))
    assert storage.row.tolist() == row.tolist()