cat.py 3.41 KB
Newer Older
rusty1s's avatar
cat  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch


def cat(tensors, dim):
    """Concatenate sparse tensors along ``dim``.

    Args:
        tensors: Non-empty sequence of sparse tensors. Which optional storage
            members (values, ``rowcount``, ``rowptr``, ``colcount``,
            ``colptr``, ``csr2csc``, ``csc2csr``) are carried over is decided
            by inspecting ``tensors[0]`` only.
        dim: ``0`` stacks the tensors row-wise. Row indices are shifted by the
            accumulated number of rows, and the result has as many columns as
            the widest input. ``(0, 1)`` or ``(1, 0)`` (a list such as
            ``[0, 1]`` is accepted as well) builds a block-diagonal tensor,
            shifting both row and column indices.

    Returns:
        A new tensor of the same class as ``tensors[0]``, built with its
        ``from_storage`` constructor.

    Raises:
        NotImplementedError: for ``dim == 1`` and any other unsupported
            ``dim`` value.
    """
    assert len(tensors) > 0

    if isinstance(dim, list):
        # ``[0, 1] == (0, 1)`` is False in Python, so normalize to a tuple.
        dim = tuple(dim)

    # Optional members are propagated only if the first tensor has them
    # cached. The remaining tensors are then read through the public
    # ``storage.rowcount`` / ``storage.rowptr`` / ... accessors
    # (presumably computed on demand when not cached -- confirm).
    first = tensors[0].storage
    has_value = tensors[0].has_value()
    has_rowcount = first._rowcount is not None
    has_rowptr = first._rowptr is not None
    has_colcount = first._colcount is not None
    has_colptr = first._colptr is not None
    has_csr2csc = first._csr2csc is not None
    has_csc2csr = first._csc2csr is not None

    rows, cols, values, sparse_size = [], [], [], [0, 0]
    rowcounts, rowptrs, colcounts, colptrs = [], [], [], []
    csr2cscs, csc2csrs, nnzs = [], [], 0

    if dim == 0:
        for tensor in tensors:
            (row, col), value = tensor.coo()
            rows.append(row + sparse_size[0])
            cols.append(col)
            if has_value:
                values.append(value)
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)

            if has_rowptr:
                rowptr = tensor.storage.rowptr
                # Drop the leading entry of every pointer after the first.
                # Once shifted by ``nnzs`` it would repeat the previous
                # tensor's final entry (assumes CSR pointers start at 0).
                if rowptrs:
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnzs)

            nnzs += tensor.nnz()

        # Column-side pointers/permutations are not propagated here. Stacking
        # rows interleaves each column's entries across the inputs, so the
        # per-tensor colptr/csr2csc/csc2csr cannot simply be concatenated.
        # Rows stay sorted because every input is shifted past all preceding
        # rows (assumes each input's coo() is row-sorted).
        storage = first.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            rowptr=torch.cat(rowptrs) if has_rowptr else None,
            is_sorted=True)

    elif dim in ((0, 1), (1, 0)):
        # Block-diagonal layout: rows AND columns are shifted, so each input
        # occupies its own block. In both row-major and column-major order,
        # every block's entries then stay contiguous and in input order. That
        # is what makes the plain offset concatenation of all pointers and
        # permutations below valid.
        for tensor in tensors:
            (row, col), value = tensor.coo()
            rows.append(row + sparse_size[0])
            cols.append(col + sparse_size[1])
            if has_value:
                values.append(value)
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] += tensor.sparse_size(1)

            if has_rowcount:
                rowcounts.append(tensor.storage.rowcount)
            if has_colcount:
                colcounts.append(tensor.storage.colcount)

            if has_rowptr:
                rowptr = tensor.storage.rowptr
                # Same leading-entry trick as in the ``dim == 0`` branch.
                if rowptrs:
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnzs)

            if has_colptr:
                colptr = tensor.storage.colptr
                if colptrs:
                    colptr = colptr[1:]
                colptrs.append(colptr + nnzs)

            # The permutations index into this tensor's own non-zeros; shift
            # them into the combined non-zero numbering.
            if has_csr2csc:
                csr2cscs.append(tensor.storage.csr2csc + nnzs)
            if has_csc2csr:
                csc2csrs.append(tensor.storage.csc2csr + nnzs)

            nnzs += tensor.nnz()

        storage = first.__class__(
            torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            rowptr=torch.cat(rowptrs) if has_rowptr else None,
            colcount=torch.cat(colcounts) if has_colcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
            csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
            is_sorted=True)

    else:
        # Covers ``dim == 1`` (not implemented yet) and any other value.
        raise NotImplementedError

    return tensors[0].__class__.from_storage(storage)