"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "80c97e6a78b6f1ead85d04c40badaa1014d8a24d"
Commit 872ff8c7 authored by Amitav Baruah's avatar Amitav Baruah Committed by Facebook GitHub Bot
Browse files

Add background color support to compositors

Summary: Support rendering different color backgrounds for pointclouds for both compositors

Reviewed By: nikhilaravi

Differential Revision: D23611043

fbshipit-source-id: ab029650d51349340372c5bd66700e6577d48851
parent dc40adfa
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -16,11 +19,20 @@ class AlphaCompositor(nn.Module): ...@@ -16,11 +19,20 @@ class AlphaCompositor(nn.Module):
Accumulate points using alpha compositing. Accumulate points using alpha compositing.
""" """
def __init__(self): def __init__(
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
):
super().__init__() super().__init__()
self.background_color = background_color
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor: def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
background_color = kwargs.get("background_color", self.background_color)
images = alpha_composite(fragments, alphas, ptclds) images = alpha_composite(fragments, alphas, ptclds)
# images are of shape (N, C, H, W)
# check for background color & feature size C (C=4 indicates rgba)
if background_color is not None and images.shape[1] == 4:
return _add_background_color_to_images(fragments, images, background_color)
return images return images
...@@ -29,9 +41,68 @@ class NormWeightedCompositor(nn.Module): ...@@ -29,9 +41,68 @@ class NormWeightedCompositor(nn.Module):
Accumulate points using a normalized weighted sum. Accumulate points using a normalized weighted sum.
""" """
def __init__(self): def __init__(
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
):
super().__init__() super().__init__()
self.background_color = background_color
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor: def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
background_color = kwargs.get("background_color", self.background_color)
images = norm_weighted_sum(fragments, alphas, ptclds) images = norm_weighted_sum(fragments, alphas, ptclds)
# images are of shape (N, C, H, W)
# check for background color & feature size C (C=4 indicates rgba)
if background_color is not None and images.shape[1] == 4:
return _add_background_color_to_images(fragments, images, background_color)
return images return images
def _add_background_color_to_images(pix_idxs, images, background_color):
"""
Mask pixels in images without corresponding points with a given background_color.
Args:
pix_idxs: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
images: Tensor of shape (N, 4, image_size, image_size) giving the
accumulated features at each point, where 4 refers to a rgba feature.
background_color: Tensor, list, or tuple with 3 or 4 values indicating the rgb/rgba
value for the new background. Values should be in the interval [0,1].
Returns:
images: Tensor of shape (N, 4, image_size, image_size), where pixels with
no nearest points have features set to the background color, and other
pixels with accumulated features have unchanged values.
"""
# Initialize background mask
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
# Convert background_color to an appropriate tensor and check shape
if not torch.is_tensor(background_color):
background_color = images.new_tensor(background_color)
background_shape = background_color.shape
if len(background_shape) != 1 or background_shape[0] not in (3, 4):
warnings.warn(
"Background color should be size (3) or (4), but is size %s instead"
% (background_shape,)
)
return images
background_color = background_color.to(images)
# add alpha channel
if background_shape[0] == 3:
alpha = images.new_ones(1)
background_color = torch.cat([background_color, alpha])
num_background_pixels = background_mask.sum()
# permute so that features are the last dimension for masked_scatter to work
masked_images = images.permute(0, 2, 3, 1)[..., :4].masked_scatter(
background_mask[..., None],
background_color[None, :].expand(num_background_pixels, -1),
)
return masked_images.permute(0, 3, 1, 2)
...@@ -18,6 +18,7 @@ from pytorch3d.renderer.cameras import ( ...@@ -18,6 +18,7 @@ from pytorch3d.renderer.cameras import (
FoVPerspectiveCameras, FoVPerspectiveCameras,
look_at_view_transform, look_at_view_transform,
) )
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
from pytorch3d.renderer.points import ( from pytorch3d.renderer.points import (
AlphaCompositor, AlphaCompositor,
NormWeightedCompositor, NormWeightedCompositor,
...@@ -171,3 +172,54 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): ...@@ -171,3 +172,54 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
DATA_DIR / filename DATA_DIR / filename
) )
self.assertClose(rgb, image_ref) self.assertClose(rgb, image_ref)
def test_compositor_background_color(self):
N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
ptclds = torch.randn((C, P))
alphas = torch.rand((N, K, H, W))
pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1
background_color = [0.5, 0, 1]
compositor_funcs = [
(NormWeightedCompositor, norm_weighted_sum),
(AlphaCompositor, alpha_composite),
]
for (compositor_class, composite_func) in compositor_funcs:
compositor = compositor_class(background_color)
# run the forward method to generate masked images
masked_images = compositor.forward(pix_idxs, alphas, ptclds)
# generate unmasked images for testing purposes
images = composite_func(pix_idxs, alphas, ptclds)
is_foreground = pix_idxs[:, 0] >= 0
# make sure foreground values are unchanged
self.assertClose(
torch.masked_select(masked_images, is_foreground[:, None]),
torch.masked_select(images, is_foreground[:, None]),
)
is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
# permute masked_images to correctly get rgb values
masked_images = masked_images.permute(0, 2, 3, 1)
for i in range(3):
channel_color = background_color[i]
# check if background colors are properly changed
self.assertTrue(
masked_images[is_background]
.view(-1, 4)[..., i]
.eq(channel_color)
.all()
)
# check background color alpha values
self.assertTrue(
masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment