test_so3.py 11.3 KB
Newer Older
Patrick Labatut's avatar
Patrick Labatut committed
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
facebook-github-bot's avatar
facebook-github-bot committed
6
7


8
import math
facebook-github-bot's avatar
facebook-github-bot committed
9
10
import unittest

11
12
import numpy as np
import torch
13
from common_testing import TestCaseMixin
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
14
from pytorch3d.common.compat import qr
facebook-github-bot's avatar
facebook-github-bot committed
15
16
from pytorch3d.transforms.so3 import (
    hat,
17
    so3_exp_map,
facebook-github-bot's avatar
facebook-github-bot committed
18
19
    so3_log_map,
    so3_relative_angle,
20
    so3_rotation_angle,
facebook-github-bot's avatar
facebook-github-bot committed
21
22
23
)


24
class TestSO3(TestCaseMixin, unittest.TestCase):
facebook-github-bot's avatar
facebook-github-bot committed
25
26
27
28
29
30
31
32
33
34
35
36
    def setUp(self) -> None:
        """Seed the torch and numpy random generators before each test."""
        super().setUp()
        seed = 42
        torch.manual_seed(seed)
        np.random.seed(seed)

    @staticmethod
    def init_log_rot(batch_size: int = 10):
        """
        Sample `batch_size` random 3-dimensional vectors that the tests use
        as logarithms of rotation matrices.

        Args:
            batch_size: Number of vectors to draw.

        Returns:
            A float32 tensor of shape `(batch_size, 3)` with standard-normal
            entries, allocated on `cuda:0`.
        """
        return torch.randn(
            (batch_size, 3),
            dtype=torch.float32,
            device=torch.device("cuda:0"),
        )

    @staticmethod
    def init_rot(batch_size: int = 10):
        """
        Randomly generate a batch of `batch_size` 3x3 matrices that the tests
        use as rotation matrices.

        Every matrix is the first output of `qr` applied to a Gaussian 3x3
        matrix, with its columns multiplied by random +1/-1 signs.

        NOTE(review): the sign vector always contains an even number of -1
        entries, so the column flips keep the determinant of the `qr` output
        unchanged; the result is a proper rotation only if `qr` returns a
        factor with determinant +1 - confirm.

        Args:
            batch_size: Number of matrices to generate.

        Returns:
            A tensor of shape `(batch_size, 3, 3)` allocated on `cuda:0`.
        """
        cuda = torch.device("cuda:0")

        # TODO(dnovotny): replace with random_rotation from random_rotation.py
        samples = []
        for _ in range(batch_size):
            q = qr(torch.randn((3, 3), device=cuda))[0]
            # one random 0/1 flag per column of q
            flags = torch.randint(2, (3,), device=cuda, dtype=torch.float32)
            if flags.sum() % 2 == 0:
                # make the number of set flags odd
                flags = 1 - flags
            # map the flags {0, 1} to the signs {-1, +1}
            signs = (2 * flags - 1).float()
            samples.append(q * signs)

        return torch.stack(samples)

    def test_determinant(self):
        """
        Check that the 3x3 matrices returned by `so3_exp_map` for a batch of
        30 random log-rotations have determinants equal to 1 (within 1e-4).
        """
        rotations = so3_exp_map(TestSO3.init_log_rot(batch_size=30))
        determinants = torch.det(rotations)
        expected = torch.ones_like(determinants)
        self.assertClose(determinants, expected, atol=1e-4)
facebook-github-bot's avatar
facebook-github-bot committed
68
69
70
71
72
73
74
75
76
77
78
79

    def test_cross(self):
        """
        Check, for 100 random pairs of 3-dimensional vectors `a` and `b`,
        that the matrix-vector product of `hat(a)` with `b` agrees with the
        cross product of `a` and `b` computed by `torch.cross`.
        """
        vec_pair = torch.randn(
            (2, 100, 3), dtype=torch.float32, device=torch.device("cuda:0")
        )
        a, b = vec_pair[0], vec_pair[1]
        # treat every b as a column vector for the batched matrix product
        via_hat = torch.bmm(hat(a), b.unsqueeze(-1)).squeeze(-1)
        reference = torch.cross(a, b, dim=1)
        self.assertClose(reference, via_hat, atol=1e-4)
facebook-github-bot's avatar
facebook-github-bot committed
81
82
83

    def test_bad_so3_input_value_err(self):
        """
        Tests whether `so3_exp_map` and `so3_log_map` raise a ValueError
        if called with an argument of incorrect shape or, in case of
        `so3_log_map`, with matrices whose trace is outside the valid range.
        The error message of every raised exception is checked as well.
        """
        device = torch.device("cuda:0")

        # a batch of 4-dimensional vectors is not a valid input of so3_exp_map
        log_rot = torch.randn(size=[5, 4], device=device)
        with self.assertRaises(ValueError) as err:
            so3_exp_map(log_rot)
        self.assertIn("Input tensor shape has to be Nx3.", str(err.exception))

        # a batch of 3x5 matrices is not a valid input of so3_log_map
        rot = torch.randn(size=[5, 3, 5], device=device)
        with self.assertRaises(ValueError) as err:
            so3_log_map(rot)
        self.assertIn("Input has to be a batch of 3x3 Tensors.", str(err.exception))

        # trace of rot definitely bigger than 3 or smaller than -1:
        # entries in [4, 5) give a trace >= 12, entries in [-3, -2) a trace < -6
        rot = torch.cat(
            (
                torch.rand(size=[5, 3, 3], device=device) + 4.0,
                torch.rand(size=[5, 3, 3], device=device) - 3.0,
            )
        )
        with self.assertRaises(ValueError) as err:
            so3_log_map(rot)
        self.assertIn(
            "A matrix has trace outside valid range [-1-eps,3+eps].",
            str(err.exception),
        )

    def test_so3_exp_singularity(self, batch_size: int = 100):
        """
        Check that `so3_exp_map` stays numerically well behaved for input
        vectors with a very small l2-norm, i.e. close to the numerically
        unstable region: both the returned matrices and the gradient of
        their sum with respect to the inputs have to be finite.
        """
        # random log-rotations scaled down to a tiny rotation angle
        tiny_log_rot = TestSO3.init_log_rot(batch_size=batch_size) * 1e-6
        tiny_log_rot.requires_grad_(True)
        rotations = so3_exp_map(tiny_log_rot)
        # every output entry has to be finite
        self.assertTrue(torch.isfinite(rotations).all())
        # the gradient has to exist and be finite everywhere
        rotations.sum().backward()
        grad = tiny_log_rot.grad
        self.assertIsNotNone(grad)
        self.assertTrue(torch.isfinite(grad).all())
facebook-github-bot's avatar
facebook-github-bot committed
131
132
133
134
135
136
137
138
139

    def test_so3_log_singularity(self, batch_size: int = 100):
        """
        Tests whether the `so3_log_map` is robust to the input matrices
        whose rotation angles are close to the numerically unstable region
        (i.e. matrices with low rotation angles).

        The batch consists of the identity, a 180-degree rotation, and
        `batch_size - 2` matrices obtained from `qr` of a slightly perturbed
        identity. Both the output of `so3_log_map` and the gradient of the
        summed output with respect to the input matrices have to be finite.
        """
        # generate random rotations with a tiny angle
        device = torch.device("cuda:0")
        identity = torch.eye(3, device=device)
        # diag(1, -1, -1): rotation by 180 degrees around the first axis
        rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
        r = [identity, rot180]
        # add random rotations and random almost orthonormal matrices
        r.extend(
            [
                qr(identity + torch.randn_like(identity) * 1e-4)[0]
                + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
                # this adds random noise to the second half
                # of the random orthogonal matrices to generate
                # near-orthogonal matrices
                for i in range(batch_size - 2)
            ]
        )
        r = torch.stack(r)
        r.requires_grad = True
        # the log of the rotation matrix r
        r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2)
        # tests whether all outputs are finite
        self.assertTrue(torch.isfinite(r_log).all())
        # tests whether the gradient is not None and all finite;
        # the loss has to depend on r_log so that the backward pass actually
        # runs through so3_log_map (summing r itself yields a gradient of ones
        # regardless of the behavior of so3_log_map)
        loss = r_log.sum()
        loss.backward()
        self.assertIsNotNone(r.grad)
        self.assertTrue(torch.isfinite(r.grad).all())
facebook-github-bot's avatar
facebook-github-bot committed
165

166
167
168
    def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
        """
        Check that
        `so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
        == so3_exp_map(log_rot)`
        for a randomly generated batch of rotation matrix logarithms `log_rot`.

        Unlike `test_so3_log_to_exp_to_log`, the random logarithms are scaled
        by 2 so that this test also covers a `log_rot` which contains
        values > math.pi.
        """
        log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
        # include the singular cases with a rotation angle of 0 and of
        # (almost) 2pi in the first two entries of the batch
        log_rot[:2] = 0
        log_rot[1, 0] = 2.0 * math.pi - 1e-6

        expected = so3_exp_map(log_rot, eps=1e-4)
        recovered_log = so3_log_map(expected, eps=1e-4, cos_bound=1e-6)
        recovered = so3_exp_map(recovered_log, eps=1e-6)

        # the matrices have to agree entry-wise ...
        self.assertClose(expected, recovered, atol=0.01)
        # ... and the angle of their relative rotation has to vanish
        rel_angles = so3_relative_angle(expected, recovered, cos_bound=1e-6)
        self.assertClose(rel_angles, torch.zeros_like(rel_angles), atol=0.01)

facebook-github-bot's avatar
facebook-github-bot committed
185
186
    def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
        """
        Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` holds
        (within 1e-4) for a randomly generated batch of rotation matrix
        logarithms `log_rot`.
        """
        expected = TestSO3.init_log_rot(batch_size=batch_size)
        # make the first entry the singular case with a rotation angle of 0
        expected[0] = 0
        rotations = so3_exp_map(expected)
        recovered = so3_log_map(rotations)
        self.assertClose(expected, recovered, atol=1e-4)
facebook-github-bot's avatar
facebook-github-bot committed
195
196
197

    def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
        """
        Check that `so3_exp_map(so3_log_map(R))==R` for a batch of randomly
        generated rotation matrices `R`.

        Matrices whose rotation angle is within 1e-2 of pi are discarded
        before the round trip.
        """
        candidates = TestSO3.init_rot(batch_size=batch_size)
        # keep only the matrices away from the rotation angle of pi
        keep = torch.abs(so3_rotation_angle(candidates) - math.pi) > 1e-2
        expected = candidates[keep]

        log_rot = so3_log_map(expected, eps=1e-8, cos_bound=1e-8)
        recovered = so3_exp_map(log_rot, eps=1e-8)

        self.assertClose(recovered, expected, atol=0.1)
        rel_angles = so3_relative_angle(expected, recovered, cos_bound=1e-4)
        self.assertClose(rel_angles, torch.zeros_like(rel_angles), atol=0.1)
facebook-github-bot's avatar
facebook-github-bot committed
208

209
    def test_so3_cos_relative_angle(self, batch_size: int = 100):
        """
        Check that the cosine of `so3_relative_angle(R1, R2, cos_angle=False)`
        matches `so3_relative_angle(R1, R2, cos_angle=True)` for batches of
        randomly generated rotation matrices `R1` and `R2`.
        """
        first = TestSO3.init_rot(batch_size=batch_size)
        second = TestSO3.init_rot(batch_size=batch_size)
        cos_of_angle = torch.cos(so3_relative_angle(first, second, cos_angle=False))
        direct_cos = so3_relative_angle(first, second, cos_angle=True)
        self.assertClose(cos_of_angle, direct_cos, atol=1e-4)

    def test_so3_cos_angle(self, batch_size: int = 100):
        """
        Check that the cosine of `so3_rotation_angle(R, cos_angle=False)`
        matches `so3_rotation_angle(R, cos_angle=True)` for a batch of
        randomly generated rotation matrices `R`.
        """
        rotations = TestSO3.init_rot(batch_size=batch_size)
        cos_of_angle = torch.cos(so3_rotation_angle(rotations, cos_angle=False))
        direct_cos = so3_rotation_angle(rotations, cos_angle=True)
        self.assertClose(cos_of_angle, direct_cos, atol=1e-4)

    def test_so3_cos_bound(self, batch_size: int = 100):
        """
        Runs `so3_rotation_angle` on matrices at or near the gradient
        singularity (the identity, a 180-degree rotation and near-identity
        matrices) with `cos_bound=1e-4` and with `cos_bound=0.0`.

        In both cases the returned angles have to be finite and a gradient
        has to be produced; the gradient values are required to be finite
        only for `cos_bound=1e-4`.
        """
        # generate random rotations with a tiny angle to generate cases
        # with the gradient singularity
        cuda = torch.device("cuda:0")
        identity = torch.eye(3, device=cuda)
        rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=cuda)
        near_identity = [
            qr(identity + torch.randn_like(identity) * 1e-4)[0]
            for _ in range(batch_size - 2)
        ]
        r = torch.stack([identity, rot180] + near_identity)
        r.requires_grad_(True)

        # first a small positive cos_bound (finite gradients are required),
        # then cos_bound=0.0 (only the existence of a gradient is checked)
        for cos_bound, expect_finite_grad in ((1e-4, True), (0.0, False)):
            # drop the gradient accumulated by the previous iteration
            r.grad = None
            angles = so3_rotation_angle(r, cos_bound=cos_bound)
            # the angles themselves have to be finite in both cases
            self.assertTrue(torch.isfinite(angles).all())
            angles.sum().backward()
            # a gradient has to be available in both cases
            self.assertIsNotNone(r.grad)
            if expect_finite_grad:
                # all grad values have to be finite
                self.assertTrue(torch.isfinite(r.grad).all())
facebook-github-bot's avatar
facebook-github-bot committed
270
271
272
273
274
275
276

    @staticmethod
    def so3_expmap(batch_size: int = 10):
        """
        Prepare a zero-argument callable that runs `so3_exp_map` on a fixed
        random batch of `batch_size` log-rotations and then waits for the
        CUDA device to finish (presumably used for timing - confirm against
        the caller).

        The inputs are generated, and the device synchronized, once before
        the callable is returned.
        """
        inputs = TestSO3.init_log_rot(batch_size=batch_size)
        torch.cuda.synchronize()

        def compute_rots():
            so3_exp_map(inputs)
            # wait until the queued CUDA work has completed
            torch.cuda.synchronize()

        return compute_rots

    @staticmethod
    def so3_logmap(batch_size: int = 10):
        """
        Prepare a zero-argument callable that runs `so3_log_map` on a fixed
        random batch of `batch_size` matrices from `init_rot` and then waits
        for the CUDA device to finish (presumably used for timing - confirm
        against the caller).

        The inputs are generated, and the device synchronized, once before
        the callable is returned.
        """
        rotations = TestSO3.init_rot(batch_size=batch_size)
        torch.cuda.synchronize()

        def compute_logs():
            so3_log_map(rotations)
            # wait until the queued CUDA work has completed
            torch.cuda.synchronize()

        return compute_logs