import mxnet as mx
import numpy as np
from dgl.geometry import farthest_point_sampler

import backend as F

def test_fps():
    """Smoke-test ``farthest_point_sampler`` on a random batched point cloud.

    Builds ``batch_size`` clouds of ``N // batch_size`` points each, with
    coordinates drawn uniformly from [0, 1) in 3-D. It then samples
    ``sample_points`` indices per cloud and checks the shape of the result.
    It also checks that the returned indices are not all zero.
    """
    N = 1000
    batch_size = 5
    sample_points = 10
    # Integer division: every batch element holds the same number of points.
    points_per_batch = N // batch_size
    x = mx.nd.array(np.random.uniform(size=(batch_size, points_per_batch, 3)))
    ctx = F.ctx()
    # Only relocate the input when the test backend reports a GPU context;
    # otherwise keep MXNet's default context for the array.
    if F.gpu_ctx():
        x = x.as_in_context(ctx)
    res = farthest_point_sampler(x, sample_points)

    # Expect one row of sampled indices per batch element.
    assert res.shape[0] == batch_size
    assert res.shape[1] == sample_points
    # Weak sanity check that the sampler produced non-trivial indices.
    assert res.sum() > 0

if __name__ == "__main__":
    # Allow running this test file directly, without going through pytest.
    test_fps()