Commit 83bacda8 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

lint

Summary: Fix recent flake complaints

Reviewed By: MichaelRamamonjisoa

Differential Revision: D51811912

fbshipit-source-id: 65183f5bc7058da910e4d5a63b2250ce8637f1cc
parent f74fc450
[flake8] [flake8]
ignore = E203, E266, E501, W503, E221 # B028 No explicit stacklevel argument found.
# B907 'foo' is manually surrounded by quotes, consider using the `!r` conversion flag.
# B905 `zip()` without an explicit `strict=` parameter.
ignore = E203, E266, E501, W503, E221, B028, B905, B907
max-line-length = 88 max-line-length = 88
max-complexity = 18 max-complexity = 18
select = B,C,E,F,W,T4,B9 select = B,C,E,F,W,T4,B9
......
...@@ -34,11 +34,7 @@ def _minify(basedir, path_manager, factors=(), resolutions=()): ...@@ -34,11 +34,7 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
imgdir = os.path.join(basedir, "images") imgdir = os.path.join(basedir, "images")
imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))] imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
imgs = [ imgs = [f for f in imgs if f.endswith("JPG", "jpg", "png", "jpeg", "PNG")]
f
for f in imgs
if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
]
imgdir_orig = imgdir imgdir_orig = imgdir
wd = os.getcwd() wd = os.getcwd()
......
...@@ -200,7 +200,7 @@ def resize_image( ...@@ -200,7 +200,7 @@ def resize_image(
mode: str = "bilinear", mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]: ) -> Tuple[torch.Tensor, float, torch.Tensor]:
if type(image) == np.ndarray: if isinstance(image, np.ndarray):
image = torch.from_numpy(image) image = torch.from_numpy(image)
if image_height is None or image_width is None: if image_height is None or image_width is None:
......
...@@ -750,7 +750,7 @@ def save_obj( ...@@ -750,7 +750,7 @@ def save_obj(
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]]) save_texture = all(t is not None for t in [faces_uvs, verts_uvs, texture_map])
output_path = Path(f) output_path = Path(f)
# Save the .obj file # Save the .obj file
......
...@@ -453,6 +453,6 @@ def parse_image_size( ...@@ -453,6 +453,6 @@ def parse_image_size(
raise ValueError("Image size can only be a tuple/list of (H, W)") raise ValueError("Image size can only be a tuple/list of (H, W)")
if not all(i > 0 for i in image_size): if not all(i > 0 for i in image_size):
raise ValueError("Image sizes must be greater than 0; got %d, %d" % image_size) raise ValueError("Image sizes must be greater than 0; got %d, %d" % image_size)
if not all(type(i) == int for i in image_size): if not all(isinstance(i, int) for i in image_size):
raise ValueError("Image sizes must be integers; got %f, %f" % image_size) raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
return tuple(image_size) return tuple(image_size)
...@@ -1698,7 +1698,7 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True) -> ...@@ -1698,7 +1698,7 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True) ->
# Now we know there are multiple meshes and they have textures to merge. # Now we know there are multiple meshes and they have textures to merge.
all_textures = [mesh.textures for mesh in meshes] all_textures = [mesh.textures for mesh in meshes]
first = all_textures[0] first = all_textures[0]
tex_types_same = all(type(tex) == type(first) for tex in all_textures) tex_types_same = all(type(tex) == type(first) for tex in all_textures) # noqa: E721
if not tex_types_same: if not tex_types_same:
raise ValueError("All meshes in the batch must have the same type of texture.") raise ValueError("All meshes in the batch must have the same type of texture.")
......
...@@ -440,22 +440,22 @@ class Transform3d: ...@@ -440,22 +440,22 @@ class Transform3d:
def translate(self, *args, **kwargs) -> "Transform3d": def translate(self, *args, **kwargs) -> "Transform3d":
return self.compose( return self.compose(
Translate(device=self.device, dtype=self.dtype, *args, **kwargs) Translate(*args, device=self.device, dtype=self.dtype, **kwargs)
) )
def scale(self, *args, **kwargs) -> "Transform3d": def scale(self, *args, **kwargs) -> "Transform3d":
return self.compose( return self.compose(
Scale(device=self.device, dtype=self.dtype, *args, **kwargs) Scale(*args, device=self.device, dtype=self.dtype, **kwargs)
) )
def rotate(self, *args, **kwargs) -> "Transform3d": def rotate(self, *args, **kwargs) -> "Transform3d":
return self.compose( return self.compose(
Rotate(device=self.device, dtype=self.dtype, *args, **kwargs) Rotate(*args, device=self.device, dtype=self.dtype, **kwargs)
) )
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d": def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
return self.compose( return self.compose(
RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs) RotateAxisAngle(*args, device=self.device, dtype=self.dtype, **kwargs)
) )
def clone(self) -> "Transform3d": def clone(self) -> "Transform3d":
......
...@@ -15,15 +15,14 @@ from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_lo ...@@ -15,15 +15,14 @@ from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_lo
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_prepare_inputs_wrong_num_dim(self): def test_prepare_inputs_wrong_num_dim(self):
img = torch.randn(3, 3, 3) img = torch.randn(3, 3, 3)
with self.assertRaises(ValueError) as context: text = (
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated."
)
with self.assertRaisesRegex(ValueError, text):
img, fg_prob, depth_map = preprocess_input( img, fg_prob, depth_map = preprocess_input(
img, None, None, True, True, 0.5, (0.0, 0.0, 0.0) img, None, None, True, True, 0.5, (0.0, 0.0, 0.0)
) )
self.assertEqual(
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated.",
context.exception,
)
def test_prepare_inputs_mask_image_true(self): def test_prepare_inputs_mask_image_true(self):
batch, channels, height, width = 2, 3, 10, 10 batch, channels, height, width = 2, 3, 10, 10
......
...@@ -224,6 +224,7 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): ...@@ -224,6 +224,7 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
def test_load_mask(self): def test_load_mask(self):
path = os.path.join(self.dataset_root, self.frame_annotation.mask.path) path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
path = self.path_manager.get_local_path(path)
mask = load_mask(path) mask = load_mask(path)
self.assertEqual(mask.dtype, np.float32) self.assertEqual(mask.dtype, np.float32)
self.assertLessEqual(np.max(mask), 1.0) self.assertLessEqual(np.max(mask), 1.0)
...@@ -231,12 +232,14 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): ...@@ -231,12 +232,14 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
def test_load_depth(self): def test_load_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
path = self.path_manager.get_local_path(path)
depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment) depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment)
self.assertEqual(depth_map.dtype, np.float32) self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 3) self.assertEqual(len(depth_map.shape), 3)
def test_load_16big_png_depth(self): def test_load_16big_png_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
path = self.path_manager.get_local_path(path)
depth_map = load_16big_png_depth(path) depth_map = load_16big_png_depth(path)
self.assertEqual(depth_map.dtype, np.float32) self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 2) self.assertEqual(len(depth_map.shape), 2)
...@@ -245,6 +248,7 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): ...@@ -245,6 +248,7 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
mask_path = os.path.join( mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path self.dataset_root, self.frame_annotation.depth.mask_path
) )
mask_path = self.path_manager.get_local_path(mask_path)
mask = load_1bit_png_mask(mask_path) mask = load_1bit_png_mask(mask_path)
self.assertEqual(mask.dtype, np.float32) self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 2) self.assertEqual(len(mask.shape), 2)
...@@ -253,6 +257,7 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): ...@@ -253,6 +257,7 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
mask_path = os.path.join( mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path self.dataset_root, self.frame_annotation.depth.mask_path
) )
mask_path = self.path_manager.get_local_path(mask_path)
mask = load_depth_mask(mask_path) mask = load_depth_mask(mask_path)
self.assertEqual(mask.dtype, np.float32) self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 3) self.assertEqual(len(mask.shape), 3)
...@@ -38,22 +38,23 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase): ...@@ -38,22 +38,23 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase):
def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths( def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
self, self,
) -> None: ) -> None:
with self.assertRaises(ValueError) as context:
ray_bundle = ImplicitronRayBundle( ray_bundle = ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3), origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3), directions=torch.rand(2, 3, 4, 3),
lengths=None, lengths=None,
xys=torch.rand(2, 3, 4, 2), xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1), bins=torch.rand(2, 3, 4, 14),
) )
ray_bundle.lengths = torch.empty(2) with self.assertRaisesRegex(
self.assertEqual( ValueError,
str(context.exception),
"If the bins attribute is not None you cannot set the lengths attribute.", "If the bins attribute is not None you cannot set the lengths attribute.",
) ):
ray_bundle.lengths = torch.empty(2)
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None: def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
with self.assertRaises(ValueError) as context: with self.assertRaisesRegex(
ValueError, "The last dim of bins must be at least superior or equal to 2."
):
ImplicitronRayBundle( ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3), origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3), directions=torch.rand(2, 3, 4, 3),
...@@ -61,15 +62,14 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase): ...@@ -61,15 +62,14 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase):
xys=torch.rand(2, 3, 4, 2), xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1), bins=torch.rand(2, 3, 4, 1),
) )
self.assertEqual(
str(context.exception),
"The last dim of bins must be at least superior or equal to 2.",
)
def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided( def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
self, self,
) -> None: ) -> None:
with self.assertRaises(ValueError) as context: with self.assertRaisesRegex(
ValueError,
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
):
ImplicitronRayBundle( ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3), origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3), directions=torch.rand(2, 3, 4, 3),
...@@ -77,10 +77,6 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase): ...@@ -77,10 +77,6 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase):
xys=torch.rand(2, 3, 4, 2), xys=torch.rand(2, 3, 4, 2),
bins=None, bins=None,
) )
self.assertEqual(
str(context.exception),
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
)
def test_conical_frustum_to_gaussian(self) -> None: def test_conical_frustum_to_gaussian(self) -> None:
origins = torch.zeros(3, 3, 3) origins = torch.zeros(3, 3, 3)
...@@ -266,8 +262,6 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase): ...@@ -266,8 +262,6 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase):
ray = ImplicitronRayBundle( ray = ImplicitronRayBundle(
origins=origins, directions=directions, lengths=lengths, xys=None origins=origins, directions=directions, lengths=lengths, xys=None
) )
with self.assertRaises(ValueError) as context:
_ = conical_frustum_to_gaussian(ray)
expected_error_message = ( expected_error_message = (
"RayBundle pixel_radii_2d or bins have not been provided." "RayBundle pixel_radii_2d or bins have not been provided."
...@@ -276,7 +270,8 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase): ...@@ -276,7 +270,8 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase):
"`cast_ray_bundle_as_cone` to True?" "`cast_ray_bundle_as_cone` to True?"
) )
self.assertEqual(expected_error_message, str(context.exception)) with self.assertRaisesRegex(ValueError, expected_error_message):
_ = conical_frustum_to_gaussian(ray)
# Ensure message is coherent with AbstractMaskRaySampler # Ensure message is coherent with AbstractMaskRaySampler
class FakeRaySampler(AbstractMaskRaySampler): class FakeRaySampler(AbstractMaskRaySampler):
......
...@@ -964,8 +964,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase): ...@@ -964,8 +964,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(IndexError, "out of bounds"): with self.assertRaisesRegex(IndexError, "out of bounds"):
cam[N_CAMERAS] cam[N_CAMERAS]
with self.assertRaisesRegex(ValueError, "does not match cameras"):
index = torch.tensor([1, 0, 1], dtype=torch.bool) index = torch.tensor([1, 0, 1], dtype=torch.bool)
with self.assertRaisesRegex(ValueError, "does not match cameras"):
cam[index] cam[index]
with self.assertRaisesRegex(ValueError, "Invalid index type"): with self.assertRaisesRegex(ValueError, "Invalid index type"):
...@@ -974,8 +974,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase): ...@@ -974,8 +974,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "Invalid index type"): with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[[True, False]] cam[[True, False]]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor(SLICE, dtype=torch.float32) index = torch.tensor(SLICE, dtype=torch.float32)
with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[index] cam[index]
def test_get_full_transform(self): def test_get_full_transform(self):
......
...@@ -422,9 +422,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ...@@ -422,9 +422,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
def test_save_obj_invalid_shapes(self): def test_save_obj_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
with self.assertRaises(ValueError) as error:
with NamedTemporaryFile(mode="w", suffix=".obj") as f: with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces) save_obj(Path(f.name), verts, faces)
expected_message = ( expected_message = (
...@@ -433,9 +433,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ...@@ -433,9 +433,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(expected_message, error.exception) self.assertTrue(expected_message, error.exception)
# Invalid faces shape # Invalid faces shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
with self.assertRaises(ValueError) as error:
with NamedTemporaryFile(mode="w", suffix=".obj") as f: with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces) save_obj(Path(f.name), verts, faces)
expected_message = ( expected_message = (
......
...@@ -308,9 +308,9 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): ...@@ -308,9 +308,9 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
def test_save_ply_invalid_shapes(self): def test_save_ply_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
with self.assertRaises(ValueError) as error:
save_ply(BytesIO(), verts, faces) save_ply(BytesIO(), verts, faces)
expected_message = ( expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)." "Argument 'verts' should either be empty or of shape (num_verts, 3)."
...@@ -318,9 +318,9 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): ...@@ -318,9 +318,9 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(expected_message, error.exception) self.assertTrue(expected_message, error.exception)
# Invalid faces shape # Invalid faces shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
with self.assertRaises(ValueError) as error:
save_ply(BytesIO(), verts, faces) save_ply(BytesIO(), verts, faces)
expected_message = ( expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)." "Argument 'faces' should either be empty or of shape (num_faces, 3)."
......
...@@ -324,17 +324,15 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): ...@@ -324,17 +324,15 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
] ]
faces_list = mesh.faces_list() faces_list = mesh.faces_list()
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "same device"):
Meshes(verts=verts_list, faces=faces_list) Meshes(verts=verts_list, faces=faces_list)
self.assertTrue("same device" in cm.msg)
verts_padded = mesh.verts_padded() # on cpu verts_padded = mesh.verts_padded() # on cpu
verts_padded = verts_padded.to("cuda:0") verts_padded = verts_padded.to("cuda:0")
faces_padded = mesh.faces_padded() faces_padded = mesh.faces_padded()
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "same device"):
Meshes(verts=verts_padded, faces=faces_padded) Meshes(verts=verts_padded, faces=faces_padded)
self.assertTrue("same device" in cm.msg)
def test_simple_random_meshes(self): def test_simple_random_meshes(self):
......
...@@ -148,31 +148,28 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): ...@@ -148,31 +148,28 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
features_list = clouds.features_list() features_list = clouds.features_list()
normals_list = clouds.normals_list() normals_list = clouds.normals_list()
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "same device"):
Pointclouds( Pointclouds(
points=points_list, features=features_list, normals=normals_list points=points_list, features=features_list, normals=normals_list
) )
self.assertTrue("same device" in cm.msg)
points_list = clouds.points_list() points_list = clouds.points_list()
features_list = [ features_list = [
f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list
] ]
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "same device"):
Pointclouds( Pointclouds(
points=points_list, features=features_list, normals=normals_list points=points_list, features=features_list, normals=normals_list
) )
self.assertTrue("same device" in cm.msg)
points_padded = clouds.points_padded() # on cuda:0 points_padded = clouds.points_padded() # on cuda:0
features_padded = clouds.features_padded().to("cpu") features_padded = clouds.features_padded().to("cpu")
normals_padded = clouds.normals_padded() normals_padded = clouds.normals_padded()
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "same device"):
Pointclouds( Pointclouds(
points=points_padded, features=features_padded, normals=normals_padded points=points_padded, features=features_padded, normals=normals_padded
) )
self.assertTrue("same device" in cm.msg)
def test_all_constructions(self): def test_all_constructions(self):
public_getters = [ public_getters = [
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import re
import unittest import unittest
from itertools import product from itertools import product
...@@ -102,62 +103,56 @@ class TestRasterizeRectangleImagesErrors(TestCaseMixin, unittest.TestCase): ...@@ -102,62 +103,56 @@ class TestRasterizeRectangleImagesErrors(TestCaseMixin, unittest.TestCase):
def test_mesh_image_size_arg(self): def test_mesh_image_size_arg(self):
meshes = Meshes(verts=[verts0], faces=[faces0]) meshes = Meshes(verts=[verts0], faces=[faces0])
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, re.escape("tuple/list of (H, W)")):
rasterize_meshes( rasterize_meshes(
meshes, meshes,
(100, 200, 3), (100, 200, 3),
0.0001, 0.0001,
faces_per_pixel=1, faces_per_pixel=1,
) )
self.assertTrue("tuple/list of (H, W)" in cm.msg)
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "sizes must be greater than 0"):
rasterize_meshes( rasterize_meshes(
meshes, meshes,
(0, 10), (0, 10),
0.0001, 0.0001,
faces_per_pixel=1, faces_per_pixel=1,
) )
self.assertTrue("sizes must be positive" in cm.msg)
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "sizes must be integers"):
rasterize_meshes( rasterize_meshes(
meshes, meshes,
(100.5, 120.5), (100.5, 120.5),
0.0001, 0.0001,
faces_per_pixel=1, faces_per_pixel=1,
) )
self.assertTrue("sizes must be integers" in cm.msg)
def test_points_image_size_arg(self): def test_points_image_size_arg(self):
points = Pointclouds([verts0]) points = Pointclouds([verts0])
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, re.escape("tuple/list of (H, W)")):
rasterize_points( rasterize_points(
points, points,
(100, 200, 3), (100, 200, 3),
0.0001, 0.0001,
points_per_pixel=1, points_per_pixel=1,
) )
self.assertTrue("tuple/list of (H, W)" in cm.msg)
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "sizes must be greater than 0"):
rasterize_points( rasterize_points(
points, points,
(0, 10), (0, 10),
0.0001, 0.0001,
points_per_pixel=1, points_per_pixel=1,
) )
self.assertTrue("sizes must be positive" in cm.msg)
with self.assertRaises(ValueError) as cm: with self.assertRaisesRegex(ValueError, "sizes must be integers"):
rasterize_points( rasterize_points(
points, points,
(100.5, 120.5), (100.5, 120.5),
0.0001, 0.0001,
points_per_pixel=1, points_per_pixel=1,
) )
self.assertTrue("sizes must be integers" in cm.msg)
class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase): class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
......
...@@ -419,16 +419,16 @@ class TestMeshRasterizerOpenGLUtils(TestCaseMixin, unittest.TestCase): ...@@ -419,16 +419,16 @@ class TestMeshRasterizerOpenGLUtils(TestCaseMixin, unittest.TestCase):
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings) fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 10, 2047, 1])) self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 10, 2047, 1]))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (2049, 512) raster_settings.image_size = (2049, 512)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
rasterizer(self.meshes_world, raster_settings=raster_settings) rasterizer(self.meshes_world, raster_settings=raster_settings)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (512, 2049) raster_settings.image_size = (512, 2049)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
rasterizer(self.meshes_world, raster_settings=raster_settings) rasterizer(self.meshes_world, raster_settings=raster_settings)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (2049, 2049) raster_settings.image_size = (2049, 2049)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
rasterizer(self.meshes_world, raster_settings=raster_settings) rasterizer(self.meshes_world, raster_settings=raster_settings)
......
...@@ -80,8 +80,8 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -80,8 +80,8 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
self.assertClose(x_padded, torch.stack(x, 0)) self.assertClose(x_padded, torch.stack(x, 0))
# catch ValueError for invalid dimensions # catch ValueError for invalid dimensions
with self.assertRaisesRegex(ValueError, "Pad size must"):
pad_size = [K] * (ndim + 1) pad_size = [K] * (ndim + 1)
with self.assertRaisesRegex(ValueError, "Pad size must"):
struct_utils.list_to_padded( struct_utils.list_to_padded(
x, pad_size=pad_size, pad_value=0.0, equisized=False x, pad_size=pad_size, pad_value=0.0, equisized=False
) )
...@@ -196,9 +196,9 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -196,9 +196,9 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
# Case 6: Input has more than 3 dims. # Case 6: Input has more than 3 dims.
# Raise an error. # Raise an error.
with self.assertRaisesRegex(ValueError, "Supports only"):
x = torch.rand((N, K, K, K, K), device=device) x = torch.rand((N, K, K, K, K), device=device)
split_size = torch.randint(1, K, size=(N,)).tolist() split_size = torch.randint(1, K, size=(N,)).tolist()
with self.assertRaisesRegex(ValueError, "Supports only"):
struct_utils.padded_to_packed(x, split_size=split_size) struct_utils.padded_to_packed(x, split_size=split_size)
def test_list_to_packed(self): def test_list_to_packed(self):
......
...@@ -1055,7 +1055,7 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase): ...@@ -1055,7 +1055,7 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
def test_simple(self): def test_simple(self):
self.assert_bb([(3, 4), (4, 3)], {6, 4}) self.assert_bb([(3, 4), (4, 3)], {6, 4})
self.assert_bb([(2, 2), (2, 4), (2, 2)], {4, 4}) self.assert_bb([(2, 2), (2, 4), (2, 2)], {4})
# many squares # many squares
self.assert_bb([(2, 2)] * 9, {2, 18}) self.assert_bb([(2, 2)] * 9, {2, 18})
......
...@@ -936,8 +936,8 @@ class TestTransformBroadcast(unittest.TestCase): ...@@ -936,8 +936,8 @@ class TestTransformBroadcast(unittest.TestCase):
y = torch.tensor([0.3] * M) y = torch.tensor([0.3] * M)
z = torch.tensor([0.4] * M) z = torch.tensor([0.4] * M)
tM = Translate(x, y, z) tM = Translate(x, y, z)
with self.assertRaises(ValueError):
t = tN.compose(tM) t = tN.compose(tM)
with self.assertRaises(ValueError):
t.get_matrix() t.get_matrix()
def test_multiple_broadcast_compose(self): def test_multiple_broadcast_compose(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment