Commit 38f62d97 authored by rusty1s's avatar rusty1s
Browse files

docstring

parent dedc6950
from typing import Tuple
import time import time
import copy import copy
from typing import Union, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -12,6 +13,9 @@ partition_fn = torch.ops.torch_sparse.partition ...@@ -12,6 +13,9 @@ partition_fn = torch.ops.torch_sparse.partition
def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False, def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
log: bool = True) -> Tuple[Tensor, Tensor]: log: bool = True) -> Tuple[Tensor, Tensor]:
r"""Computes the METIS partition of a given sparse adjacency matrix
:obj:`adj_t`, returning its "clustered" permutation :obj:`perm` and
corresponding cluster slices :obj:`ptr`."""
if log: if log:
t = time.perf_counter() t = time.perf_counter()
...@@ -34,24 +38,22 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False, ...@@ -34,24 +38,22 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
return perm, ptr return perm, ptr
def permute(data: Union[Data, SparseTensor], perm: Tensor, def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
log: bool = True) -> Union[Data, SparseTensor]: r"""Permutes a :obj:`data` object according to a given permutation
:obj:`perm`."""
if log: if log:
t = time.perf_counter() t = time.perf_counter()
print('Permuting data...', end=' ', flush=True) print('Permuting data...', end=' ', flush=True)
if isinstance(data, Data): data = copy.copy(data)
data = copy.copy(data) for key, value in data:
for key, item in data: if isinstance(value, Tensor) and value.size(0) == data.num_nodes:
if isinstance(item, Tensor) and item.size(0) == data.num_nodes: data[key] = value[perm]
data[key] = item[perm] elif isinstance(value, Tensor) and value.size(0) == data.num_edges:
elif isinstance(item, Tensor) and item.size(0) == data.num_edges: raise NotImplementedError
raise NotImplementedError elif isinstance(value, SparseTensor):
elif isinstance(item, SparseTensor): data[key] = value.permute(perm)
data[key] = permute(item, perm, log=False)
else:
data = data.permute(perm)
if log: if log:
print(f'Done! [{time.perf_counter() - t:.2f}s]') print(f'Done! [{time.perf_counter() - t:.2f}s]')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment