test_knn.py 4.64 KB
Newer Older
Justin Johnson's avatar
Justin Johnson committed
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from itertools import product

6
import torch
Justin Johnson's avatar
Justin Johnson committed
7
8
9
10
from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx


class TestKNN(unittest.TestCase):
    """Compare `knn_points_idx` against the naive reference `_knn_points_idx_naive`."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random point clouds are reproducible across runs.
        torch.manual_seed(1)

    def _check_knn_result(self, out1, out2, sorted):
        """
        Assert that two KNN results agree.

        Args:
            out1, out2: (idx, dists) pairs to compare.
            sorted: when True, neighbors must appear in the same order in both
                results and match element-wise. When False, only the set of
                neighbors is compared: idx and dists are each sorted by value
                along dim 2 before comparison.
        """
        # NOTE: `sorted` shadows the builtin; the name is kept so existing
        # callers are unaffected.
        idx1, dist1 = out1
        idx2, dist2 = out2
        if not sorted:
            # We only want to check that we got the same set of neighbors, so
            # put both results into a canonical (value-sorted) order.
            idx1 = idx1.sort(dim=2).values
            idx2 = idx2.sort(dim=2).values
            dist1 = dist1.sort(dim=2).values
            dist2 = dist2.sort(dim=2).values
        if not torch.all(idx1 == idx2):
            # Report the mismatching indices in the failure message rather
            # than printing them to stdout.
            self.fail(f"KNN indices differ:\n{idx1}\n{idx2}")
        self.assertTrue(torch.allclose(dist1, dist2))

    def test_knn_vs_python_cpu_square(self):
        """ Test CPU output vs PyTorch implementation """
        device = torch.device("cpu")
        Ns = [1, 4]
        Ds = [2, 3]
        P1s = [1, 10, 101]
        P2s = [10, 101]
        Ks = [1, 3, 10]
        sorts = [True, False]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            # Every cloud is full length, passed explicitly as lengths tensors.
            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K
            )
            out2 = knn_points_idx(
                x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
            )
            self._check_knn_result(out1, out2, sort)

    def test_knn_vs_python_cuda_square(self):
        """ Test CUDA output vs PyTorch implementation """
        device = torch.device("cuda")
        Ns = [1, 4]
        Ds = [2, 3, 8]
        P1s = [1, 8, 64, 128, 1001]
        P2s = [32, 128, 513]
        Ks = [1, 3, 10]
        sorts = [True, False]
        versions = [0, 1, 2, 3]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
            for version in versions:
                # Version 3 is skipped for K > 4.
                if version == 3 and K > 4:
                    continue
                out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
                self._check_knn_result(out1, out2, sort)

    def test_knn_vs_python_cpu_ragged(self):
        """ Test CPU output vs PyTorch implementation on ragged batches """
        device = torch.device("cpu")
        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
        N = 4
        D = 3
        # Padded sizes of the point tensors; loop-invariant, so computed once.
        P1 = int(lengths1.max())
        P2 = int(lengths2.max())
        Ks = [1, 9, 10, 11, 101]
        sorts = [False, True]
        factors = [Ks, sorts]
        for K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K
            )
            out2 = knn_points_idx(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
            )
            self._check_knn_result(out1, out2, sort)

    def test_knn_vs_python_cuda_ragged(self):
        """ Test each CUDA version vs PyTorch implementation on ragged batches """
        device = torch.device("cuda")
        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
        N = 4
        D = 3
        # Padded sizes of the point tensors; loop-invariant, so computed once.
        P1 = int(lengths1.max())
        P2 = int(lengths2.max())
        Ks = [1, 9, 10, 11, 101]
        sorts = [True, False]
        versions = [0, 1, 2, 3]
        factors = [Ks, sorts]
        for K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K
            )
            for version in versions:
                # NOTE(review): assumes the CUDA kernels only accept K <= 32
                # for version 2 and K <= 4 for version 3 — confirm against
                # the kernel's version checks.
                if version == 2 and K > 32:
                    continue
                if version == 3 and K > 4:
                    continue
                # Pass `version` explicitly. Previously it was dropped, so
                # every iteration re-tested the default version.
                out2 = knn_points_idx(
                    x,
                    y,
                    lengths1=lengths1,
                    lengths2=lengths2,
                    K=K,
                    sorted=sort,
                    version=version,
                )
                self._check_knn_result(out1, out2, sort)