Commit 57852a66 authored by rusty1s's avatar rusty1s
Browse files

bandwidth implementation

parent 539e2068
...@@ -47,6 +47,7 @@ from .matmul import matmul # noqa ...@@ -47,6 +47,7 @@ from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat, cat_diag # noqa
from .rw import random_walk # noqa from .rw import random_walk # noqa
from .metis import partition # noqa from .metis import partition # noqa
from .bandwidth import reverse_cuthill_mckee # noqa
from .saint import saint_subgraph # noqa from .saint import saint_subgraph # noqa
from .padding import padded_index, padded_index_select # noqa from .padding import padded_index, padded_index_select # noqa
...@@ -90,6 +91,7 @@ __all__ = [ ...@@ -90,6 +91,7 @@ __all__ = [
'cat_diag', 'cat_diag',
'random_walk', 'random_walk',
'partition', 'partition',
'reverse_cuthill_mckee',
'saint_subgraph', 'saint_subgraph',
'padded_index', 'padded_index',
'padded_index_select', 'padded_index_select',
......
import scipy.sparse as sp
from typing import Tuple, Optional
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def reverse_cuthill_mckee(src: SparseTensor,
is_symmetric: Optional[bool] = None
) -> Tuple[SparseTensor, torch.Tensor]:
if is_symmetric is None:
is_symmetric = src.is_symmetric()
if not is_symmetric:
src = src.to_symmetric()
sp_src = src.to_scipy(layout='csr')
perm = sp.csgraph.reverse_cuthill_mckee(sp_src, symmetric_mode=True).copy()
perm = torch.from_numpy(perm).to(torch.long).to(src.device())
out = permute(src, perm)
return out, perm
SparseTensor.reverse_cuthill_mckee = reverse_cuthill_mckee
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment