Commit 5ecce832 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

PyTorch 1.4 compat

Summary: Restore compatibility with PyTorch 1.4 and 1.5, and a few lint fixes.

Reviewed By: patricklabatut

Differential Revision: D30048115

fbshipit-source-id: ee05efa7c625f6079fb06a3cc23be93e48df9433
parent 55aaec4d
@@ -19,7 +19,7 @@ def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:  # pragma: no cover
     Like torch.linalg.solve, tries to return X
     such that AX=B, with A square.
     """
-    if hasattr(torch.linalg, "solve"):
+    if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
         # PyTorch version >= 1.8.0
         return torch.linalg.solve(A, B)
@@ -31,7 +31,7 @@ def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:  # pragma: no cover
     Like torch.linalg.lstsq, tries to return X
     such that AX=B.
     """
-    if hasattr(torch.linalg, "lstsq"):
+    if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
         # PyTorch version >= 1.9
         return torch.linalg.lstsq(A, B).solution
@@ -45,7 +45,7 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # pragma: no cover
     """
    Like torch.linalg.qr.
     """
-    if hasattr(torch.linalg, "qr"):
+    if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"):
         # PyTorch version >= 1.9
         return torch.linalg.qr(A)
     return torch.qr(A)
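Note: the three hunks above apply the same guard. PyTorch 1.4/1.5 has no torch.linalg namespace at all, so probing hasattr(torch.linalg, ...) directly would itself raise AttributeError; the added hasattr(torch, "linalg") check makes the probe safe. The legacy fallback branches fall outside the hunks shown here; the sketch below illustrates the pattern for solve only, with a torch.solve fallback assumed for older releases (illustrative, not taken verbatim from this commit).

import torch

def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return X such that A @ X = B, with A square, across PyTorch versions."""
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
        # PyTorch >= 1.8 exposes torch.linalg.solve(A, B).
        return torch.linalg.solve(A, B)
    # Older releases (e.g. 1.4/1.5): torch.solve takes (B, A) and returns
    # a (solution, LU) pair; only the solution is needed here.
    X, _ = torch.solve(B, A)
    return X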
@@ -6,7 +6,7 @@
 import math
 import warnings
-from typing import Optional, Sequence, Tuple, Union, List
+from typing import List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -259,8 +259,9 @@ class CamerasBase(TensorProperties):
         # users might might have to implement the screen to NDC transform based
         # on the definition of the camera parameters.
         # See PerspectiveCameras/OrthographicCameras for an example.
-        # We don't flip xy because we assume that world points are in PyTorch3D coodrinates
-        # and thus conversion from screen to ndc is a mere scaling from image to [-1, 1] scale.
+        # We don't flip xy because we assume that world points are in
+        # PyTorch3D coordinates, and thus conversion from screen to ndc
+        # is a mere scaling from image to [-1, 1] scale.
         return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)

     def transform_points_ndc(
@@ -551,17 +551,15 @@ class PulsarPointsRenderer(nn.Module):
                 otherargs["bg_col"] = bg_col
             # Go!
             images.append(
-                torch.flipud(
-                    self.renderer(
-                        vert_pos=vert_pos,
-                        vert_col=vert_col,
-                        vert_rad=vert_rad,
-                        cam_params=cam_params,
-                        gamma=gamma,
-                        max_depth=zfar,
-                        min_depth=znear,
-                        **otherargs,
-                    )
-                )
+                self.renderer(
+                    vert_pos=vert_pos,
+                    vert_col=vert_col,
+                    vert_rad=vert_rad,
+                    cam_params=cam_params,
+                    gamma=gamma,
+                    max_depth=zfar,
+                    min_depth=znear,
+                    **otherargs,
+                ).flip(dims=[0])
             )
         return torch.stack(images, dim=0)
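Note: torch.flipud only appeared in later PyTorch releases and is missing in 1.4/1.5, while Tensor.flip has been available much longer, so calling .flip(dims=[0]) on the rendered image gives the same vertical flip without the newer API. A standalone check of that equivalence (not part of the commit):

import torch

# x.flip(dims=[0]) reverses a tensor along its first dimension, which is
# exactly what torch.flipud does on newer PyTorch releases.
x = torch.arange(12.0).reshape(3, 4)
flipped = x.flip(dims=[0])
assert torch.equal(flipped, x[[2, 1, 0]])

# On versions that already provide torch.flipud, the two agree.
if hasattr(torch, "flipud"):
    assert torch.equal(torch.flipud(x), flipped)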
@@ -140,8 +140,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
         dim=-2,
     )

-    # clipping is not important here; if q_abs is small, the candidate won't be picked
-    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clip(0.1))
+    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+    # the candidate won't be picked.
+    # pyre-ignore [16]: `torch.Tensor` has no attribute `new_tensor`.
+    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)))

     # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
     # forall i; we pick the best-conditioned one (with the largest denominator)
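Note: Tensor.clip is a later alias of clamp that PyTorch 1.4/1.5 does not have, whereas an elementwise max against a scalar tensor works on old and new versions alike and floors the denominator at 0.1 in the same way. A small illustration of the equivalence (values chosen arbitrarily, not from the commit):

import torch

q_abs = torch.tensor([0.02, 0.5, 1.3])

# Elementwise max against a scalar tensor floors the values at 0.1,
# matching what .clip(0.1) (a minimum-only clamp) does on newer PyTorch.
floored = q_abs.max(q_abs.new_tensor(0.1))
assert torch.allclose(floored, torch.tensor([0.1, 0.5, 1.3]))

if hasattr(q_abs, "clip"):
    assert torch.equal(q_abs.clip(0.1), floored)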