storage.py 9.14 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import torch_scatter
rusty1s's avatar
rusty1s committed
5
6
from torch_scatter import scatter_add, segment_add

rusty1s's avatar
rusty1s committed
7
# Global caching switch read by `is_cache_enabled` and written by
# `set_cache_enabled`. Kept in a mutable dict so it can be updated in place.
__cache_flag__ = dict(enabled=True)
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

def is_cache_enabled():
    """Return the current value of the global caching flag."""
    flag = __cache_flag__['enabled']
    return flag


def set_cache_enabled(mode):
    """Set the global caching flag to ``mode`` (stored as given, not coerced)."""
    flags = __cache_flag__
    flags['enabled'] = mode


class no_cache(object):
    """Disable storing `cached_property` values within a scope.

    Usable as a context manager (``with no_cache(): ...``) or as a decorator
    (``@no_cache()``). On exit, the caching flag is restored to the value it
    had when the matching ``__enter__`` ran.
    """

    def __init__(self):
        # A stack rather than a single attribute, so that re-entering the same
        # instance (e.g. a recursive function decorated with `@no_cache()`)
        # restores the correct flag on every exit instead of leaving caching
        # permanently disabled.
        self._prev = []

    def __enter__(self):
        self._prev.append(is_cache_enabled())
        set_cache_enabled(False)

    def __exit__(self, *args):
        set_cache_enabled(self._prev.pop())
        return False  # never suppress exceptions

    def __call__(self, func):
        import functools

        # Preserve the wrapped function's name, docstring, etc.
        @functools.wraps(func)
        def decorate_no_cache(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_no_cache
rusty1s's avatar
rusty1s committed
33
34


rusty1s's avatar
rusty1s committed
35
36
37
class cached_property(object):
    """Descriptor that computes an attribute lazily and caches it per instance.

    The result of ``func(obj)`` is stored on the instance under
    ``'_' + func.__name__``, but only while caching is enabled (see
    `is_cache_enabled`). A stored value of ``None`` counts as "not cached",
    so ``func`` is called again in that case.
    """

    def __init__(self, func):
        self.func = func
        # Expose the wrapped function's docstring on the descriptor.
        self.__doc__ = getattr(func, '__doc__', None)

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself (e.g. by `help` or `inspect`):
            # return the descriptor rather than calling `func(None)`.
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            if is_cache_enabled():
                setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
48
49
50
51
def optional(func, src):
    """Return ``func(src)``, or ``None`` unchanged when ``src`` is ``None``."""
    if src is None:
        return src
    return func(src)


rusty1s's avatar
rusty1s committed
52
53
54
55
56
57
58
59
60
61
62
63
# Layout names accepted by `get_layout`.
layouts = 'coo csr csc'.split()


def get_layout(layout=None):
    """Validate and return a sparse layout name.

    Args:
        layout: One of the names in `layouts`, or ``None``.

    Returns:
        ``layout`` unchanged, or ``'coo'`` when ``layout`` is ``None`` (in
        which case a warning is emitted).

    Raises:
        AssertionError: If ``layout`` is not one of `layouts`.
    """
    if layout is None:
        layout = 'coo'
        # stacklevel=2 attributes the warning to the caller rather than to
        # this helper, so it is clear which call left `layout` unset.
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.',
                      stacklevel=2)
    assert layout in layouts, (
        f'Invalid layout {layout!r}, expected one of {layouts}')
    return layout


rusty1s's avatar
rusty1s committed
64
class SparseStorage(object):
    """Sparse matrix storage built on sorted COO indices with cached views.

    Non-zero entries are kept in row-major order, i.e. sorted by
    ``sparse_size[1] * row + col``. Derived data is computed lazily through
    `cached_property` and stored on the instance as ``_<key>`` for every key
    in `cache_keys`. This covers per-row/column counts, CSR/CSC pointer
    vectors, and the permutations between row-major (CSR) and column-major
    (CSC) order.
    """

    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Create a storage, sorting the entries into row-major order if needed.

        Args:
            index: ``torch.long`` tensor of shape ``[2, nnz]``; ``index[0]``
                holds row indices and ``index[1]`` column indices.
            value: Optional tensor on the same device as ``index`` whose first
                dimension has size ``nnz``.
            sparse_size: Number of rows and columns. When omitted it is
                inferred as ``index.max(dim=-1) + 1``, or ``[0, 0]`` if
                ``index`` has no entries.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: Optional
                precomputed cache entries. Each is checked for dtype, device
                and length.
            is_sorted: If ``True``, the sortedness check is skipped.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            if index.numel() == 0:
                # `max` cannot reduce over zero elements; an empty matrix
                # without an explicit size is treated as 0 x 0.
                sparse_size = torch.Size([0, 0])
            else:
                sparse_size = torch.Size(
                    (index.max(dim=-1)[0] + 1).tolist())

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == index.device
            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == index.device
            assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == index.device
            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == index.device
            assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == index.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == index.size(1)

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == index.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == index.size(1)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if necessary...
            if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # The CSR<->CSC permutations refer to the old entry order and
                # are dropped. Counts and pointers are order-independent and
                # are kept.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @property
    def index(self):
        """The ``[2, nnz]`` index tensor, in row-major order."""
        return self._index

    @property
    def row(self):
        """Row indices of the stored entries (``index[0]``)."""
        return self._index[0]

    @property
    def col(self):
        """Column indices of the stored entries (``index[1]``)."""
        return self._index[1]

    def has_value(self):
        """Return whether a value tensor is stored."""
        return self._value is not None

    @property
    def value(self):
        """The value tensor aligned with `index`, or ``None``."""
        return self._value

    def set_value_(self, value, layout=None):
        """Replace the stored values in place and return ``self``.

        ``value`` may be ``None`` to drop the values. With ``layout='csc'``
        the values are taken to be in column-major order and are permuted
        into the stored row-major order.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        self._value = value
        return self

    def set_value(self, value, layout=None):
        """Return a new storage with ``value`` as its values.

        The index tensor and all cache entries are shared with ``self``.
        ``value`` may be ``None``. With ``layout='csc'`` it is permuted from
        column-major into the stored row-major order.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        """Return the ``(rows, cols)`` size, or a single entry if ``dim`` is given."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        """Change the sparse size in place and return ``self``.

        Counts and pointer vectors depend on the size, so their cached values
        are dropped and recomputed on demand. The CSR<->CSC permutations are
        kept, since the relative order of existing entries does not change
        (assuming all indices stay within the new bounds).
        """
        assert len(sizes) == 2
        self._sparse_size = torch.Size(sizes)
        self.clear_cache_('rowcount', 'rowptr', 'colcount', 'colptr')
        return self

    @cached_property
    def rowcount(self):
        """Number of stored entries in each row (length ``sparse_size(0)``)."""
        one = torch.ones_like(self.row)
        # Entries are sorted by row, which the segment-based reduction
        # presumably relies on (torch_scatter `segment_*` expects a sorted
        # index).
        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        """CSR row pointer: ``[0, cumsum(rowcount)]``, length ``rows + 1``."""
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        """Number of stored entries in each column (length ``sparse_size(1)``)."""
        one = torch.ones_like(self.col)
        # Columns are not sorted, hence the general scatter reduction.
        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        """CSC column pointer: ``[0, cumsum(colcount)]``, length ``cols + 1``."""
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        """Permutation that reorders the (row-major) entries column-major."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        """Inverse of `csr2csc` (argsort of a permutation is its inverse)."""
        return self.csr2csc.argsort()

    def is_coalesced(self):
        """Return whether every ``(row, col)`` pair occurs at most once."""
        idx = self.sparse_size(1) * self.row + self.col
        # Row-major keys must be strictly increasing (sorted, no duplicates).
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
        return mask.all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate ``(row, col)`` entries.

        Values of duplicate entries are combined with
        ``torch_scatter.scatter_<reduce>``, where ``reduce`` is one of
        ``'add'``, ``'mean'``, ``'min'`` or ``'max'``. This is only checked
        when values exist. Returns ``self`` if already coalesced. Otherwise
        returns a new storage without cached entries.
        """
        idx = self.sparse_size(1) * self.row + self.col
        # `mask` marks the first occurrence of each (row, col) key.
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

        if mask.all():  # Already coalesced
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            assert reduce in ['add', 'mean', 'min', 'max']
            # Map each entry to the position of its unique key.
            idx = mask.cumsum(0) - 1
            op = getattr(torch_scatter, f'scatter_{reduce}')
            value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1)
            # min/max variants may return a (values, argmin/argmax) tuple.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

    def cached_keys(self):
        """Return the keys of `cache_keys` whose values are currently stored."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given cache entries (default: all) and return ``self``."""
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given cache entries (default: all) and return ``self``."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        """Shallow copy: a new storage sharing all tensors with ``self``."""
        return self.apply(lambda x: x)

    def clone(self):
        """Return a new storage with every stored tensor cloned."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the values (if any) in place and return ``self``."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage with ``func`` applied to the values (if any).

        The index tensor and cache entries are shared with ``self``.
        """
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        """Apply ``func`` in place to index, values and stored cache entries."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every stored tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``func`` applied to the index, the values (if any) and
        every stored cache entry, in that order, as a list."""
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data