transpose.py 1.79 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
4
5
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor

rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
def t(src: SparseTensor) -> SparseTensor:
    """Return the transpose of ``src`` as a new :class:`SparseTensor`.

    The non-zero entries of ``src`` are permuted with its CSR-to-CSC
    permutation, and their row/column indices are swapped. The result is
    handed to :class:`SparseStorage` with ``is_sorted=True``.

    The row-wise cached tensors of ``src`` (``_rowptr``, ``_rowcount``) are
    passed through as the column-wise caches of the result, and vice versa.
    The permutations swap roles in the same way.
    """
    old = src.storage
    perm = old.csr2csc()

    row, col, value = src.coo()
    new_value = None if value is None else value[perm]

    num_rows, num_cols = old.sparse_sizes()

    # Rows of the transpose are the columns of ``src`` (and vice versa),
    # visited in CSC order so that the new row-major order is preserved.
    new_row = col[perm]
    new_col = row[perm]

    new_storage = SparseStorage(
        row=new_row,
        rowptr=old._colptr,
        col=new_col,
        value=new_value,
        sparse_sizes=(num_cols, num_rows),
        rowcount=old._colcount,
        colptr=old._rowptr,
        colcount=old._rowcount,
        csr2csc=old._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(new_storage)
rusty1s's avatar
rusty1s committed
32
33
34


# Expose t() as the method SparseTensor.t(); the lambda looks up the
# module-level ``t`` each time it is called.
setattr(SparseTensor, "t", lambda self: t(self))
rusty1s's avatar
rusty1s committed
35
36
37
38
39
40
41
42
43
44

###############################################################################


def transpose(index, value, m, n, coalesced=True):
    """Swap dimensions 0 and 1 of a sparse matrix given in COO form.

    Args:
        index (:class:`LongTensor`): Row and column indices of the sparse
            matrix, unpacked as ``row, col = index``.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    src_row, src_col = index
    # Transposing COO indices is just exchanging the row and column vectors.
    out_row, out_col = src_col, src_row

    if not coalesced:
        return torch.stack([out_row, out_col], dim=0), value

    # Let SparseStorage.coalesce() canonicalize the swapped entries of the
    # transposed (n, m) matrix.
    merged = SparseStorage(row=out_row, col=out_col, value=value,
                           sparse_sizes=(n, m), is_sorted=False).coalesce()
    out_row, out_col, out_value = merged.row(), merged.col(), merged.value()
    return torch.stack([out_row, out_col], dim=0), out_value