# test_geometry.py: tests for farthest point sampling, KNN graph construction
# and neighbor-matching edge coarsening.
import backend as F
2
3
import dgl.nn
import dgl
4
import numpy as np
5
6
import pytest
import torch as th
7
8
from dgl import DGLError
from dgl.base import DGLWarning
9
from dgl.geometry import neighbor_matching, farthest_point_sampler
10
from test_utils import parametrize_idtype
11
12
from test_utils.graph_cases import get_cases

def test_fps():
    """Check the output shape of ``farthest_point_sampler`` on a random batch.

    Samples ``sample_points`` points from each of ``batch_size`` random point
    clouds of ``N // batch_size`` 3-D points and verifies the result has shape
    ``(batch_size, sample_points)`` and is not all zeros.
    """
    N = 1000
    batch_size = 5
    sample_points = 10
    x = th.tensor(np.random.uniform(size=(batch_size, N // batch_size, 3)))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)

    res = farthest_point_sampler(x, sample_points)

    assert res.shape[0] == batch_size
    assert res.shape[1] == sample_points
    # Guard against a degenerate all-zero index result.
    assert res.sum() > 0


def test_fps_start_idx():
    """Check that ``start_idx`` pins the first point chosen by the sampler.

    Runs ``farthest_point_sampler`` with ``start_idx=0`` on a random batch of
    ``batch_size`` point clouds and expects index 0 among the first samples.
    """
    N = 1000
    batch_size = 5
    sample_points = 10
    x = th.tensor(np.random.uniform(size=(batch_size, N // batch_size, 3)))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)
    res = farthest_point_sampler(x, sample_points, start_idx=0)
    # NOTE(review): this only requires *some* point cloud to start at index 0.
    # If the returned indices are local to each point cloud, th.all would be
    # the stronger check; confirm the index convention before tightening.
    assert th.any(res[:, 0] == 0)


def _test_knn_common(device, algorithm, dist, exclude_self):
    """Shared KNN-graph checks, run on ``device`` with the given options.

    Builds graphs with ``dgl.nn.KNNGraph`` / ``dgl.nn.SegmentedKNNGraph`` from
    8 random 3-D points. It compares each node's in-neighbours against a
    reference computed with ``th.topk`` on a dense distance matrix. It also
    covers:

    - batched (3-D) input and segmented input;
    - ``k`` larger than a segment, which must emit ``DGLWarning``;
    - an invalid ``k`` and empty input, which must raise ``DGLError``;
    - all-coincident points.

    Args:
        device: device the input points are moved to.
        algorithm: KNN algorithm name forwarded to the graph modules.
        dist: ``'euclidean'``; any other value is checked as cosine distance.
        exclude_self: whether a node is excluded from its own neighbours.
    """
    x = th.randn(8, 3).to(device)
    kg = dgl.nn.KNNGraph(3)
    # Reference pairwise distance matrix, moved to CPU for the comparisons.
    if dist == 'euclidean':
        d = th.cdist(x, x).to(F.cpu())
    else:
        # NOTE(review): the points are shifted by one random scalar only for the
        # cosine case; presumably so they are not zero-centred. Confirm intent.
        x = x + th.randn(1).item()
        tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
        d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())

    def check_knn(g, x, start, end, k, exclude_self, check_indices=True):
        # Every node in [start, end) must have exactly k in-neighbours. When
        # check_indices is set, those neighbours must be the k closest nodes in
        # the same range according to ``d`` (the node itself removed if
        # exclude_self).
        assert g.device == x.device
        g = g.to(F.cpu())
        for v in range(start, end):
            src, _ = g.in_edges(v)
            src = set(src.numpy())
            assert len(src) == k
            if check_indices:
                i = v - start
                # Ask topk for one extra neighbour when self is excluded, since
                # a node is always at distance zero from itself.
                num_nearest = k + (1 if exclude_self else 0)
                row = d[start:end, start:end][i]
                nearest = th.topk(row, num_nearest, largest=False)[1]
                src_ans = set(nearest.numpy() + start)
                if exclude_self:
                    # remove self
                    src_ans.remove(v)
                assert src == src_ans

    def check_batch(g, k, expected_batch_info):
        # Each graph in the batch has the expected node count and k edges per node.
        assert F.array_equal(g.batch_num_nodes(), F.tensor(expected_batch_info))
        assert F.array_equal(g.batch_num_edges(), k * F.tensor(expected_batch_info))

    # check knn with 2d input
    g = kg(x, algorithm, dist, exclude_self)
    check_knn(g, x, 0, 8, 3, exclude_self)
    check_batch(g, 3, [8])

    # check knn with 3d input
    g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
    check_knn(g, x, 0, 4, 3, exclude_self)
    check_knn(g, x, 4, 8, 3, exclude_self)
    check_batch(g, 3, [4, 4])

    # check segmented knn
    # there are only 2 edges per node possible when exclude_self with 3 nodes in the segment
    # and this test case isn't supposed to warn, so limit it when exclude_self is True
    adjusted_k = 3 - (1 if exclude_self else 0)
    kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
    g = kg(x, [3, 5], algorithm, dist, exclude_self)
    check_knn(g, x, 0, 3, adjusted_k, exclude_self)
    check_knn(g, x, 3, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [3, 5])

    # check k > num_points
    kg = dgl.nn.KNNGraph(10)
    with pytest.warns(DGLWarning):
        g = kg(x, algorithm, dist, exclude_self)
    # there are only 7 edges per node possible when exclude_self with 8 nodes total
    adjusted_k = 8 - (1 if exclude_self else 0)
    check_knn(g, x, 0, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [8])

    with pytest.warns(DGLWarning):
        g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
    # there are only 3 edges per node possible when exclude_self with 4 nodes per segment
    adjusted_k = 4 - (1 if exclude_self else 0)
    check_knn(g, x, 0, 4, adjusted_k, exclude_self)
    check_knn(g, x, 4, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [4, 4])

    kg = dgl.nn.SegmentedKNNGraph(5)
    with pytest.warns(DGLWarning):
        g = kg(x, [3, 5], algorithm, dist, exclude_self)
    # there are only 2 edges per node possible when exclude_self in the segment with
    # only 3 nodes, and the current implementation reduces k for all segments
    # in that case
    adjusted_k = 3 - (1 if exclude_self else 0)
    check_knn(g, x, 0, 3, adjusted_k, exclude_self)
    check_knn(g, x, 3, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [3, 5])

    # check invalid k: k == 0 must raise, but with exclude_self k == 0 is
    # accepted, so k == -1 is checked instead in that case
    adjusted_k = 0 - (1 if exclude_self else 0)
    kg = dgl.nn.KNNGraph(adjusted_k)
    with pytest.raises(DGLError):
        g = kg(x, algorithm, dist, exclude_self)
    kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
    with pytest.raises(DGLError):
        g = kg(x, [3, 5], algorithm, dist, exclude_self)

    # check empty
    x_empty = th.tensor([])
    kg = dgl.nn.KNNGraph(3)
    with pytest.raises(DGLError):
        g = kg(x_empty, algorithm, dist, exclude_self)
    kg = dgl.nn.SegmentedKNNGraph(3)
    with pytest.raises(DGLError):
        g = kg(x_empty, [3, 5], algorithm, dist, exclude_self)

    # check all coincident points
    x = th.zeros((20, 3)).to(device)
    kg = dgl.nn.KNNGraph(3)
    g = kg(x, algorithm, dist, exclude_self)
    # different algorithms may break the tie differently, so don't check the indices
    check_knn(g, x, 0, 20, 3, exclude_self, False)
    check_batch(g, 3, [20])

    # check all coincident points with segmented input
    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [4, 7, 5, 4], algorithm, dist, exclude_self)
    # different algorithms may break the tie differently, so don't check the indices
    check_knn(g, x,  0,  4, 3, exclude_self, False)
    check_knn(g, x,  4, 11, 3, exclude_self, False)
    check_knn(g, x, 11, 16, 3, exclude_self, False)
    check_knn(g, x, 16, 20, 3, exclude_self, False)
    check_batch(g, 3, [4, 7, 5, 4])


@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
@pytest.mark.parametrize('exclude_self', [False, True])
def test_knn_cpu(algorithm, dist, exclude_self):
    """Run the shared KNN-graph checks on the CPU for each parameter combination."""
    _test_knn_common(F.cpu(), algorithm, dist, exclude_self)


@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
@pytest.mark.parametrize('exclude_self', [False, True])
def test_knn_cuda(algorithm, dist, exclude_self):
    """Run the shared KNN-graph checks on the GPU for each parameter combination.

    Reported as skipped, rather than silently passing, when CUDA is unavailable.
    """
    if not th.cuda.is_available():
        pytest.skip('CUDA is not available')
    _test_knn_common(F.cuda(), algorithm, dist, exclude_self)


@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['dglgraph']))
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('relabel', [True, False])
def test_edge_coarsening(idtype, g, weight, relabel):
    """Validate the node labelling produced by ``neighbor_matching``.

    The graph is made bidirected, cast to ``idtype`` and moved to the test
    context. When ``weight`` is set, it gets random non-negative edge weights.
    The resulting labels must:

    - cover every node;
    - use between ``num_nodes // 2`` and ``num_nodes`` distinct ids;
    - give each id to at most two nodes;
    - only pair nodes that are joined by an edge.
    """
    num_nodes = g.num_nodes()
    g = dgl.to_bidirected(g)
    g = g.astype(idtype).to(F.ctx())
    edge_weight = None
    if weight:
        edge_weight = F.abs(F.randn((g.num_edges(),))).to(F.ctx())
    node_labels = neighbor_matching(g, edge_weight, relabel_idx=relabel)
    unique_ids, counts = th.unique(node_labels, return_counts=True)
    num_result_ids = unique_ids.size(0)

    # shape correct
    assert node_labels.shape == (g.num_nodes(),)

    # all nodes marked
    assert F.reduce_sum(node_labels < 0).item() == 0

    # number of unique node ids correct
    assert num_nodes // 2 <= num_result_ids <= num_nodes

    # each unique id has <= 2 nodes
    assert F.reduce_sum(counts > 2).item() == 0

    # if two nodes have the same id, they must be neighbors
    for label in unique_ids:
        # th.nonzero returns positions on node_labels' own device, so no
        # separately created index tensor has to live on the same device.
        idx = th.nonzero(node_labels == label.item(), as_tuple=True)[0]
        if idx.size(0) == 2:
            u, v = idx[0].item(), idx[1].item()
            assert g.has_edges_between(u, v)


if __name__ == '__main__':
    test_fps()
    test_fps_start_idx()
    # There is no `test_knn`; the KNN tests are parametrized, so run them over
    # the same parameter grid pytest uses. test_edge_coarsening is parametrized
    # over generated graph cases and is only run through pytest.
    for dist in ['euclidean', 'cosine']:
        for exclude_self in [False, True]:
            for algorithm in ['bruteforce-blas', 'bruteforce', 'kd-tree']:
                test_knn_cpu(algorithm, dist, exclude_self)
            if th.cuda.is_available():
                for algorithm in ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem']:
                    test_knn_cuda(algorithm, dist, exclude_self)