test_metis.py 385 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices


@pytest.mark.parametrize('device', devices)
def test_metis(device):
    """Check that ``partition_kway`` returns a well-formed 2-way partition.

    METIS is a heuristic partitioner, so the particular partition it picks is
    an implementation detail that may differ between METIS versions or
    builds. Rather than pinning one specific answer, this test checks the
    structural invariants every valid result must satisfy.
    """
    num_nodes, num_parts = 6, 2
    mat = SparseTensor.from_dense(
        torch.randn((num_nodes, num_nodes), device=device))

    # The re-ordered matrix is not inspected here; only the partition
    # boundaries and the node permutation are checked.
    _, partptr, perm = mat.partition_kway(num_parts=num_parts)

    # ``partptr`` delimits the parts: one boundary per part plus a final one,
    # starting at 0, ending at the node count, and never decreasing (a part
    # may in principle be empty, so equal neighbours are allowed).
    partptr = partptr.tolist()
    assert len(partptr) == num_parts + 1
    assert partptr[0] == 0
    assert partptr[-1] == num_nodes
    assert partptr == sorted(partptr)

    # ``perm`` must list every node exactly once.
    assert sorted(perm.tolist()) == list(range(num_nodes))