# test_storage.py: tests for torch_sparse.storage.SparseStorage.
import copy
from itertools import product

import pytest
import torch
from torch_sparse.storage import SparseStorage, no_cache

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
    """Build ``SparseStorage`` objects from COO indices and inspect them.

    The first storage has already-sorted indices and no values. The second
    has its column indices shuffled within each row, so the stored
    ``row``/``col`` must come back sorted and ``value`` must be permuted
    alongside them. Both describe a 2x2 pattern.
    """
    expected_row = [0, 0, 1, 1]
    expected_col = [0, 1, 0, 1]

    # Sorted indices, no values attached.
    r, c = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    plain = SparseStorage(row=r, col=c)
    assert plain.row.tolist() == expected_row
    assert plain.col.tolist() == expected_col
    assert plain.value is None
    assert plain.sparse_size == (2, 2)

    # Unsorted columns inside each row: (0, 1) -> 2, (0, 0) -> 1, ...
    # After sorting, the values line up as [1, 2, 3, 4].
    r, c = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    v = tensor([2, 1, 4, 3], dtype, device)
    weighted = SparseStorage(row=r, col=c, value=v)
    assert weighted.row.tolist() == expected_row
    assert weighted.col.tolist() == expected_col
    assert weighted.value.tolist() == [1, 2, 3, 4]
    assert weighted.sparse_size == (2, 2)


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
    """Exercise the cached index tensors of ``SparseStorage``.

    Covers:
    - an empty cache right after construction;
    - ``fill_cache_()`` populating every cached tensor;
    - passing precomputed tensors to the constructor;
    - ``clear_cache_()``, which keeps ``rowptr``;
    - ``no_cache`` used both as a context manager and as a decorator.
    """
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None

    # Nothing is cached right after construction. ``csc2csr`` is a cache
    # key too, so it must be checked alongside the others.
    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage._csc2csr is None
    assert storage.cached_keys() == []

    storage.fill_cache_()
    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.cached_keys() == [
        'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
    ]

    # Rebuild from the precomputed tensors: the passed-in cache is kept.
    storage = SparseStorage(row=row, rowptr=storage.rowptr, col=col,
                            value=storage.value,
                            sparse_size=storage.sparse_size,
                            rowcount=storage.rowcount, colptr=storage.colptr,
                            colcount=storage.colcount, csr2csc=storage.csr2csc,
                            csc2csr=storage.csc2csr)

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.cached_keys() == [
        'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
    ]

    # ``rowptr`` is not one of the cache keys, so it survives clearing.
    storage.clear_cache_()
    assert storage._rowcount is None
    assert storage._rowptr is not None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage._csc2csr is None
    assert storage.cached_keys() == []

    # ``no_cache`` as a context manager suppresses caching ...
    with no_cache():
        storage.fill_cache_()
    assert storage.cached_keys() == []

    # ... and so does ``no_cache()`` used as a decorator.
    @no_cache()
    def do_something(storage):
        return storage.fill_cache_()

    storage = do_something(storage)
    assert storage.cached_keys() == []

@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
    """Exercise the value, resize, copy and apply/map helpers of ``SparseStorage``.

    Covers:
    - ``set_value_``/``set_value`` with the ``'csc'`` and ``'coo'`` layouts;
    - ``sparse_resize``;
    - shallow vs. deep copies;
    - ``apply_value_``/``apply_value``;
    - ``apply_``/``apply`` and ``map``.
    """
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.has_value()

    # With layout='csc' the given values are taken in column-major order.
    # For this full 2x2 pattern, [1, 2, 3, 4] is therefore stored as
    # [1, 3, 2, 4]. With layout='coo' they are stored as given.
    storage.set_value_(value, layout='csc')
    assert storage.value.tolist() == [1, 3, 2, 4]
    storage.set_value_(value, layout='coo')
    assert storage.value.tolist() == [1, 2, 3, 4]

    # The out-of-place variants return a storage carrying the new values.
    storage = storage.set_value(value, layout='csc')
    assert storage.value.tolist() == [1, 3, 2, 4]
    storage = storage.set_value(value, layout='coo')
    assert storage.value.tolist() == [1, 2, 3, 4]

    storage = storage.sparse_resize(3, 3)
    assert storage.sparse_size == (3, 3)

    # Every copy must be a distinct object. Compare identity (``is not``)
    # rather than relying on the default identity-based ``!=``.
    # A shallow copy shares the underlying tensor memory ...
    new_storage = copy.copy(storage)
    assert new_storage is not storage
    assert new_storage.col.data_ptr() == storage.col.data_ptr()

    # ... while clone() and copy.deepcopy allocate fresh memory.
    new_storage = storage.clone()
    assert new_storage is not storage
    assert new_storage.col.data_ptr() != storage.col.data_ptr()

    new_storage = copy.deepcopy(storage)
    assert new_storage is not storage
    assert new_storage.col.data_ptr() != storage.col.data_ptr()

    storage.apply_value_(lambda x: x + 1)
    assert storage.value.tolist() == [2, 3, 4, 5]
    storage = storage.apply_value(lambda x: x + 1)
    assert storage.value.tolist() == [3, 4, 5, 6]

    # apply_/apply run the function over the stored tensors. ``value`` is
    # affected too, since its dtype becomes torch.long.
    storage.apply_(lambda x: x.to(torch.long))
    assert storage.col.dtype == torch.long
    assert storage.value.dtype == torch.long

    storage = storage.apply(lambda x: x.to(torch.long))
    assert storage.col.dtype == torch.long
    assert storage.value.dtype == torch.long

    # map presumably returns one result per stored tensor (three here after
    # clearing the cache). Which tensors are visited is defined by
    # SparseStorage.map.
    storage.clear_cache_()
    assert storage.map(lambda x: x.numel()) == [4, 4, 4]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
    """``coalesce()`` merges the duplicated (row, col) entry into one.

    The entry (0, 1) appears twice with value 1 each. After coalescing it
    appears once, with value 2, and ``is_coalesced()`` flips to True.
    """
    index = [[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]]
    r, c = tensor(index, torch.long, device)
    vals = tensor([1, 1, 1, 3, 4], dtype, device)
    duplicated = SparseStorage(row=r, col=c, value=vals)

    # Construction leaves the duplicate entry untouched.
    assert duplicated.row.tolist() == r.tolist()
    assert duplicated.col.tolist() == c.tolist()
    assert duplicated.value.tolist() == vals.tolist()
    assert not duplicated.is_coalesced()

    merged = duplicated.coalesce()
    assert merged.is_coalesced()
    assert merged.row.tolist() == [0, 0, 1, 1]
    assert merged.col.tolist() == [0, 1, 0, 1]
    assert merged.value.tolist() == [1, 2, 3, 4]