test_knn.py 2.41 KB
Newer Older
Justin Johnson's avatar
Justin Johnson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from itertools import product
import torch

from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx


class TestKNN(unittest.TestCase):
    """Compare knn_points_idx against the pure-PyTorch reference implementation."""

    def _check_knn_result(self, out1, out2, sorted):
        """
        Assert that two KNN results agree.

        Args:
            out1, out2: (idx, dist) pairs, as returned by
                _knn_points_idx_naive / knn_points_idx.
            sorted: If True, neighbors are expected to be ordered by distance,
                so idx and dist must match element-wise between the two
                implementations. If False, only the set of neighbors found for
                each query point is compared: idx and dist are each sorted
                along dim 2 before comparing.
        """
        # NOTE: `sorted` shadows the builtin; the parameter name is kept so
        # that any existing keyword callers keep working.
        idx1, dist1 = out1
        idx2, dist2 = out2
        if not sorted:
            # Sorting idx and dist independently is sufficient: we only care
            # that both implementations found the same neighbors (and hence
            # the same multiset of distances), not the order they returned.
            idx1 = idx1.sort(dim=2).values
            idx2 = idx2.sort(dim=2).values
            dist1 = dist1.sort(dim=2).values
            dist2 = dist2.sort(dim=2).values
        # Check shapes explicitly: `==` and torch.allclose broadcast, which
        # could otherwise mask a shape mismatch between the implementations.
        self.assertEqual(idx1.shape, idx2.shape)
        self.assertEqual(dist1.shape, dist2.shape)
        if not torch.all(idx1 == idx2):
            # Only format the (possibly large) tensors when actually failing.
            self.fail("KNN indices differ:\n%s\nvs\n%s" % (idx1, idx2))
        self.assertTrue(torch.allclose(dist1, dist2))

    def test_knn_vs_python_cpu(self):
        """Test CPU output vs PyTorch implementation."""
        device = torch.device("cpu")
        torch.manual_seed(1)  # make any failure reproducible
        Ns = [1, 4]
        Ds = [2, 3]
        P1s = [1, 10, 101]
        P2s = [10, 101]
        Ks = [1, 3, 10]
        sorts = [True, False]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            # subTest reports which parameter combination failed.
            with self.subTest(N=N, D=D, P1=P1, P2=P2, K=K, sort=sort):
                x = torch.randn(N, P1, D, device=device)
                y = torch.randn(N, P2, D, device=device)
                out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
                out2 = knn_points_idx(x, y, K, sort)
                self._check_knn_result(out1, out2, sort)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_knn_vs_python_cuda(self):
        """Test CUDA output vs PyTorch implementation, for every kernel version."""
        device = torch.device("cuda")
        torch.manual_seed(1)  # make any failure reproducible
        Ns = [1, 4]
        Ds = [2, 3, 8]
        P1s = [1, 8, 64, 128, 1001]
        P2s = [32, 128, 513]
        Ks = [1, 3, 10]
        sorts = [True, False]
        versions = [0, 1, 2, 3]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            # The reference result is shared by all kernel versions.
            out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
            for version in versions:
                # Version 3 is skipped for K > 4 (presumably it only supports
                # small K -- TODO confirm against knn_points_idx).
                if version == 3 and K > 4:
                    continue
                with self.subTest(
                    N=N, D=D, P1=P1, P2=P2, K=K, sort=sort, version=version
                ):
                    out2 = knn_points_idx(x, y, K, sort, version)
                    self._check_knn_result(out1, out2, sort)