storage.py 15.3 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2
from typing import Optional, List, Dict, Union, Any
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
from torch_scatter import segment_csr, scatter_add
rusty1s's avatar
rusty1s committed
6
from torch_sparse.utils import Final, is_scalar
7

rusty1s's avatar
rusty1s committed
8
# __cache__ = {'enabled': True}
rusty1s's avatar
rusty1s committed
9

rusty1s's avatar
rusty1s committed
10
11
# def is_cache_enabled():
#     return __cache__['enabled']
rusty1s's avatar
rusty1s committed
12

rusty1s's avatar
rusty1s committed
13
14
# def set_cache_enabled(mode):
#     __cache__['enabled'] = mode
rusty1s's avatar
rusty1s committed
15

rusty1s's avatar
rusty1s committed
16
17
18
19
# class no_cache(object):
#     def __enter__(self):
#         self.prev = is_cache_enabled()
#         set_cache_enabled(False)
rusty1s's avatar
rusty1s committed
20

rusty1s's avatar
rusty1s committed
21
22
23
#     def __exit__(self, *args):
#         set_cache_enabled(self.prev)
#         return False
rusty1s's avatar
rusty1s committed
24

rusty1s's avatar
rusty1s committed
25
26
27
28
#     def __call__(self, func):
#         def decorate_no_cache(*args, **kwargs):
#             with self:
#                 return func(*args, **kwargs)
rusty1s's avatar
rusty1s committed
29

rusty1s's avatar
rusty1s committed
30
#         return decorate_no_cache
rusty1s's avatar
rusty1s committed
31
32


rusty1s's avatar
rusty1s committed
33
34
35
36
def optional(func, src):
    """Apply ``func`` to ``src`` unless ``src`` is ``None``.

    ``None`` is passed straight through; any other value (including falsy
    ones such as ``0`` or ``[]``) is handed to ``func`` and its result
    returned.
    """
    if src is None:
        return src
    return func(src)


rusty1s's avatar
rusty1s committed
37
# Layout names understood by this module; `get_layout` accepts exactly
# these three strings.
layouts: Final[List[str]] = [
    'coo',
    'csr',
    'csc',
]
rusty1s's avatar
rusty1s committed
38
39


rusty1s's avatar
rusty1s committed
40
def get_layout(layout: Optional[str] = None) -> str:
    """Validate a layout name, falling back to ``'coo'``.

    Args:
        layout: One of ``'coo'``, ``'csr'`` or ``'csc'``. ``None`` selects
            ``'coo'`` and emits a warning.

    Returns:
        The validated layout name.

    Raises:
        AssertionError: If ``layout`` is not one of the three names.
    """
    if layout is not None:
        assert layout == 'coo' or layout == 'csr' or layout == 'csc'
        return layout

    warnings.warn('`layout` argument unset, using default layout '
                  '"coo". This may lead to unexpected behaviour.')
    return 'coo'


