test_graclus.py 2.09 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
import numpy as np
rusty1s's avatar
rusty1s committed
6
7
from torch_cluster import graclus_cluster

rusty1s's avatar
rusty1s committed
8
from .tensor import tensors
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18

# Graph cases shared by the CPU and GPU tests. Each case is an edge list in
# coordinate form (`row[k]` -> `col[k]`) in which every edge appears in both
# directions; the second case additionally carries one weight per edge.
tests = [
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
    ),
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        weight=[1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
    ),
]

rusty1s's avatar
rusty1s committed
19

rusty1s's avatar
rusty1s committed
20
def assert_correct_graclus(row, col, cluster):
    """Assert that ``cluster`` is a valid Graclus-style matching of a graph.

    The graph is given as an edge list: edge ``k`` goes from ``row[k]`` to
    ``col[k]``. ``cluster`` holds one cluster id per node.

    Checked invariants:
      * every node has a non-negative cluster id,
      * no cluster contains more than two nodes,
      * node ``n`` never has a cluster id larger than ``n``,
      * two nodes share a cluster only if they are connected by an edge.

    Args:
        row: 1-D tensor of edge source node indices.
        col: 1-D tensor of edge target node indices, aligned with ``row``.
        cluster: 1-D tensor with one cluster id per node.

    Raises:
        AssertionError: if any invariant above is violated.
    """
    n_nodes = cluster.size(0)
    row, col = row.cpu().numpy(), col.cpu().numpy()
    cluster = cluster.cpu().numpy()

    if n_nodes == 0:  # Nothing to check; also avoids `min()` on an empty array.
        return

    # Every node was assigned a cluster.
    assert cluster.min() >= 0, 'some node was not assigned a cluster'

    # There are no more than two nodes in each cluster.
    _, count = np.unique(cluster, return_counts=True)
    assert (count <= 2).all(), 'a cluster contains more than two nodes'

    # Cluster value is minimal: node `n` never receives an id larger than `n`.
    assert (cluster <= np.arange(n_nodes)).all(), 'cluster id is not minimal'

    # Corresponding clusters must be adjacent.
    for n in range(n_nodes):
        # Deduplicate neighbors and drop `n` itself, so that repeated edges or
        # a self-loop cannot inflate the count compared against `y` below.
        neighbors = np.unique(col[row == n])
        neighbors = neighbors[neighbors != n]
        x = (cluster[neighbors] == cluster[n]).sum()  # Neighbors in n's cluster
        y = (cluster == cluster[n]).sum() - 1  # Other nodes in n's cluster
        assert x == y, (
            'node {} shares a cluster with a non-adjacent node'.format(n))
rusty1s's avatar
rusty1s committed
40
41


rusty1s's avatar
rusty1s committed
42
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_cpu(tensor, i):
    """Run `graclus_cluster` on CPU for every test graph and weight type.

    `tensor` names a tensor constructor looked up on the `torch` module
    (presumably e.g. 'FloatTensor' — confirm against `.tensor.tensors`);
    `i` indexes into `tests`.
    """
    data = tests[i]

    row = torch.LongTensor(data['row'])
    col = torch.LongTensor(data['col'])

    # Unweighted cases pass `None`; weighted cases are built with the
    # parametrized tensor constructor.
    weight = data.get('weight')
    if weight is not None:
        weight = getattr(torch, tensor)(weight)

    cluster = graclus_cluster(row, col, weight)
    assert_correct_graclus(row, col, cluster)
rusty1s's avatar
rusty1s committed
54
55


rusty1s's avatar
rusty1s committed
56
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_gpu(tensor, i):  # pragma: no cover
    """Run `graclus_cluster` on CUDA for every test graph and weight type.

    Skipped when CUDA is unavailable. `tensor` names a tensor constructor
    looked up on `torch.cuda` (presumably e.g. 'FloatTensor' — confirm
    against `.tensor.tensors`); `i` indexes into `tests`.
    """
    data = tests[i]

    row = torch.cuda.LongTensor(data['row'])
    col = torch.cuda.LongTensor(data['col'])

    # Unweighted cases pass `None`; weighted cases are built with the
    # parametrized CUDA tensor constructor.
    weight = data.get('weight')
    if weight is not None:
        weight = getattr(torch.cuda, tensor)(weight)

    cluster = graclus_cluster(row, col, weight)
    assert_correct_graclus(row, col, cluster)