# test_metis.py
import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices


@pytest.mark.parametrize('device', devices)
def test_metis(device):
    """Check the output sizes of ``SparseTensor.partition`` on 6x6 inputs.

    Three dense 6x6 matrices (random floats, an integer ramp and all ones)
    are converted with ``SparseTensor.from_dense`` and partitioned into two
    parts, first with ``weighted=True`` and then with ``weighted=False``.
    Only the number of elements of the returned ``partptr`` and ``perm``
    tensors is verified.
    """
    size, num_parts = 6, 2

    dense_inputs = [
        torch.randn(size * size, device=device).view(size, size),
        torch.arange(size * size, dtype=torch.long,
                     device=device).view(size, size),
        torch.ones(size * size, device=device).view(size, size),
    ]

    for value in dense_inputs:
        mat = SparseTensor.from_dense(value)

        # Same checks for both settings; weighted=True runs first, as before.
        for weighted in (True, False):
            _, partptr, perm = mat.partition(num_parts=num_parts,
                                             recursive=False,
                                             weighted=weighted)
            # The test expects num_parts + 1 pointer entries (3 here) and
            # one permutation entry per matrix row (6 here).
            assert partptr.numel() == num_parts + 1
            assert perm.numel() == size