test_graclus.py 949 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import numpy as np
rusty1s's avatar
rusty1s committed
3
4
5
from torch_cluster import graclus_cluster


rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def assert_correct_graclus(row, col, cluster):
    """Assert that ``cluster`` is a valid pairwise clustering of a graph.

    The graph is given as an edge list: edge ``i`` goes from node ``row[i]``
    to node ``col[i]``.

    Args:
        row: Source node index of every edge (object exposing ``.numpy()``).
        col: Target node index of every edge (object exposing ``.numpy()``).
        cluster: Cluster id of every node (object exposing ``.numpy()``).

    Raises:
        AssertionError: If a node has a negative cluster id, if a cluster
            holds more than two nodes, or if two nodes sharing a cluster
            are not connected by an edge.
    """
    row, col, cluster = row.numpy(), col.numpy(), cluster.numpy()

    # Every node was assigned a cluster.
    assert (cluster >= 0).all()

    # There are no more than two nodes in each cluster.
    _, count = np.unique(cluster, return_counts=True)
    assert (count <= 2).all()

    # Corresponding clusters must be adjacent: every *other* node that shares
    # the cluster of `n` has to be a neighbor of `n`. Unmatched nodes
    # (singleton clusters) and nodes without edges have no such partner, so
    # the check holds trivially for them instead of failing or raising on an
    # empty reduction.
    for n in range(cluster.shape[0]):
        neighbors = col[row == n]
        mates = np.flatnonzero(cluster == cluster[n])
        mates = mates[mates != n]  # Do not compare `n` with itself.
        assert np.isin(mates, neighbors).all()


rusty1s's avatar
rusty1s committed
21
22
23
def test_graclus_cluster_cpu():
    """Cluster a small 4-node graph, without and with edge weights.

    Each result of ``graclus_cluster`` is validated with
    ``assert_correct_graclus``.
    """
    # Edge list: edge i goes from row[i] to col[i]; every edge is listed in
    # both directions.
    row = torch.LongTensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
    col = torch.LongTensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
    # One weight per edge, aligned with row/col.
    weight = torch.Tensor([1, 2, 1, 3, 2, 2, 3, 1, 2, 1])

    # Unweighted call.
    cluster = graclus_cluster(row, col)
    assert_correct_graclus(row, col, cluster)

    # Weighted call.
    cluster = graclus_cluster(row, col, weight)
    assert_correct_graclus(row, col, cluster)