# Tests for METIS graph partitioning through ``SparseTensor.partition``.
import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices

def _has_metis():
    """Return whether torch_sparse's METIS partitioning is actually usable.

    Only looking up ``torch.ops.torch_sparse.partition`` is not a reliable
    signal:

    * a missing operator is reported as ``RuntimeError`` by older PyTorch
      releases but as ``AttributeError`` by newer ones, so both are caught;
    * the operator may be registered even in builds without METIS and only
      fail once it is called (NOTE(review): this matches torch_sparse's
      "Not compiled with METIS support" error path -- confirm for the pinned
      torch_sparse version).

    After the lookup succeeds, this runs one of the partitions that
    ``test_metis`` itself performs (all-ones 6x6 graph, 2 parts, unweighted)
    on CPU. A ``RuntimeError`` from that call is treated as "unavailable",
    so the test is skipped rather than failed.
    """
    try:
        _ = torch.ops.torch_sparse.partition
    except (RuntimeError, AttributeError):
        return False

    mat = SparseTensor.from_dense(torch.ones(6, 6))
    try:
        mat.partition(num_parts=2, recursive=False, vweights=None,
                      weighted=False)
    except RuntimeError:
        return False
    return True


# Evaluated once at import time; consumed by the ``skipif`` marker below.
with_metis = _has_metis()


@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device', devices)
def test_metis(device):
    """Partition small dense graphs with METIS and check the output layout.

    The test takes every pair of an edge-value matrix and an optional
    vertex-weight vector:

    * edge values are random floats, ``long`` integers or all ones;
    * vertex weights are either absent or random.

    Each pair is partitioned under every ``(num_parts, weighted)`` setting
    in ``configs``. The returned ``partptr`` must hold ``num_parts + 1``
    non-decreasing boundaries from 0 to the node count. The returned
    ``perm`` must reorder every node exactly once.
    """
    num_nodes = 6

    value1 = torch.randn(num_nodes * num_nodes,
                         device=device).view(num_nodes, num_nodes)
    value2 = torch.arange(num_nodes * num_nodes, dtype=torch.long,
                          device=device).view(num_nodes, num_nodes)
    value3 = torch.ones(num_nodes * num_nodes,
                        device=device).view(num_nodes, num_nodes)

    vwgts = torch.rand(num_nodes, device=device)

    # (num_parts, weighted) combinations checked for each value/weight pair.
    configs = [(2, True), (2, False), (1, True)]

    for value in [value1, value2, value3]:
        for vwgt in [None, vwgts]:
            mat = SparseTensor.from_dense(value)

            for num_parts, weighted in configs:
                _, partptr, perm = mat.partition(num_parts=num_parts,
                                                 recursive=False,
                                                 vweights=vwgt,
                                                 weighted=weighted)
                assert partptr.numel() == num_parts + 1
                assert perm.numel() == num_nodes

                # partptr is a pointer array covering all nodes: it starts
                # at 0, ends at num_nodes and never decreases.
                assert partptr[0].item() == 0
                assert partptr[-1].item() == num_nodes
                assert bool((partptr[1:] >= partptr[:-1]).all())

                # perm must be a permutation of the node ids.
                assert perm.sort()[0].tolist() == list(range(num_nodes))