transpose.py 1.87 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_sparse import to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor

rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
@torch.jit.script
def t(src: SparseTensor):
    """Transpose a :class:`SparseTensor` by swapping its two sparse dimensions.

    The COO entries of ``src`` are reordered with the storage's ``csr2csc``
    permutation and the row/column roles are exchanged. The cached pointer,
    count and permutation tensors of ``src`` are handed to the new storage
    with their roles swapped as well (row <-> column, ``csr2csc`` <->
    ``csc2csr``). The new storage is flagged ``is_sorted=True``.
    """
    old = src.storage
    # Entry permutation exposed by the storage (by its name, CSR -> CSC order).
    perm = old.csr2csc()

    rows, cols, vals = src.coo()
    if vals is not None:
        vals = vals[perm]

    # Rows of the result are the permuted columns of ``src``, and vice versa.
    new_row = cols[perm]
    new_col = rows[perm]

    dims = old.sparse_sizes()
    flipped = torch.Size([dims[1], dims[0]])

    # Every cached structure of ``src`` is reused with its role exchanged.
    # The old permutation becomes the new ``csc2csr``.
    storage = SparseStorage(
        row=new_row,
        col=new_col,
        value=vals,
        sparse_sizes=flipped,
        rowptr=old._colptr,
        colptr=old._rowptr,
        rowcount=old._colcount,
        colcount=old._rowcount,
        csr2csc=old._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(storage)
rusty1s's avatar
rusty1s committed
34
35
36


def _sparse_tensor_t(self):
    """Method form of :func:`t`: ``src.t()`` returns ``t(src)``."""
    return t(self)


# Attach as a method, then drop the helper name so the module namespace is
# the same as with the inline lambda.
SparseTensor.t = _sparse_tensor_t
del _sparse_tensor_t
rusty1s's avatar
rusty1s committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

###############################################################################


def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse tensor.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)

    The returned index addresses a matrix of shape ``(n, m)``. When
    ``coalesced=False``, the entries keep their input order and only the two
    index rows are swapped.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # The SciPy fast path is taken only when its result is what the caller
    # asked for:
    # * 1-D values on the CPU, which a SciPy matrix can hold;
    # * ``coalesced=True``, since the CSC round-trip's output does not depend
    #   on the flag and so cannot honour ``coalesced=False``;
    # * ``value`` without autograd history, since the round-trip goes
    #   through NumPy/SciPy, which cannot record gradients.
    use_scipy = (value.dim() == 1 and not value.is_cuda and coalesced
                 and not value.requires_grad)

    if use_scipy:
        mat = to_scipy(index, value, m, n).tocsc()
        # ``mat`` is m x n, so its column indices are the row indices of
        # the transposed matrix.
        (col, row), value = from_scipy(mat)
        index = torch.stack([row, col], dim=0)
        return index, value

    # Generic path: swap the two index rows. This works on any device and
    # keeps ``value`` in the autograd graph.
    row, col = index
    index = torch.stack([col, row], dim=0)
    if coalesced:
        # Coalesce with respect to the transposed (n, m) shape.
        # NOTE(review): assumes torch_sparse.coalesce is differentiable
        # w.r.t. ``value`` -- confirm.
        index, value = coalesce(index, value, n, m)
    return index, value