"""Tests for dgl.geometry: farthest point sampling, KNN graphs, neighbor matching."""
import backend as F
import numpy as np
import pytest
import torch as th
from test_utils import parametrize_idtype
from test_utils.graph_cases import get_cases

import dgl
import dgl.nn
from dgl import DGLError
from dgl.base import DGLWarning
from dgl.geometry import farthest_point_sampler, neighbor_matching


def test_fps():
    """Farthest point sampling returns a (batch_size, sample_points) result."""
    N = 1000
    batch_size = 5
    sample_points = 10
    # N random 3-D points split evenly into `batch_size` point clouds.
    x = th.tensor(np.random.uniform(size=(batch_size, N // batch_size, 3)))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)
    res = farthest_point_sampler(x, sample_points)
    assert res.shape[0] == batch_size
    assert res.shape[1] == sample_points
    # Sanity check that the sampler produced at least one nonzero index.
    assert res.sum() > 0


def test_fps_start_idx():
    """With ``start_idx=0`` the first sampled index is 0 (for some cloud)."""
    N = 1000
    batch_size = 5
    sample_points = 10
    # N random 3-D points split evenly into `batch_size` point clouds.
    x = th.tensor(np.random.uniform(size=(batch_size, N // batch_size, 3)))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)
    res = farthest_point_sampler(x, sample_points, start_idx=0)
    # NOTE(review): `th.any` only requires one cloud to start at index 0.
    # If returned indices are local to each cloud, `th.all` would be the
    # stricter check — confirm against farthest_point_sampler's contract.
    assert th.any(res[:, 0] == 0)


def _test_knn_common(device, algorithm, dist, exclude_self):
    """Run the KNNGraph / SegmentedKNNGraph checks for one configuration.

    Parameters
    ----------
    device
        Device the input point features are moved to.
    algorithm : str
        KNN algorithm name forwarded to the graph modules.
    dist : str
        ``"euclidean"`` selects ``th.cdist`` for the reference distances;
        any other value uses ``1 - cosine similarity``.
    exclude_self : bool
        Forwarded to the graph modules; when true a node is not counted as
        its own neighbor, which lowers the achievable k by one.
    """
    # One fewer neighbor is available per node when self-loops are excluded.
    self_offset = 1 if exclude_self else 0

    x = th.randn(8, 3).to(device)
    kg = dgl.nn.KNNGraph(3)
    # Reference pairwise distance matrix `d`, kept on CPU for comparison.
    if dist == "euclidean":
        d = th.cdist(x, x).to(F.cpu())
    else:
        # Add the same random scalar to every coordinate, then normalize rows.
        x = x + th.randn(1).item()
        tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
        d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())

    def check_knn(g, x, start, end, k, exclude_self, check_indices=True):
        """Assert nodes ``start..end-1`` of ``g`` each have ``k`` in-neighbors.

        When ``check_indices`` is true, the in-neighbor set of each node must
        equal the ``k`` nearest points within ``[start, end)`` according to
        the reference matrix ``d`` (with the node itself removed when
        ``exclude_self``).
        """
        assert g.device == x.device
        g = g.to(F.cpu())
        for v in range(start, end):
            src, _ = g.in_edges(v)
            src = set(src.numpy())
            assert len(src) == k
            if not check_indices:
                continue
            row = d[start:end, start:end][v - start]
            num_candidates = k + (1 if exclude_self else 0)
            nearest = th.topk(row, num_candidates, largest=False)[1]
            src_ans = set(nearest.numpy() + start)
            if exclude_self:
                # the extra candidate is expected to be v itself; remove it
                src_ans.remove(v)
            assert src == src_ans

    def check_batch(g, k, expected_batch_info):
        """Assert per-graph node counts and that each graph has k edges per node."""
        assert F.array_equal(g.batch_num_nodes(), F.tensor(expected_batch_info))
        assert F.array_equal(
            g.batch_num_edges(), k * F.tensor(expected_batch_info)
        )

    # check knn with 2d input
    g = kg(x, algorithm, dist, exclude_self)
    check_knn(g, x, 0, 8, 3, exclude_self)
    check_batch(g, 3, [8])

    # check knn with 3d input
    g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
    check_knn(g, x, 0, 4, 3, exclude_self)
    check_knn(g, x, 4, 8, 3, exclude_self)
    check_batch(g, 3, [4, 4])

    # check segmented knn
    # with exclude_self, a 3-node segment allows only 2 neighbors per node;
    # this case must not warn, so k is reduced accordingly
    adjusted_k = 3 - self_offset
    kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
    g = kg(x, [3, 5], algorithm, dist, exclude_self)
    check_knn(g, x, 0, 3, adjusted_k, exclude_self)
    check_knn(g, x, 3, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [3, 5])

    # check k > num_points
    kg = dgl.nn.KNNGraph(10)
    with pytest.warns(DGLWarning):
        g = kg(x, algorithm, dist, exclude_self)
    # 8 nodes total: at most 8 neighbors, or 7 with exclude_self
    adjusted_k = 8 - self_offset
    check_knn(g, x, 0, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [8])

    with pytest.warns(DGLWarning):
        g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
    # 4 nodes per segment: at most 4 neighbors, or 3 with exclude_self
    adjusted_k = 4 - self_offset
    check_knn(g, x, 0, 4, adjusted_k, exclude_self)
    check_knn(g, x, 4, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [4, 4])

    kg = dgl.nn.SegmentedKNNGraph(5)
    with pytest.warns(DGLWarning):
        g = kg(x, [3, 5], algorithm, dist, exclude_self)
    # the smallest segment has 3 nodes (2 neighbors with exclude_self), and the
    # current implementation reduces k for all segments in that case
    adjusted_k = 3 - self_offset
    check_knn(g, x, 0, 3, adjusted_k, exclude_self)
    check_knn(g, x, 3, 8, adjusted_k, exclude_self)
    check_batch(g, adjusted_k, [3, 5])

    # check invalid k: k == 0 must raise without exclude_self; with
    # exclude_self, k == 0 is accepted, so k == -1 is used as the invalid value
    adjusted_k = 0 - self_offset
    kg = dgl.nn.KNNGraph(adjusted_k)
    with pytest.raises(DGLError):
        kg(x, algorithm, dist, exclude_self)
    kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
    with pytest.raises(DGLError):
        kg(x, [3, 5], algorithm, dist, exclude_self)

    # check empty
    x_empty = th.tensor([])
    kg = dgl.nn.KNNGraph(3)
    with pytest.raises(DGLError):
        kg(x_empty, algorithm, dist, exclude_self)
    kg = dgl.nn.SegmentedKNNGraph(3)
    with pytest.raises(DGLError):
        kg(x_empty, [3, 5], algorithm, dist, exclude_self)

    # check all coincident points
    x = th.zeros((20, 3)).to(device)
    kg = dgl.nn.KNNGraph(3)
    g = kg(x, algorithm, dist, exclude_self)
    # different algorithms may break the tie differently, so don't check the indices
    check_knn(g, x, 0, 20, 3, exclude_self, False)
    check_batch(g, 3, [20])

    # check all coincident points with segmented knn
    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [4, 7, 5, 4], algorithm, dist, exclude_self)
    # different algorithms may break the tie differently, so don't check the indices
    check_knn(g, x, 0, 4, 3, exclude_self, False)
    check_knn(g, x, 4, 11, 3, exclude_self, False)
    check_knn(g, x, 11, 16, 3, exclude_self, False)
    check_knn(g, x, 16, 20, 3, exclude_self, False)
    check_batch(g, 3, [4, 7, 5, 4])


@pytest.mark.parametrize(
    "algorithm", ["bruteforce-blas", "bruteforce", "kd-tree"]
)
@pytest.mark.parametrize("dist", ["euclidean", "cosine"])
@pytest.mark.parametrize("exclude_self", [False, True])
def test_knn_cpu(algorithm, dist, exclude_self):
    """Run the shared KNN checks on the CPU device."""
    _test_knn_common(F.cpu(), algorithm, dist, exclude_self)


@pytest.mark.skipif(not th.cuda.is_available(), reason="CUDA is not available")
@pytest.mark.parametrize(
    "algorithm", ["bruteforce-blas", "bruteforce", "bruteforce-sharemem"]
)
@pytest.mark.parametrize("dist", ["euclidean", "cosine"])
@pytest.mark.parametrize("exclude_self", [False, True])
def test_knn_cuda(algorithm, dist, exclude_self):
    """Run the shared KNN checks on the CUDA device.

    Reported as skipped (rather than silently passing) when CUDA is absent.
    """
    _test_knn_common(F.cuda(), algorithm, dist, exclude_self)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("relabel", [True, False])
def test_edge_coarsening(idtype, g, weight, relabel):
    """Check that ``neighbor_matching`` yields a valid matching of nodes.

    Every node gets a non-negative label, each label is shared by at most
    two nodes, and two nodes sharing a label must be connected by an edge.
    """
    num_nodes = g.num_nodes()
    g = dgl.to_bidirected(g)
    g = g.astype(idtype).to(F.ctx())
    edge_weight = None
    if weight:
        edge_weight = F.abs(F.randn((g.num_edges(),))).to(F.ctx())
    node_labels = neighbor_matching(g, edge_weight, relabel_idx=relabel)
    unique_ids, counts = th.unique(node_labels, return_counts=True)
    num_result_ids = unique_ids.size(0)

    # shape correct
    assert node_labels.shape == (g.num_nodes(),)

    # all nodes marked
    assert F.reduce_sum(node_labels < 0).item() == 0

    # each unique id has <= 2 nodes
    assert F.reduce_sum(counts > 2).item() == 0

    # number of unique node ids correct: with at most two nodes per id,
    # at least ceil(num_nodes / 2) ids are needed to label every node
    assert (num_nodes + 1) // 2 <= num_result_ids <= num_nodes

    # if two nodes have the same id, they must be neighbors
    for label in unique_ids[counts == 2].tolist():
        # nonzero() derives indices from node_labels itself, so no separate
        # index tensor (possibly on another device) is needed
        u, v = th.nonzero(node_labels == label, as_tuple=True)[0].tolist()
        assert g.has_edges_between(u, v)


if __name__ == "__main__":
    test_fps()
    test_fps_start_idx()
    # There is no `test_knn` in this module; run the CPU KNN test over the
    # same parameter grid pytest uses. (test_edge_coarsening relies on
    # pytest parametrization and is not run here.)
    for algorithm in ("bruteforce-blas", "bruteforce", "kd-tree"):
        for dist in ("euclidean", "cosine"):
            for exclude_self in (False, True):
                test_knn_cpu(algorithm, dist, exclude_self)