storage.py 9.54 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import torch_scatter
rusty1s's avatar
rusty1s committed
5
6
7
from torch_scatter import scatter_add, segment_add


rusty1s's avatar
rusty1s committed
8
9
def optional(func, src):
    """Apply ``func`` to ``src`` unless ``src`` is ``None``.

    Args:
        func: Callable taking a single argument.
        src: Input object, or ``None``.

    Returns:
        ``func(src)`` when ``src`` is not ``None``, otherwise ``None``.
    """
    if src is None:
        return None
    return func(src)
rusty1s's avatar
rusty1s committed
10
11


rusty1s's avatar
rusty1s committed
12
13
14
class cached_property(object):
    """Descriptor that memoizes a method's result on the instance.

    The value computed by ``func(obj)`` is stored on the instance under the
    attribute ``_<func name>`` (e.g. ``rowcount`` -> ``_rowcount``).
    ``None`` doubles as the "not computed yet" marker. Consequently:

    * resetting the attribute to ``None`` invalidates the cache, and
    * a ``func`` that returns ``None`` is re-evaluated on every access.
    """

    def __init__(self, func):
        self.func = func
        self.attr_name = f'_{func.__name__}'
        # Expose the wrapped method's docstring (``help()``, autodoc).
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls=None):
        # Accessed on the class itself (e.g. ``SparseStorage.rowcount``):
        # return the descriptor instead of calling ``func(None)`` and then
        # trying to ``setattr`` on ``None``.
        if obj is None:
            return self
        value = getattr(obj, self.attr_name, None)
        if value is None:
            value = self.func(obj)
            setattr(obj, self.attr_name, value)
        return value


