Commit c5a83f46 authored by Krzysztof Chalupka's avatar Krzysztof Chalupka Committed by Facebook GitHub Bot
Browse files

SplatterBlender

Summary: Splatting shader. See code comments for details. Same API as SoftPhongShader.

Reviewed By: jcjohnson

Differential Revision: D36354301

fbshipit-source-id: 71ee37f7ff6bb9ce028ba42a65741424a427a92d
parent 1702c85b
......@@ -57,6 +57,7 @@ from .mesh import (
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
SplatterPhongShader,
Textures,
TexturesAtlas,
TexturesUV,
......@@ -71,6 +72,7 @@ from .points import (
PulsarPointsRenderer,
rasterize_points,
)
from .splatter_blend import SplatterBlender
from .utils import (
convert_to_tensors_and_broadcast,
ndc_grid_sample,
......
......@@ -4,14 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import NamedTuple, Sequence, Union
import torch
from pytorch3d import _C
from pytorch3d.common.datatypes import Device
# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending function should return an RGBA image per batch element
......@@ -22,10 +20,12 @@ class BlendParams(NamedTuple):
Data class to store blending params with defaults
Members:
sigma (float): Controls the width of the sigmoid function used to
calculate the 2D distance based probability. Determines the
sharpness of the edges of the shape.
Higher => faces have less defined edges.
sigma (float): For SoftmaxPhong, controls the width of the sigmoid
function used to calculate the 2D distance based probability. Determines
the sharpness of the edges of the shape. Higher => faces have less defined
edges. For SplatterPhong, this is the standard deviation of the Gaussian
kernel. Higher => splats have a stronger effect and the rendered image is
more blurry.
gamma (float): Controls the scaling of the exponential function used
to set the opacity of the color.
Higher => faces are more transparent.
......@@ -36,6 +36,7 @@ class BlendParams(NamedTuple):
sigma: float = 1e-4
gamma: float = 1e-4
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
background_alpha: float = 0.0
def _get_background_color(
......
......@@ -22,6 +22,7 @@ from .shader import ( # DEPRECATED
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
SplatterPhongShader,
TexturedSoftPhongShader,
)
from .shading import gouraud_shading, phong_shading
......
......@@ -20,9 +20,15 @@ from ..blending import (
)
from ..lighting import PointLights
from ..materials import Materials
from ..splatter_blend import SplatterBlender
from ..utils import TensorProperties
from .rasterizer import Fragments
from .shading import flat_shading, gouraud_shading, phong_shading
from .shading import (
_phong_shading_with_pixels,
flat_shading,
gouraud_shading,
phong_shading,
)
# A Shader should take as input fragments from the output of rasterization
......@@ -308,3 +314,64 @@ class SoftSilhouetteShader(nn.Module):
blend_params = kwargs.get("blend_params", self.blend_params)
images = sigmoid_alpha_blend(colors, fragments, blend_params)
return images
class SplatterPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
color aggregated using splats from surrounding pixels (see [0]).
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = SplatterPhongShader(device=torch.device("cuda:0"))
Args:
detach_rasterizer: If True, stop gradients from flowing through the rasterizer.
This simulates the pipeline in [0] which uses a non-differentiable OpenGL
rasterizer.
[0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
Sampling".
"""
def __init__(self, **kwargs):
self.splatter_blender = None
super().__init__(**kwargs)
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SplatterPhongShader."
raise ValueError(msg)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors, pixel_coords_cameras = _phong_shading_with_pixels(
meshes=meshes,
fragments=fragments.detach(),
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
if not self.splatter_blender:
# Init only once, to avoid re-computing constants.
N, H, W, K, _ = colors.shape
self.splatter_blender = SplatterBlender((N, H, W, K), colors.device)
images = self.splatter_blender(
colors,
pixel_coords_cameras,
cameras,
fragments.pix_to_face < 0,
self.blend_params,
)
return images
This diff is collapsed.
......@@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import (
HardPhongShader,
SoftPhongShader,
SoftSilhouetteShader,
SplatterPhongShader,
TexturedSoftPhongShader,
)
from pytorch3d.structures.meshes import (
......@@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(SoftPhongShader, "phong", "soft_phong"),
ShaderTest(SplatterPhongShader, "phong", "splatter_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"),
]
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment