test_metis.py 1.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices

rusty1s's avatar
rusty1s committed
7
8
9
10
11
try:
    torch.ops.torch_sparse.partition
    with_metis = True
except RuntimeError:
    with_metis = False
rusty1s's avatar
rusty1s committed
12

rusty1s's avatar
rusty1s committed
13
14

@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
rusty1s's avatar
rusty1s committed
15
16
@pytest.mark.parametrize('device', devices)
def test_metis(device):
rusty1s's avatar
rusty1s committed
17
18
19
    value1 = torch.randn(6 * 6, device=device).view(6, 6)
    value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
    value3 = torch.ones(6 * 6, device=device).view(6, 6)
20

rusty1s's avatar
rusty1s committed
21
22
    for value in [value1, value2, value3]:
        mat = SparseTensor.from_dense(value)
rusty1s's avatar
rusty1s committed
23

rusty1s's avatar
rusty1s committed
24
25
26
27
        _, partptr, perm = mat.partition(num_parts=2, recursive=False,
                                         weighted=True)
        assert partptr.numel() == 3
        assert perm.numel() == 6
bwdeng20's avatar
bwdeng20 committed
28

rusty1s's avatar
rusty1s committed
29
30
31
32
        _, partptr, perm = mat.partition(num_parts=2, recursive=False,
                                         weighted=False)
        assert partptr.numel() == 3
        assert perm.numel() == 6
33
34
35
36
37

        _, partptr, perm = mat.partition(num_parts=1, recursive=False,
                                         weighted=True)
        assert partptr.numel() == 2
        assert perm.numel() == 6