Commit 38f62d97 authored by rusty1s's avatar rusty1s
Browse files

docstring

parent dedc6950
from typing import Tuple
import time
import copy
from typing import Union, Tuple
import torch
from torch import Tensor
......@@ -12,6 +13,9 @@ partition_fn = torch.ops.torch_sparse.partition
def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
log: bool = True) -> Tuple[Tensor, Tensor]:
r"""Computes the METIS partition of a given sparse adjacency matrix
:obj:`adj_t`, returning its "clustered" permutation :obj:`perm` and
corresponding cluster slices :obj:`ptr`."""
if log:
t = time.perf_counter()
......@@ -34,24 +38,22 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
return perm, ptr
def permute(data: Union[Data, SparseTensor], perm: Tensor,
log: bool = True) -> Union[Data, SparseTensor]:
def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
r"""Permutes a :obj:`data` object according to a given permutation
:obj:`perm`."""
if log:
t = time.perf_counter()
print('Permuting data...', end=' ', flush=True)
if isinstance(data, Data):
data = copy.copy(data)
for key, item in data:
if isinstance(item, Tensor) and item.size(0) == data.num_nodes:
data[key] = item[perm]
elif isinstance(item, Tensor) and item.size(0) == data.num_edges:
raise NotImplementedError
elif isinstance(item, SparseTensor):
data[key] = permute(item, perm, log=False)
else:
data = data.permute(perm)
data = copy.copy(data)
for key, value in data:
if isinstance(value, Tensor) and value.size(0) == data.num_nodes:
data[key] = value[perm]
elif isinstance(value, Tensor) and value.size(0) == data.num_edges:
raise NotImplementedError
elif isinstance(value, SparseTensor):
data[key] = value.permute(perm)
if log:
print(f'Done! [{time.perf_counter() - t:.2f}s]')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment