sparse.py 5.45 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import inspect
rusty1s's avatar
rusty1s committed
2
3
4
from textwrap import indent
import torch

rusty1s's avatar
rusty1s committed
5
6
7
8
9
from torch_sparse.storage import SparseStorage

methods = list(zip(*inspect.getmembers(SparseStorage)))[0]
methods = [name for name in methods if '__' not in name and name != 'clone']

rusty1s's avatar
rusty1s committed
10
11
12
13

class SparseTensor(object):
    def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
        assert index.dim() == 2 and index.size(0) == 2
rusty1s's avatar
rusty1s committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
        self._storage = SparseStorage(index[0], index[1], value, sparse_size,
                                      is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage):
        self = SparseTensor.__new__(SparseTensor)
        self._storage = storage
        return self

    @classmethod
    def from_dense(self, mat):
        if mat.dim() > 2:
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()

        index = index.t().contiguous()
        value = mat[index[0], index[1]]
        return SparseTensor(index, value, mat.size()[:2], is_sorted=True)
rusty1s's avatar
rusty1s committed
33

rusty1s's avatar
rusty1s committed
34
35
36
    @property
    def _storage(self):
        return self.__storage
rusty1s's avatar
rusty1s committed
37

rusty1s's avatar
rusty1s committed
38
39
40
41
42
    @_storage.setter
    def _storage(self, storage):
        self.__storage = storage
        for name in methods:
            setattr(self, name, getattr(storage, name))
rusty1s's avatar
rusty1s committed
43

rusty1s's avatar
rusty1s committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
    def clone(self):
        return SparseTensor.from_storage(self._storage.clone())

    def __copy__(self):
        return self.clone()

    def __deepcopy__(self, memo):
        memo = memo.setdefault('SparseStorage', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_sparse_tensor = self.clone()
        memo[self._cdata] = new_sparse_tensor
        return new_sparse_tensor

    def coo(self):
        return self._index, self._value

    def csr(self):
        return self._col, self._rowptr, self._value

    def csc(self):
        perm = self._arg_csr_to_csc
        return self._row[perm], self._colptr, self._value[perm]

    def is_quadratic(self):
        return self.sparse_size[0] == self.sparse_size[1]

    def is_symmetric(self):
        if not self.is_quadratic:
            return False

        index1, value1 = self.coo()
        index2, value2 = self.t().coo()
        index_symmetric = (index1 == index2).all()
        value_symmetric = (value1 == value2).all() if self.has_value else True
        return index_symmetric and value_symmetric

    def set_value(self, value, layout):
        if value is not None and layout == 'csc':
            value = value[self._arg_csc_to_csr]
        return self._apply_value(value)

    def set_value_(self, value, layout):
        if value is not None and layout == 'csc':
            value = value[self._arg_csc_to_csr]
        return self._apply_value_(value)

    def t(self):
        storage = SparseStorage(
            self._col[self._arg_csr_to_csc],
            self._row[self._arg_csr_to_csc],
            self._value[self._arg_csr_to_csc] if self.has_value else None,
            self.sparse_size()[::-1],
            self._colptr,
            self._rowptr,
            self._arg_csc_to_csr,
            self._arg_csr_to_csc,
            is_sorted=True,
        )
        return self.__class__.from_storage(storage)
rusty1s's avatar
rusty1s committed
104

rusty1s's avatar
rusty1s committed
105
    def matmul(self, mat2):
rusty1s's avatar
rusty1s committed
106
107
        pass

rusty1s's avatar
rusty1s committed
108
    def coalesce(self, reduce='add'):
rusty1s's avatar
rusty1s committed
109
110
        pass

rusty1s's avatar
rusty1s committed
111
112
    def is_coalesced(self):
        pass
rusty1s's avatar
rusty1s committed
113

rusty1s's avatar
rusty1s committed
114
115
116
117
118
    def add(self, layout=None):
        # sub, mul, div
        # can take scalars, tensors and other sparse matrices
        # inplace variants can only take scalars or tensors
        pass
rusty1s's avatar
rusty1s committed
119

rusty1s's avatar
rusty1s committed
120
121
122
123
124
    def to_dense(self, dtype=None):
        dtype = dtype or self.dtype
        mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
        mat[self._row, self._col] = self._value or 1
        return mat
rusty1s's avatar
rusty1s committed
125

rusty1s's avatar
rusty1s committed
126
127
    def to_scipy(self):
        raise NotImplementedError
rusty1s's avatar
rusty1s committed
128

rusty1s's avatar
rusty1s committed
129
130
    def to_torch_sparse_coo_tensor(self):
        raise NotImplementedError
rusty1s's avatar
rusty1s committed
131

rusty1s's avatar
rusty1s committed
132
    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masing, perm
rusty1s's avatar
rusty1s committed
133
134
135

    def __repr__(self):
        i = ' ' * 6
rusty1s's avatar
rusty1s committed
136
137
138
139
140
141
142
143
144
        index, value = self.coo()
        infos = [f'index={indent(index.__repr__(), i)[len(i):]}']
        if value is not None:
            infos += [f'value={indent(value.__repr__(), i)[len(i):]}']
        infos += [
            f'size={tuple(self.size())}, '
            f'nnz={self.nnz()}, '
            f'density={100 * self.density():.02f}%'
        ]
rusty1s's avatar
rusty1s committed
145
146
147
148
149
150
151
        infos = ',\n'.join(infos)

        i = ' ' * (len(self.__class__.__name__) + 1)
        return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'


if __name__ == '__main__':
rusty1s's avatar
rusty1s committed
152
153
    from torch_geometric.datasets import Reddit, Planetoid  # noqa
    import time  # noqa
rusty1s's avatar
rusty1s committed
154

rusty1s's avatar
rusty1s committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'

    # dataset = Reddit('/tmp/Reddit')
    dataset = Planetoid('/tmp/PubMed', 'PubMed')
    data = dataset[0].to(device)

    _bytes = data.edge_index.numel() * 8
    _kbytes = _bytes / 1024
    _mbytes = _kbytes / 1024
    _gbytes = _mbytes / 1024
    print(f'Storage: {_gbytes:.04f} GB')

    mat1 = SparseTensor(data.edge_index)
rusty1s's avatar
rusty1s committed
169
    print(mat1)
rusty1s's avatar
rusty1s committed
170
171
172
173
174
175
    mat1 = mat1.t()

    mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
                                   device=device)
    mat2 = mat2.coalesce()
    mat2 = mat2.t().coalesce()
rusty1s's avatar
rusty1s committed
176

rusty1s's avatar
rusty1s committed
177
178
179
180
181
182
183
184
185
186
    index1, value1 = mat1.coo()
    index2, value2 = mat2._indices(), mat2._values()
    assert torch.allclose(index1, index2)

    out1 = mat1.to_dense()
    out2 = mat2.to_dense()
    assert torch.allclose(out1, out2)

    mat1 = SparseTensor.from_dense(out1)
    print(mat1)