"vscode:/vscode.git/clone" did not exist on "2e2584fc66cceef6acc038c309e0e98f394428ec"
storage.py 13.4 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import segment_csr, scatter_add
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
from torch_sparse import convert_cpu
7

rusty1s's avatar
rusty1s committed
8
try:
rusty1s's avatar
rusty1s committed
9
    from torch_sparse import convert_cuda
rusty1s's avatar
rusty1s committed
10
except ImportError:
rusty1s's avatar
rusty1s committed
11
    convert_cuda = None
12

rusty1s's avatar
typo  
rusty1s committed
13
# Process-wide switch consulted by `cached_property` (through
# `is_cache_enabled`) to decide whether computed values are stored on the
# instance.
__cache__ = {'enabled': True}


def is_cache_enabled():
    """Return the current state of the global cache switch."""
    return __cache__['enabled']


def set_cache_enabled(mode):
    """Set the global cache switch to ``mode``.

    Args:
        mode: Truthy to let :class:`cached_property` store computed values,
            falsy to recompute them on every access.
    """
    __cache__['enabled'] = mode


class no_cache(object):
    """Temporarily disable caching, as a context manager or as a decorator.

    Usable as ``with no_cache(): ...`` or ``@no_cache()``. On exit the global
    switch is restored to the value it had on the matching enter. Saved
    states are kept on a stack, so a single instance may be entered
    re-entrantly (e.g. by a recursive function decorated with it) without
    leaving caching disabled afterwards.
    """

    def __init__(self):
        # One saved switch state per currently active `__enter__`.
        self._prev_states = []

    def __enter__(self):
        self._prev_states.append(is_cache_enabled())
        set_cache_enabled(False)
        return self

    def __exit__(self, *args):
        set_cache_enabled(self._prev_states.pop())
        # Never suppress exceptions raised inside the block.
        return False

    def __call__(self, func):
        """Wrap ``func`` so that every call runs with caching disabled."""
        from functools import wraps

        @wraps(func)
        def decorate_no_cache(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_no_cache


class cached_property(object):
    """Descriptor that computes a value from ``func`` and stores it lazily.

    The value is stored on the instance as ``_<func name>``, the same slot a
    constructor may fill directly. It is stored only while caching is
    enabled (see :func:`is_cache_enabled`). A stored value of ``None`` counts
    as "not cached", so a function returning ``None`` is re-evaluated on
    every access. Accessing the attribute on the class returns the
    descriptor itself.
    """

    def __init__(self, func):
        from functools import update_wrapper

        self.func = func
        # Expose the wrapped function's name/docstring on the descriptor.
        update_wrapper(self, func)

    def __get__(self, obj, cls):
        if obj is None:
            # Class-level access (e.g. `SparseStorage.rowcount`).
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            if is_cache_enabled():
                setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
54
55
56
57
def optional(func, src):
    """Apply ``func`` to ``src`` unless ``src`` is ``None``.

    Returns:
        ``None`` when ``src`` is ``None``, otherwise ``func(src)``.
    """
    if src is None:
        return src
    return func(src)


rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
64
65
66
67
68
69
# Layout names accepted by `get_layout`.
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Validate a sparse layout name, defaulting to ``'coo'``.

    Args:
        layout: One of ``layouts`` or ``None``. ``None`` selects ``'coo'``
            and emits a warning, because the default may not be what the
            caller intended.

    Returns:
        The validated layout name.

    Raises:
        AssertionError: If ``layout`` is not one of ``layouts``.
    """
    if layout is not None:
        assert layout in layouts
        return layout
    warnings.warn('`layout` argument unset, using default layout '
                  '"coo". This may lead to unexpected behaviour.')
    return 'coo'


rusty1s's avatar
rusty1s committed
70
class SparseStorage(object):
    """Sparse matrix storage kept in row-major (CSR) order with lazy caches.

    The nonzero pattern is given by ``col`` together with ``row`` and/or
    ``rowptr``. These are two encodings of the row indices; at least one is
    required, and the other is derived on first access. ``value`` is an
    optional tensor whose first dimension equals the number of nonzeros.

    ``rowcount``, ``colptr``, ``colcount``, ``csr2csc`` and ``csc2csr`` are
    derived data. They may be passed in precomputed. Otherwise they are
    computed on access through ``cached_property`` and stored on the instance
    while caching is enabled.
    """

    # Attribute names (without leading underscore) of the derived tensors
    # handled by `cached_keys`, `fill_cache_` and `clear_cache_`.
    cache_keys = ['rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr']

    def __init__(self, row=None, rowptr=None, col=None, value=None,
                 sparse_size=None, rowcount=None, colptr=None, colcount=None,
                 csr2csc=None, csc2csr=None, is_sorted=False):
        """Validate the given tensors and sort nonzeros into row-major order.

        Args:
            row: Row index of every nonzero (``torch.long``, 1-D).
            rowptr: Compressed row pointer with ``sparse_size[0] + 1``
                entries.
            col: Column index of every nonzero (``torch.long``, 1-D).
            value: Optional values; ``value.size(0)`` must equal the number
                of nonzeros.
            sparse_size: ``(num_rows, num_cols)``. When omitted it is
                inferred from ``rowptr`` or the largest indices; an empty
                index tensor contributes a size of 0.
            rowcount, colptr, colcount, csr2csc, csc2csr: Optional
                precomputed caches, checked for dtype, device, rank and
                length.
            is_sorted: When ``True``, skip the row-major ordering check.

        Raises:
            AssertionError: If a tensor has the wrong dtype, device, rank or
                length.
        """
        assert row is not None or rowptr is not None
        assert col is not None
        assert col.dtype == torch.long
        assert col.dim() == 1
        col = col.contiguous()

        if sparse_size is None:
            # `max()` is undefined for empty tensors, so an index tensor
            # without entries yields a size of 0 instead of raising.
            if row is None:
                M = rowptr.numel() - 1
            else:
                M = row.max().item() + 1 if row.numel() > 0 else 0
            N = col.max().item() + 1 if col.numel() > 0 else 0
            sparse_size = torch.Size([M, N])

        device, nnz = col.device, col.numel()
        row = self._check_index(row, device, nnz)
        rowptr = self._check_index(rowptr, device, sparse_size[0] + 1)

        if value is not None:
            assert value.device == device
            assert value.size(0) == nnz
            value = value.contiguous()

        rowcount = self._check_index(rowcount, device, sparse_size[0])
        colptr = self._check_index(colptr, device, sparse_size[1] + 1)
        colcount = self._check_index(colcount, device, sparse_size[1])
        csr2csc = self._check_index(csr2csc, device, nnz)
        csc2csr = self._check_index(csc2csr, device, nnz)

        self._row = row
        self._rowptr = rowptr
        self._col = col
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._colptr = colptr
        self._colcount = colcount
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

        if not is_sorted:
            # Linearize (row, col) to `row * N + col`; a leading 0 lets the
            # shifted comparison detect any decrease in one pass.
            idx = self.col.new_zeros(col.numel() + 1)
            idx[1:] = sparse_size[1] * self.row + self.col
            if (idx[1:] < idx[:-1]).any():
                perm = idx[1:].argsort()
                self._row = self.row[perm]
                self._col = self.col[perm]
                self._value = self.value[perm] if self.has_value() else None
                # Reordering invalidates the CSR<->CSC permutations; counts
                # and pointers are order-independent and stay valid.
                self._csr2csc = None
                self._csc2csr = None

    @staticmethod
    def _check_index(src, device, numel):
        """Check a 1-D long index tensor and return it contiguous.

        ``None`` is passed through unchanged.
        """
        if src is None:
            return None
        assert src.dtype == torch.long
        assert src.device == device
        assert src.dim() == 1
        assert src.numel() == numel
        return src.contiguous()

    @staticmethod
    def _convert_ops(src):
        """Return the conversion extension matching the device of ``src``."""
        if not src.is_cuda:
            return convert_cpu
        if convert_cuda is None:
            raise RuntimeError('`convert_cuda` could not be imported from '
                               'torch_sparse; CUDA tensors are not supported.')
        return convert_cuda

    def has_row(self):
        """Return whether the row-index tensor is materialized."""
        return self._row is not None

    @property
    def row(self):
        """Row index of every nonzero, expanded from ``rowptr`` if needed.

        The expanded tensor is stored in ``_row`` regardless of the global
        cache switch.
        """
        if self._row is None:
            ops = self._convert_ops(self.rowptr)
            self._row = ops.ptr2ind(self.rowptr, self.col.numel())
        return self._row

    def has_rowptr(self):
        """Return whether the compressed row pointer is materialized."""
        return self._rowptr is not None

    @property
    def rowptr(self):
        """Compressed row pointer, built from ``row`` if needed.

        The built tensor is stored in ``_rowptr`` regardless of the global
        cache switch.
        """
        if self._rowptr is None:
            ops = self._convert_ops(self.row)
            self._rowptr = ops.ind2ptr(self.row, self.sparse_size[0])
        return self._rowptr

    @property
    def col(self):
        """Column index of every nonzero."""
        return self._col

    def has_value(self):
        """Return whether a value tensor is attached."""
        return self._value is not None

    @property
    def value(self):
        """Nonzero values in row-major order, or ``None``."""
        return self._value

    def _prepare_value(self, value, layout, dtype):
        """Normalize ``value`` for storage (shared by the setters).

        * A Python ``int``/``float`` is broadcast to one entry per nonzero.
        * A tensor in ``'csc'`` layout is permuted into row-major order.
        * Tensors are cast to ``dtype`` (if given) and checked for device
          and length.
        """
        if isinstance(value, (int, float)):
            value = torch.full((self.col.numel(), ), value, dtype=dtype,
                               device=self.col.device)
        elif torch.is_tensor(value) and get_layout(layout) == 'csc':
            value = value[self.csc2csr]

        if torch.is_tensor(value):
            value = value if dtype is None else value.to(dtype)
            assert value.device == self.col.device
            assert value.size(0) == self.col.numel()
        return value

    def set_value_(self, value, layout=None, dtype=None):
        """Replace the values in place and return ``self``."""
        self._value = self._prepare_value(value, layout, dtype)
        return self

    def set_value(self, value, layout=None, dtype=None):
        """Return a new storage sharing this pattern with new values."""
        value = self._prepare_value(value, layout, dtype)
        return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
                              value=value, sparse_size=self._sparse_size,
                              rowcount=self._rowcount, colptr=self._colptr,
                              colcount=self._colcount, csr2csc=self._csr2csc,
                              csc2csr=self._csc2csr, is_sorted=True)

    @property
    def sparse_size(self):
        """``(num_rows, num_cols)`` of the sparse matrix."""
        return self._sparse_size

    @staticmethod
    def _resize_compressed(ptr, count, new_size, diff, nnz):
        """Grow or shrink a (pointer, count) pair to ``new_size`` entries.

        New trailing slots are empty (pointer ``nnz``, count 0). An unchanged
        size (``diff == 0``) leaves both tensors untouched.
        """
        if diff > 0:
            if ptr is not None:
                ptr = torch.cat([ptr, ptr.new_full((diff, ), nnz)])
            if count is not None:
                count = torch.cat([count, count.new_zeros(diff)])
        elif diff < 0:
            if ptr is not None:
                ptr = ptr[:new_size + 1]
            if count is not None:
                count = count[:new_size]
        return ptr, count

    def sparse_resize(self, *sizes):
        """Return a storage with sparse size ``sizes`` (rows, cols).

        Cached row/column pointers and counts are padded or truncated to the
        new size. NOTE: nonzeros are not removed, so shrinking below an
        occupied row/column leaves the storage inconsistent.
        """
        old_sparse_size, nnz = self.sparse_size, self.col.numel()

        rowptr, rowcount = self._resize_compressed(
            self._rowptr, self._rowcount, sizes[0],
            sizes[0] - old_sparse_size[0], nnz)
        colptr, colcount = self._resize_compressed(
            self._colptr, self._colcount, sizes[1],
            sizes[1] - old_sparse_size[1], nnz)

        return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
                              value=self.value, sparse_size=sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=self._csr2csc,
                              csc2csr=self._csc2csr, is_sorted=True)

    def has_rowcount(self):
        """Return whether ``rowcount`` is cached."""
        return self._rowcount is not None

    @cached_property
    def rowcount(self):
        """Number of nonzeros per row."""
        return self.rowptr[1:] - self.rowptr[:-1]

    def has_colptr(self):
        """Return whether ``colptr`` is cached."""
        return self._colptr is not None

    @cached_property
    def colptr(self):
        """Compressed column pointer (CSC layout)."""
        if self.has_csr2csc():
            ops = self._convert_ops(self.col)
            return ops.ind2ptr(self.col[self.csr2csc], self.sparse_size[1])
        else:
            colptr = self.col.new_zeros(self.sparse_size[1] + 1)
            torch.cumsum(self.colcount, dim=0, out=colptr[1:])
            return colptr

    def has_colcount(self):
        """Return whether ``colcount`` is cached."""
        return self._colcount is not None

    @cached_property
    def colcount(self):
        """Number of nonzeros per column."""
        if self.has_colptr():
            return self.colptr[1:] - self.colptr[:-1]
        else:
            return scatter_add(torch.ones_like(self.col), self.col,
                               dim_size=self.sparse_size[1])

    def has_csr2csc(self):
        """Return whether ``csr2csc`` is cached."""
        return self._csr2csc is not None

    @cached_property
    def csr2csc(self):
        """Permutation that reorders row-major nonzeros into column-major."""
        idx = self.sparse_size[0] * self.col + self.row
        return idx.argsort()

    def has_csc2csr(self):
        """Return whether ``csc2csr`` is cached."""
        return self._csc2csr is not None

    @cached_property
    def csc2csr(self):
        """Inverse permutation of ``csr2csc``."""
        return self.csr2csc.argsort()

    def is_coalesced(self):
        """Return ``True`` if linearized indices strictly increase.

        That is, nonzeros are row-major sorted with no duplicate
        ``(row, col)`` pairs.
        """
        # Sentinel -1 so that the first entry always passes the comparison.
        idx = self.col.new_full((self.col.numel() + 1, ), -1)
        idx[1:] = self.sparse_size[1] * self.row + self.col
        return (idx[1:] > idx[:-1]).all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate ``(row, col)`` entries.

        Values of duplicates are combined with ``segment_csr`` using
        ``reduce``. Returns ``self`` if there are no duplicates.
        """
        idx = self.col.new_full((self.col.numel() + 1, ), -1)
        idx[1:] = self.sparse_size[1] * self.row + self.col
        # `mask[i]` marks the first occurrence of each distinct index.
        mask = idx[1:] > idx[:-1]

        if mask.all():  # Skip if indices are already coalesced.
            return self

        row = self.row[mask]
        col = self.col[mask]

        value = self.value
        if self.has_value():
            # Segment boundaries: first occurrence of every index, plus nnz.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
            value = segment_csr(value, ptr, reduce=reduce)
            # Some reductions return a tuple (presumably values plus arg
            # indices); keep only the reduced values.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(row=row, col=col, value=value,
                              sparse_size=self.sparse_size, is_sorted=True)

    def cached_keys(self):
        """Return the names in ``cache_keys`` that currently hold a tensor."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the named caches (default: all, plus row and rowptr)."""
        for arg in args or self.cache_keys + ['row', 'rowptr']:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the named caches (default: all of ``cache_keys``)."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        # Shallow copy: the new storage shares every tensor with `self`.
        return self.apply(lambda x: x)

    def clone(self):
        """Return a storage holding clones of every tensor."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the values in place (if any); return ``self``."""
        self._value = optional(func, self.value)
        return self

    def apply_value(self, func):
        """Return a new storage whose values are ``func(value)``."""
        return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
                              value=optional(func, self.value),
                              sparse_size=self.sparse_size,
                              rowcount=self._rowcount, colptr=self._colptr,
                              colcount=self._colcount, csr2csc=self._csr2csc,
                              csc2csr=self._csc2csr, is_sorted=True)

    def apply_(self, func):
        """Apply ``func`` in place to every materialized tensor."""
        self._row = optional(func, self._row)
        self._rowptr = optional(func, self._rowptr)
        self._col = func(self.col)
        self._value = optional(func, self.value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every tensor."""
        return self.__class__(
            row=optional(func, self._row),
            rowptr=optional(func, self._rowptr),
            col=func(self.col),
            value=optional(func, self.value),
            sparse_size=self.sparse_size,
            rowcount=optional(func, self._rowcount),
            colptr=optional(func, self._colptr),
            colcount=optional(func, self._colcount),
            csr2csc=optional(func, self._csr2csc),
            csc2csr=optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``[func(t), ...]`` over every materialized tensor.

        Order: row, rowptr, col, value, then cached keys in ``cache_keys``
        order.
        """
        data = []
        if self.has_row():
            data += [func(self.row)]
        if self.has_rowptr():
            data += [func(self.rowptr)]
        data += [func(self.col)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data