sparse.py 3.02 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from textwrap import indent
import torch


class SparseTensor(object):
    """Two-dimensional sparse tensor stored in coordinate (COO) form.

    The tensor is described by a ``[2, nnz]`` index tensor (row indices in
    the first row, column indices in the second), an optional value tensor
    whose first dimension has length ``nnz``, and the size of the two sparse
    dimensions. Entries are kept sorted by ``(row, col)``.
    """

    def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
        """Build a sparse tensor from COO data.

        Args:
            index: ``torch.long`` tensor of shape ``[2, nnz]``.
            value: Optional tensor with ``value.size(0) == nnz`` living on
                the same device as ``index``. Trailing dimensions are
                treated as dense dimensions.
            sparse_size: Optional size of the two sparse dimensions. When
                omitted it is inferred as ``max index + 1`` per dimension
                (``(0, 0)`` for an empty index).
            is_sorted: Set to ``True`` to skip the sortedness check when the
                entries are already ordered by ``(row, col)``.

        Raises:
            AssertionError: If ``index``/``value`` violate the constraints
                above.
        """
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.size(0) == index.size(1)
            assert index.device == value.device
            value = value.contiguous()

        if sparse_size is None:
            if index.size(1) == 0:
                # ``max`` over an empty dimension raises, so fall back to an
                # empty matrix.
                sparse_size = (0, 0)
            else:
                sparse_size = (index.max(dim=-1)[0].cpu() + 1).tolist()

        self.__index__ = index
        self.__value__ = value
        # Normalise so that ``size()`` always deals with a ``torch.Size``.
        self.__sparse_size__ = torch.Size(sparse_size)

        if not is_sorted and not self.__is_sorted__():
            self.__sort__()

    def to(self, *args, **kwargs):
        """Placeholder for device/dtype conversion.

        Not implemented yet: currently does nothing and returns ``None``.
        """
        # TODO
        pass

    def size(self, dim=None):
        """Return the full size (sparse dims followed by dense value dims).

        Args:
            dim: Optional dimension; when given, only that entry is returned.
        """
        size = tuple(self.__sparse_size__)
        if self.__value__ is not None:
            # Dimensions of ``value`` beyond the first are dense dimensions.
            size += tuple(self.__value__.size()[1:])
        size = torch.Size(size)
        return size if dim is None else size[dim]

    def storage(self):
        """Placeholder; currently does nothing and returns ``None``."""
        pass

    @property
    def shape(self):
        """Alias for :meth:`size` without arguments."""
        return self.size()

    def dim(self):
        """Return the number of dimensions (sparse plus dense)."""
        return len(self.size())

    @property
    def dtype(self):
        """Data type of the values, or ``None`` if there are no values."""
        return None if self.__value__ is None else self.__value__.dtype

    @property
    def device(self):
        """Device on which the index tensor is stored."""
        return self.__index__.device

    def nnz(self):
        """Return the number of stored entries."""
        return self.__index__.size(1)

    def numel(self):
        """Return the number of stored value elements.

        Falls back to :meth:`nnz` when the tensor holds no values.
        """
        # Compare against ``None`` explicitly: the truth value of a tensor
        # with more than one element is ambiguous and raises.
        if self.__value__ is None:
            return self.nnz()
        return self.__value__.numel()

    def clone(self):
        """Return a copy with cloned index and value tensors."""
        return self.__class__(
            index=self.__index__.clone(),
            value=None if self.__value__ is None else self.__value__.clone(),
            sparse_size=self.__sparse_size__,
            is_sorted=True,
        )

    def sparse_resize_(self, *sizes):
        """Set the size of the two sparse dimensions in place.

        Raises:
            AssertionError: If not exactly two sizes are given.
        """
        assert len(sizes) == 2
        self.__sparse_size__ = torch.Size(sizes)

    def __is_sorted__(self):
        """Return ``True`` if entries are ordered by ``(row, col)``."""
        index = self.__index__
        # Linearised position of every entry in row-major order.
        idx = self.__sparse_size__[1] * index[0] + index[1]
        # Each entry must not precede its predecessor; trivially true for
        # zero or one entries.
        return (idx[1:] >= idx[:-1]).all().item()

    def __sort__(self):
        """Reorder index and value so entries are sorted by ``(row, col)``."""
        index = self.__index__
        idx = self.__sparse_size__[1] * index[0] + index[1]
        perm = idx.argsort()
        self.__index__ = index[:, perm]
        if self.__value__ is not None:
            self.__value__ = self.__value__[perm]

    def __repr__(self):
        """Return a multi-line description listing index, value, size, nnz."""
        i = ' ' * 6
        infos = [f'index={indent(self.__index__.__repr__(), i)[len(i):]}']
        if self.__value__ is not None:
            infos += [f'value={indent(self.__value__.__repr__(), i)[len(i):]}']
        infos += [f'size={tuple(self.size())}, nnz={self.nnz()}']
        infos = ',\n'.join(infos)

        # Align continuation lines with the opening parenthesis.
        i = ' ' * (len(self.__class__.__name__) + 1)
        return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'


if __name__ == '__main__':
    # Small demo: six entries of a 3 x 4 matrix, deliberately given out of
    # (row, col) order so that construction has to sort them.
    rows = [0, 0, 1, 1, 2, 2]
    cols = [2, 1, 2, 3, 0, 1]
    index = torch.tensor([rows, cols])
    value = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)

    mat1 = SparseTensor(index, value)
    print(mat1)

    # Built-in COO tensor from the same data, kept for manual comparison
    # (intentionally not printed).
    mat2 = torch.sparse_coo_tensor(index, value)