Commit 83fef0a5 authored by Dave Schnizlein's avatar Dave Schnizlein Committed by Facebook GitHub Bot
Browse files

Add MeshRendererWithFragments class to also return fragments after rendering

Summary: Users want to be able to obtain the depth from the renderer. Current work-around requires running the rasterizer and extra time. This change creates a new renderer class that also returns the fragments from the rasterizer.

Reviewed By: nikhilaravi

Differential Revision: D24432381

fbshipit-source-id: 6552e8a6bfee646791afb34bdb7452fbc4094aed
parent b6be3b95
......@@ -55,3 +55,46 @@ class MeshRenderer(nn.Module):
images = self.shader(fragments, meshes_world, **kwargs)
return images
class MeshRendererWithFragments(nn.Module):
"""
A class for rendering a batch of heterogeneous meshes. The class should
be initialized with a rasterizer and shader class which each have a forward
function.
In the forward pass this class returns the `fragments` from which intermediate
values such as the depth map can be easily extracted e.g.
.. code-block:: python
images, fragments = renderer(meshes)
depth = fragments.zbuf
"""
def __init__(self, rasterizer, shader):
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
def to(self, device):
# Rasterizer and shader have submodules which are not of type nn.Module
self.rasterizer.to(device)
self.shader.to(device)
def forward(self, meshes_world, **kwargs):
"""
Render a batch of images from a batch of meshes by rasterizing and then
shading.
NOTE: If the blur radius for rasterization is > 0.0, some pixels can
have one or more barycentric coordinates lying outside the range [0, 1].
For a pixel with out of bounds barycentric coordinates with respect to a
face f, clipping is required before interpolating the texture uv
coordinates and z buffer so that the colors and depths are limited to
the range for the corresponding face.
For this set rasterizer.raster_settings.clip_barycentric_coords=True
"""
fragments = self.rasterizer(meshes_world, **kwargs)
images = self.shader(fragments, meshes_world, **kwargs)
return images, fragments
......@@ -24,7 +24,7 @@ from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.renderer import MeshRenderer, MeshRendererWithFragments
from pytorch3d.renderer.mesh.shader import (
BlendParams,
HardFlatShader,
......@@ -50,7 +50,7 @@ DATA_DIR = Path(__file__).resolve().parent / "data"
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self, elevated_camera=False):
def test_simple_sphere(self, elevated_camera=False, check_depth=False):
"""
Test output of phong and gouraud shading matches a reference image using
the default values for the light sources.
......@@ -114,8 +114,16 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
if check_depth:
renderer = MeshRendererWithFragments(
rasterizer=rasterizer, shader=shader
)
images, fragments = renderer(sphere_mesh)
self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf)
else:
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
filename = "simple_sphere_light_%s%s%s.png" % (
name,
......@@ -144,8 +152,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
materials=materials,
blend_params=blend_params,
)
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
images = phong_renderer(sphere_mesh, lights=lights)
if check_depth:
phong_renderer = MeshRendererWithFragments(
rasterizer=rasterizer, shader=phong_shader
)
images, fragments = phong_renderer(sphere_mesh, lights=lights)
self.assertClose(
fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf
)
else:
phong_renderer = MeshRenderer(
rasterizer=rasterizer, shader=phong_shader
)
images = phong_renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_dark%s%s.png" % (
......@@ -171,6 +190,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
"""
self.test_simple_sphere(elevated_camera=True)
def test_simple_sphere_depth(self):
"""
Test output of phong and gouraud shading matches a reference image using
the default values for the light sources.
The rendering is performed with a camera that has non-zero elevation.
"""
self.test_simple_sphere(check_depth=True)
def test_simple_sphere_screen(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment