test_graclus.py 2.1 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
import numpy as np
rusty1s's avatar
rusty1s committed
6
7
from torch_cluster import graclus_cluster

rusty1s's avatar
rusty1s committed
8
from .tensor import tensors
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19

# Two cases on the same 4-node graph. Every edge (0-1, 0-2, 1-2, 1-3, 2-3) is
# listed once in each direction. The first case is unweighted; the second
# assigns the same weight to both directions of each edge.
tests = [
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        weight=None,
    ),
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        weight=[1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
    ),
]

rusty1s's avatar
rusty1s committed
20

rusty1s's avatar
rusty1s committed
21
def assert_correct_graclus(row, col, cluster):
    """Assert that ``cluster`` is a valid graclus matching of the edge list.

    ``row``/``col`` are indexed so that ``col[row == n]`` yields the
    neighbours of node ``n``. Presumably they are 1-D integer tensors holding
    one directed edge per position; TODO confirm against callers. All three
    tensors are moved to the CPU and converted to NumPy arrays before any
    check runs.

    Checks performed:
      * every cluster id is non-negative;
      * no cluster contains more than two nodes;
      * every node's cluster id is <= its own index;
      * the nodes sharing a cluster with ``n`` are exactly matched by
        neighbours of ``n`` with that id (counted per edge entry, so
        duplicate edges would count more than once).

    Raises:
        AssertionError: if any of the checks fails.
    """
    num_nodes = cluster.size(0)
    row = row.cpu().numpy()
    col = col.cpu().numpy()
    cluster = cluster.cpu().numpy()

    # No node may be left with a negative (unassigned) id.
    assert cluster.min() >= 0

    # A matching pairs nodes up, so a cluster holds at most two members.
    _, sizes = np.unique(cluster, return_counts=True)
    assert not (sizes > 2).any()

    # Each node's cluster id does not exceed the node's own index.
    node_ids = np.arange(num_nodes, dtype=row.dtype)
    assert (cluster <= node_ids).sum() == num_nodes

    # Any other node sharing n's cluster must be reachable through an edge of n.
    for node in range(cluster.shape[0]):
        same_cluster = cluster == cluster[node]
        same_cluster[node] = False  # ignore the node itself
        neighbour_ids = col[row == node]
        neighbour_match = cluster[neighbour_ids] == cluster[node]
        assert neighbour_match.sum() == same_cluster.sum()
rusty1s's avatar
rusty1s committed
41
42


rusty1s's avatar
rusty1s committed
43
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_cpu(tensor, i):
    """Run ``graclus_cluster`` on CPU tensors and validate the matching.

    ``tensor`` names a ``torch`` constructor (looked up with ``getattr``)
    that is used to build the edge weights. ``i`` selects an entry of
    ``tests``. When that entry has no weights, ``tensor`` is unused.
    """
    case = tests[i]
    row, col = (torch.LongTensor(case[key]) for key in ('row', 'col'))

    weight = None
    if case['weight'] is not None:
        weight = getattr(torch, tensor)(case['weight'])

    assert_correct_graclus(row, col, graclus_cluster(row, col, weight))
rusty1s's avatar
rusty1s committed
55
56


rusty1s's avatar
rusty1s committed
57
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_gpu(tensor, i):  # pragma: no cover
    """Run ``graclus_cluster`` on CUDA tensors and validate the matching.

    Skipped when CUDA is unavailable. ``tensor`` names a ``torch.cuda``
    constructor (looked up with ``getattr``) used for the edge weights.
    ``i`` selects an entry of ``tests``.
    """
    case = tests[i]
    row, col = (torch.cuda.LongTensor(case[key]) for key in ('row', 'col'))

    weight = None
    if case['weight'] is not None:
        weight = getattr(torch.cuda, tensor)(case['weight'])

    assert_correct_graclus(row, col, graclus_cluster(row, col, weight))