test_chamfer.py 12.6 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
4

facebook-github-bot's avatar
facebook-github-bot committed
5
6
import torch
import torch.nn.functional as F
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
7
from common_testing import TestCaseMixin
8
from pytorch3d.loss import chamfer_distance
facebook-github-bot's avatar
facebook-github-bot committed
9

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
10
11

class TestChamfer(TestCaseMixin, unittest.TestCase):
facebook-github-bot's avatar
facebook-github-bot committed
12
13
14
15
16
17
18
19
20
    @staticmethod
    def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
        """
        Randomly initialize two batches of point clouds of sizes
        (N, P1, D) and (N, P2, D) and return random normal vectors for
        each batch of size (N, P1, 3) and (N, P2, 3).
        """
        device = torch.device("cuda:0")
        p1 = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
21
        p1_normals = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
facebook-github-bot's avatar
facebook-github-bot committed
22
23
        p1_normals = p1_normals / p1_normals.norm(dim=2, p=2, keepdim=True)
        p2 = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
24
        p2_normals = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
facebook-github-bot's avatar
facebook-github-bot committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
        p2_normals = p2_normals / p2_normals.norm(dim=2, p=2, keepdim=True)
        weights = torch.rand((batch_size,), dtype=torch.float32, device=device)

        return p1, p2, p1_normals, p2_normals, weights

    @staticmethod
    def chamfer_distance_naive(p1, p2, p1_normals=None, p2_normals=None):
        """
        Naive iterative implementation of nearest neighbor and chamfer distance.
        Returns lists of the unreduced loss and loss_normals.
        """
        N, P1, D = p1.shape
        P2 = p2.size(1)
        device = torch.device("cuda:0")
        return_normals = p1_normals is not None and p2_normals is not None
        dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)

        for n in range(N):
            for i1 in range(P1):
                for i2 in range(P2):
45
                    dist[n, i1, i2] = torch.sum((p1[n, i1, :] - p2[n, i2, :]) ** 2)
