Commit af9127e3 authored by rusty1s's avatar rusty1s
Browse files

fix permute

parent 904b1d48
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_permute(device):
row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
value = tensor([1, 2, 3, 4, 5], torch.float, device)
adj = SparseTensor(row=row, col=col, value=value)
row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
assert row.tolist() == [0, 1, 1, 2, 2]
assert col.tolist() == [1, 0, 1, 0, 2]
print(value)
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
assert src.is_symmetric()
row, col, value = src.coo()
row = perm[row]
col = perm[col]
if value is not None:
value = value[row]
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount[perm]
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount[perm]
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
colptr=None, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=False)
return src.from_storage(storage)
assert src.is_quadratic()
return src.index_select(0, perm).index_select(1, perm)
SparseTensor.permute = lambda self, perm: permute(self, perm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment