# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings
from typing import Optional, Sequence, Tuple, Union, List

import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.common.types import Device
from pytorch3d.transforms import Rotate, Transform3d, Translate

from .utils import TensorProperties, convert_to_tensors_and_broadcast


# Default values for rotation and translation matrices.
_R = torch.eye(3)[None]  # (1, 3, 3)
_T = torch.zeros(1, 3)  # (1, 3)


class CamerasBase(TensorProperties):
    """
    `CamerasBase` implements a base class for all cameras.

    For cameras, there are four different coordinate systems (or spaces)
    - World coordinate system: This is the system in which the object lives - the world.
    - Camera view coordinate system: This is the system that has its origin on the camera
        and the Z-axis perpendicular to the image plane.
        In PyTorch3D, we assume that +X points left, +Y points up and
        +Z points out from the image plane.
        The transformation from world --> view happens after applying a rotation (R)
        and translation (T).
    - NDC coordinate system: This is the normalized coordinate system that confines
        the rendered part of the object or scene within a volume. Also known as the view volume.
        Given the PyTorch3D convention, (+1, +1, znear) is the top left near corner,
        and (-1, -1, zfar) is the bottom right far corner of the volume.
        The transformation from view --> NDC happens after applying the camera
        projection matrix (P) if defined in NDC space.
    - Screen coordinate system: This is another representation of the view volume with
        the XY coordinates defined in image space instead of a normalized space.

    A better illustration of the coordinate systems can be found in
    pytorch3d/docs/notes/cameras.md.

    It defines methods that are common to all camera models:
        - `get_camera_center` that returns the optical center of the camera in
            world coordinates
        - `get_world_to_view_transform` which returns a 3D transform from
            world coordinates to the camera view coordinates (R, T)
        - `get_full_projection_transform` which composes the projection
            transform (P) with the world-to-view transform (R, T)
        - `transform_points` which takes a set of input points in world coordinates and
            projects to the space the camera is defined in (NDC or screen)
        - `get_ndc_camera_transform` which defines the transform from screen/NDC to
            PyTorch3D's NDC space
        - `transform_points_ndc` which takes a set of points in world coordinates and
            projects them to PyTorch3D's NDC space
        - `transform_points_screen` which takes a set of points in world coordinates and
            projects them to screen space

    For each new camera, one should implement the `get_projection_transform`
    routine that returns the mapping from camera view coordinates to camera
    coordinates (NDC or screen).

    Another useful function that is specific to each camera model is
    `unproject_points` which sends points from camera coordinates (NDC or screen)
    back to camera view or world coordinates depending on the `world_coordinates`
    boolean argument of the function.
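
    A minimal usage sketch (the variable names here are only illustrative;
    `cameras` is assumed to be an instance of a concrete subclass such as
    `PerspectiveCameras`, with an `image_size` available for screen projection):

    .. code-block:: python

        # full world -> camera projection (R, T composed with P)
        full_transform = cameras.get_full_projection_transform()
        pts_proj = full_transform.transform_points(pts_world)

        # or project directly to PyTorch3D NDC / screen space
        pts_ndc = cameras.transform_points_ndc(pts_world)
        pts_screen = cameras.transform_points_screen(pts_world)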
    """

    def get_projection_transform(self, **kwargs):
        """
        Calculate the projective transformation matrix.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.

        Return:
            a `Transform3d` object which represents a batch of projection
            matrices of shape (N, 4, 4)
        """
        raise NotImplementedError()

    def unproject_points(self):
        """
        Transform input points from camera coordinates (NDC or screen)
        to the world / camera coordinates.

        Each of the input points `xy_depth` of shape (..., 3) is
        a concatenation of the x, y location and its depth.

        For instance, for an input 2D tensor of shape `(num_points, 3)`
        `xy_depth` takes the following form:
            `xy_depth[i] = [x[i], y[i], depth[i]]`,
        for each point at an index `i`.

        The following example demonstrates the relationship between
        `transform_points` and `unproject_points`:

        .. code-block:: python

            cameras = # camera object derived from CamerasBase
            xyz = # 3D points of shape (batch_size, num_points, 3)
            # transform xyz to the camera view coordinates
            xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz)
            # extract the depth of each point as the 3rd coord of xyz_cam
            depth = xyz_cam[:, :, 2:]
            # project the points xyz to the camera
            xy = cameras.transform_points(xyz)[:, :, :2]
            # append depth to xy
            xy_depth = torch.cat((xy, depth), dim=2)
            # unproject to the world coordinates
            xyz_unproj_world = cameras.unproject_points(xy_depth, world_coordinates=True)
            print(torch.allclose(xyz, xyz_unproj_world)) # True
            # unproject to the camera coordinates
            xyz_unproj = cameras.unproject_points(xy_depth, world_coordinates=False)
            print(torch.allclose(xyz_cam, xyz_unproj)) # True

        Args:
            xy_depth: torch tensor of shape (..., 3).
            world_coordinates: If `True`, unprojects the points back to world
                coordinates using the camera extrinsics `R` and `T`.
                `False` ignores `R` and `T` and unprojects to
                the camera view coordinates.

        Returns
            new_points: unprojected points with the same shape as `xy_depth`.
        """
        raise NotImplementedError()

    def get_camera_center(self, **kwargs) -> torch.Tensor:
        """
        Return the 3D location of the camera optical center
        in the world coordinates.

        Args:
            **kwargs: parameters for the camera extrinsics can be passed in
                as keyword arguments to override the default values
                set in __init__.

        Setting T here will update the values set in init as this
        value may be needed later on in the rendering pipeline e.g. for
        lighting calculations.

        Returns:
            C: a batch of 3D locations of shape (N, 3) denoting
            the locations of the center of each camera in the batch.
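
        A minimal sketch (assuming `cameras` already stores extrinsics `R`, `T`):

        .. code-block:: python

            C = cameras.get_camera_center()  # (N, 3)
            # with the PyTorch3D convention X_cam = X_world @ R + T, the center
            # solves C @ R + T = 0, so the following should match:
            C_manual = -torch.bmm(cameras.T[:, None], cameras.R.transpose(1, 2))[:, 0]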
        """
        w2v_trans = self.get_world_to_view_transform(**kwargs)
        P = w2v_trans.inverse().get_matrix()
        # the camera center is the translation component (the first 3 elements
        # of the last row) of the inverted world-to-view
        # transform (4x4 RT matrix)
        C = P[:, 3, :3]
        return C

    def get_world_to_view_transform(self, **kwargs) -> Transform3d:
        """
        Return the world-to-view transform.

        Args:
            **kwargs: parameters for the camera extrinsics can be passed in
                as keyword arguments to override the default values
                set in __init__.

        Setting R and T here will update the values set in init as these
        values may be needed later on in the rendering pipeline e.g. for
        lighting calculations.

        Returns:
            A Transform3d object which represents a batch of transforms
            of shape (N, 4, 4)
        """
        R: torch.Tensor = kwargs.get("R", self.R)
        T: torch.Tensor = kwargs.get("T", self.T)
        self.R = R  # pyre-ignore[16]
        self.T = T  # pyre-ignore[16]
        world_to_view_transform = get_world_to_view_transform(R=R, T=T)
        return world_to_view_transform

    def get_full_projection_transform(self, **kwargs) -> Transform3d:
        """
        Return the full world-to-camera transform composing the
        world-to-view and view-to-camera transforms.
        If camera is defined in NDC space, the projected points are in NDC space.
        If camera is defined in screen space, the projected points are in screen space.

        Args:
            **kwargs: parameters for the projection transforms can be passed in
                as keyword arguments to override the default values
                set in __init__.

        Setting R and T here will update the values set in init as these
        values may be needed later on in the rendering pipeline e.g. for
        lighting calculations.

        Returns:
            a Transform3d object which represents a batch of transforms
            of shape (N, 4, 4)
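
        A minimal sketch of the equivalent manual composition (assuming
        `cameras` is a concrete camera instance):

        .. code-block:: python

            full = cameras.get_full_projection_transform()
            manual = cameras.get_world_to_view_transform().compose(
                cameras.get_projection_transform()
            )
            # both apply the same world -> camera projection to input points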
        """
        self.R: torch.Tensor = kwargs.get("R", self.R)  # pyre-ignore[16]
        self.T: torch.Tensor = kwargs.get("T", self.T)  # pyre-ignore[16]
        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
        view_to_proj_transform = self.get_projection_transform(**kwargs)
        return world_to_view_transform.compose(view_to_proj_transform)

    def transform_points(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transform input points from world to camera space with the
        projection matrix defined by the camera.

        For `CamerasBase.transform_points`, setting `eps > 0`
        stabilizes gradients since it leads to avoiding division
        by excessively low numbers for points close to the camera plane.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps!=None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3d.transform_points` for details.

                For `CamerasBase.transform_points`, setting `eps > 0`
                stabilizes gradients since it leads to avoiding division
                by excessively low numbers for points close to the
                camera plane.

        Returns
            new_points: transformed points with the same shape as the input.
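
        A minimal sketch (assuming `cameras` is a concrete camera instance and
        `pts` is a tensor of shape (N, P, 3) of world coordinates):

        .. code-block:: python

            pts_proj = cameras.transform_points(pts, eps=1e-4)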
        """
        world_to_proj_transform = self.get_full_projection_transform(**kwargs)
        return world_to_proj_transform.transform_points(points, eps=eps)

    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
        """
        Returns the transform from camera projection space (screen or NDC) to NDC space.
        For cameras that can be specified in screen space, this transform
        allows points to be converted from screen to NDC space.
        The default transform scales the points from [0, W-1]x[0, H-1] to [-1, 1].
        This function should be modified per camera definitions if need be,
        e.g. for Perspective/Orthographic cameras we provide a custom implementation.
        This transform assumes PyTorch3D coordinate system conventions for
        both the NDC space and the input points.

        This transform interfaces with the PyTorch3D renderer which assumes
        input points to the renderer to be in NDC space.
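
        A minimal sketch (assuming `cameras` was defined in screen space, i.e.
        `cameras.in_ndc()` is False, and `pts_proj` holds points already mapped
        by `get_full_projection_transform`):

        .. code-block:: python

            to_ndc = cameras.get_ndc_camera_transform()
            pts_ndc = to_ndc.transform_points(pts_proj)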
        """
        if self.in_ndc():
            return Transform3d(device=self.device, dtype=torch.float32)
        else:
            # For custom cameras which can be defined in screen space,
            # users might have to implement the screen to NDC transform based
            # on the definition of the camera parameters.
            # See PerspectiveCameras/OrthographicCameras for an example.
            # We don't flip xy because we assume that world points are in PyTorch3D coordinates
            # and thus conversion from screen to ndc is a mere scaling from image to [-1, 1] scale.
            return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)

    def transform_points_ndc(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transforms points from PyTorch3D world/camera space to NDC space.
        Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
        Output points are in NDC space: +X left, +Y up, origin at image center.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps!=None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3d.transform_points` for details.

                For `CamerasBase.transform_points`, setting `eps > 0`
                stabilizes gradients since it leads to avoiding division
                by excessively low numbers for points close to the
                camera plane.

        Returns
            new_points: transformed points with the same shape as the input.
        """
        world_to_ndc_transform = self.get_full_projection_transform(**kwargs)
        if not self.in_ndc():
            to_ndc_transform = self.get_ndc_camera_transform(**kwargs)
            world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform)

        return world_to_ndc_transform.transform_points(points, eps=eps)

    def transform_points_screen(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transforms points from PyTorch3D world/camera space to screen space.
        Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
        Output points are in screen space: +X right, +Y down, origin at top left corner.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps!=None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3d.transform_points` for details.

                For `CamerasBase.transform_points`, setting `eps > 0`
                stabilizes gradients since it leads to avoiding division
                by excessively low numbers for points close to the
                camera plane.

        Returns
            new_points: transformed points with the same shape as the input.
        """
        points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs)
        return get_ndc_to_screen_transform(
            self, with_xyflip=True, **kwargs
        ).transform_points(points_ndc, eps=eps)

    def clone(self):
        """
        Returns a copy of `self`.
        """
        cam_type = type(self)
        other = cam_type(device=self.device)
        return super().clone(other)

    def is_perspective(self):
        raise NotImplementedError()

    def in_ndc(self):
        """
        Specifies whether the camera is defined in NDC space
        or in screen (image) space
        """
        raise NotImplementedError()

    def get_znear(self):
        return self.znear if hasattr(self, "znear") else None

    def get_image_size(self):
        """
        Returns the image size, if provided, expected in the form of (height, width)
        The image size is used for conversion of projected points to screen coordinates.
        """
        return self.image_size if hasattr(self, "image_size") else None


############################################################
#             Field of View Camera Classes                 #
############################################################


def OpenGLPerspectiveCameras(
    znear=1.0,
    zfar=100.0,
    aspect_ratio=1.0,
    fov=60.0,
    degrees: bool = True,
    R: torch.Tensor = _R,
    T: torch.Tensor = _T,
    device: Device = "cpu",
) -> "FoVPerspectiveCameras":
    """
    OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
    Preserving OpenGLPerspectiveCameras for backward compatibility.
    """

    warnings.warn(
        """OpenGLPerspectiveCameras is deprecated,
        Use FoVPerspectiveCameras instead.
        OpenGLPerspectiveCameras will be removed in future releases.""",
        PendingDeprecationWarning,
    )

    return FoVPerspectiveCameras(
        znear=znear,
        zfar=zfar,
        aspect_ratio=aspect_ratio,
        fov=fov,
        degrees=degrees,
        R=R,
        T=T,
        device=device,
    )


class FoVPerspectiveCameras(CamerasBase):
    """
    A class which stores a batch of parameters to generate a batch of
    projection matrices by specifying the field of view.
    The definition of the parameters follows the OpenGL perspective camera.

    The extrinsics of the camera (R and T matrices) can also be set in the
    initializer or passed in to `get_full_projection_transform` to get
    the full transformation from world -> ndc.

    The `transform_points` method calculates the full world -> ndc transform
    and then applies it to the input points.

    The transforms can also be returned separately as Transform3d objects.

    * Setting the Aspect Ratio for Non Square Images *

    If the desired output image size is non square (i.e. a tuple of (H, W) where H != W)
    the aspect ratio needs special consideration: There are two aspect ratios
    to be aware of:
        - the aspect ratio of each pixel
        - the aspect ratio of the output image
    The `aspect_ratio` setting in the FoVPerspectiveCameras sets the
    pixel aspect ratio. When using this camera with the differentiable rasterizer
    be aware that in the rasterizer we assume square pixels, but allow
    variable image aspect ratio (i.e. rectangle images).

    In most cases you will want to set the camera `aspect_ratio=1.0`
    (i.e. square pixels) and only vary the output image dimensions in pixels
    for rasterization.
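
    A minimal sketch of this setup (the rasterization settings mentioned in the
    comment are only illustrative):

    .. code-block:: python

        cameras = FoVPerspectiveCameras(fov=60.0, aspect_ratio=1.0)
        # keep square pixels in the camera and request a rectangular output
        # image from the rasterizer, e.g. image_size=(256, 512) in its settings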
    """

    def __init__(
        self,
        znear=1.0,
        zfar=100.0,
        aspect_ratio=1.0,
        fov=60.0,
        degrees: bool = True,
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
    ) -> None:
        """

        Args:
            znear: near clipping plane of the view frustum.
            zfar: far clipping plane of the view frustum.
            aspect_ratio: aspect ratio of the image pixels.
                1.0 indicates square pixels.
            fov: field of view angle of the camera.
            degrees: bool, set to True if fov is specified in degrees.
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            K: (optional) A calibration matrix of shape (N, 4, 4)
                If provided, don't need znear, zfar, fov, aspect_ratio, degrees
            device: Device (as str or torch.device)
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
        super().__init__(
            device=device,
            znear=znear,
            zfar=zfar,
            aspect_ratio=aspect_ratio,
            fov=fov,
            R=R,
            T=T,
            K=K,
        )

        # No need to convert to tensor or broadcast.
        self.degrees = degrees

    def compute_projection_matrix(
        self, znear, zfar, fov, aspect_ratio, degrees: bool
    ) -> torch.Tensor:
        """
        Compute the calibration matrix K of shape (N, 4, 4)

        Args:
            znear: near clipping plane of the view frustum.
            zfar: far clipping plane of the view frustum.
            fov: field of view angle of the camera.
            aspect_ratio: aspect ratio of the image pixels.
                1.0 indicates square pixels.
            degrees: bool, set to True if fov is specified in degrees.

        Returns:
            torch.FloatTensor of the calibration matrix with shape (N, 4, 4)
        """
        K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        if degrees:
            fov = (np.pi / 180) * fov

        if not torch.is_tensor(fov):
            fov = torch.tensor(fov, device=self.device)
        tanHalfFov = torch.tan((fov / 2))
        max_y = tanHalfFov * znear
        min_y = -max_y
        max_x = max_y * aspect_ratio
        min_x = -max_x

        # NOTE: In OpenGL the projection matrix changes the handedness of the
        # coordinate frame. i.e. the NDC space positive z direction is the
        # camera space negative z direction. This is because the sign of the z
        # in the projection matrix is set to -1.0.
        # In pytorch3d we maintain a right handed coordinate system throughout
        # so the z sign is 1.0.
        z_sign = 1.0

        K[:, 0, 0] = 2.0 * znear / (max_x - min_x)
        K[:, 1, 1] = 2.0 * znear / (max_y - min_y)
        K[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
        K[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
        K[:, 3, 2] = z_sign * ones

        # NOTE: This maps the z coordinate to [0, 1] where z = 0 if the point
        # is at the near clipping plane and z = 1 when the point is at the far
        # clipping plane.
        K[:, 2, 2] = z_sign * zfar / (zfar - znear)
        K[:, 2, 3] = -(zfar * znear) / (zfar - znear)
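        # Sanity check (a sketch, not used by the computation): with w = z after
        # the projective divide, the NDC depth is
        #   z_ndc = (K[2, 2] * z + K[2, 3]) / z
        # which gives 0 at z = znear and 1 at z = zfar.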

        return K

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the perspective projection matrix with a symmetric
        viewing frustum. Use column major order.
        The viewing frustum will be projected into ndc, s.t.
        (max_x, max_y) -> (+1, +1)
        (min_x, min_y) -> (-1, -1)

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.

        Return:
            a Transform3d object which represents a batch of projection
            matrices of shape (N, 4, 4)

        .. code-block:: python

            h1 = (max_y + min_y)/(max_y - min_y)
            w1 = (max_x + min_x)/(max_x - min_x)
            tanhalffov = tan((fov/2))
            s1 = 1/tanhalffov
            s2 = 1/(tanhalffov * (aspect_ratio))

            # To map z to the range [0, 1] use:
            f1 =  far / (far - near)
            f2 = -(far * near) / (far - near)

            # Projection matrix
            K = [
                    [s1,   0,   w1,   0],
                    [0,   s2,   h1,   0],
                    [0,    0,   f1,  f2],
                    [0,    0,    1,   0],
            ]
        """
        K = kwargs.get("K", self.K)
        if K is not None:
            if K.shape != (self._N, 4, 4):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
            K = self.compute_projection_matrix(
                kwargs.get("znear", self.znear),
                kwargs.get("zfar", self.zfar),
                kwargs.get("fov", self.fov),
                kwargs.get("aspect_ratio", self.aspect_ratio),
                kwargs.get("degrees", self.degrees),
            )

        # Transpose the projection matrix as PyTorch3D transforms use row vectors.
        transform = Transform3d(
            matrix=K.transpose(1, 2).contiguous(), device=self.device
        )
        return transform

    def unproject_points(
        self,
        xy_depth: torch.Tensor,
        world_coordinates: bool = True,
        scaled_depth_input: bool = False,
        **kwargs
    ) -> torch.Tensor:
        """>!
        FoV cameras further allow for passing depth in world units
        (`scaled_depth_input=False`) or in the [0, 1]-normalized units
        (`scaled_depth_input=True`)

        Args:
            scaled_depth_input: If `True`, assumes the input depth is in
                the [0, 1]-normalized units. If `False` the input depth is in
                the world units.
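
        A minimal sketch (assuming `cameras` is a `FoVPerspectiveCameras` batch
        and `xy_depth` stores NDC xy plus depth in world units):

        .. code-block:: python

            xyz_world = cameras.unproject_points(
                xy_depth, world_coordinates=True, scaled_depth_input=False
            )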
        """

        # obtain the relevant transformation to ndc
        if world_coordinates:
            to_ndc_transform = self.get_full_projection_transform()
        else:
            to_ndc_transform = self.get_projection_transform()

        if scaled_depth_input:
            # the input is scaled depth, so we don't have to do anything
            xy_sdepth = xy_depth
        else:
            # parse out important values from the projection matrix
            K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
            # parse out f1, f2 from K_matrix
            unsqueeze_shape = [1] * xy_depth.dim()
            unsqueeze_shape[0] = K_matrix.shape[0]
            f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape)
            f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape)
            # get the scaled depth
            sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3]
            # concatenate xy + scaled depth
            xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1)

        # unproject with inverse of the projection
        unprojection_transform = to_ndc_transform.inverse()
        return unprojection_transform.transform_points(xy_sdepth)

    def is_perspective(self):
        return True

    def in_ndc(self):
        return True


def OpenGLOrthographicCameras(
    znear=1.0,
    zfar=100.0,
    top=1.0,
    bottom=-1.0,
    left=-1.0,
    right=1.0,
    scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
    R: torch.Tensor = _R,
    T: torch.Tensor = _T,
    device: Device = "cpu",
) -> "FoVOrthographicCameras":
    """
    OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
    Preserving OpenGLOrthographicCameras for backward compatibility.
    """

    warnings.warn(
        """OpenGLOrthographicCameras is deprecated,
        Use FoVOrthographicCameras instead.
        OpenGLOrthographicCameras will be removed in future releases.""",
        PendingDeprecationWarning,
    )

    return FoVOrthographicCameras(
        znear=znear,
        zfar=zfar,
        max_y=top,
        min_y=bottom,
        max_x=right,
        min_x=left,
        scale_xyz=scale_xyz,
        R=R,
        T=T,
        device=device,
    )


class FoVOrthographicCameras(CamerasBase):
    """
    A class which stores a batch of parameters to generate a batch of
    projection matrices by specifying the field of view.
    The definition of the parameters follows the OpenGL orthographic camera.
    """

    def __init__(
        self,
        znear=1.0,
        zfar=100.0,
        max_y=1.0,
        min_y=-1.0,
        max_x=1.0,
        min_x=-1.0,
        scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
    ):
        """

        Args:
            znear: near clipping plane of the view frustum.
            zfar: far clipping plane of the view frustum.
            max_y: maximum y coordinate of the frustum.
            min_y: minimum y coordinate of the frustum.
            max_x: maximum x coordinate of the frustum.
            min_x: minimum x coordinate of the frustum.
            scale_xyz: scale factors for each axis of shape (N, 3).
            R: Rotation matrix of shape (N, 3, 3).
            T: Translation of shape (N, 3).
            K: (optional) A calibration matrix of shape (N, 4, 4)
                If provided, don't need znear, zfar, max_y, min_y, max_x, min_x, scale_xyz
            device: torch.device or string.

        Only need to set min_x, max_x, min_y, max_y for viewing frustums
        which are non-symmetric about the origin.
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
        super().__init__(
            device=device,
            znear=znear,
            zfar=zfar,
            max_y=max_y,
            min_y=min_y,
            max_x=max_x,
            min_x=min_x,
            scale_xyz=scale_xyz,
            R=R,
            T=T,
            K=K,
        )

    def compute_projection_matrix(
        self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
    ) -> torch.Tensor:
        """
        Compute the calibration matrix K of shape (N, 4, 4)

        Args:
            znear: near clipping plane of the view frustum.
            zfar: far clipping plane of the view frustum.
            max_x: maximum x coordinate of the frustum.
            min_x: minimum x coordinate of the frustum.
            max_y: maximum y coordinate of the frustum.
            min_y: minimum y coordinate of the frustum.
            scale_xyz: scale factors for each axis of shape (N, 3).
        """
        K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        # NOTE: OpenGL flips handedness of coordinate system between camera
        # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
        # right handed coordinate system throughout.
        z_sign = +1.0

        K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
        K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
        K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
        K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
        K[:, 3, 3] = ones

        # NOTE: This maps the z coordinate to the range [0, 1] and replaces
        # the OpenGL z normalization to [-1, 1]
        K[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
        K[:, 2, 3] = -znear / (zfar - znear)
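        # Sanity check (a sketch, not used by the computation): with unit
        # scale_xyz the NDC depth is
        #   z_ndc = (z - znear) / (zfar - znear)
        # which gives 0 at z = znear and 1 at z = zfar.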

        return K

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the orthographic projection matrix.
        Use column major order.

        Args:
            **kwargs: parameters for the projection can be passed in to
                      override the default values set in __init__.
        Return:
            a Transform3d object which represents a batch of projection
            matrices of shape (N, 4, 4)

        .. code-block:: python

            scale_x = 2 / (max_x - min_x)
            scale_y = 2 / (max_y - min_y)
            scale_z = 2 / (far-near)
            mid_x = (max_x + min_x) / (max_x - min_x)
            mid_y = (max_y + min_y) / (max_y - min_y)
            mid_z = (far + near) / (far - near)

            K = [
                    [scale_x,        0,         0,  -mid_x],
                    [0,        scale_y,         0,  -mid_y],
                    [0,              0,  -scale_z,  -mid_z],
                    [0,              0,         0,       1],
            ]
        """
        K = kwargs.get("K", self.K)
        if K is not None:
            if K.shape != (self._N, 4, 4):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
            K = self.compute_projection_matrix(
                kwargs.get("znear", self.znear),
                kwargs.get("zfar", self.zfar),
                kwargs.get("max_x", self.max_x),
                kwargs.get("min_x", self.min_x),
                kwargs.get("max_y", self.max_y),
                kwargs.get("min_y", self.min_y),
                kwargs.get("scale_xyz", self.scale_xyz),
            )

        transform = Transform3d(
            matrix=K.transpose(1, 2).contiguous(), device=self.device
        )
        return transform

    def unproject_points(
        self,
        xy_depth: torch.Tensor,
        world_coordinates: bool = True,
        scaled_depth_input: bool = False,
        **kwargs
    ) -> torch.Tensor:
        """>!
        FoV cameras further allow for passing depth in world units
        (`scaled_depth_input=False`) or in the [0, 1]-normalized units
        (`scaled_depth_input=True`)

        Args:
            scaled_depth_input: If `True`, assumes the input depth is in
                the [0, 1]-normalized units. If `False` the input depth is in
                the world units.
        """

        if world_coordinates:
            to_ndc_transform = self.get_full_projection_transform(**kwargs.copy())
        else:
            to_ndc_transform = self.get_projection_transform(**kwargs.copy())

        if scaled_depth_input:
            # the input depth is already scaled
            xy_sdepth = xy_depth
        else:
            # we have to obtain the scaled depth first
            K = self.get_projection_transform(**kwargs).get_matrix()
            unsqueeze_shape = [1] * K.dim()
            unsqueeze_shape[0] = K.shape[0]
            mid_z = K[:, 3, 2].reshape(unsqueeze_shape)
            scale_z = K[:, 2, 2].reshape(unsqueeze_shape)
            scaled_depth = scale_z * xy_depth[..., 2:3] + mid_z
            # cat xy and scaled depth
            xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1)
        # finally invert the transform
        unprojection_transform = to_ndc_transform.inverse()
        return unprojection_transform.transform_points(xy_sdepth)

    def is_perspective(self):
        return False

    def in_ndc(self):
        return True


############################################################
#             MultiView Camera Classes                     #
############################################################
"""
Note that the MultiView Cameras accept parameters in NDC space.
"""


def SfMPerspectiveCameras(
    focal_length=1.0,
    principal_point=((0.0, 0.0),),
    R: torch.Tensor = _R,
    T: torch.Tensor = _T,
    device: Device = "cpu",
) -> "PerspectiveCameras":
    """
    SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
    Preserving SfMPerspectiveCameras for backward compatibility.
    """

    warnings.warn(
        """SfMPerspectiveCameras is deprecated,
        Use PerspectiveCameras instead.
        SfMPerspectiveCameras will be removed in future releases.""",
        PendingDeprecationWarning,
    )

    return PerspectiveCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=device,
    )


class PerspectiveCameras(CamerasBase):
    """
    A class which stores a batch of parameters to generate a batch of
    transformation matrices using the multi-view geometry convention for
    a perspective camera.

    Parameters for this camera are specified in NDC if `in_ndc` is set to True.
    If parameters are specified in screen space, `in_ndc` must be set to False.
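
    A minimal sketch of a screen-space camera (the numbers are only
    illustrative; focal length and principal point are given in pixels):

    .. code-block:: python

        cameras = PerspectiveCameras(
            focal_length=((500.0, 500.0),),     # (fx, fy) in pixels
            principal_point=((320.0, 240.0),),  # (px, py) in pixels
            in_ndc=False,
            image_size=((480, 640),),           # (height, width)
        )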
    """

    def __init__(
        self,
        focal_length=1.0,
        principal_point=((0.0, 0.0),),
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
        in_ndc: bool = True,
        image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
    ) -> None:
        """

        Args:
            focal_length: Focal length of the camera in world units.
                A tensor of shape (N, 1) or (N, 2) for
                square and non-square pixels respectively.
            principal_point: xy coordinates of the principal point of
                the camera in pixels.
                A tensor of shape (N, 2).
            in_ndc: True if camera parameters are specified in NDC.
                If camera parameters are in screen space, it must
                be set to False.
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            K: (optional) A calibration matrix of shape (N, 4, 4)
                If provided, don't need focal_length, principal_point
            image_size: (height, width) of image size.
                A tensor of shape (N, 2). Required for screen cameras.
            device: torch.device or string
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
        kwargs = {"image_size": image_size} if image_size is not None else {}
        super().__init__(
            device=device,
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
            K=K,
            _in_ndc=in_ndc,
            **kwargs,  # pyre-ignore
        )
        if image_size is not None:
            if (self.image_size < 1).any():  # pyre-ignore
                raise ValueError("Image_size provided has invalid values")
        else:
            self.image_size = None

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the projection matrix using the
        multi-view geometry convention.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in __init__.

        Returns:
            A `Transform3d` object with a batch of `N` projection transforms.

        .. code-block:: python

            fx = focal_length[:, 0]
            fy = focal_length[:, 1]
            px = principal_point[:, 0]
            py = principal_point[:, 1]

            K = [
                    [fx,   0,   px,   0],
                    [0,   fy,   py,   0],
                    [0,    0,    0,   1],
                    [0,    0,    1,   0],
            ]
        """
        K = kwargs.get("K", self.K)
        if K is not None:
            if K.shape != (self._N, 4, 4):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
            K = _get_sfm_calibration_matrix(
                self._N,
                self.device,
                kwargs.get("focal_length", self.focal_length),
                kwargs.get("principal_point", self.principal_point),
                orthographic=False,
            )

        transform = Transform3d(
            matrix=K.transpose(1, 2).contiguous(), device=self.device
        )
        return transform

    def unproject_points(
        self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
    ) -> torch.Tensor:
        if world_coordinates:
            to_camera_transform = self.get_full_projection_transform(**kwargs)
        else:
            to_camera_transform = self.get_projection_transform(**kwargs)

        unprojection_transform = to_camera_transform.inverse()
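        # The perspective projection stores 1 / depth as the third coordinate of
        # the projected points (see the K matrix in get_projection_transform),
        # so the input depth is inverted below before applying the inverse.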
        xy_inv_depth = torch.cat(
            (xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1  # type: ignore
        )
        return unprojection_transform.transform_points(xy_inv_depth)

    def get_principal_point(self, **kwargs) -> torch.Tensor:
        """
        Return the camera's principal point

        Args:
            **kwargs: parameters for the camera extrinsics can be passed in
                as keyword arguments to override the default values
                set in __init__.
        """
        proj_mat = self.get_projection_transform(**kwargs).get_matrix()
        return proj_mat[:, 2, :2]

    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
        """
        Returns the transform from camera projection space (screen or NDC) to NDC space.
        If the camera is defined already in NDC space, the transform is identity.
        For cameras defined in screen space, we adjust the principal point computation
        which is defined in the image space (commonly) and scale the points to NDC space.

        Important: This transform assumes PyTorch3D conventions for the input points,
        i.e. +X left, +Y up.
        """
        if self.in_ndc():
            ndc_transform = Transform3d(device=self.device, dtype=torch.float32)
        else:
            # when cameras are defined in screen/image space, the principal point is
            # provided in the (+X right, +Y down), aka image, coordinate system.
            # Since input points are defined in the PyTorch3D system (+X left, +Y up),
            # we need to adjust for the principal point transform.
            pr_point_fix = torch.zeros(
                (self._N, 4, 4), device=self.device, dtype=torch.float32
            )
            pr_point_fix[:, 0, 0] = 1.0
            pr_point_fix[:, 1, 1] = 1.0
            pr_point_fix[:, 2, 2] = 1.0
            pr_point_fix[:, 3, 3] = 1.0
            pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs)
            pr_point_fix_transform = Transform3d(
                matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
            )
            screen_to_ndc_transform = get_screen_to_ndc_transform(
                self, with_xyflip=False, **kwargs
            )
            ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)

        return ndc_transform

    def is_perspective(self):
        return True

    def in_ndc(self):
        return self._in_ndc


def SfMOrthographicCameras(
    focal_length=1.0,
    principal_point=((0.0, 0.0),),
    R: torch.Tensor = _R,
    T: torch.Tensor = _T,
    device: Device = "cpu",
) -> "OrthographicCameras":
    """
    SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
    Preserving SfMOrthographicCameras for backward compatibility.
    """

    warnings.warn(
        """SfMOrthographicCameras is deprecated,
        Use OrthographicCameras instead.
        SfMOrthographicCameras will be removed in future releases.""",
        PendingDeprecationWarning,
    )

    return OrthographicCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=device,
    )


class OrthographicCameras(CamerasBase):
    """
    A class which stores a batch of parameters to generate a batch of
    transformation matrices using the multi-view geometry convention for
    an orthographic camera.

    Parameters for this camera are specified in NDC if `in_ndc` is set to True.
    If parameters are specified in screen space, `in_ndc` must be set to False.
    """

    def __init__(
        self,
        focal_length=1.0,
        principal_point=((0.0, 0.0),),
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
        in_ndc: bool = True,
        image_size: Optional[torch.Tensor] = None,
    ) -> None:
        """

        Args:
            focal_length: Focal length of the camera in world units.
                A tensor of shape (N, 1) or (N, 2) for
                square and non-square pixels respectively.
            principal_point: xy coordinates of the principal point of
                the camera in pixels.
                A tensor of shape (N, 2).
            in_ndc: True if camera parameters are specified in NDC.
                If False, then camera parameters are in screen space.
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            K: (optional) A calibration matrix of shape (N, 4, 4)
                If provided, don't need focal_length, principal_point, image_size
            image_size: (height, width) of image size.
                A tensor of shape (N, 2). Required for screen cameras.
            device: torch.device or string
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
        kwargs = {"image_size": image_size} if image_size is not None else {}
        super().__init__(
            device=device,
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
            K=K,
            _in_ndc=in_ndc,
            **kwargs,  # pyre-ignore
        )
        if image_size is not None:
            if (self.image_size < 1).any():  # pyre-ignore
                raise ValueError("Image_size provided has invalid values")
        else:
            self.image_size = None

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the projection matrix using
        the multi-view geometry convention.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in __init__.

        Returns:
            A `Transform3d` object with a batch of `N` projection transforms.

        .. code-block:: python

            fx = focal_length[:,0]
            fy = focal_length[:,1]
            px = principal_point[:,0]
            py = principal_point[:,1]

            K = [
                    [fx,   0,    0,  px],
                    [0,   fy,    0,  py],
                    [0,    0,    1,   0],
                    [0,    0,    0,   1],
            ]
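
        Example (an illustrative sketch; a default camera and random points):

        .. code-block:: python

            cameras = OrthographicCameras()
            P = cameras.get_projection_transform()
            pts = torch.rand(1, 10, 3)          # (N, P, 3) points in view space
            pts_proj = P.transform_points(pts)  # projected points, same shape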
        """
        K = kwargs.get("K", self.K)
        if K is not None:
            if K.shape != (self._N, 4, 4):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
            K = _get_sfm_calibration_matrix(
                self._N,
                self.device,
                kwargs.get("focal_length", self.focal_length),
                kwargs.get("principal_point", self.principal_point),
                orthographic=True,
            )

        transform = Transform3d(
            matrix=K.transpose(1, 2).contiguous(), device=self.device
        )
        return transform

    def unproject_points(
        self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
    ) -> torch.Tensor:
        if world_coordinates:
            to_camera_transform = self.get_full_projection_transform(**kwargs)
        else:
            to_camera_transform = self.get_projection_transform(**kwargs)

        unprojection_transform = to_camera_transform.inverse()
        return unprojection_transform.transform_points(xy_depth)

    def get_principal_point(self, **kwargs) -> torch.Tensor:
        """
        Return the camera's principal point

        Args:
            **kwargs: parameters for the camera intrinsics can be passed in
                as keyword arguments to override the default values
                set in __init__.
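
        Example (an illustrative sketch; the principal point value is arbitrary):

        .. code-block:: python

            cameras = OrthographicCameras(principal_point=((0.1, -0.2),))
            pp = cameras.get_principal_point()  # (1, 2) tensor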
        """
        proj_mat = self.get_projection_transform(**kwargs).get_matrix()
        return proj_mat[:, 3, :2]

    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
        """
        Returns the transform from camera projection space (screen or NDC) to NDC space.
        If the camera is defined already in NDC space, the transform is identity.
        For cameras defined in screen space, we adjust the principal point computation
        which is defined in the image space (commonly) and scale the points to NDC space.

        Important: This transform assumes PyTorch3D conventions for the input points,
        i.e. +X left, +Y up.
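
        Example (a minimal sketch; assumes a camera defined in screen space
        with an arbitrary image size):

        .. code-block:: python

            cameras = OrthographicCameras(in_ndc=False, image_size=((128, 256),))
            screen_to_ndc = cameras.get_ndc_camera_transform()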
        """
        if self.in_ndc():
            ndc_transform = Transform3d(device=self.device, dtype=torch.float32)
        else:
            # when cameras are defined in screen/image space, the principal point is
            # provided in the (+X right, +Y down), aka image, coordinate system.
            # Since input points are defined in the PyTorch3D system (+X left, +Y up),
            # we need to adjust for the principal point transform.
            pr_point_fix = torch.zeros(
                (self._N, 4, 4), device=self.device, dtype=torch.float32
            )
            pr_point_fix[:, 0, 0] = 1.0
            pr_point_fix[:, 1, 1] = 1.0
            pr_point_fix[:, 2, 2] = 1.0
            pr_point_fix[:, 3, 3] = 1.0
            pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs)
            pr_point_fix_transform = Transform3d(
                matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
            )
            screen_to_ndc_transform = get_screen_to_ndc_transform(
                self, with_xyflip=False, **kwargs
            )
            ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)

        return ndc_transform

    def is_perspective(self):
        return False

    def in_ndc(self):
        return self._in_ndc


################################################
#       Helper functions for cameras           #
################################################


def _get_sfm_calibration_matrix(
    N: int,
    device: Device,
    focal_length,
    principal_point,
    orthographic: bool = False,
) -> torch.Tensor:
    """
    Returns a calibration matrix of a perspective/orthographic camera.

    Args:
        N: Number of cameras.
        focal_length: Focal length of the camera.
        principal_point: xy coordinates of the principal point
            of the camera in pixels.
        orthographic: Boolean specifying if the camera is orthographic or not

        The calibration matrix `K` is set up as follows:

        .. code-block:: python

            fx = focal_length[:,0]
            fy = focal_length[:,1]
            px = principal_point[:,0]
            py = principal_point[:,1]

            for orthographic==True:
                K = [
                        [fx,   0,    0,  px],
                        [0,   fy,    0,  py],
                        [0,    0,    1,   0],
                        [0,    0,    0,   1],
                ]
            else:
                K = [
                        [fx,   0,   px,   0],
                        [0,   fy,   py,   0],
                        [0,    0,    0,   1],
                        [0,    0,    1,   0],
                ]

    Returns:
        A calibration matrix `K` of the SfM-conventioned camera
        of shape (N, 4, 4).
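
    Example (an illustrative sketch; the parameter values are arbitrary):

    .. code-block:: python

        K = _get_sfm_calibration_matrix(
            N=2,
            device=torch.device("cpu"),
            focal_length=torch.tensor([[1.0, 1.0], [2.0, 2.0]]),
            principal_point=torch.tensor([[0.0, 0.0], [0.1, 0.1]]),
            orthographic=True,
        )  # (2, 4, 4)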
    """

    if not torch.is_tensor(focal_length):
        focal_length = torch.tensor(focal_length, device=device)

    if focal_length.ndim in (0, 1) or focal_length.shape[1] == 1:
        fx = fy = focal_length
    else:
        fx, fy = focal_length.unbind(1)

    if not torch.is_tensor(principal_point):
        principal_point = torch.tensor(principal_point, device=device)

    px, py = principal_point.unbind(1)

    K = fx.new_zeros(N, 4, 4)
    K[:, 0, 0] = fx
    K[:, 1, 1] = fy
    if orthographic:
        K[:, 0, 3] = px
        K[:, 1, 3] = py
        K[:, 2, 2] = 1.0
        K[:, 3, 3] = 1.0
    else:
        K[:, 0, 2] = px
        K[:, 1, 2] = py
        K[:, 3, 2] = 1.0
        K[:, 2, 3] = 1.0

    return K


################################################
# Helper functions for world to view transforms
################################################


def get_world_to_view_transform(
    R: torch.Tensor = _R, T: torch.Tensor = _T
) -> Transform3d:
    """
    This function returns a Transform3d representing the transformation
    matrix to go from world space to view space by applying a rotation and
    a translation.

    PyTorch3D uses the same convention as Hartley & Zisserman.
    I.e., for camera extrinsic parameters R (rotation) and T (translation),
    we map a 3D point `X_world` in world coordinates to
    a point `X_cam` in camera coordinates with:
    `X_cam = X_world R + T`

    Args:
        R: (N, 3, 3) matrix representing the rotation.
        T: (N, 3) matrix representing the translation.

    Returns:
        a Transform3d object which represents the composed RT transformation.
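
    Example (a minimal sketch; R and T below are the identity/default values):

    .. code-block:: python

        R = torch.eye(3)[None]   # (1, 3, 3)
        T = torch.zeros(1, 3)    # (1, 3)
        world_to_view = get_world_to_view_transform(R=R, T=T)
        pts_view = world_to_view.transform_points(torch.rand(1, 5, 3))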

    """
    # TODO: also support the case where RT is specified as one matrix
    # of shape (N, 4, 4).

    if T.shape[0] != R.shape[0]:
        msg = "Expected R, T to have the same batch dimension; got %r, %r"
        raise ValueError(msg % (R.shape[0], T.shape[0]))
    if T.dim() != 2 or T.shape[1:] != (3,):
        msg = "Expected T to have shape (N, 3); got %r"
        raise ValueError(msg % repr(T.shape))
    if R.dim() != 3 or R.shape[1:] != (3, 3):
        msg = "Expected R to have shape (N, 3, 3); got %r"
        raise ValueError(msg % repr(R.shape))

    # Create a Transform3d object
    T_ = Translate(T, device=T.device)
    R_ = Rotate(R, device=R.device)
    return R_.compose(T_)


def camera_position_from_spherical_angles(
    distance: float,
    elevation: float,
    azimuth: float,
    degrees: bool = True,
    device: Device = "cpu",
) -> torch.Tensor:
    """
    Calculate the location of the camera based on the distance away from
    the target point, the elevation and azimuth angles.

    Args:
        distance: distance of the camera from the object.
        elevation, azimuth: angles.
            The inputs distance, elevation and azimuth can be one of the following
                - Python scalar
                - Torch scalar
                - Torch tensor of shape (N) or (1)
        degrees: bool, whether the angles are specified in degrees or radians.
        device: str or torch.device, device for new tensors to be placed on.

    The vectors are broadcast against each other so they all have shape (N, 1).

    Returns:
        camera_position: (N, 3) xyz location of the camera.
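
    Example (an illustrative sketch; angles in degrees, arbitrary values):

    .. code-block:: python

        position = camera_position_from_spherical_angles(
            distance=2.7, elevation=10.0, azimuth=30.0
        )  # (1, 3) xyz location of the camera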
    """
    broadcasted_args = convert_to_tensors_and_broadcast(
        distance, elevation, azimuth, device=device
    )
    dist, elev, azim = broadcasted_args
    if degrees:
        elev = math.pi / 180.0 * elev
        azim = math.pi / 180.0 * azim
    x = dist * torch.cos(elev) * torch.sin(azim)
    y = dist * torch.sin(elev)
    z = dist * torch.cos(elev) * torch.cos(azim)
    camera_position = torch.stack([x, y, z], dim=1)
    if camera_position.dim() == 0:
        camera_position = camera_position.view(1, -1)  # add batch dim.
    return camera_position.view(-1, 3)


def look_at_rotation(
    camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: Device = "cpu"
) -> torch.Tensor:
    """
    This function takes a vector 'camera_position' which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up direction of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.

    The output is a rotation matrix representing the transformation
    from world coordinates -> view coordinates.

    Args:
        camera_position: position of the camera in world coordinates
        at: position of the object in world coordinates
        up: vector specifying the up direction in the world coordinate frame.

    The inputs camera_position, at and up can each be a
        - 3 element tuple/list
        - torch tensor of shape (1, 3)
        - torch tensor of shape (N, 3)

    The vectors are broadcast against each other so they all have shape (N, 3).

    Returns:
        R: (N, 3, 3) batched rotation matrices
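
    Example (a minimal sketch with an arbitrary camera position looking at the origin):

    .. code-block:: python

        R = look_at_rotation(camera_position=((0.0, 0.0, -3.0),))  # (1, 3, 3)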
    """
    # Format input and broadcast
    broadcasted_args = convert_to_tensors_and_broadcast(
        camera_position, at, up, device=device
    )
    camera_position, at, up = broadcasted_args
    for t, n in zip([camera_position, at, up], ["camera_position", "at", "up"]):
        if t.shape[-1] != 3:
            msg = "Expected arg %s to have shape (N, 3); got %r"
            raise ValueError(msg % (n, t.shape))
    z_axis = F.normalize(at - camera_position, eps=1e-5)
    x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5)
    y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5)
    is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(
        dim=1, keepdim=True
    )
    if is_close.any():
        replacement = F.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5)
        x_axis = torch.where(is_close, replacement, x_axis)
    R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
    return R.transpose(1, 2)


def look_at_view_transform(
    dist=1.0,
    elev=0.0,
    azim=0.0,
    degrees: bool = True,
    eye: Optional[Sequence] = None,
    at=((0, 0, 0),),  # (1, 3)
    up=((0, 1, 0),),  # (1, 3)
    device: Device = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This function returns a rotation and translation matrix
    to apply the 'Look At' transformation from world -> view coordinates [0].

    Args:
        dist: distance of the camera from the object
        elev: angle in degrees or radians. This is the angle between the
            vector from the object to the camera, and the horizontal plane y = 0 (xz-plane).
        azim: angle in degrees or radians. The vector from the object to
            the camera is projected onto a horizontal plane y = 0.
            azim is the angle between the projected vector and a
            reference vector at (0, 0, 1) on the reference plane (the horizontal plane).
        dist, elev and azim can be of shape (1), (N).
        degrees: boolean flag to indicate if the elevation and azimuth
            angles are specified in degrees or radians.
        eye: the position of the camera(s) in world coordinates. If eye is not
            None, it will override the camera position derived from dist, elev, azim.
        up: the up direction of the world coordinate system.
        at: the position of the object(s) in world coordinates.
        eye, up and at can be of shape (1, 3) or (N, 3).

    Returns:
        2-element tuple containing

        - **R**: the rotation to apply to the points to align with the camera.
        - **T**: the translation to apply to the points to align with the camera.

    References:
    [0] https://www.scratchapixel.com
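
    Example (an illustrative sketch; typical values for viewing an object
    centered at the origin):

    .. code-block:: python

        R, T = look_at_view_transform(dist=2.7, elev=10.0, azim=20.0)
        cameras = OrthographicCameras(R=R, T=T)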
    """

    if eye is not None:
        broadcasted_args = convert_to_tensors_and_broadcast(eye, at, up, device=device)
        eye, at, up = broadcasted_args
        C = eye
    else:
        broadcasted_args = convert_to_tensors_and_broadcast(
            dist, elev, azim, at, up, device=device
        )
        dist, elev, azim, at, up = broadcasted_args
        C = (
            camera_position_from_spherical_angles(
                dist, elev, azim, degrees=degrees, device=device
            )
            + at
        )

    R = look_at_rotation(C, at, up, device=device)
    T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0]
    return R, T


def get_ndc_to_screen_transform(
    cameras, with_xyflip: bool = False, **kwargs
) -> Transform3d:
    """
    PyTorch3D NDC to screen conversion.
    Conversion from PyTorch3D's NDC space (+X left, +Y up) to screen/image space
    (+X right, +Y down, origin top left).

    Args:
        cameras
        with_xyflip: flips x- and y-axis if set to True.
    Optional kwargs:
        image_size: ((height, width),) specifying the height, width
        of the image. If not provided, it is read from the cameras object.

    We represent the NDC to screen conversion as a Transform3d
    with projection matrix

    K = [
            [s,   0,    0,  cx],
            [0,   s,    0,  cy],
            [0,   0,    1,   0],
            [0,   0,    0,   1],
    ]
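
    Example (a minimal sketch; the camera and image size below are arbitrary):

    .. code-block:: python

        cameras = OrthographicCameras()
        ndc_to_screen = get_ndc_to_screen_transform(
            cameras, with_xyflip=True, image_size=((128, 256),)
        )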

    """
    # We require the image size, which is necessary for the transform
    image_size = kwargs.get("image_size", cameras.get_image_size())
    if image_size is None:
        msg = "For NDC to screen conversion, image_size=(height, width) needs to be specified."
        raise ValueError(msg)

    K = torch.zeros((cameras._N, 4, 4), device=cameras.device, dtype=torch.float32)
    if not torch.is_tensor(image_size):
        image_size = torch.tensor(image_size, device=cameras.device)
    image_size = image_size.view(-1, 2)  # of shape (1 or B)x2
    height, width = image_size.unbind(1)

    # For non square images, we scale the points such that smallest side
    # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
    # This convention is consistent with the PyTorch3D renderer
    scale = (image_size.min(dim=1).values - 1.0) / 2.0

    K[:, 0, 0] = scale
    K[:, 1, 1] = scale
    K[:, 0, 3] = -1.0 * (width - 1.0) / 2.0
    K[:, 1, 3] = -1.0 * (height - 1.0) / 2.0
    K[:, 2, 2] = 1.0
    K[:, 3, 3] = 1.0

    # Transpose the projection matrix as PyTorch3D transforms use row vectors.
    transform = Transform3d(
        matrix=K.transpose(1, 2).contiguous(), device=cameras.device
    )

    if with_xyflip:
        # flip x, y axis
        xyflip = torch.eye(4, device=cameras.device, dtype=torch.float32)
        xyflip[0, 0] = -1.0
        xyflip[1, 1] = -1.0
        xyflip = xyflip.view(1, 4, 4).expand(cameras._N, -1, -1)
        xyflip_transform = Transform3d(
            matrix=xyflip.transpose(1, 2).contiguous(), device=cameras.device
        )
        transform = transform.compose(xyflip_transform)
    return transform


def get_screen_to_ndc_transform(
    cameras, with_xyflip: bool = False, **kwargs
) -> Transform3d:
    """
    Screen to PyTorch3D NDC conversion.
    Conversion from screen/image space (+X right, +Y down, origin top left)
    to PyTorch3D's NDC space (+X left, +Y up).

    Args:
        cameras
        with_xyflip: flips x- and y-axis if set to True.
    Optional kwargs:
        image_size: ((height, width),) specifying the height, width
        of the image. If not provided, it is read from the cameras object.

    We represent the screen to NDC conversion as a Transform3d
    with projection matrix

    K = [
            [1/s,    0,    0,  cx/s],
            [  0,  1/s,    0,  cy/s],
            [  0,    0,    1,     0],
            [  0,    0,    0,     1],
    ]
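
    Example (a minimal sketch mirroring the NDC to screen example above):

    .. code-block:: python

        cameras = OrthographicCameras()
        screen_to_ndc = get_screen_to_ndc_transform(
            cameras, with_xyflip=True, image_size=((128, 256),)
        )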

    """
    transform = get_ndc_to_screen_transform(
        cameras, with_xyflip=with_xyflip, **kwargs
    ).inverse()
    return transform