facebook-github-bot's avatar
facebook-github-bot committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

        loss = [
            torch.min(dist, dim=2)[0],  # (N, P1)
            torch.min(dist, dim=1)[0],  # (N, P2)
        ]

        lnorm = [p1.new_zeros(()), p1.new_zeros(())]

        if return_normals:
            p1_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
            p2_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
            lnorm1 = 1 - torch.abs(
                F.cosine_similarity(
                    p1_normals, p2_normals.gather(1, p1_index), dim=2, eps=1e-6
                )
            )
            lnorm2 = 1 - torch.abs(
                F.cosine_similarity(
                    p2_normals, p1_normals.gather(1, p2_index), dim=2, eps=1e-6
                )
            )
            lnorm = [lnorm1, lnorm2]  # [(N, P1), (N, P2)]

        return loss, lnorm

    def test_chamfer_default_no_normals(self):
        """
        Compare chamfer loss with the naive implementation using default
        reduction arguments and no normals.
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, _, _, weights = TestChamfer.init_pointclouds(N, P1, P2)

        naive_loss, _ = TestChamfer.chamfer_distance_naive(p1, p2)
        loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

        # Defaults: mean over points, then weighted mean over the batch.
        expected = naive_loss[0].sum(1) / P1 + naive_loss[1].sum(1) / P2
        expected = (expected * weights).sum() / weights.sum()
        self.assertClose(loss, expected)

        # No normals were passed, so no normal loss should be returned.
        self.assertTrue(loss_norm is None)

    def test_chamfer_point_reduction(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for point_reduction in ["mean", "sum", "none"] and
        batch_reduction = "none".
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            N, P1, P2
        )

        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, p1_normals, p2_normals
        )

        # point_reduction = "mean".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="none",
            point_reduction="mean",
        )
        expected_mean = (
            pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
        ) * weights
        self.assertClose(loss, expected_mean)

        expected_norm_mean = (
            pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
        ) * weights
        self.assertClose(loss_norm, expected_norm_mean)

        # point_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="none",
            point_reduction="sum",
        )
        expected_sum = (pred_loss[0].sum(1) + pred_loss[1].sum(1)) * weights
        self.assertClose(loss, expected_sum)

        expected_norm_sum = (
            pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
        ) * weights
        self.assertClose(loss_norm, expected_norm_sum)

        # point_reduction = "none" together with batch_reduction = "none"
        # is rejected.
        with self.assertRaises(ValueError):
            chamfer_distance(
                p1, p2, weights=weights, batch_reduction="none", point_reduction="none"
            )

        # batch_reduction must be one of ["none", "mean", "sum"].
        with self.assertRaises(ValueError):
            chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

    def test_chamfer_batch_reduction(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for batch_reduction in ["mean", "sum"] and point_reduction = "none".
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            N, P1, P2
        )

        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, p1_normals, p2_normals
        )

        # Apply the per-cloud weights to both unreduced terms once; every
        # expected value below is derived from these.
        w = weights.view(N, 1)
        weighted_loss = [pred_loss[0] * w, pred_loss[1] * w]
        weighted_norm = [pred_loss_norm[0] * w, pred_loss_norm[1] * w]

        # batch_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="sum",
            point_reduction="none",
        )
        expected_sum = weighted_loss[0].sum() + weighted_loss[1].sum()
        self.assertClose(loss, expected_sum)

        expected_norm_sum = weighted_norm[0].sum() + weighted_norm[1].sum()
        self.assertClose(loss_norm, expected_norm_sum)

        # batch_reduction = "mean": the sums normalized by the total weight.
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="mean",
            point_reduction="none",
        )
        self.assertClose(loss, expected_sum / weights.sum())
        self.assertClose(loss_norm, expected_norm_sum / weights.sum())

        # point_reduction must be one of ["none", "mean", "sum"].
        with self.assertRaises(ValueError):
            chamfer_distance(p1, p2, weights=weights, point_reduction="max")

    def test_chamfer_joint_reduction(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for batch_reduction in ["mean", "sum"] and
        point_reduction in ["mean", "sum"].
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            N, P1, P2
        )

        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, p1_normals, p2_normals
        )

        # Weight the unreduced losses once; all four expected values below
        # are derived from these weighted terms.
        w = weights.view(N, 1)
        wloss = [pred_loss[0] * w, pred_loss[1] * w]
        wnorm = [pred_loss_norm[0] * w, pred_loss_norm[1] * w]

        # batch_reduction = "sum", point_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="sum",
            point_reduction="sum",
        )
        expected_sum = (wloss[0].sum(1) + wloss[1].sum(1)).sum()
        self.assertClose(loss, expected_sum)

        expected_norm_sum = (wnorm[0].sum(1) + wnorm[1].sum(1)).sum()
        self.assertClose(loss_norm, expected_norm_sum)

        # batch_reduction = "mean", point_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="mean",
            point_reduction="sum",
        )
        self.assertClose(loss, expected_sum / weights.sum())
        self.assertClose(loss_norm, expected_norm_sum / weights.sum())

        # batch_reduction = "sum", point_reduction = "mean".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="sum",
            point_reduction="mean",
        )
        expected_mean = (wloss[0].sum(1) / P1 + wloss[1].sum(1) / P2).sum()
        self.assertClose(loss, expected_mean)

        expected_norm_mean = (wnorm[0].sum(1) / P1 + wnorm[1].sum(1) / P2).sum()
        self.assertClose(loss_norm, expected_norm_mean)

        # batch_reduction = "mean", point_reduction = "mean". This is the default.
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="mean",
            point_reduction="mean",
        )
        self.assertClose(loss, expected_mean / weights.sum())
        self.assertClose(loss_norm, expected_norm_mean / weights.sum())

    def test_incorrect_weights(self):
        """
        Check chamfer_distance behavior for degenerate and invalid weights:
        all-zero weights give zero losses that still carry gradients, while
        negative or wrongly-sized weights raise ValueError.
        """
        N, P1, P2 = 16, 64, 128
        device = torch.device("cuda:0")
        p1 = torch.rand(
            (N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
        )
        p2 = torch.rand(
            (N, P2, 3), dtype=torch.float32, device=device, requires_grad=True
        )

        # All-zero weights: losses are zero but remain differentiable.
        weights = torch.zeros((N,), dtype=torch.float32, device=device)
        loss, loss_norm = chamfer_distance(
            p1, p2, weights=weights, batch_reduction="mean"
        )
        self.assertClose(loss.cpu(), torch.zeros(()))
        self.assertTrue(loss.requires_grad)
        self.assertClose(loss_norm.cpu(), torch.zeros(()))
        self.assertTrue(loss_norm.requires_grad)

        loss, loss_norm = chamfer_distance(
            p1, p2, weights=weights, batch_reduction="none"
        )
        self.assertClose(loss.cpu(), torch.zeros((N, N)))
        self.assertTrue(loss.requires_grad)
        self.assertClose(loss_norm.cpu(), torch.zeros((N, N)))
        self.assertTrue(loss_norm.requires_grad)

        # Negative weights are rejected.
        weights = torch.ones((N,), dtype=torch.float32, device=device) * -1
        with self.assertRaises(ValueError):
            loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

        # A weights tensor of the wrong length is rejected.
        weights = torch.zeros((N - 1,), dtype=torch.float32, device=device)
        with self.assertRaises(ValueError):
            loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

    @staticmethod
    def chamfer_with_init(batch_size: int, P1: int, P2: int, return_normals: bool):
        """
        Benchmark helper: build random CUDA inputs, then return a closure
        that evaluates the vectorized chamfer loss on them. The CUDA
        synchronize calls make wall-clock timing of the closure accurate.
        """
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            batch_size, P1, P2
        )
        torch.cuda.synchronize()

        def loss():
            loss_val, loss_normals = chamfer_distance(
                p1, p2, p1_normals, p2_normals, weights=weights
            )
            torch.cuda.synchronize()

        return loss

    @staticmethod
    def chamfer_naive_with_init(
        batch_size: int, P1: int, P2: int, return_normals: bool
    ):
        """
        Benchmark helper: build random CUDA inputs, then return a closure
        that evaluates the naive chamfer implementation on them. The CUDA
        synchronize calls make wall-clock timing of the closure accurate.
        """
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            batch_size, P1, P2
        )
        torch.cuda.synchronize()

        def loss():
            loss_val, loss_normals = TestChamfer.chamfer_distance_naive(
                p1, p2, p1_normals, p2_normals
            )
            torch.cuda.synchronize()

        return loss