test_chamfer.py 12.7 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
import torch
import torch.nn.functional as F

from pytorch3d.loss import chamfer_distance

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
9
from common_testing import TestCaseMixin
facebook-github-bot's avatar
facebook-github-bot committed
10

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
11
12

class TestChamfer(TestCaseMixin, unittest.TestCase):
    @staticmethod
    def init_pointclouds(
        batch_size: int = 10, P1: int = 32, P2: int = 64, device="cuda:0"
    ):
        """
        Randomly initialize two batches of 3D point clouds of sizes
        (N, P1, 3) and (N, P2, 3), per-point normals of the same sizes, and
        per-cloud weights of size (N,).

        The normals are drawn with torch.rand and then L2-normalized, so
        they are unit vectors whose components are all non-negative.

        Args:
            batch_size: N, the number of point clouds in each batch.
            P1: number of points per cloud in the first batch.
            P2: number of points per cloud in the second batch.
            device: device on which all tensors are allocated.

        Returns:
            Tuple (p1, p2, p1_normals, p2_normals, weights) of float32 tensors.
        """
        device = torch.device(device)
        # The draw order is kept fixed so seeded runs stay reproducible.
        p1 = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
        p1_normals = torch.rand(
            (batch_size, P1, 3), dtype=torch.float32, device=device
        )
        p1_normals = p1_normals / p1_normals.norm(dim=2, p=2, keepdim=True)
        p2 = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
        p2_normals = torch.rand(
            (batch_size, P2, 3), dtype=torch.float32, device=device
        )
        p2_normals = p2_normals / p2_normals.norm(dim=2, p=2, keepdim=True)
        weights = torch.rand((batch_size,), dtype=torch.float32, device=device)

        return p1, p2, p1_normals, p2_normals, weights

    @staticmethod
    def chamfer_distance_naive(p1, p2, p1_normals=None, p2_normals=None):
        """
        Naive iterative implementation of nearest neighbor and chamfer distance.

        Args:
            p1: (N, P1, D) tensor of points.
            p2: (N, P2, D) tensor of points.
            p1_normals: optional (N, P1, C) normals of p1.
            p2_normals: optional (N, P2, C) normals of p2.

        Returns:
            loss: list [(N, P1), (N, P2)] holding, for every point, the
                squared distance to its nearest neighbor in the other cloud.
            lnorm: list [(N, P1), (N, P2)] holding 1 - |cosine similarity|
                between each point's normal and the normal of its nearest
                neighbor. If either normals tensor is None, this is instead
                two scalar zeros.
        """
        N, P1, D = p1.shape
        P2 = p2.size(1)
        # Allocate on the inputs' device rather than assuming CUDA.
        device = p1.device
        return_normals = p1_normals is not None and p2_normals is not None
        dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)

        for n in range(N):
            for i1 in range(P1):
                for i2 in range(P2):
                    dist[n, i1, i2] = torch.sum(
                        (p1[n, i1, :] - p2[n, i2, :]) ** 2
                    )

        loss = [
            torch.min(dist, dim=2)[0],  # (N, P1)
            torch.min(dist, dim=1)[0],  # (N, P2)
        ]

        lnorm = [p1.new_zeros(()), p1.new_zeros(())]

        if return_normals:
            # The nearest-neighbor indices are broadcast over the normal
            # channels so gather picks the whole normal vector of each
            # neighbor.
            p1_index = (
                dist.argmin(2).view(N, P1, 1).expand(N, P1, p2_normals.size(2))
            )
            p2_index = (
                dist.argmin(1).view(N, P2, 1).expand(N, P2, p1_normals.size(2))
            )
            lnorm1 = 1 - torch.abs(
                F.cosine_similarity(
                    p1_normals, p2_normals.gather(1, p1_index), dim=2, eps=1e-6
                )
            )
            lnorm2 = 1 - torch.abs(
                F.cosine_similarity(
                    p2_normals, p1_normals.gather(1, p2_index), dim=2, eps=1e-6
                )
            )
            lnorm = [lnorm1, lnorm2]  # [(N, P1), (N, P2)]

        return loss, lnorm

    def test_chamfer_default_no_normals(self):
        """
        Compare chamfer loss with naive implementation using default
        input values and no normals.
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, _, _, weights = TestChamfer.init_pointclouds(N, P1, P2)
        pred_loss, _ = TestChamfer.chamfer_distance_naive(p1, p2)
        loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
        pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
        pred_loss *= weights
        pred_loss = pred_loss.sum() / weights.sum()
        self.assertClose(loss, pred_loss)
        self.assertIsNone(loss_norm)

    def test_chamfer_point_reduction(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for point_reduction in ["mean", "sum"] and batch_reduction = "none".
        Also check that point_reduction = "none" together with
        batch_reduction = "none" is rejected, and that an unknown
        point_reduction is rejected.
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            N, P1, P2
        )

        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, p1_normals, p2_normals
        )

        # point_reduction = "mean".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="none",
            point_reduction="mean",
        )
        pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
        pred_loss_mean *= weights
        self.assertClose(loss, pred_loss_mean)

        pred_loss_norm_mean = (
            pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
        )
        pred_loss_norm_mean *= weights
        self.assertClose(loss_norm, pred_loss_norm_mean)

        # point_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="none",
            point_reduction="sum",
        )
        pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
        pred_loss_sum *= weights
        self.assertClose(loss, pred_loss_sum)

        pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
        pred_loss_norm_sum *= weights
        self.assertClose(loss_norm, pred_loss_norm_sum)

        # Error when point_reduction = "none" and batch_reduction = "none".
        with self.assertRaises(ValueError):
            chamfer_distance(
                p1,
                p2,
                weights=weights,
                batch_reduction="none",
                point_reduction="none",
            )

        # Error when point_reduction is not in ["none", "mean", "sum"].
        with self.assertRaises(ValueError):
            chamfer_distance(p1, p2, weights=weights, point_reduction="max")

    def test_chamfer_batch_reduction(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for batch_reduction in ["mean", "sum"] and point_reduction = "none".
        Also check that an unknown batch_reduction is rejected.
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            N, P1, P2
        )

        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, p1_normals, p2_normals
        )

        # batch_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="sum",
            point_reduction="none",
        )
        pred_loss[0] *= weights.view(N, 1)
        pred_loss[1] *= weights.view(N, 1)
        pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
        self.assertClose(loss, pred_loss)

        pred_loss_norm[0] *= weights.view(N, 1)
        pred_loss_norm[1] *= weights.view(N, 1)
        pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
        self.assertClose(loss_norm, pred_loss_norm)

        # batch_reduction = "mean".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="mean",
            point_reduction="none",
        )

        pred_loss /= weights.sum()
        self.assertClose(loss, pred_loss)

        pred_loss_norm /= weights.sum()
        self.assertClose(loss_norm, pred_loss_norm)

        # Error when batch_reduction is not in ["none", "mean", "sum"].
        with self.assertRaises(ValueError):
            chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

    def test_chamfer_joint_reduction(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for batch_reduction in ["mean", "sum"] and
        point_reduction in ["mean", "sum"].
        """
        N, P1, P2 = 7, 10, 18
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            N, P1, P2
        )

        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, p1_normals, p2_normals
        )

        # batch_reduction = "sum", point_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="sum",
            point_reduction="sum",
        )
        # The naive per-point losses are weighted in place here; the
        # point_reduction = "mean" cases further down reuse these weighted
        # values.
        pred_loss[0] *= weights.view(N, 1)
        pred_loss[1] *= weights.view(N, 1)
        pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)  # point sum
        pred_loss_sum = pred_loss_sum.sum()  # batch sum
        self.assertClose(loss, pred_loss_sum)

        pred_loss_norm[0] *= weights.view(N, 1)
        pred_loss_norm[1] *= weights.view(N, 1)
        pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(
            1
        )  # point sum.
        pred_loss_norm_sum = pred_loss_norm_sum.sum()  # batch sum
        self.assertClose(loss_norm, pred_loss_norm_sum)

        # batch_reduction = "mean", point_reduction = "sum".
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="mean",
            point_reduction="sum",
        )
        pred_loss_sum /= weights.sum()
        self.assertClose(loss, pred_loss_sum)

        pred_loss_norm_sum /= weights.sum()
        self.assertClose(loss_norm, pred_loss_norm_sum)

        # batch_reduction = "sum", point_reduction = "mean".
        # pred_loss / pred_loss_norm already carry the per-cloud weights.
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="sum",
            point_reduction="mean",
        )
        pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
        pred_loss_mean = pred_loss_mean.sum()
        self.assertClose(loss, pred_loss_mean)

        pred_loss_norm_mean = (
            pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
        )
        pred_loss_norm_mean = pred_loss_norm_mean.sum()
        self.assertClose(loss_norm, pred_loss_norm_mean)

        # batch_reduction = "mean", point_reduction = "mean". This is the default.
        loss, loss_norm = chamfer_distance(
            p1,
            p2,
            p1_normals,
            p2_normals,
            weights=weights,
            batch_reduction="mean",
            point_reduction="mean",
        )
        pred_loss_mean /= weights.sum()
        self.assertClose(loss, pred_loss_mean)

        pred_loss_norm_mean /= weights.sum()
        self.assertClose(loss_norm, pred_loss_norm_mean)

    def test_incorrect_weights(self):
        """
        Check behaviour for all-zero, negative, and wrongly sized weights.
        """
        N, P1, P2 = 16, 64, 128
        device = torch.device("cuda:0")
        p1 = torch.rand(
            (N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
        )
        p2 = torch.rand(
            (N, P2, 3), dtype=torch.float32, device=device, requires_grad=True
        )

        # All-zero weights: the losses are zero but must stay differentiable.
        weights = torch.zeros((N,), dtype=torch.float32, device=device)
        loss, loss_norm = chamfer_distance(
            p1, p2, weights=weights, batch_reduction="mean"
        )
        self.assertClose(loss.cpu(), torch.zeros(()))
        self.assertTrue(loss.requires_grad)
        self.assertClose(loss_norm.cpu(), torch.zeros(()))
        self.assertTrue(loss_norm.requires_grad)

        # NOTE(review): the expected (N, N) shape is what this test has
        # always pinned for the zero-weight, batch_reduction="none" path.
        loss, loss_norm = chamfer_distance(
            p1, p2, weights=weights, batch_reduction="none"
        )
        self.assertClose(loss.cpu(), torch.zeros((N, N)))
        self.assertTrue(loss.requires_grad)
        self.assertClose(loss_norm.cpu(), torch.zeros((N, N)))
        self.assertTrue(loss_norm.requires_grad)

        # Negative weights are rejected.
        weights = torch.ones((N,), dtype=torch.float32, device=device) * -1
        with self.assertRaises(ValueError):
            loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

        # Weights whose length does not match the batch size are rejected.
        weights = torch.zeros((N - 1,), dtype=torch.float32, device=device)
        with self.assertRaises(ValueError):
            loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

    @staticmethod
    def chamfer_with_init(
        batch_size: int, P1: int, P2: int, return_normals: bool
    ):
        """
        Build a zero-argument closure that runs chamfer_distance once on
        freshly initialized inputs (for benchmarking).

        Normals are only passed to chamfer_distance when return_normals is
        True, so both code paths can be timed.
        """
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            batch_size, P1, P2
        )
        if not return_normals:
            p1_normals = p2_normals = None
        torch.cuda.synchronize()

        def loss():
            chamfer_distance(p1, p2, p1_normals, p2_normals, weights=weights)
            torch.cuda.synchronize()

        return loss

    @staticmethod
    def chamfer_naive_with_init(
        batch_size: int, P1: int, P2: int, return_normals: bool
    ):
        """
        Build a zero-argument closure that runs the naive reference
        implementation once on freshly initialized inputs (for benchmarking).

        Normals are only used when return_normals is True.
        """
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            batch_size, P1, P2
        )
        if not return_normals:
            p1_normals = p2_normals = None
        torch.cuda.synchronize()

        def loss():
            TestChamfer.chamfer_distance_naive(p1, p2, p1_normals, p2_normals)
            torch.cuda.synchronize()

        return loss