test_metis.py 1.27 KB
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
from itertools import product

rusty1s's avatar
rusty1s committed
3
import pytest
rusty1s's avatar
rusty1s committed
4
5
6
7
8
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices

rusty1s's avatar
rusty1s committed
9
# Probe once, at import time, whether the compiled extension can run its
# METIS-backed `partition` op: call it on a minimal one-node graph (a single
# self-loop in CSR form) and treat a RuntimeError as "METIS is unavailable".
try:
    torch.ops.torch_sparse.partition(
        torch.tensor([0, 1]),  # rowptr
        torch.tensor([0]),  # col
        None,
        1,
    )
except RuntimeError:
    with_metis = False
else:
    with_metis = True
rusty1s's avatar
rusty1s committed
16

rusty1s's avatar
rusty1s committed
17
18

@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device,weighted', product(devices, [False, True]))
def test_metis(device, weighted):
    """Check the sizes of the tensors returned by `SparseTensor.partition`.

    Three dense 6x6 matrices (random normal floats, consecutive integers and
    all ones) are each converted to a `SparseTensor` and partitioned into one
    and then two parts, both without and with a random per-node weight vector.
    For every combination the returned `partptr` must hold `num_parts + 1`
    entries and `perm` must hold one entry per node.
    """
    num_nodes = 6
    num_entries = num_nodes * num_nodes

    dense_inputs = [
        torch.randn(num_entries, device=device).view(num_nodes, num_nodes),
        torch.arange(num_entries, dtype=torch.long,
                     device=device).view(num_nodes, num_nodes),
        torch.ones(num_entries, device=device).view(num_nodes, num_nodes),
    ]
    # `None` exercises the call without node weights.
    node_weights = [None, torch.rand(num_nodes, device=device)]

    for dense in dense_inputs:
        for node_weight in node_weights:
            sparse = SparseTensor.from_dense(dense)

            for num_parts in (1, 2):
                _, partptr, perm = sparse.partition(
                    num_parts=num_parts,
                    recursive=False,
                    weighted=weighted,
                    node_weight=node_weight,
                )
                assert partptr.numel() == num_parts + 1
                assert perm.numel() == num_nodes