Commit d57daa6f authored by Patrick Labatut's avatar Patrick Labatut Committed by Facebook GitHub Bot
Browse files

Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
parent eb512ffd
...@@ -59,9 +59,7 @@ def vert_align( ...@@ -59,9 +59,7 @@ def vert_align(
elif hasattr(verts, "verts_padded"): elif hasattr(verts, "verts_padded"):
grid = verts.verts_padded() grid = verts.verts_padded()
else: else:
raise ValueError( raise ValueError("verts must be a tensor or have a `verts_padded` attribute")
"verts must be a tensor or have a `verts_padded` attribute"
)
grid = grid[:, None, :, :2] # (N, 1, V, 2) grid = grid[:, None, :, :2] # (N, 1, V, 2)
......
...@@ -44,4 +44,5 @@ from .points import ( ...@@ -44,4 +44,5 @@ from .points import (
) )
from .utils import TensorProperties, convert_to_tensors_and_broadcast from .utils import TensorProperties, convert_to_tensors_and_broadcast
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
from typing import NamedTuple, Sequence from typing import NamedTuple, Sequence
import numpy as np
import torch import torch
# Example functions for blending the top K colors per pixel using the outputs # Example functions for blending the top K colors per pixel using the outputs
# from rasterization. # from rasterization.
# NOTE: All blending function should return an RGBA image per batch element # NOTE: All blending function should return an RGBA image per batch element
...@@ -63,9 +65,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: ...@@ -63,9 +65,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
3D Reasoning', ICCV 2019 3D Reasoning', ICCV 2019
""" """
N, H, W, K = fragments.pix_to_face.shape N, H, W, K = fragments.pix_to_face.shape
pixel_colors = torch.ones( pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
(N, H, W, 4), dtype=colors.dtype, device=colors.device
)
mask = fragments.pix_to_face >= 0 mask = fragments.pix_to_face >= 0
# The distance is negative if a pixel is inside a face and positive outside # The distance is negative if a pixel is inside a face and positive outside
...@@ -124,14 +124,10 @@ def softmax_rgb_blend( ...@@ -124,14 +124,10 @@ def softmax_rgb_blend(
N, H, W, K = fragments.pix_to_face.shape N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device device = fragments.pix_to_face.device
pixel_colors = torch.ones( pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
(N, H, W, 4), dtype=colors.dtype, device=colors.device
)
background = blend_params.background_color background = blend_params.background_color
if not torch.is_tensor(background): if not torch.is_tensor(background):
background = torch.tensor( background = torch.tensor(background, dtype=torch.float32, device=device)
background, dtype=torch.float32, device=device
)
# Background color # Background color
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10 delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math import math
import numpy as np
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.transforms import Rotate, Transform3d, Translate from pytorch3d.transforms import Rotate, Transform3d, Translate
from .utils import TensorProperties, convert_to_tensors_and_broadcast from .utils import TensorProperties, convert_to_tensors_and_broadcast
# Default values for rotation and translation matrices. # Default values for rotation and translation matrices.
r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3) r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3)
t = np.expand_dims(np.zeros(3), axis=0) # (1, 3) t = np.expand_dims(np.zeros(3), axis=0) # (1, 3)
...@@ -106,9 +107,7 @@ class OpenGLPerspectiveCameras(TensorProperties): ...@@ -106,9 +107,7 @@ class OpenGLPerspectiveCameras(TensorProperties):
aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio) aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio)
degrees = kwargs.get("degrees", self.degrees) degrees = kwargs.get("degrees", self.degrees)
P = torch.zeros( P = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
(self._N, 4, 4), device=self.device, dtype=torch.float32
)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device) ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
if degrees: if degrees:
fov = (np.pi / 180) * fov fov = (np.pi / 180) * fov
...@@ -204,9 +203,7 @@ class OpenGLPerspectiveCameras(TensorProperties): ...@@ -204,9 +203,7 @@ class OpenGLPerspectiveCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform( world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
return world_to_view_transform return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d: def get_full_projection_transform(self, **kwargs) -> Transform3d:
...@@ -229,9 +226,7 @@ class OpenGLPerspectiveCameras(TensorProperties): ...@@ -229,9 +226,7 @@ class OpenGLPerspectiveCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform( world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs) view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform) return world_to_view_transform.compose(view_to_screen_transform)
...@@ -337,9 +332,7 @@ class OpenGLOrthographicCameras(TensorProperties): ...@@ -337,9 +332,7 @@ class OpenGLOrthographicCameras(TensorProperties):
bottom = kwargs.get("bottom", self.bottom) # pyre-ignore[16] bottom = kwargs.get("bottom", self.bottom) # pyre-ignore[16]
scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16] scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16]
P = torch.zeros( P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
(self._N, 4, 4), dtype=torch.float32, device=self.device
)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device) ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
# NOTE: OpenGL flips handedness of coordinate system between camera # NOTE: OpenGL flips handedness of coordinate system between camera
# space and NDC space so z sign is -ve. In PyTorch3D we maintain a # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
...@@ -417,9 +410,7 @@ class OpenGLOrthographicCameras(TensorProperties): ...@@ -417,9 +410,7 @@ class OpenGLOrthographicCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform( world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
return world_to_view_transform return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d: def get_full_projection_transform(self, **kwargs) -> Transform3d:
...@@ -442,9 +433,7 @@ class OpenGLOrthographicCameras(TensorProperties): ...@@ -442,9 +433,7 @@ class OpenGLOrthographicCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform( world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs) view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform) return world_to_view_transform.compose(view_to_screen_transform)
...@@ -470,12 +459,7 @@ class SfMPerspectiveCameras(TensorProperties): ...@@ -470,12 +459,7 @@ class SfMPerspectiveCameras(TensorProperties):
""" """
def __init__( def __init__(
self, self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
device="cpu",
): ):
""" """
__init__(self, focal_length, principal_point, R, T, device) -> None __init__(self, focal_length, principal_point, R, T, device) -> None
...@@ -589,9 +573,7 @@ class SfMPerspectiveCameras(TensorProperties): ...@@ -589,9 +573,7 @@ class SfMPerspectiveCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform( world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
return world_to_view_transform return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d: def get_full_projection_transform(self, **kwargs) -> Transform3d:
...@@ -610,9 +592,7 @@ class SfMPerspectiveCameras(TensorProperties): ...@@ -610,9 +592,7 @@ class SfMPerspectiveCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform( world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs) view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform) return world_to_view_transform.compose(view_to_screen_transform)
...@@ -638,12 +618,7 @@ class SfMOrthographicCameras(TensorProperties): ...@@ -638,12 +618,7 @@ class SfMOrthographicCameras(TensorProperties):
""" """
def __init__( def __init__(
self, self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
device="cpu",
): ):
""" """
__init__(self, focal_length, principal_point, R, T, device) -> None __init__(self, focal_length, principal_point, R, T, device) -> None
...@@ -757,9 +732,7 @@ class SfMOrthographicCameras(TensorProperties): ...@@ -757,9 +732,7 @@ class SfMOrthographicCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform( world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
return world_to_view_transform return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d: def get_full_projection_transform(self, **kwargs) -> Transform3d:
...@@ -778,9 +751,7 @@ class SfMOrthographicCameras(TensorProperties): ...@@ -778,9 +751,7 @@ class SfMOrthographicCameras(TensorProperties):
""" """
self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform( world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs) view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform) return world_to_view_transform.compose(view_to_screen_transform)
...@@ -990,9 +961,7 @@ def look_at_rotation( ...@@ -990,9 +961,7 @@ def look_at_rotation(
z_axis = F.normalize(at - camera_position, eps=1e-5) z_axis = F.normalize(at - camera_position, eps=1e-5)
x_axis = F.normalize(torch.cross(up, z_axis), eps=1e-5) x_axis = F.normalize(torch.cross(up, z_axis), eps=1e-5)
y_axis = F.normalize(torch.cross(z_axis, x_axis), eps=1e-5) y_axis = F.normalize(torch.cross(z_axis, x_axis), eps=1e-5)
R = torch.cat( R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
(x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1
)
return R.transpose(1, 2) return R.transpose(1, 2)
...@@ -1038,9 +1007,7 @@ def look_at_view_transform( ...@@ -1038,9 +1007,7 @@ def look_at_view_transform(
""" """
if eye is not None: if eye is not None:
broadcasted_args = convert_to_tensors_and_broadcast( broadcasted_args = convert_to_tensors_and_broadcast(eye, at, up, device=device)
eye, at, up, device=device
)
eye, at, up = broadcasted_args eye, at, up = broadcasted_args
C = eye C = eye
else: else:
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
from typing import NamedTuple from typing import NamedTuple
import torch
import torch
from pytorch3d import _C from pytorch3d import _C
# Example functions for blending the top K features per pixel using the outputs # Example functions for blending the top K features per pixel using the outputs
# from rasterization. # from rasterization.
# NOTE: All blending function should return a (N, H, W, C) tensor per batch element. # NOTE: All blending function should return a (N, H, W, C) tensor per batch element.
...@@ -49,9 +50,7 @@ class _CompositeAlphaPoints(torch.autograd.Function): ...@@ -49,9 +50,7 @@ class _CompositeAlphaPoints(torch.autograd.Function):
def forward(ctx, features, alphas, points_idx): def forward(ctx, features, alphas, points_idx):
pt_cld = _C.accum_alphacomposite(features, alphas, points_idx) pt_cld = _C.accum_alphacomposite(features, alphas, points_idx)
ctx.save_for_backward( ctx.save_for_backward(features.clone(), alphas.clone(), points_idx.clone())
features.clone(), alphas.clone(), points_idx.clone()
)
return pt_cld return pt_cld
@staticmethod @staticmethod
...@@ -68,9 +67,7 @@ class _CompositeAlphaPoints(torch.autograd.Function): ...@@ -68,9 +67,7 @@ class _CompositeAlphaPoints(torch.autograd.Function):
return grad_features, grad_alphas, grad_points_idx, None return grad_features, grad_alphas, grad_points_idx, None
def alpha_composite( def alpha_composite(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
pointsidx, alphas, pt_clds, blend_params=None
) -> torch.Tensor:
""" """
Composite features within a z-buffer using alpha compositing. Given a zbuffer Composite features within a z-buffer using alpha compositing. Given a zbuffer
with corresponding features and weights, these values are accumulated according with corresponding features and weights, these values are accumulated according
...@@ -131,9 +128,7 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function): ...@@ -131,9 +128,7 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function):
def forward(ctx, features, alphas, points_idx): def forward(ctx, features, alphas, points_idx):
pt_cld = _C.accum_weightedsumnorm(features, alphas, points_idx) pt_cld = _C.accum_weightedsumnorm(features, alphas, points_idx)
ctx.save_for_backward( ctx.save_for_backward(features.clone(), alphas.clone(), points_idx.clone())
features.clone(), alphas.clone(), points_idx.clone()
)
return pt_cld return pt_cld
@staticmethod @staticmethod
...@@ -150,9 +145,7 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function): ...@@ -150,9 +145,7 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function):
return grad_features, grad_alphas, grad_points_idx, None return grad_features, grad_alphas, grad_points_idx, None
def norm_weighted_sum( def norm_weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
pointsidx, alphas, pt_clds, blend_params=None
) -> torch.Tensor:
""" """
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
with corresponding features and weights, these values are accumulated with corresponding features and weights, these values are accumulated
...@@ -213,9 +206,7 @@ class _CompositeWeightedSumPoints(torch.autograd.Function): ...@@ -213,9 +206,7 @@ class _CompositeWeightedSumPoints(torch.autograd.Function):
def forward(ctx, features, alphas, points_idx): def forward(ctx, features, alphas, points_idx):
pt_cld = _C.accum_weightedsum(features, alphas, points_idx) pt_cld = _C.accum_weightedsum(features, alphas, points_idx)
ctx.save_for_backward( ctx.save_for_backward(features.clone(), alphas.clone(), points_idx.clone())
features.clone(), alphas.clone(), points_idx.clone()
)
return pt_cld return pt_cld
@staticmethod @staticmethod
......
...@@ -114,12 +114,7 @@ def specular( ...@@ -114,12 +114,7 @@ def specular(
# Ensure all inputs have same batch dimension as points # Ensure all inputs have same batch dimension as points
matched_tensors = convert_to_tensors_and_broadcast( matched_tensors = convert_to_tensors_and_broadcast(
points, points, color, direction, camera_position, shininess, device=points.device
color,
direction,
camera_position,
shininess,
device=points.device,
) )
_, color, direction, camera_position, shininess = matched_tensors _, color, direction, camera_position, shininess = matched_tensors
...@@ -201,9 +196,7 @@ class DirectionalLights(TensorProperties): ...@@ -201,9 +196,7 @@ class DirectionalLights(TensorProperties):
normals=normals, color=self.diffuse_color, direction=self.direction normals=normals, color=self.diffuse_color, direction=self.direction
) )
def specular( def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
self, normals, points, camera_position, shininess
) -> torch.Tensor:
return specular( return specular(
points=points, points=points,
normals=normals, normals=normals,
...@@ -256,13 +249,9 @@ class PointLights(TensorProperties): ...@@ -256,13 +249,9 @@ class PointLights(TensorProperties):
def diffuse(self, normals, points) -> torch.Tensor: def diffuse(self, normals, points) -> torch.Tensor:
direction = self.location - points direction = self.location - points
return diffuse( return diffuse(normals=normals, color=self.diffuse_color, direction=direction)
normals=normals, color=self.diffuse_color, direction=direction
)
def specular( def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
self, normals, points, camera_position, shininess
) -> torch.Tensor:
direction = self.location - points direction = self.location - points
return specular( return specular(
points=points, points=points,
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .texturing import ( # isort:skip from .texturing import interpolate_texture_map, interpolate_vertex_colors # isort:skip
interpolate_texture_map,
interpolate_vertex_colors,
)
from .rasterize_meshes import rasterize_meshes from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer from .renderer import MeshRenderer
...@@ -20,4 +17,5 @@ from .shader import ( ...@@ -20,4 +17,5 @@ from .shader import (
from .shading import gouraud_shading, phong_shading from .shading import gouraud_shading, phong_shading
from .utils import interpolate_face_attributes from .utils import interpolate_face_attributes
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
from typing import Optional from typing import Optional
import torch
import numpy as np
import torch
from pytorch3d import _C from pytorch3d import _C
# TODO make the epsilon user configurable # TODO make the epsilon user configurable
kEpsilon = 1e-30 kEpsilon = 1e-30
...@@ -172,9 +173,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): ...@@ -172,9 +173,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
return pix_to_face, zbuf, barycentric_coords, dists return pix_to_face, zbuf, barycentric_coords, dists
@staticmethod @staticmethod
def backward( def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists):
ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists
):
grad_face_verts = None grad_face_verts = None
grad_mesh_to_face_first_idx = None grad_mesh_to_face_first_idx = None
grad_num_faces_per_mesh = None grad_num_faces_per_mesh = None
...@@ -243,9 +242,7 @@ def rasterize_meshes_python( ...@@ -243,9 +242,7 @@ def rasterize_meshes_python(
face_idxs = torch.full( face_idxs = torch.full(
(N, H, W, K), fill_value=-1, dtype=torch.int64, device=device (N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
) )
zbuf = torch.full( zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device)
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
)
bary_coords = torch.full( bary_coords = torch.full(
(N, H, W, K, 3), fill_value=-1, dtype=torch.float32, device=device (N, H, W, K, 3), fill_value=-1, dtype=torch.float32, device=device
) )
...@@ -308,9 +305,7 @@ def rasterize_meshes_python( ...@@ -308,9 +305,7 @@ def rasterize_meshes_python(
continue continue
# Compute barycentric coordinates and pixel z distance. # Compute barycentric coordinates and pixel z distance.
pxy = torch.tensor( pxy = torch.tensor([xf, yf], dtype=torch.float32, device=device)
[xf, yf], dtype=torch.float32, device=device
)
bary = barycentric_coordinates(pxy, v0[:2], v1[:2], v2[:2]) bary = barycentric_coordinates(pxy, v0[:2], v1[:2], v2[:2])
if perspective_correct: if perspective_correct:
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -123,8 +124,5 @@ class MeshRasterizer(nn.Module): ...@@ -123,8 +124,5 @@ class MeshRasterizer(nn.Module):
perspective_correct=raster_settings.perspective_correct, perspective_correct=raster_settings.perspective_correct,
) )
return Fragments( return Fragments(
pix_to_face=pix_to_face, pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists
zbuf=zbuf,
bary_coords=bary_coords,
dists=dists,
) )
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from .rasterizer import Fragments from .rasterizer import Fragments
from .utils import _clip_barycentric_coordinates, _interpolate_zbuf from .utils import _clip_barycentric_coordinates, _interpolate_zbuf
# A renderer class should be initialized with a # A renderer class should be initialized with a
# function for rasterization and a function for shading. # function for rasterization and a function for shading.
# The rasterizer should: # The rasterizer should:
...@@ -48,16 +49,12 @@ class MeshRenderer(nn.Module): ...@@ -48,16 +49,12 @@ class MeshRenderer(nn.Module):
the range for the corresponding face. the range for the corresponding face.
""" """
fragments = self.rasterizer(meshes_world, **kwargs) fragments = self.rasterizer(meshes_world, **kwargs)
raster_settings = kwargs.get( raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings)
"raster_settings", self.rasterizer.raster_settings
)
if raster_settings.blur_radius > 0.0: if raster_settings.blur_radius > 0.0:
# TODO: potentially move barycentric clipping to the rasterizer # TODO: potentially move barycentric clipping to the rasterizer
# if no downstream functions requires unclipped values. # if no downstream functions requires unclipped values.
# This will avoid unnecssary re-interpolation of the z buffer. # This will avoid unnecssary re-interpolation of the z buffer.
clipped_bary_coords = _clip_barycentric_coordinates( clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
fragments.bary_coords
)
clipped_zbuf = _interpolate_zbuf( clipped_zbuf = _interpolate_zbuf(
fragments.pix_to_face, clipped_bary_coords, meshes_world fragments.pix_to_face, clipped_bary_coords, meshes_world
) )
......
...@@ -16,6 +16,7 @@ from ..materials import Materials ...@@ -16,6 +16,7 @@ from ..materials import Materials
from .shading import flat_shading, gouraud_shading, phong_shading from .shading import flat_shading, gouraud_shading, phong_shading
from .texturing import interpolate_texture_map, interpolate_vertex_colors from .texturing import interpolate_texture_map, interpolate_vertex_colors
# A Shader should take as input fragments from the output of rasterization # A Shader should take as input fragments from the output of rasterization
# along with scene params and output images. A shader could perform operations # along with scene params and output images. A shader could perform operations
# such as: # such as:
...@@ -41,16 +42,12 @@ class HardPhongShader(nn.Module): ...@@ -41,16 +42,12 @@ class HardPhongShader(nn.Module):
def __init__(self, device="cpu", cameras=None, lights=None, materials=None): def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
super().__init__() super().__init__()
self.lights = ( self.lights = lights if lights is not None else PointLights(device=device)
lights if lights is not None else PointLights(device=device)
)
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = (
cameras cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
) )
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
...@@ -85,28 +82,17 @@ class SoftPhongShader(nn.Module): ...@@ -85,28 +82,17 @@ class SoftPhongShader(nn.Module):
""" """
def __init__( def __init__(
self, self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
device="cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = ( self.lights = lights if lights is not None else PointLights(device=device)
lights if lights is not None else PointLights(device=device)
)
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = (
cameras cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
) )
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes) texels = interpolate_vertex_colors(fragments, meshes)
...@@ -142,16 +128,12 @@ class HardGouraudShader(nn.Module): ...@@ -142,16 +128,12 @@ class HardGouraudShader(nn.Module):
def __init__(self, device="cpu", cameras=None, lights=None, materials=None): def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
super().__init__() super().__init__()
self.lights = ( self.lights = lights if lights is not None else PointLights(device=device)
lights if lights is not None else PointLights(device=device)
)
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = (
cameras cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
) )
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
...@@ -185,28 +167,17 @@ class SoftGouraudShader(nn.Module): ...@@ -185,28 +167,17 @@ class SoftGouraudShader(nn.Module):
""" """
def __init__( def __init__(
self, self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
device="cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = ( self.lights = lights if lights is not None else PointLights(device=device)
lights if lights is not None else PointLights(device=device)
)
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = (
cameras cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
) )
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
...@@ -241,28 +212,17 @@ class TexturedSoftPhongShader(nn.Module): ...@@ -241,28 +212,17 @@ class TexturedSoftPhongShader(nn.Module):
""" """
def __init__( def __init__(
self, self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
device="cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = ( self.lights = lights if lights is not None else PointLights(device=device)
lights if lights is not None else PointLights(device=device)
)
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = (
cameras cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
) )
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_texture_map(fragments, meshes) texels = interpolate_texture_map(fragments, meshes)
...@@ -298,16 +258,12 @@ class HardFlatShader(nn.Module): ...@@ -298,16 +258,12 @@ class HardFlatShader(nn.Module):
def __init__(self, device="cpu", cameras=None, lights=None, materials=None): def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
super().__init__() super().__init__()
self.lights = ( self.lights = lights if lights is not None else PointLights(device=device)
lights if lights is not None else PointLights(device=device)
)
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = (
cameras cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
) )
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
...@@ -346,9 +302,7 @@ class SoftSilhouetteShader(nn.Module): ...@@ -346,9 +302,7 @@ class SoftSilhouetteShader(nn.Module):
def __init__(self, blend_params=None): def __init__(self, blend_params=None):
super().__init__() super().__init__()
self.blend_params = ( self.blend_params = blend_params if blend_params is not None else BlendParams()
blend_params if blend_params is not None else BlendParams()
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
"""" """"
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from typing import Tuple from typing import Tuple
import torch import torch
from .texturing import interpolate_face_attributes from .texturing import interpolate_face_attributes
...@@ -82,9 +83,7 @@ def phong_shading( ...@@ -82,9 +83,7 @@ def phong_shading(
return colors return colors
def gouraud_shading( def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tensor:
meshes, fragments, lights, cameras, materials
) -> torch.Tensor:
""" """
Apply per vertex shading. First compute the vertex illumination by applying Apply per vertex shading. First compute the vertex illumination by applying
ambient, diffuse and specular lighting. If vertex color is available, ambient, diffuse and specular lighting. If vertex color is available,
...@@ -131,9 +130,7 @@ def gouraud_shading( ...@@ -131,9 +130,7 @@ def gouraud_shading(
return colors return colors
def flat_shading( def flat_shading(meshes, fragments, lights, cameras, materials, texels) -> torch.Tensor:
meshes, fragments, lights, cameras, materials, texels
) -> torch.Tensor:
""" """
Apply per face shading. Use the average face position and the face normals Apply per face shading. Use the average face position and the face normals
to compute the ambient, diffuse and specular lighting. Apply the ambient to compute the ambient, diffuse and specular lighting. Apply the ambient
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.structures.textures import Textures from pytorch3d.structures.textures import Textures
from .utils import interpolate_face_attributes from .utils import interpolate_face_attributes
...@@ -75,9 +74,7 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor: ...@@ -75,9 +74,7 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
# right-bottom pixel of input. # right-bottom pixel of input.
pixel_uvs = pixel_uvs * 2.0 - 1.0 pixel_uvs = pixel_uvs * 2.0 - 1.0
texture_maps = torch.flip( texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
texture_maps, [2]
) # flip y axis of the texture map
if texture_maps.device != pixel_uvs.device: if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device) texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False) texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
...@@ -107,9 +104,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor: ...@@ -107,9 +104,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
There will be one C dimensional value for each element in There will be one C dimensional value for each element in
fragments.pix_to_face. fragments.pix_to_face.
""" """
vertex_textures = meshes.textures.verts_rgb_padded().reshape( vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3) # (V, C)
-1, 3
) # (V, C)
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :] vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
faces_packed = meshes.faces_packed() faces_packed = meshes.faces_packed()
faces_textures = vertex_textures[faces_packed] # (F, 3, C) faces_textures = vertex_textures[faces_packed] # (F, 3, C)
......
...@@ -92,8 +92,6 @@ def _interpolate_zbuf( ...@@ -92,8 +92,6 @@ def _interpolate_zbuf(
verts = meshes.verts_packed() verts = meshes.verts_packed()
faces = meshes.faces_packed() faces = meshes.faces_packed()
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1) faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
return interpolate_face_attributes( return interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
pix_to_face, barycentric_coords, faces_verts_z
)[
..., 0 ..., 0
] # (1, H, W, K) ] # (1, H, W, K)
...@@ -5,4 +5,5 @@ from .rasterize_points import rasterize_points ...@@ -5,4 +5,5 @@ from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer from .rasterizer import PointsRasterizationSettings, PointsRasterizer
from .renderer import PointsRenderer from .renderer import PointsRenderer
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
...@@ -5,6 +5,7 @@ import torch.nn as nn ...@@ -5,6 +5,7 @@ import torch.nn as nn
from ..compositing import CompositeParams, alpha_composite, norm_weighted_sum from ..compositing import CompositeParams, alpha_composite, norm_weighted_sum
# A compositor should take as input 3D points and some corresponding information. # A compositor should take as input 3D points and some corresponding information.
# Given this information, the compositor can: # Given this information, the compositor can:
# - blend colors across the top K vertices at a pixel # - blend colors across the top K vertices at a pixel
...@@ -19,15 +20,11 @@ class AlphaCompositor(nn.Module): ...@@ -19,15 +20,11 @@ class AlphaCompositor(nn.Module):
super().__init__() super().__init__()
self.composite_params = ( self.composite_params = (
composite_params composite_params if composite_params is not None else CompositeParams()
if composite_params is not None
else CompositeParams()
) )
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor: def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
images = alpha_composite( images = alpha_composite(fragments, alphas, ptclds, self.composite_params)
fragments, alphas, ptclds, self.composite_params
)
return images return images
...@@ -39,13 +36,9 @@ class NormWeightedCompositor(nn.Module): ...@@ -39,13 +36,9 @@ class NormWeightedCompositor(nn.Module):
def __init__(self, composite_params=None): def __init__(self, composite_params=None):
super().__init__() super().__init__()
self.composite_params = ( self.composite_params = (
composite_params composite_params if composite_params is not None else CompositeParams()
if composite_params is not None
else CompositeParams()
) )
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor: def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
images = norm_weighted_sum( images = norm_weighted_sum(fragments, alphas, ptclds, self.composite_params)
fragments, alphas, ptclds, self.composite_params
)
return images return images
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional from typing import Optional
import torch
import torch
from pytorch3d import _C from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
...@@ -155,10 +155,7 @@ class _RasterizePoints(torch.autograd.Function): ...@@ -155,10 +155,7 @@ class _RasterizePoints(torch.autograd.Function):
def rasterize_points_python( def rasterize_points_python(
pointclouds, pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8
image_size: int = 256,
radius: float = 0.01,
points_per_pixel: int = 8,
): ):
""" """
Naive pure PyTorch implementation of pointcloud rasterization. Naive pure PyTorch implementation of pointcloud rasterization.
...@@ -177,9 +174,7 @@ def rasterize_points_python( ...@@ -177,9 +174,7 @@ def rasterize_points_python(
point_idxs = torch.full( point_idxs = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
) )
zbuf = torch.full( zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device)
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
)
pix_dists = torch.full( pix_dists = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
) )
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
# A renderer class should be initialized with a # A renderer class should be initialized with a
# function for rasterization and a function for compositing. # function for rasterization and a function for compositing.
# The rasterizer should: # The rasterizer should:
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
import warnings import warnings
from typing import Any, Union from typing import Any, Union
import numpy as np
import torch import torch
...@@ -45,10 +46,7 @@ class TensorAccessor(object): ...@@ -45,10 +46,7 @@ class TensorAccessor(object):
# Convert the attribute to a tensor if it is not a tensor. # Convert the attribute to a tensor if it is not a tensor.
if not torch.is_tensor(value): if not torch.is_tensor(value):
value = torch.tensor( value = torch.tensor(
value, value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad
device=v.device,
dtype=v.dtype,
requires_grad=v.requires_grad,
) )
# Check the shapes match the existing shape and the shape of the index. # Check the shapes match the existing shape and the shape of the index.
...@@ -253,9 +251,7 @@ class TensorProperties(object): ...@@ -253,9 +251,7 @@ class TensorProperties(object):
return self return self
def format_tensor( def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
input, dtype=torch.float32, device: str = "cpu"
) -> torch.Tensor:
""" """
Helper function for converting a scalar value to a tensor. Helper function for converting a scalar value to a tensor.
...@@ -276,9 +272,7 @@ def format_tensor( ...@@ -276,9 +272,7 @@ def format_tensor(
return input return input
def convert_to_tensors_and_broadcast( def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
*args, dtype=torch.float32, device: str = "cpu"
):
""" """
Helper function to handle parsing an arbitrary number of inputs (*args) Helper function to handle parsing an arbitrary number of inputs (*args)
which all need to have the same batch dimension. which all need to have the same batch dimension.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment