Commit 203d69f2 authored by bowendeng's avatar bowendeng
Browse files

Merge remote-tracking branch 'origin/master'

parents 46f7e1e8 2fd3c005
...@@ -6,3 +6,4 @@ dist/ ...@@ -6,3 +6,4 @@ dist/
*.egg-info/ *.egg-info/
.coverage .coverage
*.so *.so
.idea/
\ No newline at end of file
...@@ -7,11 +7,16 @@ from .utils import devices ...@@ -7,11 +7,16 @@ from .utils import devices
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_metis(device): def test_metis(device):
mat = SparseTensor.from_dense(torch.randn((6, 6), device=device)) weighted_mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
mat, partptr, perm = mat.partition(num_parts=2, recursive=False) mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=False)
assert partptr.numel() == 3 assert partptr.numel() == 3
assert perm.numel() == 6 assert perm.numel() == 6
mat, partptr, perm = mat.partition(num_parts=2, recursive=True) mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=True)
assert partptr.numel() == 3
assert perm.numel() == 6
unweighted_mat = SparseTensor.from_dense(torch.ones((6, 6), device=device))
mat, partptr, perm = unweighted_mat.partition(num_parts=2, recursive=True)
assert partptr.numel() == 3 assert partptr.numel() == 3
assert perm.numel() == 6 assert perm.numel() == 6
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment