# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.compat import lstsq
from pytorch3d.transforms import acos_linear_extrapolation
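

# For reference, a minimal sketch of the behavior exercised by these tests.
# This is an illustration only, not the pytorch3d implementation: inside
# [lower_bound, upper_bound] the sketch returns `x.acos()`; outside, it
# continues the curve linearly using the first-order Taylor expansion of
# `acos` around the nearest bound, whose slope is -1 / sqrt(1 - bound ** 2).
def _reference_acos_linear_sketch(
    x: torch.Tensor, lower_bound: float, upper_bound: float
) -> torch.Tensor:
    def _taylor(bound: float) -> torch.Tensor:
        # first-order expansion of acos around `bound`:
        #   acos(bound) + d/dx acos(x)|_{x=bound} * (x - bound)
        bound_t = torch.tensor(bound, dtype=x.dtype, device=x.device)
        return bound_t.acos() - (x - bound_t) / torch.sqrt(1.0 - bound_t**2)

    y = x.clamp(lower_bound, upper_bound).acos()
    y = torch.where(x < lower_bound, _taylor(lower_bound), y)
    y = torch.where(x > upper_bound, _taylor(upper_bound), y)
    return y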


class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)

    @staticmethod
    def init_acos_boundary_values(batch_size: int = 10000):
        """
        Initialize a tensor containing values close to the bounds of the
        domain of `acos`, i.e. close to -1 or 1; and random values between (-1, 1).
        """
        device = torch.device("cuda:0")
        # one quarter are random values between -1 and 1
        x_rand = 2 * torch.rand(batch_size // 4, dtype=torch.float32, device=device) - 1
        x = [x_rand]
        for bound in [-1, 1]:
            for above_bound in [True, False]:
                for noise_std in [1e-4, 1e-2]:
                    n_generate = (batch_size - batch_size // 4) // 8
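                    # shift samples just above/below `bound` by adding
                    # half-normal noise with standard deviation `noise_std`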
                    x_add = (
                        bound
                        + (2 * float(above_bound) - 1)
                        * torch.randn(
                            n_generate, device=device, dtype=torch.float32
                        ).abs()
                        * noise_std
                    )
                    x.append(x_add)
        x = torch.cat(x)
        return x

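    # Benchmark helper: returns a closure that runs `acos_linear_extrapolation`
    # on pre-generated inputs; the CUDA synchronizations make sure the
    # asynchronous kernels have finished before timing stops.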
    @staticmethod
    def acos_linear_extrapolation(batch_size: int):
        x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
        torch.cuda.synchronize()

        def compute_acos():
            acos_linear_extrapolation(x)
            torch.cuda.synchronize()

        return compute_acos

    def _test_acos_outside_bounds(self, x, y, dydx, bound):
        """
        Check that `acos_linear_extrapolation` yields points on a line with
        the correct slope, and that the function is continuous at `bound`.
        """
        bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
        # fit a line: slope * x + bias = y
        x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
        slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
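        # the extrapolation should follow the first-order Taylor expansion of
        # acos around `bound`; the slope is d/dx acos(x) = -1 / sqrt(1 - x^2)
        # evaluated at x = bound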
        desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t**2)
        # test that the desired slope is the same as the fitted one
        self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
        # test that the slope computed by autograd matches the desired one
        self.assertClose(desired_slope.expand_as(dydx), dydx, atol=1e-2)
        # test that the value of the fitted line at x=bound equals
        # arccos(bound), i.e. the function is continuous at the bound
        y_bound_lin = (slope * bound_t + bias).view(1)
        y_bound_acos = bound_t.acos().view(1)
        self.assertClose(y_bound_lin, y_bound_acos, atol=1e-2)

    def _one_acos_test(self, x: torch.Tensor, lower_bound: float, upper_bound: float):
        """
        Test that `acos_linear_extrapolation` returns correct values for
        inputs below `lower_bound`, above `upper_bound`, and in between.
        """
        x.requires_grad = True
        x.grad = None
        y = acos_linear_extrapolation(x, [lower_bound, upper_bound])
        # compute the gradient of the acos w.r.t. x
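        # since each y[i] depends only on x[i], backpropagating a vector of
        # ones yields the elementwise derivative dy/dx in x.grad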
        y.backward(torch.ones_like(y))
        dacos_dx = x.grad
        x_lower = x <= lower_bound
        x_upper = x >= upper_bound
        x_mid = (~x_lower) & (~x_upper)
        # test that between bounds, the function returns plain acos
        self.assertClose(x[x_mid].acos(), y[x_mid])
        # test that outside the bounds, the function is linear with the right
        # slope and continuous around the bound
        self._test_acos_outside_bounds(
            x[x_upper], y[x_upper], dacos_dx[x_upper], upper_bound
        )
        self._test_acos_outside_bounds(
            x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound
        )

    def test_acos(self, batch_size: int = 10000):
        """
        Tests whether the function returns correct outputs
        inside/outside the bounds.
        """
        x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
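        # sweep the bound magnitude from 0.9 to 0.99999 (1 - 10**-k, k=1..5)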
        bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5)
        for lower_bound in -bounds:
            for upper_bound in bounds:
                if upper_bound < lower_bound:
                    continue
                self._one_acos_test(x, float(lower_bound), float(upper_bound))

    def test_finite_gradient(self, batch_size: int = 10000):
        """
        Tests whether gradients stay finite close to the bounds.
        """
        x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
        x.requires_grad = True
        bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5)
        for lower_bound in -bounds:
            for upper_bound in bounds:
                if upper_bound < lower_bound:
                    continue
                x.grad = None
                y = acos_linear_extrapolation(
                    x,
                    [float(lower_bound), float(upper_bound)],
                )
                self.assertTrue(torch.isfinite(y).all())
                loss = y.mean()
                loss.backward()
                self.assertIsNotNone(x.grad)
                self.assertTrue(torch.isfinite(x.grad).all())