transpose.py 1.33 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_sparse import to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3
4
5


def transpose(index, value, m, n):
    """Swap dimensions 0 and 1 of a sparse tensor given by ``index``/``value``.

    The two rows of ``index`` are exchanged, and the result is passed to
    ``coalesce`` together with the transposed shape ``(n, m)``.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    src_idx, dst_idx = index
    # Former column indices become the new row indices and vice versa.
    swapped = torch.stack((dst_idx, src_idx), dim=0)
    # The shape is transposed too: the result has n rows and m columns.
    out_index, out_value = coalesce(swapped, value, n, m)
    return out_index, out_value
rusty1s's avatar
rusty1s committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


def transpose_matrix(index, value, m, n):
    """Transposes dimensions 0 and 1 of a sparse matrix, where ``value`` is
    one-dimensional.

    CUDA inputs are delegated to :func:`transpose`. CPU inputs are transposed
    by round-tripping through a SciPy sparse matrix in CSC format.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix. Must be
            one-dimensional.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """

    assert value.dim() == 1

    if index.is_cuda:
        # SciPy operates on host memory only, so CUDA inputs take the
        # pure-torch index-swap path instead.
        return transpose(index, value, m, n)

    # CPU path: build the (m, n) matrix in SciPy and convert it to CSC
    # (column-major) format before reading its entries back.
    # NOTE(review): assumes from_scipy returns the (row, col) indices of
    # `mat` in CSC order, so that swapping them below yields the (n, m)
    # transpose in row-major order — confirm against from_scipy.
    mat = to_scipy(index, value, m, n).tocsc()
    (mat_row, mat_col), value = from_scipy(mat)
    # The original column index becomes the transposed row index.
    index = torch.stack([mat_col, mat_row], dim=0)
    # NOTE(review): if to_scipy detaches its inputs, gradients do not flow
    # back to `value` on this path, unlike the CUDA path — confirm.
    return index, value