storage.py 9.16 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import torch_scatter
rusty1s's avatar
rusty1s committed
5
6
7
from torch_scatter import scatter_add, segment_add


rusty1s's avatar
rusty1s committed
8
9
def optional(func, src):
    """Return ``func(src)``, or ``None`` when ``src`` is ``None``.

    Lets optional tensors (values, caches) be mapped uniformly without
    sprinkling ``is None`` checks at every call site.
    """
    if src is None:
        return None
    return func(src)
rusty1s's avatar
rusty1s committed
10
11


rusty1s's avatar
rusty1s committed
12
13
14
class cached_property(object):
    """Non-data descriptor that computes an attribute lazily and caches it.

    The computed value is stored on the instance under ``_<func name>``.
    Callers may therefore pre-seed it by assigning that attribute, or
    invalidate it by resetting the attribute to ``None``. A ``None`` result
    is not cached and will be recomputed on the next access.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class (e.g. ``help()``/introspection): return
            # the descriptor instead of calling ``func(None)``.
            return self
        attr = f'_{self.func.__name__}'
        value = getattr(obj, attr, None)
        if value is None:
            value = self.func(obj)
            setattr(obj, attr, value)
        return value


rusty1s's avatar
rusty1s committed
24
25
26
27
28
29
30
31
32
33
34
35
layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    """Resolve a sparse layout name, falling back to ``'coo'``.

    Args:
        layout (str, optional): one of ``layouts``. When omitted, a warning
            is emitted and ``'coo'`` is used.

    Returns:
        str: the validated layout name.

    Raises:
        AssertionError: if ``layout`` is not one of ``layouts``.
    """
    unset = layout is None
    if unset:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
    resolved = 'coo' if unset else layout
    assert resolved in layouts
    return resolved


rusty1s's avatar
rusty1s committed
36
class SparseStorage(object):
    """Sparse matrix storage in row-major sorted COO form with lazy caches.

    ``index`` holds ``[row; col]`` pairs sorted by ``(row, col)``. Derived
    structures (per-row/column counts, CSR/CSC pointers and the permutations
    between CSR and CSC ordering) are computed on first access and cached in
    ``_<name>`` attributes; their names are listed in ``cache_keys``.
    """

    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        """Build the storage, validating inputs and sorting if needed.

        Args:
            index (LongTensor): ``[2, nnz]`` row/column indices.
            value (Tensor, optional): values with ``value.size(0) == nnz``,
                on the same device as ``index``.
            sparse_size (tuple, optional): ``(num_rows, num_cols)``. Inferred
                as ``max index + 1`` per dimension when omitted, or
                ``(0, 0)`` for an empty ``index``.
            rowcount, rowptr, colcount, colptr, csr2csc, csc2csr (LongTensor,
                optional): precomputed cache entries; each is checked for
                dtype, device and length.
            is_sorted (bool): skip the sort check when ``index`` is already
                known to be sorted by ``(row, col)``.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            if index.numel() == 0:
                # ``max`` on an empty tensor raises, so default to 0 x 0.
                sparse_size = torch.Size([0, 0])
            else:
                sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        nnz = index.size(1)
        expected = (
            (rowcount, sparse_size[0]),
            (rowptr, sparse_size[0] + 1),
            (colcount, sparse_size[1]),
            (colptr, sparse_size[1] + 1),
            (csr2csc, nnz),
            (csc2csr, nnz),
        )
        for tensor, numel in expected:
            if tensor is not None:
                self._check_long_vector(tensor, index.device, numel)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if some entry is smaller than its predecessor; equal
            # neighbours (duplicates) are already in valid sorted order.
            if (idx[1:] < idx[:-1]).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                # Counts/pointers do not depend on entry order, but the
                # CSR<->CSC permutations do and must be recomputed.
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @staticmethod
    def _check_long_vector(tensor, device, numel):
        """Assert ``tensor`` is a 1-D long tensor on ``device`` of ``numel``."""
        assert tensor.dtype == torch.long
        assert tensor.device == device
        assert tensor.dim() == 1 and tensor.numel() == numel

    @property
    def index(self):
        """The ``[2, nnz]`` sorted index tensor."""
        return self._index

    @property
    def row(self):
        """Row indices (first row of ``index``)."""
        return self._index[0]

    @property
    def col(self):
        """Column indices (second row of ``index``)."""
        return self._index[1]

    def has_value(self):
        """Return ``True`` if a value tensor is attached."""
        return self._value is not None

    @property
    def value(self):
        """The value tensor in storage (row-major) order, or ``None``."""
        return self._value

    def set_value_(self, value, layout=None):
        """Replace the values in-place and return ``self``.

        Args:
            value (Tensor or None): new values (``None`` removes them).
            layout (str, optional): if ``'csc'``, ``value`` is given in
                column-major order and is permuted into storage order via
                ``csc2csr``.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        # Assign directly: routing through ``apply_value_`` would silently
        # ignore the new value when no value was attached before.
        self._value = value
        return self

    def set_value(self, value, layout=None):
        """Return a new storage sharing index/caches but holding ``value``.

        See ``set_value_`` for the meaning of ``value`` and ``layout``.
        """
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        """Return the full sparse size, or its extent along ``dim``."""
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        """Set the sparse size in-place to ``(num_rows, num_cols)``.

        Cached counts/pointers have one entry per row/column, so they are
        dropped and recomputed lazily for the new size.
        """
        assert len(sizes) == 2
        self._sparse_size = torch.Size(sizes)
        self.clear_cache_('rowcount', 'rowptr', 'colcount', 'colptr')
        return self

    @cached_property
    def rowcount(self):
        """Number of stored entries per row."""
        one = torch.ones_like(self.row)
        # ``row`` is sorted, so a segment reduction applies.
        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        """CSR row pointer: exclusive prefix sum of ``rowcount``."""
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        """Number of stored entries per column."""
        one = torch.ones_like(self.col)
        # ``col`` is not sorted, so a scatter (not segment) reduction is used.
        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        """CSC column pointer: exclusive prefix sum of ``colcount``."""
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        """Permutation reordering row-major entries into column-major order."""
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        """Inverse of ``csr2csc``."""
        return self.csr2csc.argsort()

    def _first_occurrence_mask(self):
        """Mask entries whose ``(row, col)`` differs from the previous entry.

        Relies on ``index`` being sorted, which ``__init__`` guarantees.
        """
        idx = self.sparse_size(1) * self.row + self.col
        return idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

    def is_coalesced(self):
        """Return ``True`` if no ``(row, col)`` pair is stored twice."""
        return self._first_occurrence_mask().all().item()

    def coalesce(self, reduce='add'):
        """Merge duplicate entries, reducing their values with ``reduce``.

        Args:
            reduce (str): one of ``'add'``, ``'mean'``, ``'min'``, ``'max'``.

        Returns:
            ``self`` if already coalesced, else a new storage.
        """
        mask = self._first_occurrence_mask()

        if mask.all():  # Already coalesced
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            assert reduce in ['add', 'mean', 'min', 'max']
            # Group id of each entry = number of distinct keys seen so far.
            group = mask.cumsum(0) - 1
            op = getattr(torch_scatter, f'scatter_{reduce}')
            value = op(value, group, dim=0, dim_size=group[-1].item() + 1)
            # scatter_min/max return ``(values, argidx)``.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

    def cached_keys(self):
        """Return the names in ``cache_keys`` that are currently filled."""
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        """Compute the given cache entries (all of them by default)."""
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        """Drop the given cache entries (all of them by default)."""
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        """Shallow copy: a new storage sharing the same tensors."""
        return self.apply(lambda x: x)

    def clone(self):
        """Return a new storage with every tensor cloned."""
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        """Apply ``func`` to the value tensor in-place (if present)."""
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        """Return a new storage with ``func`` applied to the value tensor."""
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        """Apply ``func`` in-place to index, value and filled caches."""
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            attr = f'_{key}'
            setattr(self, attr, func(getattr(self, attr)))
        return self

    def apply(self, func):
        """Return a new storage with ``func`` applied to every tensor."""
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        """Return ``[func(t) for t in (index, value?, *filled caches)]``."""
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data

rusty1s's avatar
rusty1s committed
276
277

if __name__ == '__main__':
    # Ad-hoc benchmark: time filling the lazy caches on a real graph and
    # exercise shallow/deep copying of the storage.
    import copy
    import time

    from torch_geometric.datasets import Planetoid

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _sync():
        # CUDA kernels run asynchronously; without synchronizing, the timer
        # would only measure kernel launch overhead.
        if device == 'cuda':
            torch.cuda.synchronize()

    dataset = Planetoid('/tmp/Cora', 'Cora')
    data = dataset[0].to(device)
    edge_index = data.edge_index

    # NOTE(review): is_sorted=True assumes the dataset's edge_index is
    # already sorted by (row, col) — confirm.
    storage = SparseStorage(edge_index, is_sorted=True)

    _sync()
    t = time.perf_counter()
    storage.fill_cache_()
    _sync()
    print('first fill_cache_:', time.perf_counter() - t)

    t = time.perf_counter()
    storage.clear_cache_()
    storage.fill_cache_()
    _sync()
    print('refill after clear_cache_:', time.perf_counter() - t)

    print(storage)
    storage = copy.copy(storage)
    print(storage)
    print(id(storage))
    storage = copy.deepcopy(storage)
    print(storage)
    storage.fill_cache_()
    storage.clear_cache_()