# cat.py (5.2 KB) -- last change "cat", committed by rusty1s.
import torch


def cat(tensors, dim):
    """Concatenate sparse tensors along a sparse, diagonal, or dense dimension.

    Four modes are supported, selected by ``dim`` (negative values are
    wrapped using ``tensors[0].dim()``):

    * ``0``: stack along the first sparse dimension (rows are appended).
    * ``1``: stack along the second sparse dimension (columns are appended).
    * ``(0, 1)`` / ``(1, 0)``: stack along both sparse dimensions at once,
      i.e. every input is offset in rows *and* columns.
    * an int in ``(1, tensors[0].dim())``: concatenate only the values along
      ``dim - 1``; the index structure of ``tensors[0]`` is reused as is.

    Args:
        tensors: Non-empty sequence of sparse tensor objects. Each must
            expose ``storage``, ``dim()``, ``has_value()``, ``csr()``,
            ``coo()``, ``sparse_size(i)`` and ``nnz()`` as used below.
            NOTE(review): which optional cached fields (row, rowcount,
            colptr, colcount, csr2csc, csc2csr) get propagated is decided
            from ``tensors[0]`` only -- presumably all inputs are expected
            to agree; confirm against callers.
        dim: An int, or an iterable of ints, naming the concatenation
            dimension(s).

    Returns:
        A new object built via ``type(tensors[0]).from_storage(...)`` from a
        freshly constructed storage of the same class as
        ``tensors[0].storage``.

    Raises:
        AssertionError: If ``tensors`` is empty.
        IndexError: If ``dim`` does not select one of the supported modes.
    """
    assert len(tensors) > 0

    first = tensors[0]
    first_storage = first.storage
    storage_cls = first_storage.__class__

    # Only fields already cached on the first tensor are carried over.
    has_row = first_storage.has_row()
    has_value = first.has_value()
    has_rowcount = first_storage.has_rowcount()
    has_colptr = first_storage.has_colptr()
    has_colcount = first_storage.has_colcount()
    has_csr2csc = first_storage.has_csr2csc()
    has_csc2csr = first_storage.has_csc2csr()

    rows, rowptrs, cols, values = [], [], [], []
    rowcounts, colcounts, colptrs, csr2cscs, csc2csrs = [], [], [], [], []
    sparse_size = [0, 0]  # running (rows, cols) of the result
    nnzs = 0  # running number of non-zeros consumed so far

    # Remember what the caller passed so error messages can report it.
    requested_dim = dim
    num_dims = first.dim()
    if isinstance(dim, int):
        dim = num_dims + dim if dim < 0 else dim
    else:
        dim = tuple(num_dims + d if d < 0 else d for d in dim)

    if dim == 0:
        for tensor in tensors:
            rowptr, col, value = tensor.csr()
            # Every rowptr starts with a leading entry; keep it only once.
            rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
            rowptrs.append(rowptr + nnzs)
            cols.append(col)
            values.append(value)

            if has_row:
                rows.append(tensor.storage.row + sparse_size[0])

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)

            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
            nnzs += tensor.nnz()

        storage = storage_cls(
            row=torch.cat(rows) if has_row else None,
            rowptr=torch.cat(rowptrs),
            col=torch.cat(cols),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            is_sorted=True,
        )

    elif dim == 1:
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row)
            cols.append(col + sparse_size[1])
            values.append(value)

            if has_colcount:
                colcounts.append(tensor.storage.colcount)

            if has_colptr:
                colptr = tensor.storage.colptr
                colptr = colptr if len(colptrs) == 0 else colptr[1:]
                colptrs.append(colptr + nnzs)

            sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
            sparse_size[1] += tensor.sparse_size(1)
            nnzs += tensor.nnz()

        # Entries are simply appended per input, so the storage is told the
        # indices are not sorted.
        storage = storage_cls(
            row=torch.cat(rows),
            col=torch.cat(cols),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            colcount=torch.cat(colcounts) if has_colcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            is_sorted=False,
        )

    elif dim == (0, 1) or dim == (1, 0):
        for tensor in tensors:
            rowptr, col, value = tensor.csr()
            rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
            rowptrs.append(rowptr + nnzs)
            cols.append(col + sparse_size[1])
            values.append(value)

            if has_row:
                rows.append(tensor.storage.row + sparse_size[0])

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)

            if has_colcount:
                colcounts.append(tensor.storage.colcount)

            if has_colptr:
                colptr = tensor.storage.colptr
                colptr = colptr if len(colptrs) == 0 else colptr[1:]
                colptrs.append(colptr + nnzs)

            # The permutations index into the entry list, so they shift by
            # the number of entries that precede this input.
            if has_csr2csc:
                csr2cscs.append(tensor.storage.csr2csc + nnzs)

            if has_csc2csr:
                csc2csrs.append(tensor.storage.csc2csr + nnzs)

            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] += tensor.sparse_size(1)
            nnzs += tensor.nnz()

        storage = storage_cls(
            row=torch.cat(rows) if has_row else None,
            rowptr=torch.cat(rowptrs),
            col=torch.cat(cols),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            colcount=torch.cat(colcounts) if has_colcount else None,
            csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
            csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
            is_sorted=True,
        )

    elif isinstance(dim, int) and 1 < dim < num_dims:
        for tensor in tensors:
            values.append(tensor.storage.value)

        # Only the values change; every index field is taken unchanged from
        # the first tensor's storage.
        old_storage = first_storage
        storage = old_storage.__class__(
            row=old_storage._row,
            rowptr=old_storage._rowptr,
            col=old_storage._col,
            # The two sparse dims share one leading axis in the value
            # tensor, hence the shift by one.
            value=torch.cat(values, dim=dim - 1),
            sparse_size=old_storage.sparse_size,
            rowcount=old_storage._rowcount,
            colptr=old_storage._colptr,
            colcount=old_storage._colcount,
            csr2csc=old_storage._csr2csc,
            csc2csr=old_storage._csc2csr,
            is_sorted=True,
        )

    else:
        raise IndexError(
            f'Dimension out of range: Expected to be in range of '
            f'[{-num_dims}, {num_dims - 1}], but got {requested_dim}')

    return first.__class__.from_storage(storage)