import inspect

import torch
from torch import Size
from torch_scatter import scatter_add, segment_add


class SparseStorage(object):
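    """Storage for a two-dimensional sparse tensor.

    Entries are held in coordinate (COO) format as `row`/`col` index vectors
    (kept in row-major order) with an optional `value` tensor.  In addition,
    the compressed row pointer (`rowptr`, CSR) and column pointer (`colptr`,
    CSC) are cached, together with the permutations `arg_csr_to_csc` and
    `arg_csc_to_csr` that convert between row-major and column-major order.
    """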
    def __init__(self, row, col, value=None, sparse_size=None, rowptr=None,
                 colptr=None, arg_csr_to_csc=None, arg_csc_to_csr=None,
                 is_sorted=False):

        assert row.dtype == torch.long and col.dtype == torch.long
        assert row.device == col.device
        assert row.dim() == 1 and col.dim() == 1 and row.numel() == col.numel()

        if sparse_size is None:
            sparse_size = Size((row.max().item() + 1, col.max().item() + 1))

        if not is_sorted:
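            # Keep entries sorted in row-major order so that the CSR pointer
            # can be built by a simple prefix sum over per-row counts.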
            idx = sparse_size[1] * row + col
            # Only sort if necessary...
            if (idx <= torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                row = row[perm]
                col = col[perm]
                value = None if value is None else value[perm]
                rowptr = None
                colptr = None
                arg_csr_to_csc = None
                arg_csc_to_csr = None

        if value is not None:
            assert row.device == value.device and value.size(0) == row.size(0)
            value = value.contiguous()

        ones = None
        if rowptr is None:
            ones = torch.ones_like(row)
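            # CSR row pointer: count entries per row (out-degree) and turn the
            # counts into offsets via a prefix sum, prepending a leading zero.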
            out_deg = segment_add(ones, row, dim=0, dim_size=sparse_size[0])
            rowptr = torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
        else:
            assert rowptr.dtype == torch.long and rowptr.device == row.device
            assert rowptr.dim() == 1 and rowptr.size(0) == sparse_size[0] + 1

        if colptr is None:
            ones = torch.ones_like(col) if ones is None else ones
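            # CSC column pointer: same construction over per-column counts
            # (in-degree), but via `scatter_add` since `col` is not sorted.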
            in_deg = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])
            colptr = torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
        else:
            assert colptr.dtype == torch.long and colptr.device == col.device
            assert colptr.dim() == 1 and colptr.size(0) == sparse_size[1] + 1

        if arg_csr_to_csc is None:
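            # Permutation that reorders the row-major (CSR) entries into
            # column-major (CSC) order, obtained by sorting the column-major
            # linearized index.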
            idx = sparse_size[0] * col + row
            arg_csr_to_csc = idx.argsort()
        else:
            assert arg_csr_to_csc.dtype == torch.long
            assert arg_csr_to_csc.device == row.device
            assert arg_csr_to_csc.dim() == 1
            assert arg_csr_to_csc.size(0) == row.size(0)

        if arg_csc_to_csr is None:
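            # The reverse mapping is simply the inverse permutation, i.e. the
            # argsort of `arg_csr_to_csc`.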
            arg_csc_to_csr = arg_csr_to_csc.argsort()
        else:
            assert arg_csc_to_csr.dtype == torch.long
            assert arg_csc_to_csr.device == row.device
            assert arg_csc_to_csr.dim() == 1
            assert arg_csc_to_csr.size(0) == row.size(0)

        self.__row = row
        self.__col = col
        self.__value = value
        self.__sparse_size = sparse_size
        self.__rowptr = rowptr
        self.__colptr = colptr
        self.__arg_csr_to_csc = arg_csr_to_csc
        self.__arg_csc_to_csr = arg_csc_to_csr

    @property
    def _row(self):
        return self.__row

    @property
    def _col(self):
        return self.__col

    def _index(self):
        return torch.stack([self.__row, self.__col], dim=0)

    @property
    def _rowptr(self):
        return self.__rowptr

    @property
    def _colptr(self):
        return self.__colptr

    @property
    def _arg_csr_to_csc(self):
        return self.__arg_csr_to_csc

    @property
    def _arg_csc_to_csr(self):
        return self.__arg_csc_to_csr

    @property
    def _value(self):
        return self.__value

    @property
    def has_value(self):
        return self.__value is not None

    def sparse_size(self, dim=None):
        return self.__sparse_size if dim is None else self.__sparse_size[dim]

    def size(self, dim=None):
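        # The full shape is the sparse (matrix) size extended by the trailing
        # dimensions of `value`, if present.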
        size = self.__sparse_size
        size += () if self.__value is None else self.__value.size()[1:]
        return size if dim is None else size[dim]

    @property
    def shape(self):
        return self.size()

    def sparse_resize_(self, *sizes):
        assert len(sizes) == 2
        self.__sparse_size = Size(sizes)
        return self

    def clone(self):
        return self.__apply(lambda x: x.clone())

    def __copy__(self):
        return self.clone()

    def pin_memory(self):
        return self.__apply(lambda x: x.pin_memory())

    def is_pinned(self):
        return all([x.is_pinned() for x in self.__state().values()
                    if torch.is_tensor(x)])

    def share_memory_(self):
        return self.__apply_(lambda x: x.share_memory_())

    def is_shared(self):
        return all([x.is_shared() for x in self.__state().values()
                    if torch.is_tensor(x)])

    @property
    def device(self):
        return self.__row.device

    def cpu(self):
        return self.__apply(lambda x: x.cpu())

    def cuda(self, device=None, non_blocking=False, **kwargs):
        return self.__apply(lambda x: x.cuda(device, non_blocking, **kwargs))

    @property
    def is_cuda(self):
        return self.__row.is_cuda

    @property
    def dtype(self):
        return None if self.__value is None else self.__value.dtype

    def to(self, *args, **kwargs):
        out, args = self, list(args)
        if 'device' in kwargs:
            out = out.__apply(lambda x: x.to(kwargs['device']))
            del kwargs['device']
        for arg in args[:]:
            if isinstance(arg, (str, torch.device)):
                out = out.__apply(lambda x: x.to(arg))
                args.remove(arg)

        if len(args) > 0 or len(kwargs) > 0:
            out = out.type(*args, **kwargs)

        return out

    def type(self, dtype=None, non_blocking=False, **kwargs):
        return self.dtype if dtype is None else self.__apply_value(
            lambda x: x.type(dtype, non_blocking, **kwargs))

    def is_floating_point(self):
        return self.__value is None or torch.is_floating_point(self.__value)

    def bfloat16(self):
        return self.__apply_value(lambda x: x.bfloat16())

    def bool(self):
        return self.__apply_value(lambda x: x.bool())

    def byte(self):
        return self.__apply_value(lambda x: x.byte())

    def char(self):
        return self.__apply_value(lambda x: x.char())

    def half(self):
        return self.__apply_value(lambda x: x.half())

    def float(self):
        return self.__apply_value(lambda x: x.float())

    def double(self):
        return self.__apply_value(lambda x: x.double())

    def short(self):
        return self.__apply_value(lambda x: x.short())

    def int(self):
        return self.__apply_value(lambda x: x.int())

    def long(self):
        return self.__apply_value(lambda x: x.long())

    ###########################################################################

    def __keys(self):
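        # Constructor argument names without `self` and `is_sorted`; used to
        # snapshot the current state when rebuilding the storage.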
        return inspect.getfullargspec(self.__init__)[0][1:-1]

    def __state(self):
        return {
            key: getattr(self, f'_{self.__class__.__name__}__{key}')
            for key in self.__keys()
        }

    def __apply_value(self, func):
        state = self.__state()
        state['value'] = None if self.__value is None else func(self.__value)
        return self.__class__(is_sorted=True, **state)

    def __apply_value_(self, func):
        self.__value = None if self.__value is None else func(self.__value)
        return self

    def __apply(self, func):
        # Apply `func` to tensor attributes only; pass `None` and the sparse
        # size through unchanged.
        state = {key: func(item) if torch.is_tensor(item) else item
                 for key, item in self.__state().items()}
        return self.__class__(is_sorted=True, **state)

    def __apply_(self, func):
        state = self.__state()
        del state['value']
        for key, item in state.items():
            if torch.is_tensor(item):
                setattr(self, f'_{self.__class__.__name__}__{key}', func(item))
        return self.__apply_value_(func)


if __name__ == '__main__':
    from torch_geometric.datasets import Reddit  # noqa
    import time  # noqa

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = Reddit('/tmp/Reddit')
    data = dataset[0].to(device)
    edge_index = data.edge_index
    row, col = edge_index

    storage = SparseStorage(row, col)
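
    # A minimal timing sketch (illustrative only, not part of the original
    # benchmark): measure how long building the dual CSR/CSC representation
    # takes on the chosen device.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter()
    storage = SparseStorage(row, col)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f'SparseStorage construction took {time.perf_counter() - t:.4f}s')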