storage.py 11.5 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import segment_csr, scatter_add
rusty1s's avatar
rusty1s committed
5

6
7
from torch_sparse import rowptr_cpu

rusty1s's avatar
rusty1s committed
8
try:
9
    from torch_sparse import rowptr_cuda
rusty1s's avatar
rusty1s committed
10
11
except ImportError:
    rowptr_cuda = None
12

rusty1s's avatar
typo  
rusty1s committed
13
# Module-wide switch read by ``cached_property`` to decide whether a freshly
# computed value is stored on the instance.
__cache__ = {'enabled': True}


def is_cache_enabled():
    """Return the current value of the module-wide caching switch."""
    state = __cache__
    return state['enabled']


def set_cache_enabled(mode):
    """Overwrite the module-wide caching switch with ``mode``."""
    state = __cache__
    state['enabled'] = mode
rusty1s's avatar
rusty1s committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38


class no_cache(object):
    """Disable caching for the duration of a ``with`` block or a call.

    Usable as a context manager (``with no_cache(): ...``) or as a decorator
    (``@no_cache()``). On exit the caching switch is restored to the value it
    had on the matching entry; ``__exit__`` never suppresses exceptions.
    """

    def __init__(self):
        # A stack rather than a single attribute, so re-entering the same
        # instance (e.g. a recursive function decorated with one
        # ``no_cache()``) restores every level's previous state correctly.
        self._prev_states = []

    def __enter__(self):
        # ``prev`` is kept for backward compatibility with code reading it.
        self.prev = is_cache_enabled()
        self._prev_states.append(self.prev)
        set_cache_enabled(False)

    def __exit__(self, *args):
        set_cache_enabled(self._prev_states.pop())
        return False

    def __call__(self, func):
        from functools import wraps

        # ``wraps`` keeps the decorated function's name and docstring.
        @wraps(func)
        def decorate_no_cache(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_no_cache
rusty1s's avatar
rusty1s committed
39
40


rusty1s's avatar
rusty1s committed
41
42
43
class cached_property(object):
    """Descriptor that computes an attribute lazily and memoizes it.

    On instance access the value stored as ``_<func name>`` is returned if it
    is not ``None``; otherwise the wrapped method is called and, while
    caching is enabled (``is_cache_enabled()``), its result is stored under
    that name. A method result of ``None`` is therefore never cached.
    """

    def __init__(self, func):
        self.func = func
        # Surface the wrapped method's docstring for ``help()``/introspection.
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        # Accessed on the class itself: return the descriptor instead of
        # calling ``func(None)``.
        if obj is None:
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            if is_cache_enabled():
                setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
54
55
56
57
def optional(func, src):
    """Return ``func(src)``, or ``src`` itself when it is ``None``."""
    if src is None:
        return src
    return func(src)


rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
64
65
66
67
68
69
# Layout names accepted by ``get_layout``.
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Validate and return ``layout``.

    ``None`` emits a warning and falls back to ``'coo'``. A name not listed
    in ``layouts`` fails an ``assert``.
    """
    if layout is not None:
        assert layout in layouts
        return layout
    warnings.warn('`layout` argument unset, using default layout '
                  '"coo". This may lead to unexpected behaviour.')
    fallback = 'coo'
    assert fallback in layouts
    return fallback


rusty1s's avatar
rusty1s committed
70
class SparseStorage(object):
rusty1s's avatar
rusty1s committed
71
72
73
74
    # Lazily computed structures; each is stored on the instance under the
    # attribute ``_<name>``.
    cache_keys = [
        'rowcount',
        'rowptr',
        'colcount',
        'colptr',
        'csr2csc',
        'csc2csr',
    ]

rusty1s's avatar
rusty1s committed
75
76
77
    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Build a sparse storage from a ``[2, nnz]`` ``torch.long`` index.

        Args:
            index: (row, col) coordinates, shape ``[2, nnz]``.
            value: optional tensor on the same device with ``size(0) == nnz``.
            sparse_size: 2-element size. When omitted it is inferred as
                ``max index + 1`` per dimension, or ``(0, 0)`` for an empty
                index.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: optional
                precomputed caches, checked for dtype, device and length.
            is_sorted: when ``False``, entries not already in row-major order
                are sorted (this discards ``csr2csc``/``csc2csr``, which
                depend on entry order).
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            if index.numel() == 0:
                # ``max`` has no identity on an empty tensor and would raise.
                sparse_size = torch.Size((0, 0))
            else:
                sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == index.device
            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == index.device
            assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == index.device
            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == index.device
            assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == index.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == index.size(1)

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == index.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == index.size(1)

        if not is_sorted:
            # Row-major linear key of every entry.
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if necessary...
            if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # Both permutations depend on entry order, which changed.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr
rusty1s's avatar
rusty1s committed
142
143

    @property
    def index(self):
        """The ``[2, nnz]`` (row, col) coordinate tensor."""
        stored = self._index
        return stored
rusty1s's avatar
rusty1s committed
146
147

    @property
    def row(self):
        """Row coordinates (first row of ``index``)."""
        coords = self._index
        return coords[0]
rusty1s's avatar
rusty1s committed
150
151

    @property
    def col(self):
        """Column coordinates (second row of ``index``)."""
        coords = self._index
        return coords[1]
rusty1s's avatar
rusty1s committed
154

rusty1s's avatar
rusty1s committed
155
156
    def has_value(self):
        """Whether a value tensor is attached."""
        stored = self._value
        return stored is not None
rusty1s's avatar
rusty1s committed
157
158

    @property
    def value(self):
        """The per-entry value tensor, or ``None`` when absent."""
        stored = self._value
        return stored

    def set_value_(self, value, layout=None):
        """Replace the stored values in place and return ``self``.

        Args:
            value: a tensor with one entry per stored element, or a Python
                ``int``/``float`` broadcast to all ``nnz`` entries.
            layout: order of a tensor ``value``. ``'csc'`` values are
                permuted into the internal row-major order via ``csc2csr``.
                For a tensor, ``None`` warns and is treated as ``'coo'``.
        """
        if isinstance(value, (int, float)):
            # ``torch.full`` requires the fill value; it was missing before,
            # so this branch always raised a TypeError.
            value = torch.full((self.nnz(), ), value,
                               device=self.index.device)
        elif torch.is_tensor(value) and get_layout(layout) == 'csc':
            value = value[self.csc2csr]

        assert value.device == self.index.device
        assert value.size(0) == self.index.size(1)

        self._value = value
        return self
rusty1s's avatar
rusty1s committed
171
172

    def set_value(self, value, layout=None):
        """Return a new storage with ``value`` and this storage's structure.

        Args:
            value: a tensor with one entry per stored element, or a Python
                ``int``/``float`` broadcast to all ``nnz`` entries.
            layout: order of a tensor ``value``. ``'csc'`` values are
                permuted into the internal row-major order via ``csc2csr``.
                For a tensor, ``None`` warns and is treated as ``'coo'``.

        The index, sparse size and cached structures are shared.
        """
        if isinstance(value, (int, float)):
            # ``torch.full`` requires the fill value; it was missing before,
            # so this branch always raised a TypeError.
            value = torch.full((self.nnz(), ), value,
                               device=self.index.device)
        elif torch.is_tensor(value) and get_layout(layout) == 'csc':
            value = value[self.csc2csr]

        assert value.device == self._index.device
        assert value.size(0) == self._index.size(1)

        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )
rusty1s's avatar
rusty1s committed
191
192

    def sparse_size(self, dim=None):
        """Return the full 2-D sparse size, or only its extent along ``dim``."""
        if dim is None:
            return self._sparse_size
        return self._sparse_size[dim]
rusty1s's avatar
rusty1s committed
194

rusty1s's avatar
rusty1s committed
195
    def sparse_resize(self, *sizes):
        """Return a copy of this storage with sparse size ``sizes``.

        Stored row/column counts and pointers are padded (growing) or
        truncated (shrinking) to the new size. The index, values and
        ``csr2csc``/``csc2csr`` permutations are shared unchanged.

        Args:
            *sizes: the new ``(num_rows, num_cols)``.

        Note: shrinking does not check that existing indices still fit.
        """
        assert len(sizes) == 2
        old_sizes, nnz = self.sparse_size(), self.nnz()

        rowcount, rowptr = self._resize_dim(self._rowcount, self._rowptr,
                                            old_sizes[0], sizes[0], nnz)
        colcount, colptr = self._resize_dim(self._colcount, self._colptr,
                                            old_sizes[1], sizes[1], nnz)

        return self.__class__(
            self._index,
            self._value,
            sizes,
            rowcount=rowcount,
            rowptr=rowptr,
            colcount=colcount,
            colptr=colptr,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
        )

    @staticmethod
    def _resize_dim(count, ptr, old_size, new_size, nnz):
        """Pad or truncate a per-slice ``count``/``ptr`` pair to ``new_size``."""
        diff = new_size - old_size
        if diff > 0:
            if count is not None:
                count = torch.cat([count, count.new_zeros(diff)])
            if ptr is not None:
                # Appended slices are empty, so they all point at ``nnz``.
                ptr = torch.cat([ptr, ptr.new_full((diff, ), nnz)])
        else:
            # Keep the first ``new_size`` slices. The previous ``[:-diff]``
            # slice kept ``|diff|`` entries, and nothing when ``diff == 0``.
            if count is not None:
                count = count[:new_size]
            if ptr is not None:
                ptr = ptr[:new_size + 1]
        return count, ptr
rusty1s's avatar
rusty1s committed
237

rusty1s's avatar
rusty1s committed
238
239
240
    def has_rowcount(self):
        """Whether a row-count tensor is currently stored."""
        stored = self._rowcount
        return stored is not None

rusty1s's avatar
rusty1s committed
241
242
    @cached_property
    def rowcount(self):
        """Non-zeros per row, as consecutive differences of ``rowptr``."""
        ptr = self.rowptr
        upper, lower = ptr[1:], ptr[:-1]
        return upper - lower
rusty1s's avatar
rusty1s committed
245

rusty1s's avatar
rusty1s committed
246
247
248
    def has_rowptr(self):
        """Whether a row-pointer tensor is currently stored."""
        stored = self._rowptr
        return stored is not None

rusty1s's avatar
rusty1s committed
249
250
    @cached_property
    def rowptr(self):
        """Row pointer computed from ``row`` by the compiled extension.

        Uses ``rowptr_cuda`` for CUDA indices and ``rowptr_cpu`` otherwise.
        """
        backend = rowptr_cpu
        if self.index.is_cuda:
            backend = rowptr_cuda
        return backend.rowptr(self.row, self.sparse_size(0))
rusty1s's avatar
rusty1s committed
253

rusty1s's avatar
rusty1s committed
254
255
256
    def has_colcount(self):
        """Whether a column-count tensor is currently stored."""
        stored = self._colcount
        return stored is not None

rusty1s's avatar
rusty1s committed
257
258
    @cached_property
    def colcount(self):
        """Non-zeros per column.

        Derived from a stored ``colptr`` when available; otherwise counted
        with ``scatter_add`` over ``col``.
        """
        if not self.has_colptr():
            col = self.col
            ones = torch.ones_like(col)
            return scatter_add(ones, col, dim_size=self.sparse_size(1))
        ptr = self.colptr
        return ptr[1:] - ptr[:-1]

    def has_colptr(self):
        """Whether a column-pointer tensor is currently stored."""
        stored = self._colptr
        return stored is not None
rusty1s's avatar
rusty1s committed
268
269
270

    @cached_property
    def colptr(self):
        """Column pointer for column-major access.

        When a ``csr2csc`` permutation is stored, column indices are permuted
        into column-major order and fed to the compiled ``rowptr`` kernel.
        Otherwise it is the exclusive cumulative sum of ``colcount``.
        """
        # Previously ``if self._csr2csc:``, which takes the truth value of a
        # tensor: it raises for more than one element and misreads ``[0]``.
        if self.has_csr2csc():
            func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
            return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))

        colcount = self.colcount
        colptr = colcount.new_zeros(colcount.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])
        return colptr

    def has_csr2csc(self):
        """Whether a CSR-to-CSC permutation is currently stored."""
        stored = self._csr2csc
        return stored is not None
rusty1s's avatar
rusty1s committed
282
283

    @cached_property
    def csr2csc(self):
        """Permutation sorting the (row-major) entries by column-major key."""
        num_rows = self._sparse_size[0]
        column_major_key = num_rows * self.col + self.row
        return column_major_key.argsort()

rusty1s's avatar
rusty1s committed
288
289
290
    def has_csc2csr(self):
        """Whether a CSC-to-CSR permutation is currently stored."""
        stored = self._csc2csr
        return stored is not None

rusty1s's avatar
rusty1s committed
291
    @cached_property
    def csc2csr(self):
        """Inverse of ``csr2csc`` (the argsort of that permutation)."""
        forward = self.csr2csc
        return forward.argsort()
rusty1s's avatar
rusty1s committed
294

rusty1s's avatar
rusty1s committed
295
    def is_coalesced(self):
        """Return ``True`` if the row-major keys are strictly increasing."""
        key = self.sparse_size(1) * self.row + self.col
        previous = torch.cat([key.new_full((1, ), -1), key[:-1]], dim=0)
        strictly_increasing = key > previous
        return strictly_increasing.all().item()
rusty1s's avatar
rusty1s committed
299

rusty1s's avatar
rusty1s committed
300
301
302
303
    def coalesce(self, reduce='add'):
        """Merge entries that share the same (row, col) coordinate.

        Duplicates are assumed adjacent, i.e. entries are in row-major order,
        which ``__init__`` establishes unless ``is_sorted=True`` was passed.
        Values of each run of duplicates are combined by ``segment_csr``
        using ``reduce``. Returns ``self`` when there are no duplicates.
        """
        idx = self.sparse_size(1) * self.row + self.col
        # ``mask[i]`` marks the first entry of each run of equal keys.
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

        if mask.all():  # Skip if indices are already coalesced.
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            # ``segment_csr(src, indptr)`` needs the start offset of every run
            # plus a final end offset. The old call passed a per-entry group
            # index as ``src`` and the values as ``indptr``.
            ptr = mask.nonzero().view(-1)
            ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
            value = segment_csr(value, ptr, reduce=reduce)
            # ``segment_csr`` may return a tuple; keep its first element.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)
rusty1s's avatar
rusty1s committed
316

rusty1s's avatar
rusty1s committed
317
318
319
320
321
322
    def cached_keys(self):
        """Return the entries of ``cache_keys`` whose ``_<key>`` is populated."""
        present = []
        for key in self.cache_keys:
            if getattr(self, f'_{key}', None) is not None:
                present.append(key)
        return present

rusty1s's avatar
rusty1s committed
323
    def fill_cache_(self, *args):
        """Access the named cached properties (all ``cache_keys`` by default).

        Whether results are stored depends on the global caching switch.
        Returns ``self``.
        """
        keys = args if args else self.cache_keys
        for key in keys:
            getattr(self, key)
        return self
rusty1s's avatar
rusty1s committed
327

rusty1s's avatar
rusty1s committed
328
329
330
331
    def clear_cache_(self, *args):
        """Reset the named cached structures (all by default) to ``None``.

        Returns ``self``.
        """
        targets = args if args else self.cache_keys
        for name in targets:
            setattr(self, f'_{name}', None)
        return self
rusty1s's avatar
rusty1s committed
332

rusty1s's avatar
rusty1s committed
333
334
335
    def __copy__(self):
        """Shallow copy: a new storage built from the same tensor objects."""
        def _same(tensor):
            return tensor

        return self.apply(_same)

rusty1s's avatar
test  
rusty1s committed
336
337
338
339
340
    def clone(self):
        """Return a new storage in which every tensor has been ``clone()``d."""
        def _copy(tensor):
            return tensor.clone()

        return self.apply(_copy)

    def __deepcopy__(self, memo):
        """``copy.deepcopy`` hook: clone all tensors and record the result."""
        duplicate = self.clone()
        memo[id(self)] = duplicate
        return duplicate

rusty1s's avatar
rusty1s committed
344
345
    def apply_value_(self, func):
        """Apply ``func`` to the values in place (no-op without values).

        Returns ``self``.
        """
        if self._value is not None:
            self._value = func(self._value)
        return self
rusty1s's avatar
rusty1s committed
347

rusty1s's avatar
rusty1s committed
348
349
350
351
352
    def apply_value(self, func):
        """Return a new storage whose values are ``func(values)``.

        The index, sparse size and all cached structures are shared.
        """
        new_value = optional(func, self._value)
        return self.__class__(
            self._index, new_value, self._sparse_size,
            self._rowcount, self._rowptr, self._colcount, self._colptr,
            self._csr2csc, self._csc2csr,
            is_sorted=True)

    def apply_(self, func):
        """Apply ``func`` in place to the index, values and populated caches.

        Returns ``self``.
        """
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            attr = f'_{key}'
            setattr(self, attr, func(getattr(self, attr)))
        return self
rusty1s's avatar
rusty1s committed
368
369
370
371
372
373

    def apply(self, func):
        """Return a new storage with ``func`` applied to every stored tensor.

        ``None`` entries (missing values or caches) are passed through.
        """
        index = func(self._index)
        optional_tensors = (self._value, self._rowcount, self._rowptr,
                            self._colcount, self._colptr, self._csr2csc,
                            self._csc2csr)
        (value, rowcount, rowptr, colcount, colptr, csr2csc,
         csc2csr) = [optional(func, t) for t in optional_tensors]
        return self.__class__(index, value, self._sparse_size, rowcount,
                              rowptr, colcount, colptr, csr2csc, csc2csr,
                              is_sorted=True)

rusty1s's avatar
rusty1s committed
383
384
385
386
    def map(self, func):
        """Return ``func`` applied to each stored tensor, as a list.

        Order: index, values (if present), then each populated cache
        tensor in ``cached_keys()`` order.
        """
        results = [func(self.index)]
        if self.has_value():
            results.append(func(self.value))
        results.extend(
            func(getattr(self, f'_{key}')) for key in self.cached_keys())
        return results