Commit c5a83f46 authored by Krzysztof Chalupka's avatar Krzysztof Chalupka Committed by Facebook GitHub Bot
Browse files

SplatterBlender

Summary: Splatting shader. See code comments for details. Same API as SoftPhongShader.

Reviewed By: jcjohnson

Differential Revision: D36354301

fbshipit-source-id: 71ee37f7ff6bb9ce028ba42a65741424a427a92d
parent 1702c85b
...@@ -57,6 +57,7 @@ from .mesh import ( ...@@ -57,6 +57,7 @@ from .mesh import (
SoftGouraudShader, SoftGouraudShader,
SoftPhongShader, SoftPhongShader,
SoftSilhouetteShader, SoftSilhouetteShader,
SplatterPhongShader,
Textures, Textures,
TexturesAtlas, TexturesAtlas,
TexturesUV, TexturesUV,
...@@ -71,6 +72,7 @@ from .points import ( ...@@ -71,6 +72,7 @@ from .points import (
PulsarPointsRenderer, PulsarPointsRenderer,
rasterize_points, rasterize_points,
) )
from .splatter_blend import SplatterBlender
from .utils import ( from .utils import (
convert_to_tensors_and_broadcast, convert_to_tensors_and_broadcast,
ndc_grid_sample, ndc_grid_sample,
......
...@@ -4,14 +4,12 @@ ...@@ -4,14 +4,12 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import NamedTuple, Sequence, Union from typing import NamedTuple, Sequence, Union
import torch import torch
from pytorch3d import _C from pytorch3d import _C
from pytorch3d.common.datatypes import Device from pytorch3d.common.datatypes import Device
# Example functions for blending the top K colors per pixel using the outputs # Example functions for blending the top K colors per pixel using the outputs
# from rasterization. # from rasterization.
# NOTE: All blending function should return an RGBA image per batch element # NOTE: All blending function should return an RGBA image per batch element
...@@ -22,10 +20,12 @@ class BlendParams(NamedTuple): ...@@ -22,10 +20,12 @@ class BlendParams(NamedTuple):
Data class to store blending params with defaults Data class to store blending params with defaults
Members: Members:
sigma (float): Controls the width of the sigmoid function used to sigma (float): For SoftmaxPhong, controls the width of the sigmoid
calculate the 2D distance based probability. Determines the function used to calculate the 2D distance based probability. Determines
sharpness of the edges of the shape. the sharpness of the edges of the shape. Higher => faces have less defined
Higher => faces have less defined edges. edges. For SplatterPhong, this is the standard deviation of the Gaussian
kernel. Higher => splats have a stronger effect and the rendered image is
more blurry.
gamma (float): Controls the scaling of the exponential function used gamma (float): Controls the scaling of the exponential function used
to set the opacity of the color. to set the opacity of the color.
Higher => faces are more transparent. Higher => faces are more transparent.
...@@ -36,6 +36,7 @@ class BlendParams(NamedTuple): ...@@ -36,6 +36,7 @@ class BlendParams(NamedTuple):
sigma: float = 1e-4 sigma: float = 1e-4
gamma: float = 1e-4 gamma: float = 1e-4
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0) background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
background_alpha: float = 0.0
def _get_background_color( def _get_background_color(
......
...@@ -22,6 +22,7 @@ from .shader import ( # DEPRECATED ...@@ -22,6 +22,7 @@ from .shader import ( # DEPRECATED
SoftGouraudShader, SoftGouraudShader,
SoftPhongShader, SoftPhongShader,
SoftSilhouetteShader, SoftSilhouetteShader,
SplatterPhongShader,
TexturedSoftPhongShader, TexturedSoftPhongShader,
) )
from .shading import gouraud_shading, phong_shading from .shading import gouraud_shading, phong_shading
......
...@@ -20,9 +20,15 @@ from ..blending import ( ...@@ -20,9 +20,15 @@ from ..blending import (
) )
from ..lighting import PointLights from ..lighting import PointLights
from ..materials import Materials from ..materials import Materials
from ..splatter_blend import SplatterBlender
from ..utils import TensorProperties from ..utils import TensorProperties
from .rasterizer import Fragments from .rasterizer import Fragments
from .shading import flat_shading, gouraud_shading, phong_shading from .shading import (
_phong_shading_with_pixels,
flat_shading,
gouraud_shading,
phong_shading,
)
# A Shader should take as input fragments from the output of rasterization # A Shader should take as input fragments from the output of rasterization
...@@ -308,3 +314,64 @@ class SoftSilhouetteShader(nn.Module): ...@@ -308,3 +314,64 @@ class SoftSilhouetteShader(nn.Module):
blend_params = kwargs.get("blend_params", self.blend_params) blend_params = kwargs.get("blend_params", self.blend_params)
images = sigmoid_alpha_blend(colors, fragments, blend_params) images = sigmoid_alpha_blend(colors, fragments, blend_params)
return images return images
class SplatterPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
color aggregated using splats from surrounding pixels (see [0]).
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = SplatterPhongShader(device=torch.device("cuda:0"))
Args:
detach_rasterizer: If True, stop gradients from flowing through the rasterizer.
This simulates the pipeline in [0] which uses a non-differentiable OpenGL
rasterizer.
[0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
Sampling".
"""
def __init__(self, **kwargs):
self.splatter_blender = None
super().__init__(**kwargs)
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SplatterPhongShader."
raise ValueError(msg)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors, pixel_coords_cameras = _phong_shading_with_pixels(
meshes=meshes,
fragments=fragments.detach(),
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
if not self.splatter_blender:
# Init only once, to avoid re-computing constants.
N, H, W, K, _ = colors.shape
self.splatter_blender = SplatterBlender((N, H, W, K), colors.device)
images = self.splatter_blender(
colors,
pixel_coords_cameras,
cameras,
fragments.pix_to_face < 0,
self.blend_params,
)
return images
This diff is collapsed.
...@@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import ( ...@@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import (
HardPhongShader, HardPhongShader,
SoftPhongShader, SoftPhongShader,
SoftSilhouetteShader, SoftSilhouetteShader,
SplatterPhongShader,
TexturedSoftPhongShader, TexturedSoftPhongShader,
) )
from pytorch3d.structures.meshes import ( from pytorch3d.structures.meshes import (
...@@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ...@@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
shader_tests = [ shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"), ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(SoftPhongShader, "phong", "soft_phong"), ShaderTest(SoftPhongShader, "phong", "soft_phong"),
ShaderTest(SplatterPhongShader, "phong", "splatter_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"), ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"), ShaderTest(HardFlatShader, "flat", "hard_flat"),
] ]
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment