transpose.py 1.51 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_sparse import to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3
4


rusty1s's avatar
linting  
rusty1s committed
5
def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse tensor.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # The SciPy fast path below always produces a coalesced result (the
    # COO -> CSC conversion sums duplicates), so it is only taken when the
    # caller actually asked for coalescing; otherwise ``coalesced=False``
    # would be silently ignored on CPU.
    if coalesced and value.dim() == 1 and not value.is_cuda:
        # Converting to CSC stores entries column by column, i.e. row by row
        # of the transposed matrix, so the output needs no further coalescing.
        # NOTE(review): this round-trips through SciPy/NumPy -- confirm that
        # ``to_scipy``/``from_scipy`` keep ``value``'s autograd history if
        # gradients through ``transpose`` are required.
        mat = to_scipy(index, value, m, n).tocsc()
        # ``from_scipy`` returns indices of the original (m x n) matrix, so
        # naming them (col, row) here refers to the transposed (n x m) one.
        # Presumably ``from_scipy`` preserves the CSC entry order -- verify.
        (col, row), value = from_scipy(mat)
        index = torch.stack([row, col], dim=0)
        return index, value

    # Generic path: swap the two index rows, then optionally sort/merge.
    row, col = index
    index = torch.stack([col, row], dim=0)
    if coalesced:
        # The transposed matrix has shape (n, m).
        index, value = coalesce(index, value, n, m)
    return index, value
rusty1s's avatar
rusty1s committed
29
30
31


def t(mat):
    """Return the transpose of the sparse matrix ``mat``.

    A new storage object of the same class as ``mat._storage`` is built. Its
    row and column indices are swapped, and its entries (and values, if
    ``mat`` has any) are reordered with the storage's ``csr_to_csc``
    permutation. The cached compressed pointers and permutations of ``mat``
    are reused with their roles exchanged (``rowptr`` <-> ``colptr``,
    ``csr_to_csc`` <-> ``csc_to_csr``), and the sparse size is reversed.

    Args:
        mat: The sparse matrix object to transpose.

    Returns:
        A new object of the same class as ``mat``, created through
        ``from_storage``.
    """
    (row, col), value = mat.coo()
    src = mat._storage
    # NOTE(review): assumes ``csr_to_csc`` maps the entry order returned by
    # ``coo()`` to column-major order; that is what makes the permuted
    # transpose row-major and justifies ``is_sorted=True`` -- confirm.
    perm = src.csr_to_csc

    storage_cls = src.__class__
    new_index = torch.stack([col, row], dim=0)[:, perm]
    new_value = value[perm] if mat.has_value() else None
    new_storage = storage_cls(
        index=new_index,
        value=new_value,
        sparse_size=mat.sparse_size()[::-1],
        # Row and column compression pointers trade places under transpose.
        rowptr=src._colptr,
        colptr=src._rowptr,
        # The old CSC->CSR map becomes the new CSR->CSC map, and vice versa.
        csr_to_csc=src._csc_to_csr,
        csc_to_csr=perm,
        is_sorted=True,
    )
    return mat.__class__.from_storage(new_storage)