Commit 9e2bc3a1 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

ambient lights batching #1043

Summary:
convert_to_tensors_and_broadcast had a special case for a single input, which is not used anywhere except fails to do the right thing if a TensorProperties has only one kwarg. At the moment AmbientLights may be the only way to hit the problem. Fix by removing the special case.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1043

Reviewed By: nikhilaravi

Differential Revision: D33638345

fbshipit-source-id: 7a6695f44242e650504320f73b6da74254d49ac7
parent fddd6a70
...@@ -349,7 +349,4 @@ def convert_to_tensors_and_broadcast( ...@@ -349,7 +349,4 @@ def convert_to_tensors_and_broadcast(
expand_sizes = (N,) + (-1,) * len(c.shape[1:]) expand_sizes = (N,) + (-1,) * len(c.shape[1:])
args_Nd.append(c.expand(*expand_sizes)) args_Nd.append(c.expand(*expand_sizes))
if len(args) == 1:
args_Nd = args_Nd[0] # Return the first element
return args_Nd return args_Nd
...@@ -9,7 +9,7 @@ import unittest ...@@ -9,7 +9,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.renderer.lighting import DirectionalLights, PointLights from pytorch3d.renderer.lighting import AmbientLights, DirectionalLights, PointLights
from pytorch3d.transforms import RotateAxisAngle from pytorch3d.transforms import RotateAxisAngle
...@@ -121,6 +121,17 @@ class TestLights(TestCaseMixin, unittest.TestCase): ...@@ -121,6 +121,17 @@ class TestLights(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
PointLights(location=torch.randn(10, 4)) PointLights(location=torch.randn(10, 4))
def test_initialize_ambient(self):
N = 13
color = 0.8 * torch.ones((N, 3))
lights = AmbientLights(ambient_color=color)
self.assertEqual(len(lights), N)
self.assertClose(lights.ambient_color, color)
lights = AmbientLights(ambient_color=color[:1])
self.assertEqual(len(lights), 1)
self.assertClose(lights.ambient_color, color[:1])
class TestDiffuseLighting(TestCaseMixin, unittest.TestCase): class TestDiffuseLighting(TestCaseMixin, unittest.TestCase):
def test_diffuse_directional_lights(self): def test_diffuse_directional_lights(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment