You need to sign in or sign up before continuing.
Commit 17ca6ecd authored by Nikhila Ravi's avatar Nikhila Ravi Committed by Facebook GitHub Bot
Browse files

allow cameras to be None in rasterizer initialization

Summary: Fix to enable a mesh/point rasterizer to be initialized without having to specify the camera.

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21362359

fbshipit-source-id: 4f84ea18ad9f179c7b7c2289ebf9422a2f5e26de
parent 9c5ab571
...@@ -79,6 +79,7 @@ ...@@ -79,6 +79,7 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"collapsed": true,
"id": "w9mH5iVprQdZ" "id": "w9mH5iVprQdZ"
}, },
"outputs": [], "outputs": [],
...@@ -260,7 +261,7 @@ ...@@ -260,7 +261,7 @@
" cameras=cameras, \n", " cameras=cameras, \n",
" raster_settings=raster_settings\n", " raster_settings=raster_settings\n",
" ),\n", " ),\n",
" shader=HardPhongShader(device=device, lights=lights)\n", " shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n",
")" ")"
] ]
}, },
......
...@@ -54,7 +54,7 @@ class MeshRasterizer(nn.Module): ...@@ -54,7 +54,7 @@ class MeshRasterizer(nn.Module):
Meshes. Meshes.
""" """
def __init__(self, cameras, raster_settings=None): def __init__(self, cameras=None, raster_settings=None):
""" """
Args: Args:
cameras: A cameras object which has a `transform_points` method cameras: A cameras object which has a `transform_points` method
...@@ -88,6 +88,11 @@ class MeshRasterizer(nn.Module): ...@@ -88,6 +88,11 @@ class MeshRasterizer(nn.Module):
be moved into forward. be moved into forward.
""" """
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of MeshRasterizer"
raise ValueError(msg)
verts_world = meshes_world.verts_padded() verts_world = meshes_world.verts_padded()
verts_world_packed = meshes_world.verts_packed() verts_world_packed = meshes_world.verts_packed()
verts_screen = cameras.transform_points(verts_world, **kwargs) verts_screen = cameras.transform_points(verts_world, **kwargs)
......
...@@ -10,7 +10,6 @@ from ..blending import ( ...@@ -10,7 +10,6 @@ from ..blending import (
sigmoid_alpha_blend, sigmoid_alpha_blend,
softmax_rgb_blend, softmax_rgb_blend,
) )
from ..cameras import OpenGLPerspectiveCameras
from ..lighting import PointLights from ..lighting import PointLights
from ..materials import Materials from ..materials import Materials
from .shading import flat_shading, gouraud_shading, phong_shading from .shading import flat_shading, gouraud_shading, phong_shading
...@@ -46,13 +45,16 @@ class HardPhongShader(nn.Module): ...@@ -46,13 +45,16 @@ class HardPhongShader(nn.Module):
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = cameras
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
lights = kwargs.get("lights", self.lights) lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials) materials = kwargs.get("materials", self.materials)
colors = phong_shading( colors = phong_shading(
...@@ -89,14 +91,16 @@ class SoftPhongShader(nn.Module): ...@@ -89,14 +91,16 @@ class SoftPhongShader(nn.Module):
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = cameras
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
lights = kwargs.get("lights", self.lights) lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials) materials = kwargs.get("materials", self.materials)
colors = phong_shading( colors = phong_shading(
...@@ -132,12 +136,14 @@ class HardGouraudShader(nn.Module): ...@@ -132,12 +136,14 @@ class HardGouraudShader(nn.Module):
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = cameras
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
lights = kwargs.get("lights", self.lights) lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials) materials = kwargs.get("materials", self.materials)
pixel_colors = gouraud_shading( pixel_colors = gouraud_shading(
...@@ -174,13 +180,15 @@ class SoftGouraudShader(nn.Module): ...@@ -174,13 +180,15 @@ class SoftGouraudShader(nn.Module):
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = cameras
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
lights = kwargs.get("lights", self.lights) lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials) materials = kwargs.get("materials", self.materials)
pixel_colors = gouraud_shading( pixel_colors = gouraud_shading(
...@@ -219,14 +227,16 @@ class TexturedSoftPhongShader(nn.Module): ...@@ -219,14 +227,16 @@ class TexturedSoftPhongShader(nn.Module):
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = cameras
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_texture_map(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_texture_map(fragments, meshes)
lights = kwargs.get("lights", self.lights) lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials) materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params) blend_params = kwargs.get("blend_params", self.blend_params)
...@@ -262,13 +272,15 @@ class HardFlatShader(nn.Module): ...@@ -262,13 +272,15 @@ class HardFlatShader(nn.Module):
self.materials = ( self.materials = (
materials if materials is not None else Materials(device=device) materials if materials is not None else Materials(device=device)
) )
self.cameras = ( self.cameras = cameras
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
lights = kwargs.get("lights", self.lights) lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials) materials = kwargs.get("materials", self.materials)
colors = flat_shading( colors = flat_shading(
......
...@@ -48,7 +48,7 @@ class PointsRasterizer(nn.Module): ...@@ -48,7 +48,7 @@ class PointsRasterizer(nn.Module):
This class implements methods for rasterizing a batch of pointclouds. This class implements methods for rasterizing a batch of pointclouds.
""" """
def __init__(self, cameras, raster_settings=None): def __init__(self, cameras=None, raster_settings=None):
""" """
cameras: A cameras object which has a `transform_points` method cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the which returns the transformed points after applying the
...@@ -80,6 +80,10 @@ class PointsRasterizer(nn.Module): ...@@ -80,6 +80,10 @@ class PointsRasterizer(nn.Module):
be moved into forward. be moved into forward.
""" """
cameras = kwargs.get("cameras", self.cameras) cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of PointsRasterizer"
raise ValueError(msg)
pts_world = point_clouds.points_padded() pts_world = point_clouds.points_padded()
pts_world_packed = point_clouds.points_packed() pts_world_packed = point_clouds.points_packed()
......
...@@ -9,6 +9,11 @@ import torch ...@@ -9,6 +9,11 @@ import torch
from PIL import Image from PIL import Image
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.points.rasterizer import (
PointsRasterizationSettings,
PointsRasterizer,
)
from pytorch3d.structures import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
...@@ -99,3 +104,101 @@ class TestMeshRasterizer(unittest.TestCase): ...@@ -99,3 +104,101 @@ class TestMeshRasterizer(unittest.TestCase):
DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png" DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png"
) )
self.assertTrue(torch.allclose(image, image_ref)) self.assertTrue(torch.allclose(image, image_ref))
#################################
# 4. Test init without cameras.
##################################
# Create a new empty rasterizer:
rasterizer = MeshRasterizer()
# Check that omitting the cameras in both initialization
# and the forward pass throws an error:
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
rasterizer(sphere_mesh)
# Now pass in the cameras as a kwarg
fragments = rasterizer(
sphere_mesh, cameras=cameras, raster_settings=raster_settings
)
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
image[image >= 0] = 1.0
image[image < 0] = 0.0
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
)
self.assertTrue(torch.allclose(image, image_ref))
class TestPointRasterizer(unittest.TestCase):
def test_simple_sphere(self):
device = torch.device("cuda:0")
# Load reference image
ref_filename = "test_simple_pointcloud_sphere.png"
image_ref_filename = DATA_DIR / ref_filename
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
image_ref = convert_image_to_binary_mask(image_ref_filename).to(torch.int32)
sphere_mesh = ico_sphere(1, device)
verts_padded = sphere_mesh.verts_padded()
verts_padded[..., 1] += 0.2
verts_padded[..., 0] += 0.2
pointclouds = Pointclouds(points=verts_padded)
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = PointsRasterizationSettings(
image_size=256, radius=5e-2, points_per_pixel=1
)
#################################
# 1. Test init without cameras.
##################################
# Initialize without passing in the cameras
rasterizer = PointsRasterizer()
# Check that omitting the cameras in both initialization
# and the forward pass throws an error:
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
rasterizer(pointclouds)
##########################################
# 2. Test rasterizing a single pointcloud
##########################################
fragments = rasterizer(
pointclouds, cameras=cameras, raster_settings=raster_settings
)
# Convert idx to a binary mask
image = fragments.idx[0, ..., 0].squeeze().cpu()
image[image >= 0] = 1.0
image[image < 0] = 0.0
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere_points.png"
)
self.assertTrue(torch.allclose(image, image_ref[..., 0]))
########################################
# 3. Test with a batch of pointclouds
########################################
batch_size = 10
pointclouds = pointclouds.extend(batch_size)
fragments = rasterizer(
pointclouds, cameras=cameras, raster_settings=raster_settings
)
for i in range(batch_size):
image = fragments.idx[i, ..., 0].squeeze().cpu()
image[image >= 0] = 1.0
image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, image_ref[..., 0]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment