"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "cd2b7f071cbfb207400e70deaf4b1b2c46c3e47c"
Commit a15c33a3 authored by Nikhila Ravi's avatar Nikhila Ravi Committed by Facebook GitHub Bot
Browse files

Alpha channel to return the mask

Summary: Updated the alpha channel in the `hard_rgb_blend` function to return the mask of the pixels which have overlapping mesh faces.

Reviewed By: bottler

Differential Revision: D29001604

fbshipit-source-id: 22a2173d769f2d3ad34892d68ceb628f073bca22
parent ac6c07fa
...@@ -58,7 +58,8 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: ...@@ -58,7 +58,8 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
) # (N, H, W, 3) ) # (N, H, W, 3)
# Concat with the alpha channel. # Concat with the alpha channel.
alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device) alpha = (~is_background).type_as(pixel_colors)[..., None]
return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4) return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4)
......
...@@ -184,8 +184,8 @@ class TestBlending(TestCaseMixin, unittest.TestCase): ...@@ -184,8 +184,8 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
channel_color = blend_params.background_color[i] channel_color = blend_params.background_color[i]
self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all()) self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())
# Examine the alpha channel is correct # Examine the alpha channel
self.assertTrue(images[..., 3].eq(1).all()) self.assertClose(images[..., 3], (pix_to_face[..., 0] >= 0).float())
def test_sigmoid_alpha_blend_manual_gradients(self): def test_sigmoid_alpha_blend_manual_gradients(self):
# Create dummy outputs of rasterization # Create dummy outputs of rasterization
......
...@@ -125,6 +125,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ...@@ -125,6 +125,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
) )
images, fragments = renderer(sphere_mesh) images, fragments = renderer(sphere_mesh)
self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf) self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf)
# Check the alpha channel is the mask
self.assertClose(
images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float()
)
else: else:
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh) images = renderer(sphere_mesh)
...@@ -165,6 +169,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ...@@ -165,6 +169,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
self.assertClose( self.assertClose(
fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf
) )
# Check the alpha channel is the mask
self.assertClose(
images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float()
)
else: else:
phong_renderer = MeshRenderer( phong_renderer = MeshRenderer(
rasterizer=rasterizer, shader=phong_shader rasterizer=rasterizer, shader=phong_shader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment