Commit 36edf2b3 authored by Krzysztof Chalupka's avatar Krzysztof Chalupka Committed by Facebook GitHub Bot
Browse files

Add .to methods to the splatter and SplatterPhongShader.

Summary: Needed to properly change devices during OpenGL rasterization.

Reviewed By: jcjohnson

Differential Revision: D37698568

fbshipit-source-id: 38968149d577322e662d3b5d04880204b0a7be29
parent 78bb6d17
...@@ -324,6 +324,11 @@ class SplatterPhongShader(ShaderBase): ...@@ -324,6 +324,11 @@ class SplatterPhongShader(ShaderBase):
self.splatter_blender = None self.splatter_blender = None
super().__init__(**kwargs) super().__init__(**kwargs)
def to(self, device: Device):
if self.splatter_blender:
self.splatter_blender.to(device)
return super().to(device)
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = super()._get_cameras(**kwargs) cameras = super()._get_cameras(**kwargs)
texels = meshes.sample_textures(fragments) texels = meshes.sample_textures(fragments)
...@@ -349,7 +354,7 @@ class SplatterPhongShader(ShaderBase): ...@@ -349,7 +354,7 @@ class SplatterPhongShader(ShaderBase):
pixel_coords_cameras, pixel_coords_cameras,
cameras, cameras,
fragments.pix_to_face < 0, fragments.pix_to_face < 0,
self.blend_params, kwargs.get("blend_params", self.blend_params),
) )
return images return images
...@@ -398,6 +403,9 @@ class SoftDepthShader(ShaderBase): ...@@ -398,6 +403,9 @@ class SoftDepthShader(ShaderBase):
""" """
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
if fragments.dists is None:
raise ValueError("SoftDepthShader requires Fragments.dists to be present.")
cameras = super()._get_cameras(**kwargs) cameras = super()._get_cameras(**kwargs)
N, H, W, K = fragments.pix_to_face.shape N, H, W, K = fragments.pix_to_face.shape
......
...@@ -464,6 +464,12 @@ class SplatterBlender(torch.nn.Module): ...@@ -464,6 +464,12 @@ class SplatterBlender(torch.nn.Module):
input_shape, device input_shape, device
) )
def to(self, device):
self.offsets = self.offsets.to(device)
self.crop_ids_h = self.crop_ids_h.to(device)
self.crop_ids_w = self.crop_ids_w.to(device)
super().to(device)
def forward( def forward(
self, self,
colors: torch.Tensor, colors: torch.Tensor,
......
...@@ -60,7 +60,7 @@ class TestShader(TestCaseMixin, unittest.TestCase): ...@@ -60,7 +60,7 @@ class TestShader(TestCaseMixin, unittest.TestCase):
self.assertIs(cpu_shader, cuda_shader) self.assertIs(cpu_shader, cuda_shader)
if cameras is None: if cameras is None:
self.assertIsNone(cuda_shader.cameras) self.assertIsNone(cuda_shader.cameras)
with self.assertRaisesRegexp(ValueError, "Cameras must be"): with self.assertRaisesRegex(ValueError, "Cameras must be"):
cuda_shader._get_cameras() cuda_shader._get_cameras()
else: else:
self.assertEqual(cuda_device, cuda_shader.cameras.device) self.assertEqual(cuda_device, cuda_shader.cameras.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment