transpose.py 1.56 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_sparse import to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3
4


rusty1s's avatar
linting  
rusty1s committed
5
def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse tensor.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)

    Returns:
        The index and value tensors of the transposed sparse matrix, whose
        corresponding dense matrix has shape ``(n, m)``. With
        ``coalesced=False`` the entries keep their input order and ``value``
        is returned unchanged.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Swapping the two coordinate rows is the whole transpose. The same
    # path serves every device and every value shape. It never converts
    # ``value`` to NumPy/SciPy, because that conversion would detach it from
    # autograd. It also honours ``coalesced=False``.
    row, col = index
    index = torch.stack([col, row], dim=0)
    if coalesced:
        # The transposed matrix has n rows and m columns.
        index, value = coalesce(index, value, n, m)
    return index, value
rusty1s's avatar
rusty1s committed
29
30


rusty1s's avatar
rusty1s committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def t(src):
    """Transpose a sparse tensor by exchanging its row and column roles.

    The coordinates from ``src.coo()`` are swapped and then reordered with
    the storage's ``csr2csc`` permutation. Values, when ``src`` has any,
    follow the same permutation. Metadata already held by ``src.storage``
    (row/column counts, row/column pointers, and the ``csr2csc``/``csc2csr``
    permutations) is passed on with the row and column sides exchanged.
    The new storage is flagged ``is_sorted=True``.

    Args:
        src: The sparse tensor to transpose.

    Returns:
        The tensor produced by ``src.from_storage`` for the swapped storage,
        whose sparse size is ``src.sparse_size()`` reversed.
    """
    (row, col), value = src.coo()
    old = src.storage
    perm = old.csr2csc

    # Swap the coordinate rows, then reorder the entries by the permutation.
    new_index = torch.stack([col, row], dim=0)[:, perm]
    if src.has_value():
        new_value = value[perm]
    else:
        new_value = None

    flipped = old.__class__(
        index=new_index,
        value=new_value,
        sparse_size=src.sparse_size()[::-1],
        # The result's row-side metadata is the column-side metadata of
        # ``src``, and vice versa.
        rowcount=old._colcount,
        rowptr=old._colptr,
        colcount=old._rowcount,
        colptr=old._rowptr,
        csr2csc=old._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(flipped)