"awq/vscode:/vscode.git/clone" did not exist on "6f516b8d4c154f79c2f86fbd8f702dd7584df2d3"
test_three_nn.py 2.68 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
4
5
import pytest
import torch

from mmcv.ops import three_nn
6
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
7
8


9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_three_nn(device):
    """Check ``three_nn`` against precomputed neighbours on ``device``.

    ``known`` holds two batches of 5 points and ``unknown`` two batches of
    10 points (3 coordinates each). For every ``unknown`` point the op must
    return 3 distances and 3 indices into ``known`` that match the
    hard-coded expectations below.

    Fixture notes:
      * In each batch ``known`` repeats its point 0 at indices 3 and 4, so
        the expected index ordering (e.g. ``[0, 3, 4]``) pins down how ties
        are broken.
      * The first three ``unknown`` points of each batch coincide with
        ``known`` points 0-2, giving exact zero distances.
    """
    known = torch.tensor(
        [[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
          [-0.6503, 3.6637, -1.0622], [-1.8373, 3.5605, -0.7867],
          [-1.8373, 3.5605, -0.7867]],
         [[-1.3399, 1.9991, -0.3698], [-0.0799, 0.9698, -0.8457],
          [0.0858, 2.4721, -0.1928], [-1.3399, 1.9991, -0.3698],
          [-1.3399, 1.9991, -0.3698]]],
        device=device)

    unknown = torch.tensor(
        [[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
          [-0.6503, 3.6637, -1.0622], [-1.5237, 2.3976, -0.8097],
          [-0.0722, 3.4017, -0.2880], [0.5198, 3.0661, -0.4605],
          [-2.0185, 3.5019, -0.3236], [0.5098, 3.1020, 0.5799],
          [-1.6137, 3.8443, -0.5269], [0.7341, 2.9626, -0.3189]],
         [[-1.3399, 1.9991, -0.3698], [-0.0799, 0.9698, -0.8457],
          [0.0858, 2.4721, -0.1928], [-0.9022, 1.6560, -1.3090],
          [0.1156, 1.6901, -0.4366], [-0.6477, 2.3576, -0.1563],
          [-0.8482, 1.1466, -1.2704], [-0.8753, 2.0845, -0.3460],
          [-0.5621, 1.4233, -1.2858], [-0.5883, 1.3114, -1.2899]]],
        device=device)

    dist, idx = three_nn(unknown, known)

    expected_dist = torch.tensor(
        [[[0.0000, 0.0000, 0.0000], [0.0000, 2.0463, 2.8588],
          [0.0000, 1.2229, 1.2229], [1.2047, 1.2047, 1.2047],
          [1.0011, 1.0845, 1.8411], [0.7433, 1.4451, 2.4304],
          [0.5007, 0.5007, 0.5007], [0.4587, 2.0875, 2.7544],
          [0.4450, 0.4450, 0.4450], [0.5514, 1.7206, 2.6811]],
         [[0.0000, 0.0000, 0.0000], [0.0000, 1.6464, 1.6952],
          [0.0000, 1.5125, 1.5125], [1.0915, 1.0915, 1.0915],
          [0.8197, 0.8511, 1.4894], [0.7433, 0.8082, 0.8082],
          [0.8955, 1.3340, 1.3340], [0.4730, 0.4730, 0.4730],
          [0.7949, 1.3325, 1.3325], [0.7566, 1.3727, 1.3727]]],
        device=device)
    expected_idx = torch.tensor(
        [[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0], [1, 2, 0],
          [0, 3, 4], [1, 2, 0], [0, 3, 4], [1, 2, 0]],
         [[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0], [2, 0, 3],
          [1, 0, 3], [0, 3, 4], [1, 0, 3], [1, 0, 3]]],
        device=device)

    # Both `torch.allclose` and elementwise `==` broadcast, so a wrongly
    # shaped output could still compare equal; pin the exact shapes first.
    assert dist.shape == expected_dist.shape
    assert idx.shape == expected_idx.shape
    assert torch.allclose(dist, expected_dist, atol=1e-4)
    # Elementwise `==` instead of `torch.equal` so the check does not depend
    # on the integer dtype the op returns for `idx`.
    assert torch.all(idx == expected_idx)