test_metis.py 1003 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices


def _check_partition(mat, partptr, perm, num_nodes, num_parts):
    """Assert that one ``SparseTensor.partition`` result is self-consistent.

    Args:
        mat: The permuted sparse matrix returned by ``partition``.
        partptr: Part boundary pointer returned by ``partition``.
        perm: Node permutation returned by ``partition``.
        num_nodes: Number of rows/columns of the (square) input matrix.
        num_parts: Number of parts requested from ``partition``.
    """
    # The permuted matrix must keep the original square shape.
    assert mat.sizes() == [num_nodes, num_nodes]

    # One boundary per part plus a leading zero, covering every node once.
    assert partptr.numel() == num_parts + 1
    partptr = partptr.cpu()
    assert int(partptr[0]) == 0
    assert int(partptr[-1]) == num_nodes
    assert bool((partptr[1:] >= partptr[:-1]).all())

    # ``perm`` must reorder the nodes, i.e. be a permutation of 0..num_nodes-1.
    assert perm.numel() == num_nodes
    assert torch.equal(perm.cpu().sort()[0], torch.arange(num_nodes))


@pytest.mark.parametrize('device', devices)
def test_metis(device):
    """Exercise METIS partitioning on weighted, unit-valued and value-less graphs."""
    # Seed so the random values (and thus the resulting partition) are
    # reproducible when a failure needs to be investigated.
    torch.manual_seed(12345)
    num_nodes, num_parts = 6, 2

    # Dense random matrix: every entry becomes a stored (non-uniform) value.
    weighted_mat = SparseTensor.from_dense(
        torch.randn((num_nodes, num_nodes), device=device))
    for sort_strategy in (True, False):
        mat, partptr, perm = weighted_mat.partition(
            num_parts=num_parts, recursive=False, sort_strategy=sort_strategy)
        _check_partition(mat, partptr, perm, num_nodes, num_parts)

    # All-ones matrix: every stored value is identical.
    unweighted_mat = SparseTensor.from_dense(
        torch.ones((num_nodes, num_nodes), device=device))
    mat, partptr, perm = unweighted_mat.partition(
        num_parts=num_parts, recursive=True, sort_strategy=True)
    _check_partition(mat, partptr, perm, num_nodes, num_parts)

    # Drop the stored values entirely so only the sparsity structure remains.
    unweighted_mat = unweighted_mat.set_value(None)
    mat, partptr, perm = unweighted_mat.partition(
        num_parts=num_parts, recursive=True)
    _check_partition(mat, partptr, perm, num_nodes, num_parts)