rusty1s's avatar
rusty1s committed
24
25
26
27
28
29
30
31
32
33
34
35
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Validate a sparse layout name, defaulting to ``'coo'``.

    Args:
        layout: One of the names in ``layouts``, or ``None``. ``None`` emits
            a warning and resolves to ``'coo'``.

    Returns:
        The validated layout name.

    Raises:
        AssertionError: If ``layout`` is not listed in ``layouts``.
    """
    if layout is None:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
        # Route the default through the same validation path.
        return get_layout('coo')
    assert layout in layouts
    return layout


rusty1s's avatar
rusty1s committed
36
class SparseStorage(object):
    """Row-major sorted COO storage of a 2-D sparse matrix.

    Entries are held as a ``[2, nnz]`` ``torch.long`` index tensor
    (``row = index[0]``, ``col = index[1]``) plus an optional ``value``
    tensor whose first dimension is ``nnz``. Unless ``is_sorted=True`` is
    passed, the constructor sorts the entries by ``(row, col)``.

    The helper tensors named in :attr:`cache_keys` are computed lazily on
    first access and memoized in ``_<name>`` attributes:

    * ``rowcount`` / ``colcount`` -- number of entries per row / column.
    * ``rowptr`` / ``colptr`` -- the counts cumulatively summed with a
      leading 0 (length ``sparse_size + 1``).
    * ``csr2csc`` -- permutation ordering the entries by ``(col, row)``.
    * ``csc2csr`` -- the inverse permutation of ``csr2csc``.
    """

    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self,
                 index,
                 value=None,
                 sparse_size=None,
                 rowcount=None,
                 rowptr=None,
                 colcount=None,
                 colptr=None,
                 csr2csc=None,
                 csc2csr=None,
                 is_sorted=False):
        """
        Args:
            index: ``[2, nnz]`` ``torch.long`` tensor of (row, col) indices.
            value: Optional tensor on the same device as ``index`` with
                ``value.size(0) == nnz``.
            sparse_size: ``(num_rows, num_cols)``. When omitted it is
                inferred as ``index.max(dim=-1) + 1``, or ``(0, 0)`` for an
                empty ``index``.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr: Optional
                precomputed cache tensors. Each is checked for dtype, device
                and length. ``csr2csc``/``csc2csr`` are discarded if the
                entries have to be re-sorted.
            is_sorted: Skip the sortedness check; the caller guarantees
                ``index`` is already in row-major order.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()
        device = index.device
        nnz = index.size(1)

        if value is not None:
            assert value.device == device
            assert value.size(0) == nnz
            value = value.contiguous()

        if sparse_size is None:
            if nnz > 0:
                sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())
            else:
                # ``max`` over an empty dimension raises, so an empty matrix
                # without an explicit size is treated as 0 x 0.
                sparse_size = torch.Size((0, 0))

        self._check_long_vector(rowcount, device, sparse_size[0])
        self._check_long_vector(rowptr, device, sparse_size[0] + 1)
        self._check_long_vector(colcount, device, sparse_size[1])
        self._check_long_vector(colptr, device, sparse_size[1] + 1)
        self._check_long_vector(csr2csc, device, nnz)
        self._check_long_vector(csc2csr, device, nnz)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if some entry is smaller than its predecessor. Equal
            # neighbours (duplicate entries) are already in order and must
            # not trigger a re-sort, which would needlessly drop user-given
            # ``csr2csc``/``csc2csr`` permutations.
            if (idx[1:] < idx[:-1]).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # Both permutations describe the old entry order.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @staticmethod
    def _check_long_vector(tensor, device, numel):
        """Assert an optional cache tensor is a 1-D ``torch.long`` tensor on
        ``device`` with ``numel`` elements (no-op for ``None``)."""
        if tensor is None:
            return
        assert tensor.dtype == torch.long
        assert tensor.device == device
        assert tensor.dim() == 1 and tensor.numel() == numel

    @property
    def index(self):
        """The ``[2, nnz]`` index tensor in row-major order."""
        return self._index

    @property
    def row(self):
        """Row indices of all entries (``index[0]``), sorted ascending."""
        return self._index[0]

    @property
    def col(self):
        """Column indices of all entries (``index[1]``)."""
        return self._index[1]

    def has_value(self):
        """Return whether explicit per-entry values are stored."""
        return self._value is not None

    @property
    def value(self):
        """Per-entry values in row-major entry order, or ``None``."""
        return self._value

    def _prepare_value(self, value, layout):
        """Validate ``value`` against the index and bring it into row-major
        entry order. ``None`` is passed through unchanged.

        If ``layout`` resolves to ``'csc'``, ``value`` is taken to be in
        column-major entry order and is permuted with ``csc2csr``.
        """
        if value is None:
            return None
        assert value.device == self._index.device
        assert value.size(0) == self._index.size(1)
        if get_layout(layout) == 'csc':
            value = value[self.csc2csr]
        return value

    def set_value_(self, value, layout=None):
        """Replace the stored values in place and return ``self``.

        Args:
            value: New per-entry values, or ``None`` to drop them.
            layout: Entry order of ``value`` (see :func:`get_layout`).
        """
        self._value = self._prepare_value(value, layout)
        return self

    def set_value(self, value, layout=None):
        """Return a new storage sharing index and caches but with ``value``.

        Args:
            value: New per-entry values, or ``None`` to drop them.
            layout: Entry order of ``value`` (see :func:`get_layout`).
        """
        value = self._prepare_value(value, layout)
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        """Return the ``(num_rows, num_cols)`` size, or one entry of it."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        """Set the sparse ``(num_rows, num_cols)`` size in place.

        The cached ``rowcount``, ``rowptr``, ``colcount`` and ``colptr``
        have a length tied to the size. They are dropped and recomputed on
        next access. ``csr2csc``/``csc2csr`` only depend on the relative
        ``(col, row)`` order of the entries and are kept.

        Returns:
            ``self``.
        """
        assert len(sizes) == 2
        # Previously ``==`` (a discarded comparison), so nothing changed.
        self._sparse_size = torch.Size(sizes)
        self.clear_cache_('rowcount', 'rowptr', 'colcount', 'colptr')
        return self

    @cached_property
    def rowcount(self):
        """Number of stored entries in each row."""
        one = torch.ones_like(self.row)
        # ``row`` is sorted (row-major storage), which ``segment_add``
        # presumably relies on; ``col`` is not, hence ``scatter_add`` below.
        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        """``rowcount`` cumulatively summed, with a leading 0."""
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        """Number of stored entries in each column."""
        one = torch.ones_like(self.col)
        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        """``colcount`` cumulatively summed, with a leading 0."""
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        """Permutation that reorders the row-major entries by (col, row)."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        """Inverse permutation of ``csr2csc``."""
        return self.csr2csc.argsort()

    def _first_occurrence_mask(self):
        """Boolean mask marking the first entry of every distinct (row, col)
        position. Relies on the entries being sorted in row-major order."""
        idx = self.sparse_size(1) * self.row + self.col
        return idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

    def is_coalesced(self):
        """Return whether every (row, col) position occurs at most once."""
        return self._first_occurrence_mask().all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate (row, col) entries.

        Returns ``self`` if there are no duplicates. Otherwise a new storage
        is returned, built without cached tensors. Its values, if present,
        are reduced per position with ``torch_scatter.scatter_<reduce>``.

        Args:
            reduce: One of ``'add'``, ``'mean'``, ``'min'``, ``'max'``.
        """
        mask = self._first_occurrence_mask()

        if mask.all():  # Already coalesced.
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            assert reduce in ['add', 'mean', 'min', 'max']
            # Output slot of every entry: rank of its distinct position.
            idx = mask.cumsum(0) - 1
            op = getattr(torch_scatter, f'scatter_{reduce}')
            value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1)
            # Some reductions (presumably min/max) return a tuple whose
            # first element is the reduced tensor.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

    def cached_keys(self):
        """Names from ``cache_keys`` whose tensor is currently computed."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given cache entries (default: all) and return self."""
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given cache entries (default: all) and return self."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        # New storage object sharing all tensors with ``self``.
        return self.apply(lambda x: x)

    def clone(self):
        """Return a new storage with every tensor cloned."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the stored values in place (if any)."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage with ``func`` applied to the values only."""
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        """Apply ``func`` in place to index, value and all cached tensors."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            # Previously ``setattr(self, name, func, tensor)``, which raised
            # ``TypeError`` (setattr takes exactly three arguments).
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``func`` applied to the index, the value (if present) and
        every currently cached tensor, in that order."""
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data

rusty1s's avatar
rusty1s committed
296
297

if __name__ == '__main__':
    # Ad-hoc smoke test / micro-benchmark: time cache filling on a real
    # graph, then exercise shallow and deep copies.
    import copy  # noqa
    import time  # noqa

    from torch_geometric.datasets import Reddit, Planetoid  # noqa

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ``Reddit('/tmp/Reddit')`` can be swapped in for a larger graph.
    dataset = Planetoid('/tmp/Cora', 'Cora')
    graph = dataset[0].to(device)
    edge_index = graph.edge_index

    storage = SparseStorage(edge_index, is_sorted=True)

    # Cold cache fill.
    start = time.perf_counter()
    storage.fill_cache_()
    print(time.perf_counter() - start)

    # Clear and refill.
    start = time.perf_counter()
    storage.clear_cache_()
    storage.fill_cache_()
    print(time.perf_counter() - start)
    print(storage)

    # Copy protocol round-trips.
    storage = copy.copy(storage)
    print(storage)
    print(id(storage))
    storage = copy.deepcopy(storage)
    print(storage)
    storage.fill_cache_()
    storage.clear_cache_()