# transpose.py
import torch
from torch_sparse import coalesce


def transpose(index, value, m, n):
    """Transposes dimensions 0 and 1 of a sparse matrix.

    The two rows of :obj:`index` are swapped, and the result is passed
    through :func:`coalesce` together with :obj:`value`, using the swapped
    matrix dimensions :obj:`(n, m)`.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """

    # Unpacking enforces that ``index`` holds exactly two rows.
    src, dst = index
    flipped = torch.stack((dst, src), dim=0)

    # The transposed matrix has shape (n, m), hence the swapped size arguments.
    out_index, out_value = coalesce(flipped, value, n, m)

    return out_index, out_value