test_graclus.py 1.67 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
5
6
import torch
from torch_cluster import graclus_cluster

rusty1s's avatar
rusty1s committed
7
from .utils import dtypes, devices, tensor
rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
16
17

# Example graphs for the tests below. Both cases describe the same 4-node
# graph, given as a directed edge list in which every edge appears in both
# directions: 0-1, 0-2, 1-2, 1-3 and 2-3.
_EDGE_ROW = (0, 0, 1, 1, 1, 2, 2, 2, 3, 3)
_EDGE_COL = (1, 2, 0, 2, 3, 0, 1, 3, 1, 2)

# One weight per directed edge; both directions of an edge carry the same
# value.
_EDGE_WEIGHT = (1, 2, 1, 3, 2, 2, 3, 1, 2, 1)

tests = [
    # Unweighted case: no 'weight' key at all.
    {
        'row': list(_EDGE_ROW),
        'col': list(_EDGE_COL),
    },
    # Weighted case.
    {
        'row': list(_EDGE_ROW),
        'col': list(_EDGE_COL),
        'weight': list(_EDGE_WEIGHT),
    },
]

rusty1s's avatar
rusty1s committed
18
19
20
# NOTE(review): these rebindings shadow the `dtypes` and `devices` imported
# from `.utils` and cut the fixture list down to its first (unweighted)
# entry, so the parametrized test below runs for a single combination only.
# This looks like a temporary debugging restriction -- confirm and remove
# once the full matrix is expected to pass.
tests = [tests[0]]
dtypes = [torch.float]
devices = [torch.device('cpu')]
rusty1s's avatar
rusty1s committed
21

rusty1s's avatar
rusty1s committed
22
23

def assert_correct(row, col, cluster):
    """Assert that ``cluster`` is a valid pairing of adjacent graph nodes.

    The graph is given as an edge list: ``col[row == i]`` are the neighbors
    of node ``i``. ``cluster`` holds one cluster id per node. All tensors are
    moved to the CPU before checking, so inputs may live on any device.

    The following invariants are verified:

    * no cluster id is negative,
    * no cluster id is shared by more than two nodes,
    * the cluster id of node ``i`` is never larger than ``i``,
    * nodes sharing a cluster id are connected by an edge.

    Args:
        row: Index tensor with the source node of every edge.
        col: Index tensor with the target node of every edge.
        cluster: Integer tensor with the cluster id of every node.

    Raises:
        AssertionError: If any invariant above is violated.
    """
    row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
    n = cluster.size(0)

    # Every node was assigned a cluster.
    assert cluster.min() >= 0, 'found a negative cluster id'

    # There are no more than two nodes in each cluster. `index` maps every
    # node to the position of its cluster id among the unique ids, so
    # scattering ones over it yields the size of each cluster.
    _, index = torch.unique(cluster, return_inverse=True)
    count = torch.zeros_like(cluster)
    count.scatter_add_(0, index, torch.ones_like(cluster))
    assert count.max() <= 2, 'a cluster contains more than two nodes'

    # Cluster value is minimal: the id of node `i` never exceeds `i`.
    upper_bound = torch.arange(n, dtype=cluster.dtype)
    assert (cluster <= upper_bound).all(), 'cluster id exceeds node index'

    # Corresponding clusters must be adjacent.
    for i in range(n):
        x = cluster[col[row == i]] == cluster[i]  # Neighbors with same cluster
        y = cluster == cluster[i]  # Nodes with same cluster.
        y[i] = 0  # Do not look at cluster of `i`.
        # Every other member of the cluster of `i` must show up among the
        # neighbors of `i`, and vice versa.
        assert x.sum() == y.sum(), (
            'node {} shares a cluster with a non-adjacent node'.format(i))
rusty1s's avatar
rusty1s committed
45
46


rusty1s's avatar
rusty1s committed
47
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device):
    """Run ``graclus_cluster`` on an example graph and validate the result.

    Args:
        test: Fixture dict with the keys ``'row'`` and ``'col'`` and an
            optional ``'weight'``.
        dtype: Data type used for the edge weights.
        device: Device on which the input tensors are created.
    """
    row = tensor(test['row'], torch.long, device)
    col = tensor(test['col'], torch.long, device)
    # 'weight' is optional, so `test.get` may hand `None` to `tensor`.
    # NOTE(review): presumably `tensor` passes `None` through unchanged --
    # confirm against `.utils`.
    weight = tensor(test.get('weight'), dtype, device)

    cluster = graclus_cluster(row, col, weight)

    # Actually verify the clustering instead of only printing it; without
    # this check the test could only fail if `graclus_cluster` raised.
    assert_correct(row, col, cluster)