cat.py 3.4 KB
Newer Older
rusty1s's avatar
cat  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch


def cat(tensors, dim):
    """Concatenate sparse tensors along their sparse dimensions.

    Two layouts are supported:

    * ``dim == 0`` -- stack the inputs vertically. Row indices of each input
      are shifted by the number of rows that precede it. The result is as
      wide as the widest input.
    * ``dim == (0, 1)`` or ``(1, 0)`` (a list is accepted too) -- place the
      inputs block-diagonally. Both row and column indices are shifted.

    Cached compressed-format fields are carried over when ``tensors[0]`` has
    them cached, so the result does not have to recompute them:

    * ``rowptr`` and ``rowcount`` in both layouts;
    * ``colptr``, ``colcount``, ``csr2csc`` and ``csc2csr`` in the
      block-diagonal layout only.

    Args:
        tensors: Non-empty sequence of sparse tensors. Whether values and each
            cached field are present is decided from ``tensors[0]`` alone.
            NOTE(review): this assumes every input has the same fields
            cached; that is not checked here.
        dim: ``0``, ``(0, 1)`` or ``(1, 0)``.

    Returns:
        A new tensor of the same class as ``tensors[0]``. It is built with
        ``from_storage`` from a storage of the same class as
        ``tensors[0].storage``.

    Raises:
        NotImplementedError: if ``dim`` is ``1`` or any other unsupported
            value.
    """
    assert len(tensors) > 0

    # Let a list such as [0, 1] compare equal to the tuple forms below.
    if isinstance(dim, (tuple, list)):
        dim = tuple(dim)

    first = tensors[0].storage
    has_value = tensors[0].has_value()
    has_rowcount = first._rowcount is not None
    has_rowptr = first._rowptr is not None
    has_colcount = first._colcount is not None
    has_colptr = first._colptr is not None
    has_csr2csc = first._csr2csc is not None
    has_csc2csr = first._csc2csr is not None

    rows, cols, values, sparse_size = [], [], [], [0, 0]
    rowcounts, rowptrs, colcounts, colptrs = [], [], [], []
    csr2cscs, csc2csrs, nnzs = [], [], 0

    if dim == 0:
        for tensor in tensors:
            (row, col), value = tensor.coo()
            rows.append(row + sparse_size[0])
            cols.append(col)
            if has_value:
                values.append(value)
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)

            if has_rowptr:
                # Every pointer array starts at 0. Keep that leading entry
                # only for the first input, and shift the rest by the nnz
                # already emitted so the concatenation stays monotonic.
                rowptr = tensor.storage.rowptr
                rowptr = rowptr if not rowptrs else rowptr[1:]
                rowptrs.append(rowptr + nnzs)

            nnzs += tensor.nnz()

        # Column-side fields are not carried over here: the inputs' columns
        # overlap, so they cannot simply be concatenated.
        storage = first.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            rowptr=torch.cat(rowptrs) if has_rowptr else None,
            is_sorted=True)

    elif dim == 1:
        raise NotImplementedError

    elif dim in ((0, 1), (1, 0)):
        for tensor in tensors:
            (row, col), value = tensor.coo()
            rows.append(row + sparse_size[0])
            cols.append(col + sparse_size[1])
            if has_value:
                values.append(value)
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] += tensor.sparse_size(1)

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)
            if has_colcount:
                colcounts.append(tensor.storage.colcount)

            # Same leading-zero drop plus nnz shift as in the dim == 0 case.
            if has_rowptr:
                rowptr = tensor.storage.rowptr
                rowptr = rowptr if not rowptrs else rowptr[1:]
                rowptrs.append(rowptr + nnzs)

            if has_colptr:
                colptr = tensor.storage.colptr
                colptr = colptr if not colptrs else colptr[1:]
                colptrs.append(colptr + nnzs)

            # Blocks do not share rows or columns, so each block's
            # permutation only needs to be offset into its own nnz range.
            if has_csr2csc:
                csr2cscs.append(tensor.storage.csr2csc + nnzs)
            if has_csc2csr:
                csc2csrs.append(tensor.storage.csc2csr + nnzs)

            nnzs += tensor.nnz()

        storage = first.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            rowptr=torch.cat(rowptrs) if has_rowptr else None,
            colcount=torch.cat(colcounts) if has_colcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
            csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
            is_sorted=True)

    else:
        raise NotImplementedError

    return tensors[0].__class__.from_storage(storage)