Commit c371a9a6 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

rasterizer.to without cameras

Summary: As reported in https://github.com/facebookresearch/pytorch3d/pull/1100, a rasterizer couldn't be moved if it was missing the optional cameras member. Fix that. This matters because the renderer.to calls rasterizer.to, so this to() could be called even by a user who never sets a cameras member.

Reviewed By: nikhilaravi

Differential Revision: D34643841

fbshipit-source-id: 7e26e32e8bc585eb1ee533052754a7b59bc7467a
parent 4a1f1760
...@@ -110,7 +110,8 @@ class MeshRasterizer(nn.Module): ...@@ -110,7 +110,8 @@ class MeshRasterizer(nn.Module):
def to(self, device): def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module # Manually move to device cameras as it is not a subclass of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
return self return self
def transform(self, meshes_world, **kwargs) -> torch.Tensor: def transform(self, meshes_world, **kwargs) -> torch.Tensor:
......
...@@ -115,7 +115,8 @@ class PointsRasterizer(nn.Module): ...@@ -115,7 +115,8 @@ class PointsRasterizer(nn.Module):
def to(self, device): def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module # Manually move to device cameras as it is not a subclass of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
return self return self
def forward(self, point_clouds, **kwargs) -> PointFragments: def forward(self, point_clouds, **kwargs) -> PointFragments:
......
...@@ -134,6 +134,12 @@ class TestMeshRasterizer(unittest.TestCase): ...@@ -134,6 +134,12 @@ class TestMeshRasterizer(unittest.TestCase):
self.assertTrue(torch.allclose(image, image_ref)) self.assertTrue(torch.allclose(image, image_ref))
def test_simple_to(self):
# Check that to() works without a cameras object.
device = torch.device("cuda:0")
rasterizer = MeshRasterizer()
rasterizer.to(device)
class TestPointRasterizer(unittest.TestCase): class TestPointRasterizer(unittest.TestCase):
def test_simple_sphere(self): def test_simple_sphere(self):
...@@ -203,3 +209,9 @@ class TestPointRasterizer(unittest.TestCase): ...@@ -203,3 +209,9 @@ class TestPointRasterizer(unittest.TestCase):
image[image >= 0] = 1.0 image[image >= 0] = 1.0
image[image < 0] = 0.0 image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, image_ref[..., 0])) self.assertTrue(torch.allclose(image, image_ref[..., 0]))
def test_simple_to(self):
# Check that to() works without a cameras object.
device = torch.device("cuda:0")
rasterizer = PointsRasterizer()
rasterizer.to(device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment