storage.py 8.55 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import torch_scatter
rusty1s's avatar
rusty1s committed
5
6
7
from torch_scatter import scatter_add, segment_add


rusty1s's avatar
rusty1s committed
8
9
def optional(func, src):
    """Apply ``func`` to ``src`` unless ``src`` is ``None``.

    Returns ``None`` unchanged, otherwise ``func(src)``.
    """
    if src is None:
        return None
    return func(src)
rusty1s's avatar
rusty1s committed
10
11


rusty1s's avatar
rusty1s committed
12
13
14
class cached_property(object):
    """Non-data descriptor that memoizes a method's result on the instance.

    The computed value is stored in the instance attribute ``_<func name>``
    (e.g. ``rowptr`` is cached in ``_rowptr``). Callers can therefore
    pre-populate the cache or clear it by assigning that attribute directly.
    A stored value of ``None`` counts as "not cached", so the function is
    run again on the next access.
    """

    def __init__(self, func):
        self.func = func
        # Keep the wrapped method's documentation visible on the descriptor.
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        # Accessed on the class itself (``Owner.attr``): return the
        # descriptor instead of calling ``func(None)``.
        if obj is None:
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
24
25
26
27
28
29
30
31
32
33
34
35
# Sparse layouts understood by ``get_layout``.
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Return a validated layout name.

    ``None`` falls back to ``'coo'`` and emits a warning. Any other value
    must be one of ``layouts``; otherwise an ``AssertionError`` is raised.
    """
    resolved = layout
    if resolved is None:
        resolved = 'coo'
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
    assert resolved in layouts
    return resolved


rusty1s's avatar
rusty1s committed
36
class SparseStorage(object):
    """Sparse matrix storage in sorted COO form with lazily cached helpers.

    The non-zero coordinates live in ``index``, a ``[2, nnz]`` long tensor of
    (row, col) pairs. An optional ``value`` tensor has first dimension
    ``nnz``. Unless ``is_sorted=True`` is passed, the constructor sorts the
    entries by their row-major key ``row * num_cols + col``.

    The tensors named in ``cache_keys`` are computed on first access and
    memoized in ``_<key>`` attributes. They are the row/column counts, the
    compressed row/column pointers and the CSR<->CSC permutations. They can
    also be supplied up front through the constructor.
    """

    # Names of the lazily computed tensors; each is stored as ``_<name>``.
    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Validate the inputs and build the storage.

        Args:
            index: ``[2, nnz]`` ``torch.long`` tensor of (row, col) indices.
            value: Optional tensor whose first dimension is ``nnz``. It must
                be on the same device as ``index``.
            sparse_size: ``(num_rows, num_cols)``. When omitted, it is
                inferred from the largest row/column index, or ``(0, 0)``
                for an empty ``index``.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: Optional
                precomputed cache tensors. Each is checked for dtype, device
                and length.
            is_sorted: If ``True``, ``index`` is trusted to already be in
                row-major order and is neither checked nor permuted.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            if index.numel() == 0:
                # ``max`` is undefined on an empty tensor; an empty matrix
                # without an explicit size is treated as 0 x 0.
                sparse_size = torch.Size((0, 0))
            else:
                sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == index.device
            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == index.device
            assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == index.device
            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == index.device
            assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == index.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == index.size(1)

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == index.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == index.size(1)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if some key is smaller than its predecessor. The
            # leading 0 sentinel assumes non-negative indices.
            if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # The permutations refer to the old entry order and must be
                # recomputed. Counts/pointers do not depend on entry order.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @property
    def index(self):
        """The ``[2, nnz]`` (row, col) index tensor."""
        return self._index

    @property
    def row(self):
        """Row indices of the stored entries (``index[0]``)."""
        return self._index[0]

    @property
    def col(self):
        """Column indices of the stored entries (``index[1]``)."""
        return self._index[1]

    def has_value(self):
        """Return whether a value tensor is attached."""
        return self._value is not None

    @property
    def value(self):
        """The per-entry value tensor, or ``None``."""
        return self._value

    def set_value_(self, value, layout=None):
        """Replace the value tensor in place and return ``self``.

        ``value`` may be ``None`` to drop the values. If ``layout`` is
        ``'csc'``, ``value`` is taken to be in column-major entry order and is
        permuted back to this storage's row-major order.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
            # Match the contiguity guarantee made by ``__init__``.
            value = value.contiguous()
        self._value = value
        return self

    def set_value(self, value, layout=None):
        """Return a new storage sharing ``index``/caches but using ``value``.

        ``value`` may be ``None``. ``layout='csc'`` means ``value`` is given
        in column-major entry order, as in ``set_value_``.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        """Return ``(num_rows, num_cols)``, or a single entry if ``dim`` is given."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        """Set the sparse size to ``sizes`` (two values) in place.

        Counts and compressed pointers have one entry per row/column, so they
        are dropped when the size changes and recomputed on next access.
        """
        assert len(sizes) == 2
        if tuple(self._sparse_size) != tuple(sizes):
            self.clear_cache_('rowcount', 'rowptr', 'colcount', 'colptr')
        self._sparse_size = sizes
        return self

    @cached_property
    def rowcount(self):
        """Number of stored entries in each row (length ``num_rows``)."""
        # Presumably a segment reduction is valid here because rows are
        # sorted (see ``__init__``).
        one = torch.ones_like(self.row)
        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        """CSR row pointer: ``[0, cumsum(rowcount)]`` (length ``num_rows + 1``)."""
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        """Number of stored entries in each column (length ``num_cols``)."""
        # Columns are not sorted, hence a scatter rather than a segment op.
        one = torch.ones_like(self.col)
        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        """CSC column pointer: ``[0, cumsum(colcount)]`` (length ``num_cols + 1``)."""
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        """Permutation that reorders row-major entries into column-major order."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        """Inverse of ``csr2csc``: column-major order back to row-major."""
        return self.csr2csc.argsort()

    def is_coalesced(self):
        """Return whether row-major keys are strictly increasing (no duplicates)."""
        idx = self.sparse_size(1) * self.row + self.col
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
        return mask.all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate (row, col) entries.

        Because entries are sorted, duplicates are adjacent. Their values are
        combined with ``torch_scatter.scatter_<reduce>``, where ``reduce`` is
        one of ``'add'``, ``'mean'``, ``'min'`` or ``'max'``. Returns ``self``
        when already coalesced. Otherwise returns a new storage with no cached
        tensors.
        """
        idx = self.sparse_size(1) * self.row + self.col
        # ``mask[i]`` marks the first occurrence of each distinct key.
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

        if mask.all():  # Already coalesced
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            assert reduce in ['add', 'mean', 'min', 'max']
            # Group id of every entry: position of its key among unique keys.
            group = mask.cumsum(0) - 1
            op = getattr(torch_scatter, f'scatter_{reduce}')
            value = op(value, group, dim=0, dim_size=group[-1].item() + 1)
            # min/max variants return a (values, arg) tuple.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

    def cached_keys(self):
        """Return the names in ``cache_keys`` that currently hold a tensor."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given cached tensors (all of them if none given)."""
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given cached tensors (all of them if none given)."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        """Shallow copy: a new storage sharing all tensors with ``self``."""
        return self.apply(lambda x: x)

    def clone(self):
        """Return a new storage with every tensor (and cache) cloned."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Replace ``value`` with ``func(value)`` in place (no-op if ``None``)."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage with ``func`` applied to ``value`` only."""
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        """Apply ``func`` in place to index, value and every cached tensor."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every held tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``[func(index), func(value)?, func(cached)...]`` as a list."""
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data