test_nearest_neighbor_points.py 2.8 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
import torch

from pytorch3d import _C


class TestNearestNeighborPoints(unittest.TestCase):
    """
    Compare the compiled ``_C.nn_points_idx`` op against a pure-PyTorch
    reference implementation, and provide benchmark factories for the CPU,
    CUDA and pure-Python variants.
    """

    @staticmethod
    def nn_points_idx_naive(x, y):
        """
        PyTorch implementation of nn_points_idx function.

        For every point in ``x``, find the index of the point in ``y`` (within
        the same batch element) with the smallest squared Euclidean distance.

        Args:
            x: Tensor of shape (N, P1, D).
            y: Tensor of shape (N, P2, D); N and D must match ``x``.

        Returns:
            Tensor of shape (N, P1) holding indices into the P2 axis of ``y``.
        """
        N, P1, D = x.shape
        _N, P2, _D = y.shape
        assert N == _N and D == _D
        # Broadcast to all pairwise differences: (N, P1, P2, D).
        diffs = x.view(N, P1, 1, D) - y.view(N, 1, P2, D)
        dists2 = (diffs * diffs).sum(3)
        idx = dists2.argmin(2)
        return idx

    def _test_nn_helper(self, device):
        """
        Check ``_C.nn_points_idx`` against ``nn_points_idx_naive`` on random
        inputs over a grid of point dimensions, batch sizes and point counts.

        Args:
            device: torch.device on which the random inputs are created.
        """
        for D in [3, 4]:
            for N in [1, 4]:
                for P1 in [1, 8, 64, 128]:
                    for P2 in [32, 128]:
                        x = torch.randn(N, P1, D, device=device)
                        y = torch.randn(N, P2, D, device=device)

                        # _C.nn_points_idx should dispatch
                        # to the cpp or cuda versions of the function
                        # depending on the input type.
                        idx1 = _C.nn_points_idx(x, y)
                        idx2 = TestNearestNeighborPoints.nn_points_idx_naive(
                            x, y
                        )
                        # Check the full shape first: `idx1 == idx2` broadcasts,
                        # so a wrongly-shaped result could otherwise pass.
                        self.assertEqual(tuple(idx1.shape), (N, P1))
                        self.assertTrue(torch.all(idx1 == idx2))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_nn_cuda(self):
        """
        Test cuda output vs naive python implementation.
        """
        device = torch.device("cuda:0")
        self._test_nn_helper(device)

    def test_nn_cpu(self):
        """
        Test cpu output vs naive python implementation.
        """
        device = torch.device("cpu")
        self._test_nn_helper(device)

    @staticmethod
    def bm_nn_points_cpu_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Build inputs on the CPU and return a zero-argument closure that runs
        ``_C.nn_points_idx`` on them once.
        """
        device = torch.device("cpu")
        x = torch.randn(N, P1, D, device=device)
        y = torch.randn(N, P2, D, device=device)

        def nn_cpu():
            _C.nn_points_idx(x.contiguous(), y.contiguous())

        return nn_cpu

    @staticmethod
    def bm_nn_points_cuda_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Build inputs on ``cuda:0`` and return a zero-argument closure that runs
        ``_C.nn_points_idx`` on them once.
        """
        device = torch.device("cuda:0")
        x = torch.randn(N, P1, D, device=device)
        y = torch.randn(N, P2, D, device=device)
        # Make sure input creation is finished before any timing starts.
        torch.cuda.synchronize()

        def nn_cuda():
            _C.nn_points_idx(x.contiguous(), y.contiguous())
            # Kernel launches are asynchronous; wait so the timing covers
            # the actual computation.
            torch.cuda.synchronize()

        return nn_cuda

    @staticmethod
    def bm_nn_points_python_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Build inputs on the default device and return a zero-argument closure
        that runs the pure-PyTorch reference ``nn_points_idx_naive`` once.
        """
        x = torch.randn(N, P1, D)
        y = torch.randn(N, P2, D)

        def nn_python():
            TestNearestNeighborPoints.nn_points_idx_naive(x, y)

        return nn_python