Commit d57daa6f authored by Patrick Labatut's avatar Patrick Labatut Committed by Facebook GitHub Bot
Browse files

Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
parent eb512ffd
...@@ -4,23 +4,17 @@ ...@@ -4,23 +4,17 @@
""" """
Sanity checks for output images from the renderer. Sanity checks for output images from the renderer.
""" """
import numpy as np
import unittest import unittest
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from pytorch3d.io import load_objs_as_meshes from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import ( from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
OpenGLPerspectiveCameras,
look_at_view_transform,
)
from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.rasterizer import ( from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
MeshRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import ( from pytorch3d.renderer.mesh.shader import (
BlendParams, BlendParams,
...@@ -34,6 +28,7 @@ from pytorch3d.renderer.mesh.texturing import Textures ...@@ -34,6 +28,7 @@ from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
...@@ -65,9 +60,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -65,9 +60,7 @@ class TestRenderingMeshes(unittest.TestCase):
verts_padded = sphere_mesh.verts_padded() verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded() faces_padded = sphere_mesh.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded)) textures = Textures(verts_rgb=torch.ones_like(verts_padded))
sphere_mesh = Meshes( sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
verts=verts_padded, faces=faces_padded, textures=textures
)
# Init rasterizer settings # Init rasterizer settings
if elevated_camera: if elevated_camera:
...@@ -90,9 +83,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -90,9 +83,7 @@ class TestRenderingMeshes(unittest.TestCase):
raster_settings = RasterizationSettings( raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0 image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
) )
rasterizer = MeshRasterizer( rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
cameras=cameras, raster_settings=raster_settings
)
# Test several shaders # Test several shaders
shaders = { shaders = {
...@@ -101,9 +92,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -101,9 +92,7 @@ class TestRenderingMeshes(unittest.TestCase):
"flat": HardFlatShader, "flat": HardFlatShader,
} }
for (name, shader_init) in shaders.items(): for (name, shader_init) in shaders.items():
shader = shader_init( shader = shader_init(lights=lights, cameras=cameras, materials=materials)
lights=lights, cameras=cameras, materials=materials
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh) images = renderer(sphere_mesh)
filename = "simple_sphere_light_%s%s.png" % (name, postfix) filename = "simple_sphere_light_%s%s.png" % (name, postfix)
...@@ -125,9 +114,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -125,9 +114,7 @@ class TestRenderingMeshes(unittest.TestCase):
phong_shader = HardPhongShader( phong_shader = HardPhongShader(
lights=lights, cameras=cameras, materials=materials lights=lights, cameras=cameras, materials=materials
) )
phong_renderer = MeshRenderer( phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
rasterizer=rasterizer, shader=phong_shader
)
images = phong_renderer(sphere_mesh, lights=lights) images = phong_renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu() rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG: if DEBUG:
...@@ -137,9 +124,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -137,9 +124,7 @@ class TestRenderingMeshes(unittest.TestCase):
) )
# Load reference image # Load reference image
image_ref_phong_dark = load_rgb_image( image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
"test_simple_sphere_dark%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05)) self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
def test_simple_sphere_elevated_camera(self): def test_simple_sphere_elevated_camera(self):
...@@ -184,18 +169,14 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -184,18 +169,14 @@ class TestRenderingMeshes(unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
# Init renderer # Init renderer
rasterizer = MeshRasterizer( rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
cameras=cameras, raster_settings=raster_settings
)
shaders = { shaders = {
"phong": HardGouraudShader, "phong": HardGouraudShader,
"gouraud": HardGouraudShader, "gouraud": HardGouraudShader,
"flat": HardFlatShader, "flat": HardFlatShader,
} }
for (name, shader_init) in shaders.items(): for (name, shader_init) in shaders.items():
shader = shader_init( shader = shader_init(lights=lights, cameras=cameras, materials=materials)
lights=lights, cameras=cameras, materials=materials
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_meshes) images = renderer(sphere_meshes)
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name) image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
...@@ -228,9 +209,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -228,9 +209,7 @@ class TestRenderingMeshes(unittest.TestCase):
# Init renderer # Init renderer
renderer = MeshRenderer( renderer = MeshRenderer(
rasterizer=MeshRasterizer( rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
cameras=cameras, raster_settings=raster_settings
),
shader=SoftSilhouetteShader(blend_params=blend_params), shader=SoftSilhouetteShader(blend_params=blend_params),
) )
images = renderer(sphere_mesh) images = renderer(sphere_mesh)
...@@ -258,9 +237,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -258,9 +237,7 @@ class TestRenderingMeshes(unittest.TestCase):
The pupils in the eyes of the cow should always be looking to the left. The pupils in the eyes of the cow should always be looking to the left.
""" """
device = torch.device("cuda:0") device = torch.device("cuda:0")
DATA_DIR = ( DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
Path(__file__).resolve().parent.parent / "docs/tutorials/data"
)
obj_filename = DATA_DIR / "cow_mesh/cow.obj" obj_filename = DATA_DIR / "cow_mesh/cow.obj"
# Load mesh + texture # Load mesh + texture
...@@ -283,9 +260,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -283,9 +260,7 @@ class TestRenderingMeshes(unittest.TestCase):
# Init renderer # Init renderer
renderer = MeshRenderer( renderer = MeshRenderer(
rasterizer=MeshRasterizer( rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
cameras=cameras, raster_settings=raster_settings
),
shader=TexturedSoftPhongShader( shader=TexturedSoftPhongShader(
lights=lights, cameras=cameras, materials=materials lights=lights, cameras=cameras, materials=materials
), ),
...@@ -306,9 +281,7 @@ class TestRenderingMeshes(unittest.TestCase): ...@@ -306,9 +281,7 @@ class TestRenderingMeshes(unittest.TestCase):
# Check grad exists # Check grad exists
[verts] = mesh.verts_list() [verts] = mesh.verts_list()
verts.requires_grad = True verts.requires_grad = True
mesh2 = Meshes( mesh2 = Meshes(verts=[verts], faces=mesh.faces_list(), textures=mesh.textures)
verts=[verts], faces=mesh.faces_list(), textures=mesh.textures
)
images = renderer(mesh2) images = renderer(mesh2)
images[0, ...].sum().backward() images[0, ...].sum().backward()
self.assertIsNotNone(verts.grad) self.assertIsNotNone(verts.grad)
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
import unittest import unittest
import torch
from pytorch3d.renderer.utils import TensorProperties
import numpy as np
import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.renderer.utils import TensorProperties
# Example class for testing # Example class for testing
...@@ -81,9 +80,5 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase): ...@@ -81,9 +80,5 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
if inds.sum() > 0: if inds.sum() > 0:
# Check the gathered points in the output have the same value from # Check the gathered points in the output have the same value from
# the input. # the input.
self.assertClose( self.assertClose(test_class_gathered.x[inds].mean(dim=0), x[i, ...])
test_class_gathered.x[inds].mean(dim=0), x[i, ...] self.assertClose(test_class_gathered.y[inds].mean(dim=0), y[i, ...])
)
self.assertClose(
test_class_gathered.y[inds].mean(dim=0), y[i, ...]
)
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import itertools import itertools
import math import math
import unittest import unittest
import torch
import torch
from pytorch3d.transforms.rotation_conversions import ( from pytorch3d.transforms.rotation_conversions import (
euler_angles_to_matrix, euler_angles_to_matrix,
matrix_to_euler_angles, matrix_to_euler_angles,
...@@ -45,9 +45,7 @@ class TestRandomRotation(unittest.TestCase): ...@@ -45,9 +45,7 @@ class TestRandomRotation(unittest.TestCase):
) )
# The 0.1 significance level for chisquare(8-1) is # The 0.1 significance level for chisquare(8-1) is
# scipy.stats.chi2(7).ppf(0.9) == 12.017. # scipy.stats.chi2(7).ppf(0.9) == 12.017.
self.assertLess( self.assertLess(chisquare_statistic, 12, (counts, chisquare_statistic, k))
chisquare_statistic, 12, (counts, chisquare_statistic, k)
)
class TestRotationConversion(unittest.TestCase): class TestRotationConversion(unittest.TestCase):
......
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
import unittest import unittest
from pathlib import Path from pathlib import Path
import torch
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import sample_points_from_meshes from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
from common_testing import TestCaseMixin
class TestSamplePoints(TestCaseMixin, unittest.TestCase): class TestSamplePoints(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
...@@ -28,9 +27,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -28,9 +27,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
verts_list = [] verts_list = []
faces_list = [] faces_list = []
for _ in range(num_meshes): for _ in range(num_meshes):
verts = torch.rand( verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
(num_verts, 3), dtype=torch.float32, device=device
)
faces = torch.randint( faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
) )
...@@ -48,13 +45,9 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -48,13 +45,9 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
device = torch.device("cuda:0") device = torch.device("cuda:0")
verts1 = torch.tensor([], dtype=torch.float32, device=device) verts1 = torch.tensor([], dtype=torch.float32, device=device)
faces1 = torch.tensor([], dtype=torch.int64, device=device) faces1 = torch.tensor([], dtype=torch.int64, device=device)
meshes = Meshes( meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1])
verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1]
)
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
sample_points_from_meshes( sample_points_from_meshes(meshes, num_samples=100, return_normals=True)
meshes, num_samples=100, return_normals=True
)
self.assertTrue("Meshes are empty." in str(err.exception)) self.assertTrue("Meshes are empty." in str(err.exception))
def test_sampling_output(self): def test_sampling_output(self):
...@@ -67,12 +60,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -67,12 +60,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
# Unit simplex. # Unit simplex.
verts_pyramid = torch.tensor( verts_pyramid = torch.tensor(
[ [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
],
dtype=torch.float32, dtype=torch.float32,
device=device, device=device,
) )
...@@ -113,12 +101,8 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -113,12 +101,8 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
pyramid_verts = samples[2, :] pyramid_verts = samples[2, :]
pyramid_normals = normals[2, :] pyramid_normals = normals[2, :]
self.assertClose( self.assertClose(pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts))
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts) self.assertClose((pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts))
)
self.assertClose(
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
)
# Face 1: z = 0, x + y <= 1, normals = (0, 0, 1). # Face 1: z = 0, x + y <= 1, normals = (0, 0, 1).
face_1_idxs = pyramid_verts[:, 2] == 0 face_1_idxs = pyramid_verts[:, 2] == 0
...@@ -126,14 +110,10 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -126,14 +110,10 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
pyramid_verts[face_1_idxs, :], pyramid_verts[face_1_idxs, :],
pyramid_normals[face_1_idxs, :], pyramid_normals[face_1_idxs, :],
) )
self.assertTrue( self.assertTrue(torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1))
torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1)
)
self.assertClose( self.assertClose(
face_1_normals, face_1_normals,
torch.tensor([0, 0, 1], dtype=torch.float32).expand( torch.tensor([0, 0, 1], dtype=torch.float32).expand(face_1_normals.size()),
face_1_normals.size()
),
) )
# Face 2: x = 0, z + y <= 1, normals = (1, 0, 0). # Face 2: x = 0, z + y <= 1, normals = (1, 0, 0).
...@@ -142,14 +122,10 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -142,14 +122,10 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
pyramid_verts[face_2_idxs, :], pyramid_verts[face_2_idxs, :],
pyramid_normals[face_2_idxs, :], pyramid_normals[face_2_idxs, :],
) )
self.assertTrue( self.assertTrue(torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1))
torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1)
)
self.assertClose( self.assertClose(
face_2_normals, face_2_normals,
torch.tensor([1, 0, 0], dtype=torch.float32).expand( torch.tensor([1, 0, 0], dtype=torch.float32).expand(face_2_normals.size()),
face_2_normals.size()
),
) )
# Face 3: y = 0, x + z <= 1, normals = (0, -1, 0). # Face 3: y = 0, x + z <= 1, normals = (0, -1, 0).
...@@ -158,14 +134,10 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -158,14 +134,10 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
pyramid_verts[face_3_idxs, :], pyramid_verts[face_3_idxs, :],
pyramid_normals[face_3_idxs, :], pyramid_normals[face_3_idxs, :],
) )
self.assertTrue( self.assertTrue(torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1))
torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1)
)
self.assertClose( self.assertClose(
face_3_normals, face_3_normals,
torch.tensor([0, -1, 0], dtype=torch.float32).expand( torch.tensor([0, -1, 0], dtype=torch.float32).expand(face_3_normals.size()),
face_3_normals.size()
),
) )
# Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3). # Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3).
...@@ -279,22 +251,15 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -279,22 +251,15 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
num_faces = 50 num_faces = 50
for device in ["cpu", "cuda:0"]: for device in ["cpu", "cuda:0"]:
for invalid in ["nan", "inf"]: for invalid in ["nan", "inf"]:
verts = torch.rand( verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
(num_verts, 3), dtype=torch.float32, device=device
)
# randomly assign an invalid type # randomly assign an invalid type
verts[torch.randperm(num_verts)[:10]] = float(invalid) verts[torch.randperm(num_verts)[:10]] = float(invalid)
faces = torch.randint( faces = torch.randint(
num_verts, num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
size=(num_faces, 3),
dtype=torch.int64,
device=device,
) )
meshes = Meshes(verts=[verts], faces=[faces]) meshes = Meshes(verts=[verts], faces=[faces])
with self.assertRaisesRegex( with self.assertRaisesRegex(ValueError, "Meshes contain nan or inf."):
ValueError, "Meshes contain nan or inf."
):
sample_points_from_meshes( sample_points_from_meshes(
meshes, num_samples=100, return_normals=True meshes, num_samples=100, return_normals=True
) )
...@@ -310,9 +275,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): ...@@ -310,9 +275,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
verts_list = [] verts_list = []
faces_list = [] faces_list = []
for _ in range(num_meshes): for _ in range(num_meshes):
verts = torch.rand( verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
(num_verts, 3), dtype=torch.float32, device=device
)
faces = torch.randint( faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
) )
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
import unittest import unittest
import torch
import numpy as np
import torch
from pytorch3d.transforms.so3 import ( from pytorch3d.transforms.so3 import (
hat, hat,
so3_exponential_map, so3_exponential_map,
...@@ -26,9 +26,7 @@ class TestSO3(unittest.TestCase): ...@@ -26,9 +26,7 @@ class TestSO3(unittest.TestCase):
randomly generated logarithms of rotation matrices. randomly generated logarithms of rotation matrices.
""" """
device = torch.device("cuda:0") device = torch.device("cuda:0")
log_rot = torch.randn( log_rot = torch.randn((batch_size, 3), dtype=torch.float32, device=device)
(batch_size, 3), dtype=torch.float32, device=device
)
return log_rot return log_rot
@staticmethod @staticmethod
...@@ -85,16 +83,12 @@ class TestSO3(unittest.TestCase): ...@@ -85,16 +83,12 @@ class TestSO3(unittest.TestCase):
log_rot = torch.randn(size=[5, 4], device=device) log_rot = torch.randn(size=[5, 4], device=device)
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
so3_exponential_map(log_rot) so3_exponential_map(log_rot)
self.assertTrue( self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception))
"Input tensor shape has to be Nx3." in str(err.exception)
)
rot = torch.randn(size=[5, 3, 5], device=device) rot = torch.randn(size=[5, 3, 5], device=device)
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
so3_log_map(rot) so3_log_map(rot)
self.assertTrue( self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))
"Input has to be a batch of 3x3 Tensors." in str(err.exception)
)
# trace of rot definitely bigger than 3 or smaller than -1 # trace of rot definitely bigger than 3 or smaller than -1
rot = torch.cat( rot = torch.cat(
......
...@@ -2,11 +2,10 @@ ...@@ -2,11 +2,10 @@
import unittest import unittest
import torch
from pytorch3d.structures import utils as struct_utils
import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.structures import utils as struct_utils
class TestStructUtils(TestCaseMixin, unittest.TestCase): class TestStructUtils(TestCaseMixin, unittest.TestCase):
...@@ -27,22 +26,16 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -27,22 +26,16 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
self.assertEqual(x_padded.shape[1], K) self.assertEqual(x_padded.shape[1], K)
self.assertEqual(x_padded.shape[2], K) self.assertEqual(x_padded.shape[2], K)
for i in range(N): for i in range(N):
self.assertClose( self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i])
x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i]
)
# check for no pad size (defaults to max dimension) # check for no pad size (defaults to max dimension)
x_padded = struct_utils.list_to_padded( x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
x, pad_value=0.0, equisized=False
)
max_size0 = max(y.shape[0] for y in x) max_size0 = max(y.shape[0] for y in x)
max_size1 = max(y.shape[1] for y in x) max_size1 = max(y.shape[1] for y in x)
self.assertEqual(x_padded.shape[1], max_size0) self.assertEqual(x_padded.shape[1], max_size0)
self.assertEqual(x_padded.shape[2], max_size1) self.assertEqual(x_padded.shape[2], max_size1)
for i in range(N): for i in range(N):
self.assertClose( self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i])
x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i]
)
# check for equisized # check for equisized
x = [torch.rand((K, 10), device=device) for _ in range(N)] x = [torch.rand((K, 10), device=device) for _ in range(N)]
...@@ -88,9 +81,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -88,9 +81,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0) split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
x_list = struct_utils.padded_to_list(x, split_size) x_list = struct_utils.padded_to_list(x, split_size)
for i in range(N): for i in range(N):
self.assertClose( self.assertClose(x_list[i], x[i, : split_size[i][0], : split_size[i][1]])
x_list[i], x[i, : split_size[i][0], : split_size[i][1]]
)
with self.assertRaisesRegex(ValueError, "Supports only"): with self.assertRaisesRegex(ValueError, "Supports only"):
x = torch.rand((N, K, K, K, K), device=device) x = torch.rand((N, K, K, K, K), device=device)
...@@ -124,32 +115,24 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -124,32 +115,24 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
# Add some random values in the input which are the same as the pad_value. # Add some random values in the input which are the same as the pad_value.
# These should not be filtered out. # These should not be filtered out.
x_list.append( x_list.append(
torch.randint( torch.randint(low=pad_value, high=10, size=(dim, K), device=device)
low=pad_value, high=10, size=(dim, K), device=device
)
) )
split_size.append(dim) split_size.append(dim)
x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value) x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value)
x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value) x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value)
curr = 0 curr = 0
for i in range(N): for i in range(N):
self.assertClose( self.assertClose(x_packed[curr : curr + split_size[i], ...], x_list[i])
x_packed[curr : curr + split_size[i], ...], x_list[i]
)
self.assertClose(torch.cat(x_list), x_packed) self.assertClose(torch.cat(x_list), x_packed)
curr += split_size[i] curr += split_size[i]
# Case 3: split_size is provided. # Case 3: split_size is provided.
# Check each section of the packed tensor matches the corresponding # Check each section of the packed tensor matches the corresponding
# unpadded elements. # unpadded elements.
x_packed = struct_utils.padded_to_packed( x_packed = struct_utils.padded_to_packed(x_padded, split_size=split_size)
x_padded, split_size=split_size
)
curr = 0 curr = 0
for i in range(N): for i in range(N):
self.assertClose( self.assertClose(x_packed[curr : curr + split_size[i], ...], x_list[i])
x_packed[curr : curr + split_size[i], ...], x_list[i]
)
self.assertClose(torch.cat(x_list), x_packed) self.assertClose(torch.cat(x_list), x_packed)
curr += split_size[i] curr += split_size[i]
...@@ -157,17 +140,13 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -157,17 +140,13 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
# Raise an error. # Raise an error.
split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0) split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
with self.assertRaisesRegex(ValueError, "1-dimensional"): with self.assertRaisesRegex(ValueError, "1-dimensional"):
x_packed = struct_utils.padded_to_packed( x_packed = struct_utils.padded_to_packed(x_padded, split_size=split_size)
x_padded, split_size=split_size
)
split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist() split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist()
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "same length as inputs first dimension" ValueError, "same length as inputs first dimension"
): ):
x_packed = struct_utils.padded_to_packed( x_packed = struct_utils.padded_to_packed(x_padded, split_size=split_size)
x_padded, split_size=split_size
)
# Case 5: both pad_value and split_size are provided. # Case 5: both pad_value and split_size are provided.
# Raise an error. # Raise an error.
...@@ -204,8 +183,6 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): ...@@ -204,8 +183,6 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
for i in range(N): for i in range(N):
self.assertTrue(num_items[i] == x_dims[i]) self.assertTrue(num_items[i] == x_dims[i])
self.assertTrue(item_packed_first_idx[i] == cur) self.assertTrue(item_packed_first_idx[i] == cur)
self.assertTrue( self.assertTrue(item_packed_to_list_idx[cur : cur + x_dims[i]].eq(i).all())
item_packed_to_list_idx[cur : cur + x_dims[i]].eq(i).all()
)
self.assertClose(x_packed[cur : cur + x_dims[i]], x[i]) self.assertClose(x_packed[cur : cur + x_dims[i]], x[i])
cur += x_dims[i] cur += x_dims[i]
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
import unittest import unittest
import torch
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops.subdivide_meshes import SubdivideMeshes from pytorch3d.ops.subdivide_meshes import SubdivideMeshes
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
from common_testing import TestCaseMixin
class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase): class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_subdivide(self): def test_simple_subdivide(self):
...@@ -72,25 +71,14 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase): ...@@ -72,25 +71,14 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
) )
faces1 = torch.tensor([[0, 1, 2]], dtype=torch.int64, device=device) faces1 = torch.tensor([[0, 1, 2]], dtype=torch.int64, device=device)
verts2 = torch.tensor( verts2 = torch.tensor(
[ [[0.5, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.5, 1.0, 0.0]],
[0.5, 1.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[1.5, 1.0, 0.0],
],
dtype=torch.float32, dtype=torch.float32,
device=device, device=device,
requires_grad=True, requires_grad=True,
) )
faces2 = torch.tensor( faces2 = torch.tensor([[0, 1, 2], [0, 3, 1]], dtype=torch.int64, device=device)
[[0, 1, 2], [0, 3, 1]], dtype=torch.int64, device=device faces3 = torch.tensor([[0, 1, 2], [0, 2, 3]], dtype=torch.int64, device=device)
) mesh = Meshes(verts=[verts1, verts2, verts2], faces=[faces1, faces2, faces3])
faces3 = torch.tensor(
[[0, 1, 2], [0, 2, 3]], dtype=torch.int64, device=device
)
mesh = Meshes(
verts=[verts1, verts2, verts2], faces=[faces1, faces2, faces3]
)
subdivide = SubdivideMeshes() subdivide = SubdivideMeshes()
new_mesh = subdivide(mesh.clone()) new_mesh = subdivide(mesh.clone())
...@@ -218,9 +206,7 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase): ...@@ -218,9 +206,7 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad) self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
@staticmethod @staticmethod
def subdivide_meshes_with_init( def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
num_meshes: int = 10, same_topo: bool = False
):
device = torch.device("cuda:0") device = torch.device("cuda:0")
meshes = ico_sphere(0, device=device) meshes = ico_sphere(0, device=device)
if num_meshes > 1: if num_meshes > 1:
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
import unittest import unittest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh.rasterizer import Fragments from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import ( from pytorch3d.renderer.mesh.texturing import (
interpolate_face_attributes, interpolate_face_attributes,
...@@ -13,8 +14,6 @@ from pytorch3d.renderer.mesh.texturing import ( ...@@ -13,8 +14,6 @@ from pytorch3d.renderer.mesh.texturing import (
) )
from pytorch3d.structures import Meshes, Textures from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures.utils import list_to_padded from pytorch3d.structures.utils import list_to_padded
from common_testing import TestCaseMixin
from test_meshes import TestMeshes from test_meshes import TestMeshes
...@@ -68,12 +67,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -68,12 +67,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
dists=torch.ones_like(pix_to_face), dists=torch.ones_like(pix_to_face),
) )
grad_vert_tex = torch.tensor( grad_vert_tex = torch.tensor(
[ [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
[0.3, 0.3, 0.3],
[0.9, 0.9, 0.9],
[0.5, 0.5, 0.5],
[0.3, 0.3, 0.3],
],
dtype=torch.float32, dtype=torch.float32,
) )
texels = interpolate_vertex_colors(fragments, mesh) texels = interpolate_vertex_colors(fragments, mesh)
...@@ -115,9 +109,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -115,9 +109,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1) ).view(1, 1, 1, 2, -1)
dummy_verts = torch.zeros(4, 3) dummy_verts = torch.zeros(4, 3)
vert_uvs = torch.tensor( vert_uvs = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)
[[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32
)
face_uvs = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.int64) face_uvs = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.int64)
interpolated_uvs = torch.tensor( interpolated_uvs = torch.tensor(
[[0.5 + 0.2, 0.3 + 0.2], [0.6, 0.3 + 0.6]], dtype=torch.float32 [[0.5 + 0.2, 0.3 + 0.2], [0.6, 0.3 + 0.6]], dtype=torch.float32
...@@ -137,9 +129,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -137,9 +129,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
dists=pix_to_face, dists=pix_to_face,
) )
tex = Textures( tex = Textures(
maps=tex_map, maps=tex_map, faces_uvs=face_uvs[None, ...], verts_uvs=vert_uvs[None, ...]
faces_uvs=face_uvs[None, ...],
verts_uvs=vert_uvs[None, ...],
) )
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex) meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
texels = interpolate_texture_map(fragments, meshes) texels = interpolate_texture_map(fragments, meshes)
...@@ -151,9 +141,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -151,9 +141,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
tex_map = tex_map.permute(0, 3, 1, 2) tex_map = tex_map.permute(0, 3, 1, 2)
tex_map = torch.cat([tex_map, tex_map], dim=0) tex_map = torch.cat([tex_map, tex_map], dim=0)
expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False) expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
self.assertTrue( self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
torch.allclose(texels.squeeze(), expected_out.squeeze())
)
def test_init_rgb_uv_fail(self): def test_init_rgb_uv_fail(self):
V = 20 V = 20
...@@ -183,9 +171,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -183,9 +171,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
Textures(verts_rgb=torch.ones((5, 16, 16, 3))) Textures(verts_rgb=torch.ones((5, 16, 16, 3)))
# maps provided without verts/faces uvs # maps provided without verts/faces uvs
with self.assertRaisesRegex( with self.assertRaisesRegex(ValueError, "faces_uvs and verts_uvs are required"):
ValueError, "faces_uvs and verts_uvs are required"
):
Textures(maps=torch.ones((5, 16, 16, 3))) Textures(maps=torch.ones((5, 16, 16, 3)))
def test_padded_to_packed(self): def test_padded_to_packed(self):
...@@ -209,9 +195,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -209,9 +195,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
# This is set inside Meshes when textures is passed as an input. # This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity. # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone() tex1 = tex.clone()
tex1._num_faces_per_mesh = ( tex1._num_faces_per_mesh = faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
)
tex1._num_verts_per_mesh = torch.tensor([5, 4]) tex1._num_verts_per_mesh = torch.tensor([5, 4])
faces_packed = tex1.faces_uvs_packed() faces_packed = tex1.faces_uvs_packed()
verts_packed = tex1.verts_uvs_packed() verts_packed = tex1.verts_uvs_packed()
...@@ -245,16 +229,12 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -245,16 +229,12 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
for i in range(N): for i in range(N):
self.assertTrue( self.assertTrue(
(faces_list[i] == faces_uvs_padded[i, ...].squeeze()) (faces_list[i] == faces_uvs_padded[i, ...].squeeze()).all().item()
.all()
.item()
) )
for i in range(N): for i in range(N):
self.assertTrue( self.assertTrue(
(verts_list[i] == verts_uvs_padded[i, ...].squeeze()) (verts_list[i] == verts_uvs_padded[i, ...].squeeze()).all().item()
.all()
.item()
) )
def test_clone(self): def test_clone(self):
...@@ -344,9 +324,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -344,9 +324,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
verts_uvs=torch.randn((B, V, 2)), verts_uvs=torch.randn((B, V, 2)),
) )
tex_mesh = Meshes( tex_mesh = Meshes(
verts=mesh.verts_padded(), verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
faces=mesh.faces_padded(),
textures=tex_uv,
) )
N = 20 N = 20
new_mesh = tex_mesh.extend(N) new_mesh = tex_mesh.extend(N)
...@@ -359,12 +337,10 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -359,12 +337,10 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
for i in range(len(tex_mesh)): for i in range(len(tex_mesh)):
for n in range(N): for n in range(N):
self.assertClose( self.assertClose(
tex_init.faces_uvs_list()[i], tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
new_tex.faces_uvs_list()[i * N + n],
) )
self.assertClose( self.assertClose(
tex_init.verts_uvs_list()[i], tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
new_tex.verts_uvs_list()[i * N + n],
) )
self.assertAllSeparate( self.assertAllSeparate(
[ [
...@@ -384,9 +360,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -384,9 +360,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
# 2. Texture vertex RGB # 2. Texture vertex RGB
tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3))) tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
tex_mesh_rgb = Meshes( tex_mesh_rgb = Meshes(
verts=mesh.verts_padded(), verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_rgb
faces=mesh.faces_padded(),
textures=tex_rgb,
) )
N = 20 N = 20
new_mesh_rgb = tex_mesh_rgb.extend(N) new_mesh_rgb = tex_mesh_rgb.extend(N)
...@@ -399,8 +373,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): ...@@ -399,8 +373,7 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
for i in range(len(tex_mesh_rgb)): for i in range(len(tex_mesh_rgb)):
for n in range(N): for n in range(N):
self.assertClose( self.assertClose(
tex_init.verts_rgb_list()[i], tex_init.verts_rgb_list()[i], new_tex.verts_rgb_list()[i * N + n]
new_tex.verts_rgb_list()[i * N + n],
) )
self.assertAllSeparate( self.assertAllSeparate(
[tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()] [tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import math import math
import unittest import unittest
import torch
import torch
from pytorch3d.transforms.so3 import so3_exponential_map from pytorch3d.transforms.so3 import so3_exponential_map
from pytorch3d.transforms.transform3d import ( from pytorch3d.transforms.transform3d import (
Rotate, Rotate,
...@@ -18,9 +18,7 @@ from pytorch3d.transforms.transform3d import ( ...@@ -18,9 +18,7 @@ from pytorch3d.transforms.transform3d import (
class TestTransform(unittest.TestCase): class TestTransform(unittest.TestCase):
def test_to(self): def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]])) tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor( R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
)
R = Rotate(R) R = Rotate(R)
t = Transform3d().compose(R, tr) t = Transform3d().compose(R, tr)
for _ in range(3): for _ in range(3):
...@@ -36,9 +34,7 @@ class TestTransform(unittest.TestCase): ...@@ -36,9 +34,7 @@ class TestTransform(unittest.TestCase):
the same as composition of clones of translation and rotation. the same as composition of clones of translation and rotation.
""" """
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]])) tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor( R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
)
R = Rotate(R) R = Rotate(R)
# check that the _matrix property of clones of # check that the _matrix property of clones of
...@@ -63,9 +59,9 @@ class TestTransform(unittest.TestCase): ...@@ -63,9 +59,9 @@ class TestTransform(unittest.TestCase):
def test_translate(self): def test_translate(self):
t = Transform3d().translate(1, 2, 3) t = Transform3d().translate(1, 2, 3)
points = torch.tensor( points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]] 1, 3, 3
).view(1, 3, 3) )
normals = torch.tensor( normals = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
).view(1, 3, 3) ).view(1, 3, 3)
...@@ -82,9 +78,9 @@ class TestTransform(unittest.TestCase): ...@@ -82,9 +78,9 @@ class TestTransform(unittest.TestCase):
def test_scale(self): def test_scale(self):
t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0) t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
points = torch.tensor( points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]] 1, 3, 3
).view(1, 3, 3) )
normals = torch.tensor( normals = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
).view(1, 3, 3) ).view(1, 3, 3)
...@@ -101,9 +97,9 @@ class TestTransform(unittest.TestCase): ...@@ -101,9 +97,9 @@ class TestTransform(unittest.TestCase):
def test_scale_translate(self): def test_scale_translate(self):
t = Transform3d().scale(2, 1, 3).translate(1, 2, 3) t = Transform3d().scale(2, 1, 3).translate(1, 2, 3)
points = torch.tensor( points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]] 1, 3, 3
).view(1, 3, 3) )
normals = torch.tensor( normals = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
).view(1, 3, 3) ).view(1, 3, 3)
...@@ -120,9 +116,9 @@ class TestTransform(unittest.TestCase): ...@@ -120,9 +116,9 @@ class TestTransform(unittest.TestCase):
def test_rotate_axis_angle(self): def test_rotate_axis_angle(self):
t = Transform3d().rotate_axis_angle(90.0, axis="Z") t = Transform3d().rotate_axis_angle(90.0, axis="Z")
points = torch.tensor( points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]).view(
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]] 1, 3, 3
).view(1, 3, 3) )
normals = torch.tensor( normals = torch.tensor(
[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]] [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
).view(1, 3, 3) ).view(1, 3, 3)
...@@ -194,9 +190,7 @@ class TestTransform(unittest.TestCase): ...@@ -194,9 +190,7 @@ class TestTransform(unittest.TestCase):
t_ = Rotate( t_ = Rotate(
so3_exponential_map( so3_exponential_map(
torch.randn( torch.randn(
(batch_size, 3), (batch_size, 3), dtype=torch.float32, device=device
dtype=torch.float32,
device=device,
) )
), ),
device=device, device=device,
...@@ -717,9 +711,7 @@ class TestRotate(unittest.TestCase): ...@@ -717,9 +711,7 @@ class TestRotate(unittest.TestCase):
def test_inverse(self, batch_size=5): def test_inverse(self, batch_size=5):
device = torch.device("cuda:0") device = torch.device("cuda:0")
log_rot = torch.randn( log_rot = torch.randn((batch_size, 3), dtype=torch.float32, device=device)
(batch_size, 3), dtype=torch.float32, device=device
)
R = so3_exponential_map(log_rot) R = so3_exponential_map(log_rot)
t = Rotate(R) t = Rotate(R)
im = t.inverse()._matrix im = t.inverse()._matrix
...@@ -749,9 +741,7 @@ class TestRotateAxisAngle(unittest.TestCase): ...@@ -749,9 +741,7 @@ class TestRotateAxisAngle(unittest.TestCase):
transformed_points = t.transform_points(points) transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, 1.0]) expected_points = torch.tensor([0.0, 0.0, 1.0])
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
transformed_points.squeeze(), expected_points, atol=1e-7
)
) )
self.assertTrue(torch.allclose(t._matrix, matrix)) self.assertTrue(torch.allclose(t._matrix, matrix))
...@@ -775,9 +765,7 @@ class TestRotateAxisAngle(unittest.TestCase): ...@@ -775,9 +765,7 @@ class TestRotateAxisAngle(unittest.TestCase):
transformed_points = t.transform_points(points) transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, 1.0]) expected_points = torch.tensor([0.0, 0.0, 1.0])
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
transformed_points.squeeze(), expected_points, atol=1e-7
)
) )
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
...@@ -835,9 +823,7 @@ class TestRotateAxisAngle(unittest.TestCase): ...@@ -835,9 +823,7 @@ class TestRotateAxisAngle(unittest.TestCase):
transformed_points = t.transform_points(points) transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, -1.0]) expected_points = torch.tensor([0.0, 0.0, -1.0])
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
transformed_points.squeeze(), expected_points, atol=1e-7
)
) )
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
...@@ -866,9 +852,7 @@ class TestRotateAxisAngle(unittest.TestCase): ...@@ -866,9 +852,7 @@ class TestRotateAxisAngle(unittest.TestCase):
transformed_points = t.transform_points(points) transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, -1.0]) expected_points = torch.tensor([0.0, 0.0, -1.0])
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
transformed_points.squeeze(), expected_points, atol=1e-7
)
) )
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
...@@ -923,9 +907,7 @@ class TestRotateAxisAngle(unittest.TestCase): ...@@ -923,9 +907,7 @@ class TestRotateAxisAngle(unittest.TestCase):
transformed_points = t.transform_points(points) transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 1.0, 0.0]) expected_points = torch.tensor([0.0, 1.0, 0.0])
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
transformed_points.squeeze(), expected_points, atol=1e-7
)
) )
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
...@@ -949,9 +931,7 @@ class TestRotateAxisAngle(unittest.TestCase): ...@@ -949,9 +931,7 @@ class TestRotateAxisAngle(unittest.TestCase):
transformed_points = t.transform_points(points) transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 1.0, 0.0]) expected_points = torch.tensor([0.0, 1.0, 0.0])
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
transformed_points.squeeze(), expected_points, atol=1e-7
)
) )
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
......
...@@ -2,22 +2,18 @@ ...@@ -2,22 +2,18 @@
import unittest import unittest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.ops.vert_align import vert_align from pytorch3d.ops.vert_align import vert_align
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from common_testing import TestCaseMixin
class TestVertAlign(TestCaseMixin, unittest.TestCase): class TestVertAlign(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def vert_align_naive( def vert_align_naive(
feats, feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True
verts_or_meshes,
return_packed: bool = False,
align_corners: bool = True,
): ):
""" """
Naive implementation of vert_align. Naive implementation of vert_align.
...@@ -60,16 +56,13 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): ...@@ -60,16 +56,13 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
return out_feats return out_feats
@staticmethod @staticmethod
def init_meshes( def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
):
device = torch.device("cuda:0") device = torch.device("cuda:0")
verts_list = [] verts_list = []
faces_list = [] faces_list = []
for _ in range(num_meshes): for _ in range(num_meshes):
verts = ( verts = (
torch.rand((num_verts, 3), dtype=torch.float32, device=device) torch.rand((num_verts, 3), dtype=torch.float32, device=device) * 2.0
* 2.0
- 1.0 - 1.0
) # verts in the space of [-1, 1] ) # verts in the space of [-1, 1]
faces = torch.randint( faces = torch.randint(
...@@ -82,15 +75,11 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): ...@@ -82,15 +75,11 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
return meshes return meshes
@staticmethod @staticmethod
def init_feats( def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
batch_size: int = 10, num_channels: int = 256, device: str = "cuda"
):
H, W = [14, 28], [14, 28] H, W = [14, 28], [14, 28]
feats = [] feats = []
for (h, w) in zip(H, W): for (h, w) in zip(H, W):
feats.append( feats.append(torch.rand((batch_size, num_channels, h, w), device=device))
torch.rand((batch_size, num_channels, h, w), device=device)
)
return feats return feats
def test_vert_align_with_meshes(self): def test_vert_align_with_meshes(self):
...@@ -102,16 +91,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): ...@@ -102,16 +91,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
# feats in list # feats in list
out = vert_align(feats, meshes, return_packed=True) out = vert_align(feats, meshes, return_packed=True)
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(feats, meshes, return_packed=True)
feats, meshes, return_packed=True
)
self.assertClose(out, naive_out) self.assertClose(out, naive_out)
# feats as tensor # feats as tensor
out = vert_align(feats[0], meshes, return_packed=True) out = vert_align(feats[0], meshes, return_packed=True)
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
feats[0], meshes, return_packed=True
)
self.assertClose(out, naive_out) self.assertClose(out, naive_out)
def test_vert_align_with_verts(self): def test_vert_align_with_verts(self):
...@@ -120,30 +105,21 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): ...@@ -120,30 +105,21 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
""" """
feats = TestVertAlign.init_feats(10, 256) feats = TestVertAlign.init_feats(10, 256)
verts = ( verts = (
torch.rand( torch.rand((10, 100, 3), dtype=torch.float32, device=feats[0].device) * 2.0
(10, 100, 3), dtype=torch.float32, device=feats[0].device
)
* 2.0
- 1.0 - 1.0
) )
# feats in list # feats in list
out = vert_align(feats, verts, return_packed=True) out = vert_align(feats, verts, return_packed=True)
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(feats, verts, return_packed=True)
feats, verts, return_packed=True
)
self.assertClose(out, naive_out) self.assertClose(out, naive_out)
# feats as tensor # feats as tensor
out = vert_align(feats[0], verts, return_packed=True) out = vert_align(feats[0], verts, return_packed=True)
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(feats[0], verts, return_packed=True)
feats[0], verts, return_packed=True
)
self.assertClose(out, naive_out) self.assertClose(out, naive_out)
out2 = vert_align( out2 = vert_align(feats[0], verts, return_packed=True, align_corners=False)
feats[0], verts, return_packed=True, align_corners=False
)
naive_out2 = TestVertAlign.vert_align_naive( naive_out2 = TestVertAlign.vert_align_naive(
feats[0], verts, return_packed=True, align_corners=False feats[0], verts, return_packed=True, align_corners=False
) )
...@@ -158,9 +134,7 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): ...@@ -158,9 +134,7 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
verts_list = [] verts_list = []
faces_list = [] faces_list = []
for _ in range(num_meshes): for _ in range(num_meshes):
verts = torch.rand( verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
(num_verts, 3), dtype=torch.float32, device=device
)
faces = torch.randint( faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment