storage.py 15.6 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2
from typing import Optional, List, Dict, Any
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
from torch_scatter import segment_csr, scatter_add
rusty1s's avatar
rusty1s committed
6
from torch_sparse.utils import Final
7

rusty1s's avatar
typo  
rusty1s committed
8
# Module-wide cache switch, read through is_cache_enabled() and written
# through set_cache_enabled().
__cache__ = dict(enabled=True)


def is_cache_enabled():
    """Return the value currently stored under ``__cache__['enabled']``."""
    enabled = __cache__['enabled']
    return enabled


def set_cache_enabled(mode):
    """Store ``mode`` under ``__cache__['enabled']`` (returns nothing)."""
    flags = __cache__
    flags['enabled'] = mode


class no_cache(object):
    """Context manager / decorator that temporarily disables caching.

    On entry the current flag (``is_cache_enabled()``) is saved and caching
    is switched off via ``set_cache_enabled(False)``; on exit the saved flag
    is restored. Saved flags are kept on a stack, so one instance can be
    entered again before it has exited (e.g. a recursive function decorated
    with ``@no_cache()``) without losing the outermost value.
    """

    def __init__(self):
        self._prev_stack = []

    def __enter__(self):
        self._prev_stack.append(is_cache_enabled())
        set_cache_enabled(False)
        return self

    def __exit__(self, *args):
        set_cache_enabled(self._prev_stack.pop())
        return False  # Never suppress exceptions raised inside the block.

    def __call__(self, func):
        """Wrap ``func`` so that every call runs with caching disabled."""
        from functools import wraps

        @wraps(func)
        def decorate_no_cache(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_no_cache


# class cached_property(object):
#     def __init__(self, func):
#         self.func = func

#     def __get__(self, obj, cls):
#         value = getattr(obj, f'_{self.func.__name__}', None)
#         if value is None:
#             value = self.func(obj)
#             if is_cache_enabled():
#                 setattr(obj, f'_{self.func.__name__}', value)
#         return value


def optional(func, src):
    """Apply ``func`` to ``src`` unless ``src`` is ``None`` (then return it)."""
    if src is None:
        return src
    return func(src)


# Layout names accepted by get_layout().
layouts: Final[List[str]] = [
    'coo',
    'csr',
    'csc',
]


def get_layout(layout=None):
    """Validate a layout name, falling back to ``'coo'`` with a warning.

    Args:
        layout: one of the names in ``layouts``, or ``None``.

    Returns:
        The validated layout name.

    Raises:
        AssertionError: if the resulting name is not listed in ``layouts``.
    """
    if layout is not None:
        chosen = layout
    else:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
        chosen = 'coo'
    assert chosen in layouts
    return chosen


# Sparse-matrix storage. Entries are (row, col[, value]) triplets kept in
# row-major order (sorted in __init__ unless `is_sorted=True`). `row` and
# `rowptr` are two encodings of the same row indices; at least one must be
# supplied and the other is derived on demand. `rowcount`, `colptr`,
# `colcount`, `csr2csc` and `csc2csr` are derived tensors that are computed
# on first access and then stored on the instance.
@torch.jit.script
class SparseStorage(object):
    _row: Optional[torch.Tensor]
    _rowptr: Optional[torch.Tensor]
    _col: torch.Tensor
    _value: Optional[torch.Tensor]
    _sparse_size: List[int]
    _rowcount: Optional[torch.Tensor]
    _colptr: Optional[torch.Tensor]
    _colcount: Optional[torch.Tensor]
    _csr2csc: Optional[torch.Tensor]
    _csc2csr: Optional[torch.Tensor]

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_size: Optional[List[int]] = None,
                 rowcount: Optional[torch.Tensor] = None,
                 colptr: Optional[torch.Tensor] = None,
                 colcount: Optional[torch.Tensor] = None,
                 csr2csc: Optional[torch.Tensor] = None,
                 csc2csr: Optional[torch.Tensor] = None,
                 is_sorted: bool = False):
        """Validate the given tensors and store them.

        Args:
            row: int64 1-D row index per entry (optional if `rowptr` given).
            rowptr: int64 1-D row pointer of length `sparse_size[0] + 1`
                (optional if `row` given).
            col: int64 1-D column index per entry (required).
            value: optional per-entry values; `value.size(0)` must equal the
                number of entries.
            sparse_size: `[num_rows, num_cols]`. When omitted it is inferred
                as `rowptr.numel() - 1` (or `max(row) + 1`) by `max(col) + 1`.
            rowcount, colptr, colcount, csr2csc, csc2csr: optional
                precomputed derived tensors.
            is_sorted: when False, entries are re-sorted into row-major order
                if they are not already.
        """
        assert row is not None or rowptr is not None
        assert col is not None
        assert col.dtype == torch.long
        assert col.dim() == 1
        col = col.contiguous()

        if sparse_size is None:
            # `.max()` raises on empty tensors, so an empty index set yields a
            # zero-sized dimension. `int(...)` keeps `M`/`N` plain ints in
            # every branch.
            if rowptr is not None:
                M = rowptr.numel() - 1
            elif row is not None:
                M = int(row.max()) + 1 if row.numel() > 0 else 0
            else:
                raise ValueError
            N = int(col.max()) + 1 if col.numel() > 0 else 0
            sparse_size = torch.Size([M, N])
        else:
            assert len(sparse_size) == 2

        if row is not None:
            assert row.dtype == torch.long
            assert row.device == col.device
            assert row.dim() == 1
            assert row.numel() == col.numel()
            row = row.contiguous()

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == col.device
            assert rowptr.dim() == 1
            assert rowptr.numel() - 1 == sparse_size[0]
            rowptr = rowptr.contiguous()

        if value is not None:
            assert value.device == col.device
            assert value.size(0) == col.size(0)
            value = value.contiguous()

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == col.device
            assert rowcount.dim() == 1
            assert rowcount.numel() == sparse_size[0]
            rowcount = rowcount.contiguous()

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == col.device
            assert colptr.dim() == 1
            assert colptr.numel() - 1 == sparse_size[1]
            colptr = colptr.contiguous()

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == col.device
            assert colcount.dim() == 1
            assert colcount.numel() == sparse_size[1]
            colcount = colcount.contiguous()

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == col.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == col.size(0)
            csr2csc = csr2csc.contiguous()

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == col.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == col.size(0)
            csc2csr = csc2csr.contiguous()

        self._row = row
        self._rowptr = rowptr
        self._col = col
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._colptr = colptr
        self._colcount = colcount
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

        if not is_sorted:
            # Linearized row-major key per entry; idx[0] = 0 is a sentinel so
            # the first real key is compared against a non-greater value.
            idx = col.new_zeros(col.numel() + 1)
            idx[1:] = sparse_size[1] * self.row() + col
            if bool((idx[1:] < idx[:-1]).any()):
                perm = idx[1:].argsort()
                self._row = self.row()[perm]
                self._col = col[perm]
                if value is not None:
                    self._value = value[perm]
                # Permutations between orderings are no longer valid.
                self._csr2csc = None
                self._csc2csr = None

    def has_row(self) -> bool:
        return self._row is not None

    def row(self) -> torch.Tensor:
        """Return row indices, expanding them from `rowptr` (and storing the
        result) if only the pointer form is held.

        Raises:
            ValueError: if neither `row` nor `rowptr` is stored.
        """
        row = self._row
        if row is not None:
            return row

        rowptr = self._rowptr
        if rowptr is not None:
            if rowptr.is_cuda:
                expanded = torch.ops.torch_sparse_cuda.ptr2ind(
                    rowptr, self._col.numel())
            else:
                expanded = torch.ops.torch_sparse_cpu.ptr2ind(
                    rowptr, self._col.numel())
            self._row = expanded
            return expanded

        raise ValueError

    def has_rowptr(self) -> bool:
        return self._rowptr is not None

    def rowptr(self) -> torch.Tensor:
        """Return the row pointer, compressing it from `row` (and storing the
        result) if only the index form is held.

        Raises:
            ValueError: if neither `row` nor `rowptr` is stored.
        """
        rowptr = self._rowptr
        if rowptr is not None:
            return rowptr

        row = self._row
        if row is not None:
            if row.is_cuda:
                compressed = torch.ops.torch_sparse_cuda.ind2ptr(
                    row, self._sparse_size[0])
            else:
                compressed = torch.ops.torch_sparse_cpu.ind2ptr(
                    row, self._sparse_size[0])
            self._rowptr = compressed
            return compressed

        raise ValueError

    def col(self) -> torch.Tensor:
        return self._col

    def has_value(self) -> bool:
        return self._value is not None

    def value(self) -> Optional[torch.Tensor]:
        return self._value

    # NOTE(review): the following methods are commented out in the original;
    # kept here verbatim for reference.

    # def set_value_(self, value, layout=None, dtype=None):
    #     if isinstance(value, int) or isinstance(value, float):
    #         value = torch.full((self.col.numel(), ), dtype=dtype,
    #                            device=self.col.device)

    #     elif torch.is_tensor(value) and get_layout(layout) == 'csc':
    #         value = value[self.csc2csr]

    #     if torch.is_tensor(value):
    #         value = value if dtype is None else value.to(dtype)
    #         assert value.device == self.col.device
    #         assert value.size(0) == self.col.numel()

    #     self._value = value
    #     return self

    # def set_value(self, value, layout=None, dtype=None):
    #     if isinstance(value, int) or isinstance(value, float):
    #         value = torch.full((self.col.numel(), ), dtype=dtype,
    #                            device=self.col.device)

    #     elif torch.is_tensor(value) and get_layout(layout) == 'csc':
    #         value = value[self.csc2csr]

    #     if torch.is_tensor(value):
    #         value = value if dtype is None else value.to(dtype)
    #         assert value.device == self.col.device
    #         assert value.size(0) == self.col.numel()

    #     return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
    #                           value=value, sparse_size=self._sparse_size,
    #                           rowcount=self._rowcount, colptr=self._colptr,
    #                           colcount=self._colcount, csr2csc=self._csr2csc,
    #                           csc2csr=self._csc2csr, is_sorted=True)

    def sparse_size(self) -> List[int]:
        return self._sparse_size

    # def sparse_resize(self, *sizes):
    #     old_sparse_size, nnz = self.sparse_size, self.col.numel()

    #     diff_0 = sizes[0] - old_sparse_size[0]
    #     rowcount, rowptr = self._rowcount, self._rowptr
    #     if diff_0 > 0:
    #         if rowptr is not None:
    #             rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
    #         if rowcount is not None:
    #             rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
    #     else:
    #         if rowptr is not None:
    #             rowptr = rowptr[:-diff_0]
    #         if rowcount is not None:
    #             rowcount = rowcount[:-diff_0]

    #     diff_1 = sizes[1] - old_sparse_size[1]
    #     colcount, colptr = self._colcount, self._colptr
    #     if diff_1 > 0:
    #         if colptr is not None:
    #             colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
    #         if colcount is not None:
    #             colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
    #     else:
    #         if colptr is not None:
    #             colptr = colptr[:-diff_1]
    #         if colcount is not None:
    #             colcount = colcount[:-diff_1]

    #     return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
    #                           value=self.value, sparse_size=sizes,
    #                           rowcount=rowcount, colptr=colptr,
    #                           colcount=colcount, csr2csc=self._csr2csc,
    #                           csc2csr=self._csc2csr, is_sorted=True)

    def has_rowcount(self) -> bool:
        return self._rowcount is not None

    def rowcount(self) -> torch.Tensor:
        """Return the number of entries per row (stored after first call)."""
        rowcount = self._rowcount
        if rowcount is not None:
            return rowcount

        rowptr = self.rowptr()
        # Consecutive pointer differences; the original `rowptr[1:] -
        # rowptr[1:]` always produced zeros.
        counts = rowptr[1:] - rowptr[:-1]
        self._rowcount = counts
        return counts

    def has_colptr(self) -> bool:
        return self._colptr is not None

    def colptr(self) -> torch.Tensor:
        """Return the column pointer (length `num_cols + 1`), stored after the
        first call."""
        colptr = self._colptr
        if colptr is not None:
            return colptr

        csr2csc = self._csr2csc
        if csr2csc is not None:
            # Permuting `col` into CSC order yields non-decreasing column
            # indices, which can be compressed directly. Dispatch on device
            # like rowptr() does.
            sorted_col = self._col[csr2csc]
            if sorted_col.is_cuda:
                ptr = torch.ops.torch_sparse_cuda.ind2ptr(
                    sorted_col, self._sparse_size[1])
            else:
                ptr = torch.ops.torch_sparse_cpu.ind2ptr(
                    sorted_col, self._sparse_size[1])
        else:
            ptr = self._col.new_zeros(self._sparse_size[1] + 1)
            ptr[1:] = torch.cumsum(self.colcount(), dim=0)
        self._colptr = ptr
        return ptr

    def has_colcount(self) -> bool:
        return self._colcount is not None

    def colcount(self) -> torch.Tensor:
        """Return the number of entries per column (stored after first call)."""
        colcount = self._colcount
        if colcount is not None:
            return colcount

        colptr = self._colptr
        if colptr is not None:
            counts = colptr[1:] - colptr[:-1]
        else:
            # Histogram of column indices via in-place Tensor.scatter_add_.
            counts = self._col.new_zeros(self._sparse_size[1])
            counts.scatter_add_(0, self._col, torch.ones_like(self._col))
        self._colcount = counts
        return counts

    def has_csr2csc(self) -> bool:
        return self._csr2csc is not None

    def csr2csc(self) -> torch.Tensor:
        """Return the permutation that reorders entries column-major."""
        csr2csc = self._csr2csc
        if csr2csc is not None:
            return csr2csc

        idx = self._sparse_size[0] * self._col + self.row()
        perm = idx.argsort()
        self._csr2csc = perm
        return perm

    def has_csc2csr(self) -> bool:
        return self._csc2csr is not None

    def csc2csr(self) -> torch.Tensor:
        """Return the inverse of `csr2csc()`."""
        csc2csr = self._csc2csr
        if csc2csr is not None:
            return csc2csr

        inverse = self.csr2csc().argsort()
        self._csc2csr = inverse
        return inverse

    def is_coalesced(self) -> bool:
        """True if linearized row-major keys are strictly increasing, i.e. no
        duplicate (row, col) entries (given the row-major sort order)."""
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_size[1] * self.row() + self._col
        return bool((idx[1:] > idx[:-1]).all())

    def coalesce(self, reduce: str = "add"):
        """Merge duplicate (row, col) entries.

        Relies on entries being in row-major order so that duplicates are
        adjacent. Returns `self` when there are no duplicates.

        Args:
            reduce: how values of duplicates are combined. Only "add"/"sum"
                is currently supported.

        Raises:
            NotImplementedError: if values are stored and `reduce` is not
                "add" or "sum".
        """
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_size[1] * self.row() + self._col
        mask = idx[1:] > idx[:-1]  # True where a new (row, col) group starts.

        if bool(mask.all()):  # Skip if indices are already coalesced.
            return self

        row = self.row()[mask]
        col = self._col[mask]

        value = self._value
        if value is not None:
            if reduce != "add" and reduce != "sum":
                raise NotImplementedError
            # Output slot of every entry: each group start opens a new slot.
            group = mask.cumsum(0) - 1
            merged = value.new_zeros([col.numel()] + list(value.size())[1:])
            merged.index_add_(0, group, value)
            value = merged

        return SparseStorage(row=row, rowptr=None, col=col, value=value,
                             sparse_size=self._sparse_size, rowcount=None,
                             colptr=None, colcount=None, csr2csc=None,
                             csc2csr=None, is_sorted=True)

    def fill_cache_(self):
        """Compute and store every derived tensor; returns `self`."""
        self.row()
        self.rowptr()
        self.rowcount()
        self.colptr()
        self.colcount()
        self.csr2csc()
        self.csc2csr()
        return self

    def clear_cache_(self):
        """Drop the derived tensors (not `row`/`rowptr`); returns `self`."""
        self._rowcount = None
        self._colptr = None
        self._colcount = None
        self._csr2csc = None
        self._csc2csr = None
        return self

    def __copy__(self):
        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
                             value=self._value, sparse_size=self._sparse_size,
                             rowcount=self._rowcount, colptr=self._colptr,
                             colcount=self._colcount, csr2csc=self._csr2csc,
                             csc2csr=self._csc2csr, is_sorted=True)

    def clone(self):
        """Return a new storage whose stored tensors are all cloned."""
        row = self._row
        if row is not None:
            row = row.clone()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.clone()
        value = self._value
        if value is not None:
            value = value.clone()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.clone()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.clone()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.clone()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.clone()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.clone()
        return SparseStorage(row=row, rowptr=rowptr, col=self._col.clone(),
                             value=value, sparse_size=self._sparse_size,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=csr2csc,
                             csc2csr=csc2csr, is_sorted=True)

    def __deepcopy__(self, memo: Dict[str, Any]):
        return self.clone()