test_storage.py 5.32 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch_sparse.storage import SparseStorage
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
    """Construct a SparseStorage from COO indices and check its accessors.

    The first case has no values. In the second case the two columns of
    each row are given out of order; the expectations show the storage
    returning the indices in row-major order with the values permuted
    alongside them. In both cases `sparse_sizes` is not passed, and the
    expected (2, 2) is derived from the indices.
    """
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)

    storage = SparseStorage(row=row, col=col)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value() is None
    assert storage.sparse_sizes() == (2, 2)

    # Columns within each row are swapped on input; values must follow them.
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([2, 1, 4, 3], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]
    assert storage.sparse_sizes() == (2, 2)


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
    """Check the lazily computed CSR/CSC caches of SparseStorage.

    Covers three things:
    - A fresh storage has nothing cached.
    - `fill_cache_()` populates every cache entry.
    - Passing those tensors back into the constructor keeps them.

    It then checks that `clear_cache_()` drops the cached entries but keeps
    `rowptr`. The expected counts (5 after filling, 0 after clearing while
    `rowptr` survives) imply `rowptr` is not counted by `num_cached_keys()`.
    """
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None

    # Nothing is computed until requested.
    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage._csc2csr is None
    assert storage.num_cached_keys() == 0

    storage.fill_cache_()
    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    # Rebuilding from the cached tensors must carry them over unchanged.
    storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
                            value=storage._value,
                            sparse_sizes=storage._sparse_sizes,
                            rowcount=storage._rowcount, colptr=storage._colptr,
                            colcount=storage._colcount,
                            csr2csc=storage._csr2csc, csc2csr=storage._csc2csr)

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    # clear_cache_ drops every counted entry but keeps rowptr.
    storage.clear_cache_()
    assert storage._rowcount is None
    assert storage._rowptr is not None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage._csc2csr is None
    assert storage.num_cached_keys() == 0


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
    """Exercise value setters, sparse resizing, copy() and clone().

    The input columns are swapped within each row. The storage reorders
    entries row-major, so the stored values come out permuted; that
    permutation is what the `layout='csc'` expectations build on.
    """
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.has_value()
    # Values follow their (row, col) entries into row-major order.
    assert storage.value().tolist() == [2, 1, 4, 3]

    # In-place setters: 'csc' values are given in column-major entry order.
    storage.set_value_(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage.set_value_(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    # Out-of-place setters return a new storage with the same semantics.
    storage = storage.set_value(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage = storage.set_value(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    storage = storage.sparse_resize((3, 3))
    assert storage.sparse_sizes() == (3, 3)

    # copy() returns a new object that shares the underlying index memory.
    new_storage = storage.copy()
    assert new_storage is not storage
    assert new_storage.col().data_ptr() == storage.col().data_ptr()

    # clone() returns a new object backed by freshly allocated memory.
    new_storage = storage.clone()
    assert new_storage is not storage
    assert new_storage.col().data_ptr() != storage.col().data_ptr()


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
    """Check that coalesce() merges duplicate entries by summing their values.

    The input contains entry (0, 1) twice with value 1 each. After
    coalescing, the expectations show a single (0, 1) entry with value 2.
    """
    row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
    value = tensor([1, 1, 1, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    # Construction alone keeps duplicates as they are.
    assert storage.row().tolist() == row.tolist()
    assert storage.col().tolist() == col.tolist()
    assert storage.value().tolist() == value.tolist()

    assert not storage.is_coalesced()
    storage = storage.coalesce()
    assert storage.is_coalesced()

    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_reshape(dtype, device):
    """Chain sparse_reshape calls on a 4x4 diagonal and check each result.

    Each step reshapes the previous result. A -1 dimension is inferred from
    the total size. The expected coordinates match keeping every entry at the
    same row-major linear position (row * num_cols + col).
    """
    diag_row, diag_col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long,
                                device)
    storage = SparseStorage(row=diag_row, col=diag_col)

    # (reshape args, expected sizes, expected rows, expected cols)
    steps = [
        ((2, 8), (2, 8), [0, 0, 1, 1], [0, 5, 2, 7]),
        ((-1, 4), (4, 4), [0, 1, 2, 3], [0, 1, 2, 3]),
        ((2, -1), (2, 8), [0, 0, 1, 1], [0, 5, 2, 7]),
    ]
    for shape, want_sizes, want_rows, want_cols in steps:
        storage = storage.sparse_reshape(*shape)
        assert storage.sparse_sizes() == want_sizes
        assert storage.row().tolist() == want_rows
        assert storage.col().tolist() == want_cols