rusty1s's avatar
rusty1s committed
49
@torch.jit.script
class SparseStorage(object):
    """Index/value storage for a 2-D sparse matrix.

    Stored entries are kept in row-major order (sorted by the key
    ``row * num_cols + col``) unless the constructor is told the input is
    already sorted. Besides the primary tensors (``row`` and/or
    ``rowptr``, ``col`` and the optional ``value``), the storage lazily
    computes and caches ``rowcount``, ``colptr``, ``colcount`` and the
    permutations ``csr2csc`` / ``csc2csr`` between row-major and
    column-major entry order.
    """

    _row: Optional[torch.Tensor]
    _rowptr: Optional[torch.Tensor]
    _col: torch.Tensor
    _value: Optional[torch.Tensor]
    _sparse_size: List[int]
    _rowcount: Optional[torch.Tensor]
    _colptr: Optional[torch.Tensor]
    _colcount: Optional[torch.Tensor]
    _csr2csc: Optional[torch.Tensor]
    _csc2csr: Optional[torch.Tensor]

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_size: Optional[List[int]] = None,
                 rowcount: Optional[torch.Tensor] = None,
                 colptr: Optional[torch.Tensor] = None,
                 colcount: Optional[torch.Tensor] = None,
                 csr2csc: Optional[torch.Tensor] = None,
                 csc2csr: Optional[torch.Tensor] = None,
                 is_sorted: bool = False):
        """Validate the given tensors and store them.

        Args:
            row: Row index of every stored entry (1-D ``long``).
            rowptr: Compressed row pointer with ``num_rows + 1`` entries
                (1-D ``long``). At least one of ``row``/``rowptr`` is
                required.
            col: Column index of every stored entry (1-D ``long``).
            value: Optional per-entry values; ``value.size(0)`` must equal
                the number of entries.
            sparse_size: ``[num_rows, num_cols]``. When omitted it is
                inferred from ``rowptr`` (or ``row.max() + 1``) and
                ``col.max() + 1``.
            rowcount, colptr, colcount, csr2csc, csc2csr: Optional
                precomputed caches; checked for dtype, device, rank and
                size before being stored.
            is_sorted: When ``False`` the entries are checked for
                row-major order and re-sorted if necessary.

        Raises:
            AssertionError: If a tensor has the wrong dtype, device, rank
                or size, or neither ``row`` nor ``rowptr`` is given.
        """
        assert row is not None or rowptr is not None
        assert col is not None
        assert col.dtype == torch.long
        assert col.dim() == 1
        col = col.contiguous()

        if sparse_size is None:
            if rowptr is not None:
                M = rowptr.numel() - 1
            elif row is not None:
                M = row.max().item() + 1
            else:
                raise ValueError
            N = col.max().item() + 1
            sparse_size = torch.Size([int(M), int(N)])
        else:
            assert len(sparse_size) == 2

        if row is not None:
            assert row.dtype == torch.long
            assert row.device == col.device
            assert row.dim() == 1
            assert row.numel() == col.numel()
            row = row.contiguous()

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == col.device
            assert rowptr.dim() == 1
            assert rowptr.numel() - 1 == sparse_size[0]
            rowptr = rowptr.contiguous()

        if value is not None:
            assert value.device == col.device
            assert value.size(0) == col.size(0)
            value = value.contiguous()

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == col.device
            assert rowcount.dim() == 1
            assert rowcount.numel() == sparse_size[0]
            rowcount = rowcount.contiguous()

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == col.device
            assert colptr.dim() == 1
            assert colptr.numel() - 1 == sparse_size[1]
            colptr = colptr.contiguous()

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == col.device
            assert colcount.dim() == 1
            assert colcount.numel() == sparse_size[1]
            colcount = colcount.contiguous()

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == col.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == col.size(0)
            csr2csc = csr2csc.contiguous()

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == col.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == col.size(0)
            csc2csr = csc2csr.contiguous()

        self._row = row
        self._rowptr = rowptr
        self._col = col
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._colptr = colptr
        self._colcount = colcount
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

        if not is_sorted:
            # Row-major key of every entry. idx[0] = 0 is a sentinel that
            # no non-negative key is smaller than.
            idx = col.new_zeros(col.numel() + 1)
            idx[1:] = sparse_size[1] * self.row() + col
            if (idx[1:] < idx[:-1]).any():
                perm = idx[1:].argsort()
                self._row = self.row()[perm]
                self._col = col[perm]
                if value is not None:
                    self._value = value[perm]
                # Permutations given for the old entry order are stale.
                self._csr2csc = None
                self._csc2csr = None

    def has_row(self) -> bool:
        """Return whether per-entry row indices are materialized."""
        return self._row is not None

    def row(self) -> torch.Tensor:
        """Return per-entry row indices, expanding ``rowptr`` if needed.

        The expansion uses the compiled ``ptr2ind`` op for the tensor's
        device and is cached on the storage.

        Raises:
            ValueError: If neither ``row`` nor ``rowptr`` is stored.
        """
        cached = self._row
        if cached is not None:
            return cached

        rowptr = self._rowptr
        if rowptr is not None:
            numel = self._col.numel()
            if rowptr.is_cuda:
                row = torch.ops.torch_sparse_cuda.ptr2ind(rowptr, numel)
            else:
                row = torch.ops.torch_sparse_cpu.ptr2ind(rowptr, numel)
            self._row = row
            return row

        raise ValueError

    def has_rowptr(self) -> bool:
        """Return whether the compressed row pointer is materialized."""
        return self._rowptr is not None

    def rowptr(self) -> torch.Tensor:
        """Return the compressed row pointer, building it from ``row``.

        The result is cached.

        Raises:
            ValueError: If neither ``rowptr`` nor ``row`` is stored.
        """
        cached = self._rowptr
        if cached is not None:
            return cached

        row = self._row
        if row is not None:
            num_rows = self._sparse_size[0]
            if row.is_cuda:
                rowptr = torch.ops.torch_sparse_cuda.ind2ptr(row, num_rows)
            else:
                rowptr = torch.ops.torch_sparse_cpu.ind2ptr(row, num_rows)
            self._rowptr = rowptr
            return rowptr

        raise ValueError

    def col(self) -> torch.Tensor:
        """Return the per-entry column indices."""
        return self._col

    def has_value(self) -> bool:
        """Return whether per-entry values are stored."""
        return self._value is not None

    def value(self) -> Optional[torch.Tensor]:
        """Return the per-entry values, or ``None``."""
        return self._value

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the stored values in place and return ``self``.

        Args:
            value: New per-entry values, or ``None`` to drop them.
            layout: Entry order of ``value``. ``'csc'`` means column-major
                order; such values are permuted into the stored row-major
                order. ``None`` warns and is treated as ``'coo'``.
        """
        if value is not None:
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr()]
            value = value.contiguous()
            assert value.device == self._col.device
            assert value.size(0) == self._col.numel()

        self._value = value
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new storage with ``value`` and all other tensors shared.

        ``layout`` is interpreted as in :meth:`set_value_`.
        """
        if value is not None:
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr()]
            value = value.contiguous()
            assert value.device == self._col.device
            assert value.size(0) == self._col.numel()

        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
                             value=value, sparse_size=self._sparse_size,
                             rowcount=self._rowcount, colptr=self._colptr,
                             colcount=self._colcount, csr2csc=self._csr2csc,
                             csc2csr=self._csc2csr, is_sorted=True)

    def fill_value_(self, fill_value: float,
                    dtype: Optional[torch.dtype] = None):
        """Set every stored value to ``fill_value`` in place.

        Args:
            fill_value: Scalar written to every entry.
            dtype: dtype of the new value tensor; ``None`` lets
                ``torch.empty`` use PyTorch's default dtype.
        """
        value = torch.empty([self._col.numel()], dtype=dtype,
                            device=self._col.device)
        return self.set_value_(value.fill_(fill_value), layout='csr')

    def fill_value(self, fill_value: float,
                   dtype: Optional[torch.dtype] = None):
        """Return a new storage whose values all equal ``fill_value``.

        ``dtype`` is interpreted as in :meth:`fill_value_`.
        """
        value = torch.empty([self._col.numel()], dtype=dtype,
                            device=self._col.device)
        return self.set_value(value.fill_(fill_value), layout='csr')

    def sparse_size(self) -> List[int]:
        """Return ``[num_rows, num_cols]``."""
        return self._sparse_size

    def sparse_resize(self, sparse_size: List[int]):
        """Return a storage with a new ``[num_rows, num_cols]`` shape.

        Growing a dimension pads ``rowptr``/``colptr`` with ``nnz`` and
        ``rowcount``/``colcount`` with zeros, so the new rows/columns are
        empty. Shrinking drops the trailing pointer/count entries. Index,
        value and permutation tensors are shared with ``self``. Shrinking
        does not check that every stored entry lies inside the new bounds.
        """
        assert len(sparse_size) == 2
        old_sparse_size, nnz = self._sparse_size, self._col.numel()

        diff_0 = sparse_size[0] - old_sparse_size[0]
        rowcount, rowptr = self._rowcount, self._rowptr
        if diff_0 > 0:
            if rowptr is not None:
                rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
            if rowcount is not None:
                rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
        elif diff_0 < 0:
            # A negative stop index drops the last ``-diff_0`` entries.
            if rowptr is not None:
                rowptr = rowptr[:diff_0]
            if rowcount is not None:
                rowcount = rowcount[:diff_0]

        diff_1 = sparse_size[1] - old_sparse_size[1]
        colcount, colptr = self._colcount, self._colptr
        if diff_1 > 0:
            if colptr is not None:
                colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
            if colcount is not None:
                colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
        elif diff_1 < 0:
            if colptr is not None:
                colptr = colptr[:diff_1]
            if colcount is not None:
                colcount = colcount[:diff_1]

        return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
                             value=self._value, sparse_size=sparse_size,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=self._csr2csc,
                             csc2csr=self._csc2csr, is_sorted=True)

    def has_rowcount(self) -> bool:
        """Return whether per-row entry counts are cached."""
        return self._rowcount is not None

    def rowcount(self) -> torch.Tensor:
        """Return the number of stored entries in every row (cached)."""
        cached = self._rowcount
        if cached is not None:
            return cached

        rowptr = self.rowptr()
        # Consecutive pointer differences give the per-row entry counts.
        rowcount = rowptr[1:] - rowptr[:-1]
        self._rowcount = rowcount
        return rowcount

    def has_colptr(self) -> bool:
        """Return whether the compressed column pointer is cached."""
        return self._colptr is not None

    def colptr(self) -> torch.Tensor:
        """Return the compressed column pointer (cached).

        Uses the cached ``csr2csc`` permutation when available; otherwise
        it is the prefix sum of :meth:`colcount`.
        """
        cached = self._colptr
        if cached is not None:
            return cached

        csr2csc = self._csr2csc
        num_cols = self._sparse_size[1]
        if csr2csc is not None:
            # Column indices in column-major entry order are sorted and can
            # be compressed directly.
            col_csc = self._col[csr2csc]
            if col_csc.is_cuda:
                colptr = torch.ops.torch_sparse_cuda.ind2ptr(col_csc, num_cols)
            else:
                colptr = torch.ops.torch_sparse_cpu.ind2ptr(col_csc, num_cols)
        else:
            colptr = self._col.new_zeros(num_cols + 1)
            torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
        self._colptr = colptr
        return colptr

    def has_colcount(self) -> bool:
        """Return whether per-column entry counts are cached."""
        return self._colcount is not None

    def colcount(self) -> torch.Tensor:
        """Return the number of stored entries in every column (cached)."""
        cached = self._colcount
        if cached is not None:
            return cached

        colptr = self._colptr
        if colptr is not None:
            colcount = colptr[1:] - colptr[:-1]
        else:
            # Count occurrences of each column index; ``minlength`` keeps
            # trailing empty columns in the result.
            colcount = torch.bincount(self._col,
                                      minlength=self._sparse_size[1])
        self._colcount = colcount
        return colcount

    def has_csr2csc(self) -> bool:
        """Return whether the row-major -> column-major permutation is cached."""
        return self._csr2csc is not None

    def csr2csc(self) -> torch.Tensor:
        """Return the permutation that puts entries in column-major order.

        Indexing any per-entry tensor with the result reorders it by the
        key ``col * num_rows + row``. The permutation is cached.
        """
        cached = self._csr2csc
        if cached is not None:
            return cached

        idx = self._sparse_size[0] * self._col + self.row()
        csr2csc = idx.argsort()
        self._csr2csc = csr2csc
        return csr2csc

    def has_csc2csr(self) -> bool:
        """Return whether the column-major -> row-major permutation is cached."""
        return self._csc2csr is not None

    def csc2csr(self) -> torch.Tensor:
        """Return the inverse of :meth:`csr2csc` (cached)."""
        cached = self._csc2csr
        if cached is not None:
            return cached

        csc2csr = self.csr2csc().argsort()
        self._csc2csr = csc2csr
        return csc2csr

    def is_coalesced(self) -> bool:
        """Return whether the row-major keys are strictly increasing.

        That is, the entries are sorted and no ``(row, col)`` pair repeats.
        """
        # idx[0] = -1 is a sentinel below every valid key.
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_size[1] * self.row() + self._col
        return bool((idx[1:] > idx[:-1]).all())

    def coalesce(self, reduce: str = "add"):
        """Remove duplicate ``(row, col)`` entries.

        Returns ``self`` when the entries are already coalesced. Without
        values, the first occurrence of every duplicated index pair is kept
        and a fresh storage (with no caches) is returned. Reducing duplicate
        values is not implemented yet, so ``reduce`` is currently unused.

        Raises:
            NotImplementedError: If duplicates exist and values are stored.
        """
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_size[1] * self.row() + self._col
        # True for the first entry of every distinct key (entries are
        # expected to be in row-major order).
        mask = idx[1:] > idx[:-1]

        if mask.all():  # Skip if indices are already coalesced.
            return self

        row = self.row()[mask]
        col = self._col[mask]

        value = self._value
        if value is not None:
            # TODO: reduce duplicate values per segment, e.g. with
            # ``segment_csr(value, ptr, reduce=reduce)`` where ``ptr`` holds
            # the segment start offsets followed by ``value.size(0)``.
            raise NotImplementedError

        return SparseStorage(row=row, rowptr=None, col=col, value=value,
                             sparse_size=self._sparse_size, rowcount=None,
                             colptr=None, colcount=None, csr2csc=None,
                             csc2csr=None, is_sorted=True)

    def fill_cache_(self):
        """Materialize every lazily computed tensor and return ``self``."""
        self.row()
        self.rowptr()
        self.rowcount()
        self.colptr()
        self.colcount()
        self.csr2csc()
        self.csc2csr()
        return self

    def clear_cache_(self):
        """Drop the cached counts, column pointer and permutations.

        ``row``, ``rowptr``, ``col`` and ``value`` are kept. Returns
        ``self``.
        """
        self._rowcount = None
        self._colptr = None
        self._colcount = None
        self._csr2csc = None
        self._csc2csr = None
        return self

    def copy(self):
        """Return a new storage that shares every tensor with ``self``."""
        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
                             value=self._value, sparse_size=self._sparse_size,
                             rowcount=self._rowcount, colptr=self._colptr,
                             colcount=self._colcount, csr2csc=self._csr2csc,
                             csc2csr=self._csc2csr, is_sorted=True)

    def clone(self):
        """Return a new storage holding clones of every stored tensor."""
        row = self._row
        if row is not None:
            row = row.clone()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.clone()
        value = self._value
        if value is not None:
            value = value.clone()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.clone()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.clone()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.clone()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.clone()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.clone()
        return SparseStorage(row=row, rowptr=rowptr, col=self._col.clone(),
                             value=value, sparse_size=self._sparse_size,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=csr2csc,
                             csc2csr=csc2csr, is_sorted=True)