test_graclus.py 1.44 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
import numpy as np
rusty1s's avatar
rusty1s committed
6
7
from torch_cluster import graclus_cluster

rusty1s's avatar
rusty1s committed
8
from .tensor import cpu_tensors
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19

# Test cases: the same 4-node graph, stored with both directions of every
# undirected edge (10 entries in `row`/`col`). The first case is unweighted;
# the second attaches a (symmetric) weight to each directed edge.
tests = [
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        weight=edge_weight,
    )
    for edge_weight in (None, [1, 2, 1, 3, 2, 2, 3, 1, 2, 1])
]

rusty1s's avatar
rusty1s committed
20

rusty1s's avatar
rusty1s committed
21
22
23
24
25
26
27
28
29
30
31
32
def assert_correct_graclus(row, col, cluster):
    row, col, cluster = row.numpy(), col.numpy(), cluster.numpy()

    # Every node was assigned a cluster.
    assert cluster.min() >= 0

    # There are no more than two nodes in each cluster.
    _, count = np.unique(cluster, return_counts=True)
    assert (count > 2).max() == 0

    # Corresponding clusters must be adjacent.
    for n in range(cluster.shape[0]):
rusty1s's avatar
rusty1s committed
33
34
35
36
        x = cluster[col[row == n]] == cluster[n]  # Neighbors with same cluster
        y = cluster == cluster[n]  # Nodes with same cluster
        y[n] = 0  # Do not look at cluster of node `n`.
        assert x.sum() == y.sum()
rusty1s's avatar
rusty1s committed
37
38


rusty1s's avatar
rusty1s committed
39
40
41
@pytest.mark.parametrize('tensor,i', product(cpu_tensors, range(len(tests))))
def test_graclus_cluster_cpu(tensor, i):
    """Cluster test graph ``i`` on the CPU and validate the result.

    ``tensor`` is the name of a ``torch`` attribute (looked up with
    ``getattr(torch, tensor)``) that is used to construct the edge weights.
    Cases without weights pass ``None`` to ``graclus_cluster``.
    """
    case = tests[i]

    row, col = torch.LongTensor(case['row']), torch.LongTensor(case['col'])

    weight = None
    if case['weight'] is not None:
        weight = getattr(torch, tensor)(case['weight'])

    assert_correct_graclus(row, col, graclus_cluster(row, col, weight))