storage.py 11.6 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import segment_csr, scatter_add
rusty1s's avatar
rusty1s committed
5

6
7
from torch_sparse import rowptr_cpu

rusty1s's avatar
rusty1s committed
8
try:
9
    from torch_sparse import rowptr_cuda
rusty1s's avatar
rusty1s committed
10
11
except ImportError:
    rowptr_cuda = None
12

rusty1s's avatar
typo  
rusty1s committed
13
# Global switch read by `cached_property`: when disabled, computed values are
# returned but not stored on the instance.
__cache__ = {'enabled': True}


def is_cache_enabled():
    """Return whether `cached_property` currently stores computed values."""
    return __cache__['enabled']


def set_cache_enabled(mode):
    """Set the global caching switch to `mode`."""
    __cache__['enabled'] = mode


class no_cache(object):
    """Context manager / decorator that temporarily disables caching.

    The caching state seen on entry is pushed onto a per-instance stack and
    restored on exit, so nested or recursive use of the *same* instance
    (e.g. a recursive function decorated with ``@no_cache()``) restores the
    correct state instead of leaving caching disabled.
    """

    def __init__(self):
        self._prev = []

    def __enter__(self):
        self._prev.append(is_cache_enabled())
        set_cache_enabled(False)

    def __exit__(self, *args):
        set_cache_enabled(self._prev.pop())
        return False  # Never suppress exceptions.

    def __call__(self, func):
        import functools

        @functools.wraps(func)
        def decorate_no_cache(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_no_cache
rusty1s's avatar
rusty1s committed
39
40


rusty1s's avatar
rusty1s committed
41
42
43
class cached_property(object):
    """Descriptor that computes an attribute lazily and stores it on the instance.

    The value is looked up under ``_<func name>`` on the instance; if that is
    missing or ``None`` it is computed by calling ``func(instance)`` and, only
    while ``is_cache_enabled()`` is true, stored back under the same name.
    Because ``None`` means "not computed", a stored ``None`` is recomputed on
    every access.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself (introspection, `help()`): there is
            # no instance to compute from, so expose the descriptor.
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            if is_cache_enabled():
                setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
54
55
56
57
def optional(func, src):
    """Return ``func(src)``, or ``src`` unchanged when it is ``None``."""
    if src is None:
        return src
    return func(src)


rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
64
65
66
67
68
69
# Accepted names for the ordering of an externally supplied value tensor.
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Validate `layout` against ``layouts`` and return it.

    When `layout` is ``None`` a warning is emitted and ``'coo'`` is used.
    An unknown layout fails the membership assertion.
    """
    chosen = 'coo' if layout is None else layout
    if layout is None:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
    assert chosen in layouts
    return chosen


rusty1s's avatar
rusty1s committed
70
class SparseStorage(object):
    """Row-major sorted COO storage of a 2-D sparse matrix with lazy caches.

    Entries are held in a ``[2, nnz]`` long ``index`` tensor (rows in
    ``index[0]``, columns in ``index[1]``) plus an optional ``value`` tensor
    whose first dimension is ``nnz``. Unless ``is_sorted=True`` is passed,
    entries are sorted by ``(row, col)`` on construction. Compressed views
    (per-row/column counts, CSR/CSC pointers, and the permutations between
    CSR and CSC order) are computed on first access via ``cached_property``
    and may also be supplied up front.
    """

    # Lazily computed attributes; each is stored on the instance as `_<key>`.
    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Validate inputs, sort entries when needed, and store them.

        Args:
            index: ``[2, nnz]`` ``torch.long`` tensor of (row, col) indices.
            value: optional tensor on ``index``'s device with
                ``value.size(0) == nnz``.
            sparse_size: ``(rows, cols)``. When omitted, it is inferred as
                ``index.max() + 1`` per dimension, or ``(0, 0)`` for an empty
                index. Always stored as a ``torch.Size``.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: optional
                precomputed caches. Each is checked to be a 1-D long tensor
                of the matching length on ``index``'s device.
            is_sorted: skip the row-major ordering check when ``True``.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            if index.numel() == 0:
                # `max` cannot reduce an empty tensor; nothing to infer from.
                sparse_size = torch.Size([0, 0])
            else:
                sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())
        else:
            sparse_size = torch.Size(sparse_size)

        nnz = index.size(1)
        expected_numel = [
            (rowcount, sparse_size[0]),
            (rowptr, sparse_size[0] + 1),
            (colcount, sparse_size[1]),
            (colptr, sparse_size[1] + 1),
            (csr2csc, nnz),
            (csc2csr, nnz),
        ]
        for cache, numel in expected_numel:
            if cache is not None:
                assert cache.dtype == torch.long
                assert cache.device == index.device
                assert cache.dim() == 1 and cache.numel() == numel

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if some entry is smaller than its predecessor.
            if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # Both permutations referred to the old entry order.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @property
    def index(self):
        """The ``[2, nnz]`` (row, col) index tensor."""
        return self._index

    @property
    def row(self):
        """Row indices of all stored entries."""
        return self._index[0]

    @property
    def col(self):
        """Column indices of all stored entries."""
        return self._index[1]

    def nnz(self):
        """Return the number of stored entries."""
        return self._index.size(1)

    def has_value(self):
        """Return whether a value tensor is attached."""
        return self._value is not None

    @property
    def value(self):
        """The value tensor (first dimension ``nnz``) or ``None``."""
        return self._value

    def _prepare_value(self, value, layout):
        """Normalize a new value for this storage (CSR entry order).

        Python ``int``/``float`` scalars are broadcast to a length-``nnz``
        tensor. A tensor given in ``'csc'`` layout is permuted into CSR
        order; ``layout`` is only consulted for tensors.
        """
        if isinstance(value, (int, float)):
            value = torch.full((self.nnz(), ), value,
                               device=self._index.device)
        elif torch.is_tensor(value) and get_layout(layout) == 'csc':
            value = value[self.csc2csr]

        if torch.is_tensor(value):
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
        return value

    def set_value_(self, value, layout=None):
        """Replace the value in place (see `_prepare_value`); return self."""
        self._value = self._prepare_value(value, layout)
        return self

    def set_value(self, value, layout=None):
        """Return a new storage with ``value``, sharing index and caches."""
        value = self._prepare_value(value, layout)
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        """Return the full ``(rows, cols)`` size, or its ``dim`` entry."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    @staticmethod
    def _resize_count_and_ptr(count, ptr, old_size, new_size, nnz):
        """Pad or truncate optional count/pointer caches to ``new_size``."""
        diff = new_size - old_size
        if diff > 0:
            # New trailing rows/columns are empty: zero counts, and their
            # pointers all sit at the end of the entry list.
            if count is not None:
                count = torch.cat([count, count.new_zeros(diff)])
            if ptr is not None:
                ptr = torch.cat([ptr, ptr.new_full((diff, ), nnz)])
        elif diff < 0:
            if count is not None:
                count = count[:new_size]
            if ptr is not None:
                ptr = ptr[:new_size + 1]
        return count, ptr

    def sparse_resize(self, *sizes):
        """Return a storage with sparse size ``sizes`` sharing index/value.

        Cached row/column counts and pointers are padded or truncated to the
        new size; the CSR/CSC permutations are reused as-is. Entries outside
        a shrunk size are not removed. NOTE(review): callers presumably only
        shrink past empty trailing rows/columns.
        """
        assert len(sizes) == 2
        sizes = torch.Size(sizes)
        old_sizes, nnz = self.sparse_size(), self.nnz()

        rowcount, rowptr = self._resize_count_and_ptr(
            self._rowcount, self._rowptr, old_sizes[0], sizes[0], nnz)
        colcount, colptr = self._resize_count_and_ptr(
            self._colcount, self._colptr, old_sizes[1], sizes[1], nnz)

        return self.__class__(
            self._index,
            self._value,
            sizes,
            rowcount=rowcount,
            rowptr=rowptr,
            colcount=colcount,
            colptr=colptr,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
        )

    def has_rowcount(self):
        return self._rowcount is not None

    @cached_property
    def rowcount(self):
        """Number of stored entries per row, derived from ``rowptr``."""
        rowptr = self.rowptr
        return rowptr[1:] - rowptr[:-1]

    def has_rowptr(self):
        return self._rowptr is not None

    @cached_property
    def rowptr(self):
        """CSR row pointer, built by the ``rowptr_cpu``/``rowptr_cuda`` op."""
        func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
        return func.rowptr(self.row, self.sparse_size(0))

    def has_colcount(self):
        return self._colcount is not None

    @cached_property
    def colcount(self):
        """Number of stored entries per column."""
        if self.has_colptr():
            colptr = self.colptr
            return colptr[1:] - colptr[:-1]
        col, dim_size = self.col, self.sparse_size(1)
        return scatter_add(torch.ones_like(col), col, dim_size=dim_size)

    def has_colptr(self):
        return self._colptr is not None

    @cached_property
    def colptr(self):
        """CSC column pointer."""
        if self.has_csr2csc():
            # Column indices gathered in CSC order are sorted, mirroring the
            # sorted row indices that `rowptr` is built from.
            func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
            return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
        colcount = self.colcount
        colptr = colcount.new_zeros(colcount.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])
        return colptr

    def has_csr2csc(self):
        return self._csr2csc is not None

    @cached_property
    def csr2csc(self):
        """Permutation that reorders entries from (row, col) to (col, row)."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    def has_csc2csr(self):
        return self._csc2csr is not None

    @cached_property
    def csc2csr(self):
        """Inverse permutation of ``csr2csc``."""
        return self.csr2csc.argsort()

    def is_coalesced(self):
        """Return whether every (row, col) pair occurs at most once.

        Relies on the row-major ordering established by the constructor.
        """
        idx = self.sparse_size(1) * self.row + self.col
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
        return mask.all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate (row, col) entries, reducing their values.

        Args:
            reduce: reduction passed to ``segment_csr`` for duplicate values.

        Returns:
            ``self`` if already coalesced, otherwise a new storage without
            cached layouts.
        """
        idx = self.sparse_size(1) * self.row + self.col
        # True at the first occurrence of each distinct (row, col) pair.
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

        if mask.all():  # Skip if indices are already coalesced.
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            # Start offset of each duplicate group, closed by nnz, forms the
            # CSR-style pointer that `segment_csr(src, indptr)` reduces over.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), mask.numel())])
            value = segment_csr(value, ptr, reduce=reduce)
            # Some reductions return a (values, arg) tuple.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

    def cached_keys(self):
        """Return the names from ``cache_keys`` that currently hold a value."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given cache entries (default: all); return self."""
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given cache entries (default: all); return self."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        return self.apply(lambda x: x)

    def clone(self):
        """Return a storage whose tensors are all cloned."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the value tensor in place; return self."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage with ``func`` applied to the value tensor."""
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        """Apply ``func`` to index, value and filled caches in place."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``func`` applied to index, value (if any) and filled caches."""
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data