transpose.py 1.34 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_sparse import to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3
4
5


def transpose(index, value, m, n):
    """Swap dimensions 0 and 1 of a sparse tensor given as ``(index, value)``.

    The two rows of ``index`` are exchanged, and the result is passed through
    :func:`coalesce` together with the swapped shape ``(n, m)``.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Unpacking requires exactly two index rows (row indices, column indices).
    src_row, src_col = index
    swapped = torch.stack((src_col, src_row), dim=0)

    # The transposed tensor has shape (n, m). NOTE(review): presumably
    # coalesce re-sorts the swapped entries and merges duplicates; confirm
    # against torch_sparse.coalesce.
    out_index, out_value = coalesce(swapped, value, n, m)
    return out_index, out_value
rusty1s's avatar
rusty1s committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44


def transpose_matrix(index, value, m, n):
    """Swap dimensions 0 and 1 of a sparse matrix with a one-dimensional ``value``.

    CUDA inputs are delegated to :func:`transpose`. CPU inputs are converted
    to a SciPy matrix, switched to CSC storage, and read back with their
    row/column roles exchanged.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    assert value.dim() == 1

    # GPU tensors skip the SciPy round trip entirely.
    if index.is_cuda:
        return transpose(index, value, m, n)

    # CSC storage groups entries by column. NOTE(review): presumably
    # from_scipy returns entries in that storage order; confirm. Swapping the
    # (row, col) pairs read back yields the transpose's index.
    csc = to_scipy(index, value, m, n).tocsc()
    (old_row, old_col), new_value = from_scipy(csc)
    return torch.stack([old_col, old_row], dim=0), new_value