test_graclus.py 1.59 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
5
6
import torch
from torch_cluster import graclus_cluster

rusty1s's avatar
rusty1s committed
7
from .utils import dtypes, devices, tensor
rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
16
17

# Fixtures: a 4-node undirected graph with edges 0-1, 0-2, 1-2, 1-3 and 2-3,
# stored in COO form (row[k] -> col[k]) with every edge listed in both
# directions. The second fixture attaches a symmetric weight to each edge.
tests = [
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
    ),
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        weight=[1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
    ),
]

rusty1s's avatar
rusty1s committed
18

rusty1s's avatar
rusty1s committed
19
def assert_correct_graclus(row, col, cluster):
    """Assert that ``cluster`` is a valid Graclus matching of a graph.

    Args:
        row: Source node index of every edge.
        col: Target node index of every edge (edge ``k`` is ``row[k] -> col[k]``).
        cluster: Cluster id assigned to every node.

    Raises:
        AssertionError: If any required property is violated:
            every node has a non-negative cluster id, no cluster holds more
            than two nodes, every cluster is labelled with the smallest node
            index it contains, and the members of a pair are adjacent.
    """
    # Move everything to the CPU so the checks work regardless of device.
    row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
    n = cluster.size(0)
    if n == 0:
        return  # Nothing to check; `cluster.min()` would raise on empty input.

    # Every node was assigned a cluster.
    assert cluster.min() >= 0

    # There are no more than two nodes in each cluster.
    _, index = torch.unique(cluster, return_inverse=True)
    count = torch.zeros_like(cluster)
    count.scatter_add_(0, index, torch.ones_like(cluster))
    assert (count <= 2).all()

    # Cluster value is minimal: every cluster id is at most the index of each
    # node carrying it ...
    assert (cluster <= torch.arange(n, dtype=cluster.dtype)).all()
    # ... and the id is itself a member of its cluster, so it equals the
    # smallest member index. Without this, e.g. pairs {0, 1}, {2, 3} labelled
    # [0, 0, 1, 1] would be accepted. Indexing is safe: 0 <= cluster[i] <= i.
    assert (cluster[cluster.long()] == cluster).all()

    # Corresponding clusters must be adjacent.
    for i in range(n):
        # Other nodes that share i's cluster.
        partners = set((cluster == cluster[i]).nonzero().flatten().tolist())
        partners.discard(i)
        # Neighbors of i in the same cluster. Sets make the check insensitive
        # to duplicate edges, and discarding i ignores self-loops.
        neighbors = col[row == i]
        same = set(neighbors[cluster[neighbors] == cluster[i]].tolist())
        same.discard(i)
        assert partners == same
rusty1s's avatar
rusty1s committed
41
42


rusty1s's avatar
rusty1s committed
43
44
45
46
47
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster_cpu(test, dtype, device):
    """Run ``graclus_cluster`` on one fixture/dtype/device combination and
    check that the result is a valid matching."""
    # Edge indices are always long; only the optional weight uses `dtype`.
    edge_row, edge_col = (
        tensor(test[key], torch.long, device) for key in ('row', 'col'))
    # NOTE(review): assumes the project-local `tensor` helper accepts None
    # for fixtures without a 'weight' entry - confirm in `.utils`.
    edge_weight = tensor(test.get('weight'), dtype, device)

    result = graclus_cluster(edge_row, edge_col, edge_weight)
    assert_correct_graclus(edge_row, edge_col, result)