# cat.py (5.27 KB)
# Source-viewer blame header: commit "cat" by rusty1s.
import torch


def cat(tensors, dim):
    """Concatenate sparse tensors along ``dim``.

    The first tensor decides which optional storage fields (value,
    rowcount, rowptr, colcount, colptr, csr2csc, csc2csr) are carried over
    into the result; the remaining tensors are read for the same fields.

    Args:
        tensors: Non-empty sequence of sparse tensor objects exposing
            ``coo()``, ``nnz()``, ``dim()``, ``sparse_size(i)``,
            ``has_value()`` and a ``storage`` attribute.
        dim: An ``int`` or an iterable of ``int``. Negative entries are
            offset by ``tensors[0].dim()``. Supported values after
            normalisation:

            * ``0`` -- row indices of each tensor are shifted by the number
              of rows accumulated so far; the column size is the maximum
              over all inputs.
            * ``1`` -- column indices are shifted by the number of columns
              accumulated so far; the row size is the maximum over all
              inputs.
            * ``(0, 1)`` / ``(1, 0)`` -- both row and column indices are
              shifted (block-diagonal placement).
            * an ``int`` with ``1 < dim < tensors[0].dim()`` -- only the
              values are concatenated (along ``dim - 1``); every other
              storage field is taken from the first tensor.

    Returns:
        ``tensors[0].__class__.from_storage(storage)`` for the newly built
        storage.

    Raises:
        AssertionError: If ``tensors`` is empty.
        IndexError: If ``dim`` matches none of the supported cases. The
            message reports ``dim`` as passed by the caller.
    """
    assert len(tensors) > 0

    first = tensors[0]
    has_value = first.has_value()
    has_rowcount = first.storage.has_rowcount()
    has_rowptr = first.storage.has_rowptr()
    has_colcount = first.storage.has_colcount()
    has_colptr = first.storage.has_colptr()
    has_csr2csc = first.storage.has_csr2csc()
    has_csc2csr = first.storage.has_csc2csr()

    rows, cols, values, sparse_size = [], [], [], [0, 0]
    rowcounts, rowptrs, colcounts, colptrs = [], [], [], []
    csr2cscs, csc2csrs, nnzs = [], [], 0

    # Remember what the caller asked for so the error message can echo it
    # instead of the normalised value.
    requested_dim = dim if isinstance(dim, int) else tuple(dim)
    if isinstance(requested_dim, int):
        dim = first.dim() + requested_dim if requested_dim < 0 \
            else requested_dim
    else:
        dim = tuple(first.dim() + d if d < 0 else d for d in requested_dim)

    # NOTE(review): the first three branches hand a stacked (2, nnz) index
    # to the storage constructor positionally, while the `dim > 1` branch
    # passes `row=`/`rowptr=`/`col=` keywords -- confirm both call styles
    # are accepted by the storage class.
    if dim == 0:
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row + sparse_size[0])
            cols.append(col)
            values.append(value)
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)

            if has_rowptr:
                # Every pointer vector but the first loses its first entry
                # (presumably the leading 0 of a CSR-style pointer) and is
                # shifted by the number of non-zeros seen so far.
                rowptr = tensor.storage.rowptr
                rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
                rowptrs.append(rowptr + nnzs)

            nnzs += tensor.nnz()

        storage = first.storage.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            rowptr=torch.cat(rowptrs) if has_rowptr else None,
            is_sorted=True,
        )

    elif dim == 1:
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row)
            cols.append(col + sparse_size[1])
            values.append(value)
            sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
            sparse_size[1] += tensor.sparse_size(1)

            if has_colcount:
                colcounts.append(tensor.storage.colcount)

            if has_colptr:
                # Same merge rule as for row pointers in the `dim == 0` case.
                colptr = tensor.storage.colptr
                colptr = colptr if len(colptrs) == 0 else colptr[1:]
                colptrs.append(colptr + nnzs)

            nnzs += tensor.nnz()

        storage = first.storage.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            colcount=torch.cat(colcounts) if has_colcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            # Presumably False because the row indices of the inputs are
            # simply appended one after another -- TODO confirm.
            is_sorted=False,
        )

    elif dim == (0, 1) or dim == (1, 0):
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row + sparse_size[0])
            cols.append(col + sparse_size[1])
            if has_value:
                values.append(value)
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] += tensor.sparse_size(1)

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)

            if has_rowptr:
                rowptr = tensor.storage.rowptr
                rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
                rowptrs.append(rowptr + nnzs)

            if has_colcount:
                colcounts.append(tensor.storage.colcount)

            if has_colptr:
                colptr = tensor.storage.colptr
                colptr = colptr if len(colptrs) == 0 else colptr[1:]
                colptrs.append(colptr + nnzs)

            # The permutation vectors are shifted by the number of
            # non-zeros that precede this tensor in the result.
            if has_csr2csc:
                csr2cscs.append(tensor.storage.csr2csc + nnzs)

            if has_csc2csr:
                csc2csrs.append(tensor.storage.csc2csr + nnzs)

            nnzs += tensor.nnz()

        storage = first.storage.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            rowptr=torch.cat(rowptrs) if has_rowptr else None,
            colcount=torch.cat(colcounts) if has_colcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
            csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
            is_sorted=True,
        )

    elif isinstance(dim, int) and dim > 1 and dim < first.dim():
        # Only the values grow; all index-related fields are reused from
        # the first tensor's storage.
        values = [tensor.storage.value for tensor in tensors]

        old_storage = first.storage
        # NOTE(review): `colcount` is forwarded but `rowcount` is not --
        # confirm whether that omission is intentional.
        storage = old_storage.__class__(
            row=old_storage._row,
            rowptr=old_storage._rowptr,
            col=old_storage._col,
            # `dim - 1` because the value tensor is indexed with one
            # dimension fewer than the sparse tensor here.
            value=torch.cat(values, dim=dim - 1),
            sparse_size=old_storage.sparse_size(),
            colptr=old_storage._colptr,
            colcount=old_storage._colcount,
            csr2csc=old_storage._csr2csc,
            csc2csr=old_storage._csc2csr,
            is_sorted=True,
        )

    else:
        raise IndexError(
            (f'Dimension out of range: Expected to be in range of '
             f'[{-first.dim()}, {first.dim() - 1}], '
             f'but got {requested_dim}'))

    return first.__class__.from_storage(storage)