Commit eb8c2ec0 authored by rusty1s's avatar rusty1s
Browse files

cleanup

parent 45a4d985
......@@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
weighted=False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
assert num_parts >= 1
if num_parts == 1:
partptr = torch.tensor([0, src.size(0)], device=src.device())
perm = torch.arange(src.size(0), device=src.device())
return src, partptr, perm
rowptr, col, value = src.csr()
rowptr, col = rowptr.cpu(), col.cpu()
......@@ -33,13 +40,8 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
else:
value = None
if num_parts > 1:
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
num_parts, recursive)
elif num_parts == 1:
cluster = torch.zeros((src.size(0)), dtype=torch.long)
else:
raise ValueError
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
recursive)
cluster = cluster.to(src.device())
cluster, perm = cluster.sort()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment