storage.py 9.92 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import segment_csr, scatter_add
rusty1s's avatar
rusty1s committed
5

6
7
8
9
10
from torch_sparse import rowptr_cpu

if torch.cuda.is_available():
    from torch_sparse import rowptr_cuda

rusty1s's avatar
typo  
rusty1s committed
11
# Module-wide switch consulted by ``cached_property`` to decide whether a
# freshly computed value gets stored on the instance.
__cache__ = dict(enabled=True)


def is_cache_enabled():
    """Return the current value of the global caching switch."""
    enabled = __cache__['enabled']
    return enabled


def set_cache_enabled(mode):
    """Set the global caching switch to ``mode`` (stored exactly as given)."""
    __cache__.update(enabled=mode)
rusty1s's avatar
rusty1s committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


class no_cache(object):
    """Temporarily disable storing of ``cached_property`` results.

    Usable as a context manager (``with no_cache(): ...``) or as a decorator
    (``@no_cache()``). On exit, the caching setting that was active on entry
    is restored.
    """
    def __enter__(self):
        self.prev = is_cache_enabled()
        set_cache_enabled(False)
        return self

    def __exit__(self, *args):
        set_cache_enabled(self.prev)
        return False  # Never suppress exceptions.

    def __call__(self, func):
        from functools import wraps

        @wraps(func)
        def decorate_no_cache(*args, **kwargs):
            # Enter a fresh context per call: reusing ``self`` would let a
            # recursive/nested call overwrite the saved ``prev`` state, so the
            # outermost exit would restore ``False`` instead of the original.
            with self.__class__():
                return func(*args, **kwargs)

        return decorate_no_cache
rusty1s's avatar
rusty1s committed
37
38


rusty1s's avatar
rusty1s committed
39
40
41
class cached_property(object):
    """Descriptor that computes a value on first access and memoises it.

    The result of ``func(obj)`` is stored on ``obj`` under the attribute
    ``_<func name>``, but only while ``is_cache_enabled()`` is true. A stored
    value of ``None`` counts as "not cached", so ``func`` is re-run then.
    """
    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__', None)

    def __get__(self, obj, cls=None):
        if obj is None:
            # Accessed on the class itself: return the descriptor rather
            # than calling ``func(None)``.
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            if is_cache_enabled():
                setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
52
53
54
55
def optional(func, src):
    """Return ``func(src)``, or ``None`` unchanged when ``src`` is ``None``."""
    if src is None:
        return src
    return func(src)


rusty1s's avatar
rusty1s committed
56
57
58
59
60
61
62
63
64
65
66
67
# Layout names accepted by ``get_layout``.
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Validate and return ``layout``, defaulting to ``'coo'`` with a warning.

    Raises AssertionError if the resulting layout is not listed in ``layouts``.
    """
    chosen = 'coo' if layout is None else layout
    if layout is None:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
    assert chosen in layouts
    return chosen


rusty1s's avatar
rusty1s committed
68
class SparseStorage(object):
    """Row-major sorted COO storage of a 2-D sparse matrix.

    Entries are kept sorted by ``(row, col)``. Derived tensors used for
    CSR/CSC access (``rowcount``, ``rowptr``, ``colcount``, ``colptr``,
    ``csr2csc``, ``csc2csr``) may be passed in up front; otherwise they are
    computed on first access via ``cached_property`` and stored as
    ``_<name>`` attributes.
    """
    # Names of the derived tensors that may be cached as ``_<name>``.
    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Validate the inputs and store them, sorting entries if needed.

        Args:
            index: ``torch.long`` tensor of shape ``[2, nnz]``; ``index[0]``
                holds row indices, ``index[1]`` column indices.
            value: Optional tensor on the same device whose first dimension
                has size ``nnz``.
            sparse_size: Optional ``(num_rows, num_cols)``; inferred as the
                maximum row/column index plus one when omitted.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: Optional
                precomputed derived tensors (see ``cache_keys``).
            is_sorted: If ``True``, skip the row-major sortedness check.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == index.device
            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == index.device
            assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == index.device
            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == index.device
            assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == index.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == index.size(1)

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == index.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == index.size(1)

        if not is_sorted:
            # Flatten (row, col) into a single row-major key.
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if necessary...
            if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # Per-row/column counts and pointers do not depend on entry
                # order, but the stored permutations do, so drop them.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @property
    def index(self):
        """The ``[2, nnz]`` index tensor."""
        return self._index

    @property
    def row(self):
        """Row indices (``index[0]``)."""
        return self._index[0]

    @property
    def col(self):
        """Column indices (``index[1]``)."""
        return self._index[1]

    def has_value(self):
        return self._value is not None

    @property
    def value(self):
        """The value tensor, or ``None`` if the matrix has no values."""
        return self._value

    def set_value_(self, value, layout=None):
        """Replace the values in place and return ``self``.

        ``value`` may be ``None`` to drop the values. If ``layout`` resolves
        to ``'csc'``, ``value`` is taken to be in column-major order and is
        permuted into the stored row-major order.
        """
        # Only validate/permute real tensors; ``None`` is a valid value.
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        self._value = value
        return self

    def set_value(self, value, layout=None):
        """Return a new storage sharing index and caches, with new values.

        Same ``value``/``layout`` semantics as ``set_value_``.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        """Return the full sparse size, or its entry at ``dim``."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        """Set the sparse size to ``sizes`` (two entries) in place.

        Cached counts and pointers depend on the sparse size, so they are
        invalidated and will be recomputed lazily on next access.
        """
        assert len(sizes) == 2
        self._sparse_size = torch.Size(sizes)
        self.clear_cache_('rowcount', 'rowptr', 'colcount', 'colptr')
        return self

    def has_rowcount(self):
        return self._rowcount is not None

    @cached_property
    def rowcount(self):
        """Number of entries in each row, derived from ``rowptr``."""
        rowptr = self.rowptr
        return rowptr[1:] - rowptr[:-1]

    def has_rowptr(self):
        return self._rowptr is not None

    @cached_property
    def rowptr(self):
        """CSR row pointer, built by the compiled ``rowptr`` kernel.

        Expected to have ``sparse_size(0) + 1`` entries (as checked for a
        user-supplied ``rowptr`` in ``__init__``).
        """
        func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
        return func.rowptr(self.row, self.sparse_size(0))

    def has_colcount(self):
        return self._colcount is not None

    @cached_property
    def colcount(self):
        """Number of entries in each column.

        Uses ``colptr`` when it is already cached; otherwise counts column
        occurrences with ``scatter_add``.
        """
        if self._colptr is not None:
            colptr = self.colptr
            return colptr[1:] - colptr[:-1]
        else:
            col, dim_size = self.col, self.sparse_size(1)
            return scatter_add(torch.ones_like(col), col, dim_size=dim_size)

    def has_colptr(self):
        return self._colptr is not None

    @cached_property
    def colptr(self):
        """CSC column pointer.

        If ``csr2csc`` is cached, the column indices are permuted into
        column-major order and fed to the ``rowptr`` kernel; otherwise the
        pointer is the cumulative sum of ``colcount`` with a leading zero.
        """
        # Explicit ``is not None``: the truthiness of a multi-element tensor
        # raises, and a single-element tensor would be judged by its value.
        if self._csr2csc is not None:
            func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
            return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
        else:
            colcount = self.colcount
            colptr = colcount.new_zeros(colcount.size(0) + 1)
            torch.cumsum(colcount, dim=0, out=colptr[1:])
            return colptr

    def has_csr2csc(self):
        return self._csr2csc is not None

    @cached_property
    def csr2csc(self):
        """Permutation that reorders the stored entries by ``(col, row)``."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    def has_csc2csr(self):
        return self._csc2csr is not None

    @cached_property
    def csc2csr(self):
        """Inverse permutation of ``csr2csc``."""
        return self.csr2csc.argsort()

    def is_coalesced(self):
        """Return ``True`` if no ``(row, col)`` pair occurs more than once.

        Relies on entries being row-major sorted: the flattened keys must
        then be strictly increasing.
        """
        idx = self.sparse_size(1) * self.row + self.col
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
        return mask.all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate ``(row, col)`` entries into a new storage.

        Values of duplicate entries are combined with ``segment_csr`` using
        the given ``reduce`` mode. Returns ``self`` if already coalesced.
        """
        idx = self.sparse_size(1) * self.row + self.col
        # ``mask[i]`` is True where entry ``i`` starts a new (row, col) run.
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

        if mask.all():  # Skip if indices are already coalesced.
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            # Entries are sorted, so duplicates are contiguous: the run start
            # positions plus the total count form a CSR pointer into `value`.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), mask.numel())], dim=0)
            value = segment_csr(value, ptr, reduce=reduce)
            # Some reductions (e.g. min/max) return (out, arg_out).
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

    def cached_keys(self):
        """Return the names in ``cache_keys`` whose tensors are present."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given (default: all) cached tensors; return ``self``."""
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given (default: all) cached tensors; return ``self``."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        # New storage object sharing the same tensors.
        return self.apply(lambda x: x)

    def clone(self):
        """Return a new storage with every tensor cloned."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the values (if any) in place; return ``self``."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage with ``func`` applied to the values only."""
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        """Apply ``func`` to index, value and present caches in place."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``[func(t) for t in (index, value?, *present caches)]``."""
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data