Unverified Commit 468aea5b authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #68 from james77777778/master

Fixed the bug from metis if num_parts == 1
parents 6884191b eb8c2ec0
......@@ -30,3 +30,8 @@ def test_metis(device):
weighted=False)
assert partptr.numel() == 3
assert perm.numel() == 6
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
weighted=True)
assert partptr.numel() == 2
assert perm.numel() == 6
......@@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
weighted=False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
assert num_parts >= 1
if num_parts == 1:
partptr = torch.tensor([0, src.size(0)], device=src.device())
perm = torch.arange(src.size(0), device=src.device())
return src, partptr, perm
rowptr, col, value = src.csr()
rowptr, col = rowptr.cpu(), col.cpu()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment