Commit 45a4d985 authored by james77777778's avatar james77777778
Browse files

fix the bug from metis if num_parts == 1

parent 6884191b
......@@ -30,3 +30,8 @@ def test_metis(device):
weighted=False)
assert partptr.numel() == 3
assert perm.numel() == 6
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
weighted=True)
assert partptr.numel() == 2
assert perm.numel() == 6
......@@ -33,8 +33,13 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
else:
value = None
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
recursive)
if num_parts > 1:
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
num_parts, recursive)
elif num_parts == 1:
cluster = torch.zeros((src.size(0)), dtype=torch.long)
else:
raise ValueError
cluster = cluster.to(src.device())
cluster, perm = cluster.sort()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment