"vscode:/vscode.git/clone" did not exist on "d6cf1442c0c95c8c2b4283f621b25215c116caff"
sparse.py 5.55 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import inspect
rusty1s's avatar
rusty1s committed
2
3
4
from textwrap import indent
import torch

rusty1s's avatar
rusty1s committed
5
6
7
8
9
from torch_sparse.storage import SparseStorage

# Names of every SparseStorage member that is neither a dunder/name-mangled
# attribute (anything containing '__') nor `clone`. SparseTensor copies these
# members from its storage onto each instance and defines `clone` itself.
methods = [
    member_name for member_name, _ in inspect.getmembers(SparseStorage)
    if '__' not in member_name and member_name != 'clone'
]

rusty1s's avatar
rusty1s committed
10
11
12
13

class SparseTensor(object):
    """Sparse matrix wrapper around a :class:`SparseStorage`.

    Whenever ``_storage`` is assigned, every storage member whose name is
    listed in the module-level ``methods`` is copied onto the instance.
    Attributes used below that are not defined in this class (``_row``,
    ``_col``, ``_value``, ``size``, ``nnz``, ...) are therefore expected to
    come from the wrapped storage.
    """

    def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
        """Build a sparse tensor from coordinate indices.

        Args:
            index: Tensor with ``index.dim() == 2`` and ``index.size(0) == 2``.
                ``index[0]`` and ``index[1]`` are passed to the storage as its
                first two positional arguments.
            value: Optional values, forwarded to the storage unchanged.
            sparse_size: Optional size, forwarded to the storage unchanged.
            is_sorted: Forwarded to the storage as ``is_sorted``.

        Raises:
            AssertionError: If ``index`` does not have the expected layout.
        """
        assert index.dim() == 2 and index.size(0) == 2
        self._storage = SparseStorage(index[0], index[1], value, sparse_size,
                                      is_sorted=is_sorted)

    @classmethod
    def from_storage(cls, storage):
        """Wrap an existing ``storage`` without running ``__init__``."""
        obj = cls.__new__(cls)
        obj._storage = storage
        return obj

    @classmethod
    def from_dense(cls, mat):
        """Create a sparse tensor from the non-zero entries of ``mat``.

        For inputs with more than two dimensions, position ``(i, j)`` is kept
        when the absolute values of ``mat[i, j]`` summed over all trailing
        dimensions are non-zero. The sparse size is ``mat.size()[:2]``.
        """
        if mat.dim() > 2:
            trailing_dims = list(range(2, mat.dim()))
            index = mat.abs().sum(trailing_dims).nonzero()
        else:
            index = mat.nonzero()

        index = index.t().contiguous()
        value = mat[index[0], index[1]]
        # NOTE(review): `is_sorted=True` relies on `nonzero()` returning
        # indices in row-major order -- confirm against the torch docs.
        return cls(index, value, mat.size()[:2], is_sorted=True)

    @property
    def _storage(self):
        """The wrapped storage object."""
        return self.__storage

    @_storage.setter
    def _storage(self, storage):
        self.__storage = storage
        # Re-export the storage members on this instance. `getattr` is
        # evaluated here, so the values are snapshots taken at assignment.
        for name in methods:
            setattr(self, name, getattr(storage, name))

    def clone(self):
        """Return a new tensor wrapping a clone of the storage."""
        return self.__class__.from_storage(self._storage.clone())

    def __copy__(self):
        """``copy.copy`` support; delegates to :meth:`clone`."""
        return self.clone()

    def __deepcopy__(self, memo):
        """``copy.deepcopy`` support; delegates to :meth:`clone`.

        The memo is keyed by ``id(self)``, as in the standard deepcopy
        protocol, so repeated references map to one copy.
        """
        key = id(self)
        if key in memo:
            return memo[key]
        new_sparse_tensor = self.clone()
        memo[key] = new_sparse_tensor
        return new_sparse_tensor

    def coo(self):
        """Return the ``(index, value)`` pair."""
        return self._index, self._value

    def csr(self):
        """Return the ``(rowptr, col, value)`` triple."""
        return self._rowptr, self._col, self._value

    def csc(self):
        """Return ``(colptr, row, value)`` permuted by ``_arg_csr_to_csc``.

        ``value`` is ``None`` when the tensor holds no values.
        """
        perm = self._arg_csr_to_csc
        value = self._value[perm] if self.has_value else None
        return self._colptr, self._row[perm], value

    def is_quadratic(self):
        """Return whether both sparse dimensions have the same size."""
        # NOTE(review): `sparse_size` is called here to match its use in
        # `t()`; confirm that SparseStorage exposes it as a method.
        sparse_size = self.sparse_size()
        return sparse_size[0] == sparse_size[1]

    def is_symmetric(self):
        """Return whether the tensor equals its transpose.

        Non-square tensors yield ``False`` immediately. Otherwise the indices
        (and the values, if present) are compared element-wise with those of
        ``self.t()``.
        """
        if not self.is_quadratic():
            return False

        index1, value1 = self.coo()
        index2, value2 = self.t().coo()
        index_symmetric = (index1 == index2).all()
        value_symmetric = (value1 == value2).all() if self.has_value else True
        return index_symmetric and value_symmetric

    def set_value(self, value, layout):
        """Apply ``value`` through ``_apply_value`` and return its result.

        When ``layout == 'csc'``, a non-``None`` ``value`` is first reordered
        with ``_arg_csc_to_csr``.
        """
        if value is not None and layout == 'csc':
            value = value[self._arg_csc_to_csr]
        return self._apply_value(value)

    def set_value_(self, value, layout):
        """Same as :meth:`set_value`, but goes through ``_apply_value_``."""
        if value is not None and layout == 'csc':
            value = value[self._arg_csc_to_csr]
        return self._apply_value_(value)

    def t(self):
        """Return the transposed tensor.

        Rows and columns are swapped and reordered by ``_arg_csr_to_csc``.
        The pointer and permutation attributes are passed on in swapped
        order.
        """
        perm = self._arg_csr_to_csc
        storage = SparseStorage(
            self._col[perm],
            self._row[perm],
            self._value[perm] if self.has_value else None,
            self.sparse_size()[::-1],
            self._colptr,
            self._rowptr,
            self._arg_csc_to_csr,
            self._arg_csr_to_csc,
            is_sorted=True,
        )
        return self.__class__.from_storage(storage)

    def matmul(self, mat2):
        """Not implemented yet."""
        raise NotImplementedError

    def coalesce(self, reduce='add'):
        """Not implemented yet."""
        raise NotImplementedError

    def is_coalesced(self):
        """Not implemented yet."""
        raise NotImplementedError

    def add(self, layout=None):
        """Not implemented yet."""
        # sub, mul, div
        # can take scalars, tensors and other sparse matrices
        # inplace variants can only take scalars or tensors
        raise NotImplementedError

    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masking, perm

    def to_dense(self, dtype=None):
        """Return a dense tensor of shape ``self.size()``.

        Stored positions receive the stored values, or ``1`` when the tensor
        holds no values; all other entries are zero.

        Args:
            dtype: Result dtype; falls back to ``self.dtype`` when falsy.
        """
        dtype = dtype or self.dtype
        mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
        mat[self._row, self._col] = self._value if self.has_value else 1
        return mat

    def to_scipy(self):
        """Not implemented yet."""
        raise NotImplementedError

    def to_torch_sparse_coo_tensor(self):
        """Not implemented yet."""
        raise NotImplementedError

    def __repr__(self):
        """Multi-line summary: index, optional value, size, nnz, density."""
        # Continuation lines of the nested tensor reprs are padded by the
        # width of the 'index=' / 'value=' prefix.
        i = ' ' * 6
        index, value = self.coo()
        infos = [f'index={indent(index.__repr__(), i)[len(i):]}']
        if value is not None:
            infos += [f'value={indent(value.__repr__(), i)[len(i):]}']
        infos += [
            f'size={tuple(self.size())}, '
            f'nnz={self.nnz()}, '
            f'density={100 * self.density():.02f}%'
        ]
        infos = ',\n'.join(infos)

        # Align follow-up lines with the text after 'ClassName('.
        i = ' ' * (len(self.__class__.__name__) + 1)
        return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'


if __name__ == '__main__':
    # Ad-hoc smoke test: builds a SparseTensor from the PubMed graph and
    # compares it against torch's built-in sparse COO tensor.
    from torch_geometric.datasets import Reddit, Planetoid  # noqa
    import time  # noqa

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE: the detected device is overridden; everything below runs on CPU.
    device = 'cpu'

    # dataset = Reddit('/tmp/Reddit')
    dataset = Planetoid('/tmp/PubMed', 'PubMed')
    data = dataset[0].to(device)

    # Size of the edge index, counting 8 bytes per element
    # (presumably int64 indices -- confirm).
    _bytes = data.edge_index.numel() * 8
    _kbytes = _bytes / 1024
    _mbytes = _kbytes / 1024
    _gbytes = _mbytes / 1024
    print(f'Storage: {_gbytes:.04f} GB')

    mat1 = SparseTensor(data.edge_index)
    print(mat1)
    mat1 = mat1.t()

    # Reference: the same graph as a transposed, coalesced torch COO tensor.
    mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
                                   device=device)
    mat2 = mat2.coalesce()
    mat2 = mat2.t().coalesce()

    # Both transposes must produce the same index tensor.
    index1, value1 = mat1.coo()
    index2, value2 = mat2._indices(), mat2._values()
    assert torch.allclose(index1, index2)

    # Their dense forms must match as well.
    out1 = mat1.to_dense()
    out2 = mat2.to_dense()
    assert torch.allclose(out1, out2)

    # Round trip: dense -> SparseTensor.
    mat1 = SparseTensor.from_dense(out1)
    print(mat1)