Commit f8697f48 authored by rusty1s's avatar rusty1s
Browse files

fix test

parent a06899bb
...@@ -8,6 +8,10 @@ from .utils import devices ...@@ -8,6 +8,10 @@ from .utils import devices
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_metis(device): def test_metis(device):
mat = SparseTensor.from_dense(torch.randn((6, 6), device=device)) mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
mat, partptr, perm = mat.partition_kway(num_parts=2) mat, partptr, perm = mat.partition(num_parts=2, recursive=False)
assert partptr.numel() == 3
assert perm.numel() == 6
mat, partptr, perm = mat.partition(num_parts=2, recursive=True)
assert partptr.numel() == 3 assert partptr.numel() == 3
assert perm.numel() == 6 assert perm.numel() == 6
...@@ -21,4 +21,5 @@ def partition( ...@@ -21,4 +21,5 @@ def partition(
return out, partptr, perm return out, partptr, perm
SparseTensor.partition = lambda self, num_parts: partition(self, num_parts) SparseTensor.partition = lambda self, num_parts, recursive=False: partition(
self, num_parts, recursive)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment