# test_geometry.py: tests for dgl.geometry using MXNet tensors.
import backend as F
import mxnet as mx
import numpy as np

from dgl.geometry import farthest_point_sampler


def test_fps():
    """Smoke-test ``farthest_point_sampler`` on a random batch of 3-D points.

    Builds ``batch_size`` point clouds of ``N // batch_size`` uniformly random
    3-D points each, asks for ``sample_points`` samples per cloud, and checks
    the shape and basic sanity of the returned indices.
    """
    N = 1000
    batch_size = 5
    sample_points = 10
    # Integer division so every cloud gets the same number of points.
    points_per_batch = N // batch_size

    x = mx.nd.array(
        np.random.uniform(size=(batch_size, points_per_batch, 3))
    )
    # Relocate the input when the test backend reports a GPU context
    # (presumably F.gpu_ctx() is truthy only for GPU test runs -- confirm).
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.as_in_context(ctx)

    res = farthest_point_sampler(x, sample_points)

    assert res.shape[0] == batch_size
    assert res.shape[1] == sample_points

    # Inspect the result on the host as a plain numpy array.
    idx = res.asnumpy()
    assert idx.sum() > 0
    # Every index must address a real input point; the [0, N) bound holds
    # whether indices are per-cloud or flattened across the batch.
    assert idx.min() >= 0
    assert idx.max() < N
    # Coordinates are continuous random values, so FPS should never pick the
    # same point twice within one cloud.
    for row in idx:
        assert len(set(row.tolist())) == sample_points


# Allow running this test file directly, without a test runner.
if __name__ == "__main__":
    test_fps()