Commit 2bbca5f2 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

Allow setting verts_normals on Meshes

Summary: Add ability to set the vertex normals when creating a Meshes, so that the pluggable loaders can return them from a file.

Reviewed By: nikhilaravi

Differential Revision: D27765258

fbshipit-source-id: b5ddaa00de3707f636f94d9f74d1da12ecce0608
parent 502f15ac
......@@ -207,7 +207,14 @@ class Meshes(object):
"equisized",
]
def __init__(self, verts=None, faces=None, textures=None):
def __init__(
self,
verts=None,
faces=None,
textures=None,
*,
verts_normals=None,
):
"""
Args:
verts:
......@@ -229,6 +236,17 @@ class Meshes(object):
the same number of faces.
textures: Optional instance of the Textures class with mesh
texture properties.
verts_normals:
Optional. Can be either
- List where each element is a tensor of shape (num_verts, 3)
containing the normals of each vertex.
- Padded float tensor with shape (num_meshes, max_num_verts, 3).
They should be padded with fill value of 0 so they all have
the same number of vertices.
Note that modifying the mesh later, e.g. with offset_verts_,
can cause these normals to be forgotten and normals to be recalculated
based on the new vertex positions.
Refer to comments above for descriptions of List and Padded representations.
"""
......@@ -354,8 +372,8 @@ class Meshes(object):
self.equisized = True
elif torch.is_tensor(verts) and torch.is_tensor(faces):
if verts.size(2) != 3 and faces.size(2) != 3:
raise ValueError("Verts and Faces tensors have incorrect dimensions.")
if verts.size(2) != 3 or faces.size(2) != 3:
raise ValueError("Verts or Faces tensors have incorrect dimensions.")
self._verts_padded = verts
self._faces_padded = faces.to(torch.int64)
self._N = self._verts_padded.shape[0]
......@@ -412,6 +430,36 @@ class Meshes(object):
self.textures._N = self._N
self.textures.valid = self.valid
if verts_normals is not None:
self._set_verts_normals(verts_normals)
def _set_verts_normals(self, verts_normals) -> None:
if isinstance(verts_normals, list):
if len(verts_normals) != self._N:
raise ValueError("Invalid verts_normals input")
for item, n_verts in zip(verts_normals, self._num_verts_per_mesh):
if (
not isinstance(item, torch.Tensor)
or item.ndim != 2
or item.shape[1] != 3
or item.shape[0] != n_verts
):
raise ValueError("Invalid verts_normals input")
self._verts_normals_packed = torch.cat(verts_normals, 0)
elif torch.is_tensor(verts_normals):
if (
verts_normals.ndim != 3
or verts_normals.size(2) != 3
or verts_normals.size(0) != self._N
):
raise ValueError("Vertex normals tensor has incorrect dimensions.")
self._verts_normals_packed = struct_utils.padded_to_packed(
verts_normals, split_size=self._num_verts_per_mesh.tolist()
)
else:
raise ValueError("verts_normals must be a list or tensor")
def __len__(self):
return self._N
......@@ -1253,6 +1301,7 @@ class Meshes(object):
def offset_verts_(self, vert_offsets_packed):
"""
Add an offset to the vertices of this Meshes. In place operation.
If normals are present they may be recalculated.
Args:
vert_offsets_packed: A Tensor of shape (3,) or the same shape as
......@@ -1286,7 +1335,7 @@ class Meshes(object):
self._verts_padded[i, : verts.shape[0], :] = verts
# update face areas and normals and vertex normals
# only if the original attributes are computed
# only if the original attributes are present
if update_normals and any(
v is not None
for v in [self._faces_areas_packed, self._faces_normals_packed]
......
......@@ -223,7 +223,7 @@ class TestMeshNormalConsistency(unittest.TestCase):
Test Mesh Normal Consistency for a mesh known to have no
intersecting faces.
"""
verts = torch.rand(1, 6, 2)
verts = torch.rand(1, 6, 3)
faces = torch.arange(6).reshape(1, 2, 3)
meshes = Meshes(verts=verts, faces=faces)
out = mesh_normal_consistency(meshes)
......
......@@ -1138,6 +1138,20 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(meshes.faces_normals_padded().shape[0], 0)
self.assertEqual(meshes.faces_normals_list(), [])
def test_assigned_normals(self):
verts = torch.rand(2, 6, 3)
faces = torch.randint(6, size=(2, 4, 3))
for verts_normals in [list(verts.unbind(0)), verts]:
yes_normals = Meshes(
verts=verts.clone(), faces=faces, verts_normals=verts_normals
)
self.assertClose(yes_normals.verts_normals_padded(), verts)
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]))
self.assertClose(yes_normals.verts_normals_padded(), verts)
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]).expand(12, 3))
self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts))
def test_compute_faces_areas_cpu_cuda(self):
num_meshes = 10
max_v = 100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment