transpose.py 1.79 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_sparse import to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor

rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
linting  
rusty1s committed
8
def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse tensor.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    if not coalesced:
        # Swapping the two index rows is the entire transpose. Return early so
        # the documented "will not coalesce" contract also holds on the CPU
        # path below, which would otherwise sum duplicates and sort entries.
        row, col = index
        return torch.stack([col, row], dim=0), value

    if value.dim() == 1 and not value.is_cuda:
        # CPU fast path for scalar values: round-trip through SciPy. Converting
        # to CSC orders entries by column, i.e. by row of the transposed
        # matrix, and SciPy's COO->CSC conversion sums duplicate entries, so
        # the result is coalesced like the generic path below.
        mat = to_scipy(index, value, m, n).tocsc()
        # from_scipy yields indices of the *original* orientation; unpack them
        # swapped so that stacking [row, col] produces the transposed index.
        (col, row), value = from_scipy(mat)
        index = torch.stack([row, col], dim=0)
        return index, value

    # Generic path (CUDA or multi-dimensional values): swap index rows, then
    # coalesce against the transposed sizes (n x m).
    row, col = index
    index = torch.stack([col, row], dim=0)
    index, value = coalesce(index, value, n, m)
    return index, value
rusty1s's avatar
rusty1s committed
32
33


rusty1s's avatar
rusty1s committed
34
35
36
37
38
39
40
41
@torch.jit.script
def t(src: SparseTensor):
    """Return the transpose of ``src`` as a new SparseTensor.

    The transposed storage is assembled directly from ``src``'s cached
    compressed structures: row-side and column-side caches are swapped and the
    CSR<->CSC permutations trade places, so no sort is needed.
    """
    # Permutation taking CSR-ordered entries to CSC (column-major) order.
    perm = src.storage.csr2csc()
    row, col, value = src.coo()

    # Applying `perm` to the swapped index vectors lists the entries of the
    # transposed matrix in row-major order.
    new_row = col[perm]
    new_col = row[perm]
    if value is not None:
        value = value[perm]

    sizes = src.storage.sparse_sizes()
    flipped_sizes = torch.Size([sizes[1], sizes[0]])

    old = src.storage
    transposed = SparseStorage(
        row=new_row,
        rowptr=old._colptr,
        col=new_col,
        value=value,
        sparse_sizes=flipped_sizes,
        rowcount=old._colcount,
        colptr=old._rowptr,
        colcount=old._rowcount,
        csr2csc=old._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(transposed)
rusty1s's avatar
rusty1s committed
60
61
62


# Attach the scripted transpose as a ``SparseTensor.t()`` method. It is wrapped
# in a plain lambda so normal method binding passes the instance as ``self``.
setattr(SparseTensor, "t", lambda self: t(self))