# permute.py — last committed by rusty1s.
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
    """Symmetrically permute the rows and columns of a square sparse matrix.

    Every stored entry at ``(i, j)`` is moved to ``(perm[i], perm[j])``; its
    value is left unchanged.

    Args:
        src: Square sparse matrix to permute.
        perm: 1-D index tensor with one entry per row of ``src``. It is
            assumed (not verified) to be a permutation of ``0 .. N-1``.

    Returns:
        A new ``SparseTensor`` built from the permuted coordinates. Any
        cached row/column counts of ``src`` are carried over, reordered to
        match the new row/column ids.
    """
    sizes = src.sparse_sizes()
    assert sizes[0] == sizes[1], 'permute expects a square sparse matrix'
    assert perm.numel() == sizes[0], 'perm must have one entry per row'

    row, col, value = src.coo()

    # Entry (i, j) moves to (perm[i], perm[j]). `value` is aligned with the
    # entries (not with rows), so it is reused unchanged.
    row = perm[row]
    col = perm[col]

    # Old row i becomes new row perm[i], so new_count[perm[i]] == old_count[i],
    # i.e. new_count = old_count[inv_perm]. The argsort of a permutation is
    # its inverse.
    inv_perm = perm.argsort()

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount[inv_perm]

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount[inv_perm]

    # The permuted coordinates are no longer ordered, hence is_sorted=False.
    # Compressed pointers and CSR<->CSC maps are dropped rather than remapped.
    storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
                            colptr=None, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=False)
    return src.from_storage(storage)


# Attach `permute` as a method so it can be called as `mat.permute(perm)`;
# the lambda forwards the bound instance explicitly as the `src` argument.
setattr(SparseTensor, 'permute', lambda self, perm: permute(self, perm))