"""Storage for sparse matrices: COO indices, optional values and lazily
cached CSR/CSC metadata (``storage.py``)."""
import warnings
import torch
from torch_scatter import scatter_add, segment_add


def optional(func, src):
    """Apply ``func`` to ``src``, passing ``None`` through untouched.

    Args:
        func: Callable taking a single argument.
        src: The object to transform, or ``None``.

    Returns:
        ``func(src)`` when ``src`` is not ``None``, otherwise ``None``.
    """
    if src is None:
        return None
    return func(src)


class cached_property(object):
    """Descriptor that computes a value once and caches it on the instance.

    The result of ``func(obj)`` is stored on the instance under the attribute
    ``_<func.__name__>``. A stored ``None`` is treated as "not computed yet",
    so resetting that attribute to ``None`` forces a recomputation on the next
    access (and a function that returns ``None`` is never cached).
    """

    def __init__(self, func):
        """Remember the function whose result should be cached."""
        self.func = func

    def __get__(self, obj, cls=None):
        """Return the cached value for ``obj``, computing it if necessary.

        Args:
            obj: The instance the attribute is accessed on, or ``None`` when
                it is accessed on the class itself.
            cls: The owner class (unused).

        Returns:
            The descriptor itself for class-level access, otherwise the
            cached (or freshly computed) value.
        """
        if obj is None:
            # Class-level access (e.g. ``Owner.attr``): there is no instance
            # to compute on or to cache on, so hand back the descriptor.
            return self

        key = f'_{self.func.__name__}'
        value = getattr(obj, key, None)
        if value is None:
            value = self.func(obj)
            setattr(obj, key, value)
        return value


# Layout names accepted by `get_layout`.
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Validate a layout name, falling back to ``'coo'`` when unset.

    Args:
        layout: One of the names in ``layouts``, or ``None``.

    Returns:
        The given layout, or ``'coo'`` (after emitting a warning) when
        ``layout`` is ``None``.

    Raises:
        AssertionError: If the resulting layout is not listed in ``layouts``.
    """
    resolved = layout
    if resolved is None:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
        resolved = 'coo'
    assert resolved in layouts
    return resolved


class SparseStorage(object):
    """Hold the index and optional value tensors of a sparse matrix.

    The ``[2, nnz]`` index tensor is sorted by ``row * num_cols + col`` on
    construction unless the caller passes ``is_sorted=True``, in which case
    the given order is trusted. The attributes named in ``cache_keys`` are
    derived from the index, computed on first access and kept in the matching
    ``_<name>`` attribute.
    """

    # Names of the lazily computed attributes; each one is backed by
    # ``self._<name>`` where ``None`` means "not computed yet".
    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Validate the inputs and store them, sorting entries if needed.

        Args:
            index: ``torch.long`` tensor of shape ``[2, nnz]`` holding row
                indices in ``index[0]`` and column indices in ``index[1]``.
            value: Optional tensor on the same device as ``index`` whose
                first dimension has ``nnz`` entries.
            sparse_size: Pair ``(num_rows, num_cols)``. When ``None`` it is
                inferred as the per-dimension maximum index plus one.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: Optional
                precomputed cache tensors (1-dim ``torch.long`` on the device
                of ``index``); their expected lengths are asserted below.
            is_sorted: Set to ``True`` to skip the sortedness check.

        Raises:
            AssertionError: If any of the checks on the inputs fails.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        nnz = index.size(1)
        expected_numel = (
            (rowcount, sparse_size[0]),
            (rowptr, sparse_size[0] + 1),
            (colcount, sparse_size[1]),
            (colptr, sparse_size[1] + 1),
            (csr2csc, nnz),
            (csc2csr, nnz),
        )
        for cached, numel in expected_numel:
            if cached is not None:
                self._assert_long_vector(cached, index.device, numel)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if necessary: the entries are out of order exactly
            # when some key is smaller than its predecessor. Comparing
            # neighbours directly avoids a spurious sort for an entry at
            # position (0, 0) or for duplicate entries.
            if (idx[1:] < idx[:-1]).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # The permutations refer to the old entry order.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @staticmethod
    def _assert_long_vector(src, device, numel):
        """Assert ``src`` is a 1-dim ``torch.long`` tensor of given size."""
        assert src.dtype == torch.long
        assert src.device == device
        assert src.dim() == 1 and src.numel() == numel

    @property
    def index(self):
        """The ``[2, nnz]`` index tensor."""
        return self._index

    @property
    def row(self):
        """Row indices, i.e. ``index[0]``."""
        return self._index[0]

    @property
    def col(self):
        """Column indices, i.e. ``index[1]``."""
        return self._index[1]

    def has_value(self):
        """Return ``True`` if a value tensor is stored."""
        return self._value is not None

    @property
    def value(self):
        """The stored value tensor, or ``None``."""
        return self._value

    def _prepare_value(self, value, layout):
        """Validate ``value`` and bring it into the stored entry order.

        ``None`` is passed through. For ``layout == 'csc'`` the value is
        permuted with ``csc2csr``; for every other layout it is used as given.
        """
        if value is None:
            return None
        assert value.device == self._index.device
        assert value.size(0) == self._index.size(1)
        if get_layout(layout) == 'csc':
            value = value[self.csc2csr]
        return value

    def _with_value(self, value):
        """Return a new storage with ``value`` that shares index and caches."""
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def set_value_(self, value, layout=None):
        """Replace the stored value in place and return ``self``.

        Args:
            value: New value tensor (first dimension of size ``nnz``, same
                device as the index) or ``None`` to drop the value.
            layout: Layout ``value`` is given in; see ``get_layout``.
        """
        # Assign directly instead of going through ``apply_value_``: that
        # helper skips its function when no value is stored yet, which would
        # make it impossible to attach a value to a value-less storage.
        self._value = self._prepare_value(value, layout)
        return self

    def set_value(self, value, layout=None):
        """Return a new storage holding ``value``; ``self`` keeps its value.

        Args:
            value: New value tensor (first dimension of size ``nnz``, same
                device as the index) or ``None`` for no value.
            layout: Layout ``value`` is given in; see ``get_layout``.
        """
        return self._with_value(self._prepare_value(value, layout))

    def sparse_size(self, dim=None):
        """Return the full sparse size, or its entry ``dim`` if given."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        """Set the sparse size to ``sizes`` in place and return ``self``.

        Raises:
            AssertionError: If not exactly two sizes are given.
        """
        assert len(sizes) == 2
        self._sparse_size = sizes
        # The lengths of these cached tensors depend on the sparse size, so
        # drop them and let them be recomputed lazily.
        self.clear_cache_('rowcount', 'rowptr', 'colcount', 'colptr')
        return self

    @cached_property
    def rowcount(self):
        """Number of stored entries per row (length ``sparse_size[0]``)."""
        one = torch.ones_like(self.row)
        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        """Cumulative ``rowcount`` with a leading zero."""
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        """Number of stored entries per column (length ``sparse_size[1]``)."""
        one = torch.ones_like(self.col)
        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        """Cumulative ``colcount`` with a leading zero."""
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        """Permutation that orders the entries by ``col * num_rows + row``."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        """Argsort of ``csr2csc``, i.e. the inverse of that permutation."""
        return self.csr2csc.argsort()

    def is_coalesced(self):
        """Not implemented yet."""
        raise NotImplementedError

    def coalesce(self):
        """Not implemented yet."""
        raise NotImplementedError

    def cached_keys(self):
        """Return the names in ``cache_keys`` that are currently computed."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given cache entries (all if none given); return self."""
        for arg in args or self.cache_keys:
            # Attribute access triggers the lazy computation.
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given cache entries (all if none given); return self."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        """Return a new storage that shares all tensors with ``self``."""
        return self.apply(lambda x: x)

    def clone(self):
        """Return a new storage holding clones of all tensors."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        """Deep-copy via ``clone`` and register the copy in ``memo``."""
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the stored value (if any) in place."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage whose value is ``func(value)`` (if any)."""
        return self._with_value(optional(func, self._value))

    def apply_(self, func):
        """Apply ``func`` in place to index, value and all computed caches."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``func`` applied to index, value (if any) and each cache.

        Returns:
            A list with the results in that order; nothing is modified.
        """
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data


if __name__ == '__main__':
    # Ad-hoc smoke test and timing run; only executed when the file is run
    # as a script.
    import copy  # noqa
    import time  # noqa
    from torch_geometric.datasets import Reddit, Planetoid  # noqa

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    # Alternative dataset: Reddit('/tmp/Reddit')
    dataset = Planetoid('/tmp/Cora', 'Cora')
    graph = dataset[0].to(device)
    edge_index = graph.edge_index

    storage = SparseStorage(edge_index, is_sorted=True)

    # Time the initial cache fill.
    started = time.perf_counter()
    storage.fill_cache_()
    print(time.perf_counter() - started)

    # Time clearing the cache followed by a refill.
    started = time.perf_counter()
    storage.clear_cache_()
    storage.fill_cache_()
    print(time.perf_counter() - started)

    print(storage)

    # Exercise shallow and deep copies.
    storage = copy.copy(storage)
    print(storage)
    print(id(storage))

    storage = copy.deepcopy(storage)
    print(storage)

    storage.fill_cache_()
    storage.clear_cache_()