test_graclus.py 1.67 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
5
import torch
from torch_cluster import graclus_cluster
Matthias Fey's avatar
Matthias Fey committed
6
from torch_cluster.testing import devices, dtypes, tensor
rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
16

# Graph fixtures in edge-list form: edge ``k`` connects ``row[k]`` and
# ``col[k]``. Both entries list the same 4-node graph with every edge given
# in both directions; the second one also carries one weight per edge.
tests = [
    # Unweighted graph.
    {
        'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
    },
    # Same edges, with a weight for each of them.
    {
        'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        'weight': [1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
    },
]

rusty1s's avatar
rusty1s committed
17
18

def assert_correct(row, col, cluster):
    """Assert that ``cluster`` is a valid pairing of adjacent nodes.

    The checks are: every cluster id is non-negative, no cluster holds more
    than two nodes, no node's cluster id exceeds its own index, and two
    nodes sharing a cluster are connected by an edge.

    Args:
        row: First endpoint of every edge.
        col: Second endpoint of every edge (edge ``k`` is
            ``row[k] -> col[k]``).
        cluster: Cluster id of every node; its length is used as the number
            of nodes.

    Raises:
        AssertionError: If one of the invariants above is violated.
    """
    # All checks run on the CPU, wherever the inputs live.
    row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
    n = cluster.size(0)

    # Every node was assigned a cluster.
    assert cluster.min() >= 0

    # There are no more than two nodes in each cluster.
    _, count = torch.unique(cluster, return_counts=True)
    assert count.max() <= 2

    # A cluster id never exceeds the index of a node it is assigned to.
    assert bool((cluster <= torch.arange(n, dtype=cluster.dtype)).all())

    # Corresponding clusters must be adjacent.
    for i in range(n):
        # Nodes other than `i` that share the cluster of `i` ...
        same_cluster = cluster == cluster[i]
        same_cluster[i] = False
        # ... must all show up among the neighbors of `i`.
        neighbors = col[row == i]
        linked = cluster[neighbors] == cluster[i]
        assert linked.sum() == same_cluster.sum()
rusty1s's avatar
rusty1s committed
40
41


rusty1s's avatar
rusty1s committed
42
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device):
    """Cluster one fixture graph and validate the result.

    Runs once for every combination of fixture graph, weight dtype and
    device.
    """
    if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
        # Report the excluded combination as skipped instead of letting it
        # count as a passing test.
        pytest.skip('bfloat16 on cuda:0 is excluded from this test')

    row = tensor(test['row'], torch.long, device)
    col = tensor(test['col'], torch.long, device)
    # 'weight' is optional, so `test.get` yields None for the unweighted
    # fixture. NOTE(review): presumably `tensor` passes None through --
    # confirm in torch_cluster.testing.
    weight = tensor(test.get('weight'), dtype, device)

    cluster = graclus_cluster(row, col, weight)
    assert_correct(row, col, cluster)