You need to sign in or sign up before continuing.
Commit 7f2f95f2 authored by Georgia Gkioxari's avatar Georgia Gkioxari Committed by Facebook GitHub Bot
Browse files

detach for meshes, pointclouds, textures

Summary: Add `detach` for Meshes, Pointclouds, Textures

Reviewed By: nikhilaravi

Differential Revision: D23070418

fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
parent 5852b74d
...@@ -242,6 +242,13 @@ class TexturesBase(object): ...@@ -242,6 +242,13 @@ class TexturesBase(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def detach(self):
"""
Each texture class should implement a method
to detach all necessary internal tensors.
"""
raise NotImplementedError()
def __getitem__(self, index): def __getitem__(self, index):
""" """
Each texture class should implement a method Each texture class should implement a method
...@@ -388,6 +395,8 @@ class TexturesAtlas(TexturesBase): ...@@ -388,6 +395,8 @@ class TexturesAtlas(TexturesBase):
def clone(self): def clone(self):
tex = self.__class__(atlas=self.atlas_padded().clone()) tex = self.__class__(atlas=self.atlas_padded().clone())
if self._atlas_list is not None:
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
num_faces = ( num_faces = (
self._num_faces_per_mesh.clone() self._num_faces_per_mesh.clone()
if torch.is_tensor(self._num_faces_per_mesh) if torch.is_tensor(self._num_faces_per_mesh)
...@@ -397,6 +406,19 @@ class TexturesAtlas(TexturesBase): ...@@ -397,6 +406,19 @@ class TexturesAtlas(TexturesBase):
tex._num_faces_per_mesh = num_faces tex._num_faces_per_mesh = num_faces
return tex return tex
def detach(self):
tex = self.__class__(atlas=self.atlas_padded().detach())
if self._atlas_list is not None:
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
num_faces = (
self._num_faces_per_mesh.detach()
if torch.is_tensor(self._num_faces_per_mesh)
else self._num_faces_per_mesh
)
tex.valid = self.valid.detach()
tex._num_faces_per_mesh = num_faces
return tex
def __getitem__(self, index): def __getitem__(self, index):
props = ["atlas_list", "_num_faces_per_mesh"] props = ["atlas_list", "_num_faces_per_mesh"]
new_props = self._getitem(index, props=props) new_props = self._getitem(index, props=props)
...@@ -656,6 +678,12 @@ class TexturesUV(TexturesBase): ...@@ -656,6 +678,12 @@ class TexturesUV(TexturesBase):
self.faces_uvs_padded().clone(), self.faces_uvs_padded().clone(),
self.verts_uvs_padded().clone(), self.verts_uvs_padded().clone(),
) )
if self._maps_list is not None:
tex._maps_list = [m.clone() for m in self._maps_list]
if self._verts_uvs_list is not None:
tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list]
if self._faces_uvs_list is not None:
tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list]
num_faces = ( num_faces = (
self._num_faces_per_mesh.clone() self._num_faces_per_mesh.clone()
if torch.is_tensor(self._num_faces_per_mesh) if torch.is_tensor(self._num_faces_per_mesh)
...@@ -665,6 +693,27 @@ class TexturesUV(TexturesBase): ...@@ -665,6 +693,27 @@ class TexturesUV(TexturesBase):
tex.valid = self.valid.clone() tex.valid = self.valid.clone()
return tex return tex
def detach(self):
tex = self.__class__(
self.maps_padded().detach(),
self.faces_uvs_padded().detach(),
self.verts_uvs_padded().detach(),
)
if self._maps_list is not None:
tex._maps_list = [m.detach() for m in self._maps_list]
if self._verts_uvs_list is not None:
tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]
if self._faces_uvs_list is not None:
tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list]
num_faces = (
self._num_faces_per_mesh.detach()
if torch.is_tensor(self._num_faces_per_mesh)
else self._num_faces_per_mesh
)
tex._num_faces_per_mesh = num_faces
tex.valid = self.valid.detach()
return tex
def __getitem__(self, index): def __getitem__(self, index):
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"] props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
new_props = self._getitem(index, props) new_props = self._getitem(index, props)
...@@ -892,8 +941,8 @@ class TexturesVertex(TexturesBase): ...@@ -892,8 +941,8 @@ class TexturesVertex(TexturesBase):
has a D dimensional feature vector. has a D dimensional feature vector.
Args: Args:
verts_features: (N, V, D) tensor giving a feature vector with verts_features: list of (Vi, D) or (N, V, D) tensor giving a feature
artbitrary dimensions for each vertex. vector with artbitrary dimensions for each vertex.
""" """
if isinstance(verts_features, (tuple, list)): if isinstance(verts_features, (tuple, list)):
correct_shape = all( correct_shape = all(
...@@ -948,15 +997,28 @@ class TexturesVertex(TexturesBase): ...@@ -948,15 +997,28 @@ class TexturesVertex(TexturesBase):
tex = self.__class__(self.verts_features_padded().clone()) tex = self.__class__(self.verts_features_padded().clone())
if self._verts_features_list is not None: if self._verts_features_list is not None:
tex._verts_features_list = [f.clone() for f in self._verts_features_list] tex._verts_features_list = [f.clone() for f in self._verts_features_list]
num_faces = ( num_verts = (
self._num_verts_per_mesh.clone() self._num_verts_per_mesh.clone()
if torch.is_tensor(self._num_verts_per_mesh) if torch.is_tensor(self._num_verts_per_mesh)
else self._num_verts_per_mesh else self._num_verts_per_mesh
) )
tex._num_verts_per_mesh = num_faces tex._num_verts_per_mesh = num_verts
tex.valid = self.valid.clone() tex.valid = self.valid.clone()
return tex return tex
def detach(self):
tex = self.__class__(self.verts_features_padded().detach())
if self._verts_features_list is not None:
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
num_verts = (
self._num_verts_per_mesh.detach()
if torch.is_tensor(self._num_verts_per_mesh)
else self._num_verts_per_mesh
)
tex._num_verts_per_mesh = num_verts
tex.valid = self.valid.detach()
return tex
def __getitem__(self, index): def __getitem__(self, index):
props = ["verts_features_list", "_num_verts_per_mesh"] props = ["verts_features_list", "_num_verts_per_mesh"]
new_props = self._getitem(index, props) new_props = self._getitem(index, props)
......
...@@ -1138,6 +1138,28 @@ class Meshes(object): ...@@ -1138,6 +1138,28 @@ class Meshes(object):
other.textures = self.textures.clone() other.textures = self.textures.clone()
return other return other
def detach(self):
"""
Detach Meshes object. All internal tensors are detached individually.
Returns:
new Meshes object.
"""
verts_list = self.verts_list()
faces_list = self.faces_list()
new_verts_list = [v.detach() for v in verts_list]
new_faces_list = [f.detach() for f in faces_list]
other = self.__class__(verts=new_verts_list, faces=new_faces_list)
for k in self._INTERNAL_TENSORS:
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.detach())
# Textures is not a tensor but has a detach method
if self.textures is not None:
other.textures = self.textures.detach()
return other
def to(self, device, copy: bool = False): def to(self, device, copy: bool = False):
""" """
Match functionality of torch.Tensor.to() Match functionality of torch.Tensor.to()
......
...@@ -655,6 +655,42 @@ class Pointclouds(object): ...@@ -655,6 +655,42 @@ class Pointclouds(object):
setattr(other, k, v.clone()) setattr(other, k, v.clone())
return other return other
def detach(self):
"""
Detach Pointclouds object. All internal tensors are detached
individually.
Returns:
new Pointclouds object.
"""
# instantiate new pointcloud with the representation which is not None
# (either list or tensor) to save compute.
new_points, new_normals, new_features = None, None, None
if self._points_list is not None:
new_points = [v.detach() for v in self.points_list()]
normals_list = self.normals_list()
features_list = self.features_list()
if normals_list is not None:
new_normals = [n.detach() for n in normals_list]
if features_list is not None:
new_features = [f.detach() for f in features_list]
elif self._points_padded is not None:
new_points = self.points_padded().detach()
normals_padded = self.normals_padded()
features_padded = self.features_padded()
if normals_padded is not None:
new_normals = self.normals_padded().detach()
if features_padded is not None:
new_features = self.features_padded().detach()
other = self.__class__(
points=new_points, normals=new_normals, features=new_features
)
for k in self._INTERNAL_TENSORS:
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.detach())
return other
def to(self, device, copy: bool = False): def to(self, device, copy: bool = False):
""" """
Match functionality of torch.Tensor.to() Match functionality of torch.Tensor.to()
......
...@@ -20,6 +20,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): ...@@ -20,6 +20,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
max_f: int = 300, max_f: int = 300,
lists_to_tensors: bool = False, lists_to_tensors: bool = False,
device: str = "cpu", device: str = "cpu",
requires_grad: bool = False,
): ):
""" """
Function to generate a Meshes object of N meshes with Function to generate a Meshes object of N meshes with
...@@ -57,7 +58,12 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): ...@@ -57,7 +58,12 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
# Generate the actual vertices and faces. # Generate the actual vertices and faces.
for i in range(num_meshes): for i in range(num_meshes):
verts = torch.rand((v[i], 3), dtype=torch.float32, device=device) verts = torch.rand(
(v[i], 3),
dtype=torch.float32,
device=device,
requires_grad=requires_grad,
)
faces = torch.randint( faces = torch.randint(
v[i], size=(f[i], 3), dtype=torch.int64, device=device v[i], size=(f[i], 3), dtype=torch.int64, device=device
) )
...@@ -353,6 +359,26 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): ...@@ -353,6 +359,26 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertSeparate(new_mesh.faces_padded(), mesh.faces_padded()) self.assertSeparate(new_mesh.faces_padded(), mesh.faces_padded())
self.assertSeparate(new_mesh.edges_packed(), mesh.edges_packed()) self.assertSeparate(new_mesh.edges_packed(), mesh.edges_packed())
def test_detach(self):
N = 5
mesh = TestMeshes.init_mesh(N, 10, 100, requires_grad=True)
for force in [0, 1]:
if force:
# force mesh to have computed attributes
mesh.verts_packed()
mesh.edges_packed()
mesh.verts_padded()
new_mesh = mesh.detach()
self.assertFalse(new_mesh.verts_packed().requires_grad)
self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
self.assertTrue(new_mesh.verts_padded().requires_grad == False)
self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
self.assertTrue(newv.requires_grad == False)
self.assertClose(newv, v)
def test_laplacian_packed(self): def test_laplacian_packed(self):
def naive_laplacian_packed(meshes): def naive_laplacian_packed(meshes):
verts_packed = meshes.verts_packed() verts_packed = meshes.verts_packed()
......
...@@ -24,6 +24,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): ...@@ -24,6 +24,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
with_normals: bool = True, with_normals: bool = True,
with_features: bool = True, with_features: bool = True,
min_points: int = 0, min_points: int = 0,
requires_grad: bool = False,
): ):
""" """
Function to generate a Pointclouds object of N meshes with Function to generate a Pointclouds object of N meshes with
...@@ -49,16 +50,31 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): ...@@ -49,16 +50,31 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
p.fill_(p[0]) p.fill_(p[0])
points_list = [ points_list = [
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p torch.rand(
(i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
)
for i in p
] ]
normals_list, features_list = None, None normals_list, features_list = None, None
if with_normals: if with_normals:
normals_list = [ normals_list = [
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p torch.rand(
(i, 3),
device=device,
dtype=torch.float32,
requires_grad=requires_grad,
)
for i in p
] ]
if with_features: if with_features:
features_list = [ features_list = [
torch.rand((i, channels), device=device, dtype=torch.float32) for i in p torch.rand(
(i, channels),
device=device,
dtype=torch.float32,
requires_grad=requires_grad,
)
for i in p
] ]
if lists_to_tensors: if lists_to_tensors:
...@@ -382,6 +398,39 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): ...@@ -382,6 +398,39 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
self.assertCloudsEqual(clouds, new_clouds) self.assertCloudsEqual(clouds, new_clouds)
def test_detach(self):
N = 5
for lists_to_tensors in (True, False):
clouds = self.init_cloud(
N, 100, 5, lists_to_tensors=lists_to_tensors, requires_grad=True
)
for force in (False, True):
if force:
clouds.points_packed()
new_clouds = clouds.detach()
for cloud in new_clouds.points_list():
self.assertTrue(cloud.requires_grad == False)
for normal in new_clouds.normals_list():
self.assertTrue(normal.requires_grad == False)
for feats in new_clouds.features_list():
self.assertTrue(feats.requires_grad == False)
for attrib in [
"points_packed",
"normals_packed",
"features_packed",
"points_padded",
"normals_padded",
"features_padded",
]:
self.assertTrue(
getattr(new_clouds, attrib)().requires_grad == False
)
self.assertCloudsEqual(clouds, new_clouds)
def assertCloudsEqual(self, cloud1, cloud2): def assertCloudsEqual(self, cloud1, cloud2):
N = len(cloud1) N = len(cloud1)
self.assertEqual(N, len(cloud2)) self.assertEqual(N, len(cloud2))
......
...@@ -113,11 +113,37 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): ...@@ -113,11 +113,37 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
def test_clone(self): def test_clone(self):
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128))) tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
tex.verts_features_list()
tex_cloned = tex.clone() tex_cloned = tex.clone()
self.assertSeparate( self.assertSeparate(
tex._verts_features_padded, tex_cloned._verts_features_padded tex._verts_features_padded, tex_cloned._verts_features_padded
) )
self.assertClose(tex._verts_features_padded, tex_cloned._verts_features_padded)
self.assertSeparate(tex.valid, tex_cloned.valid) self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(
tex._verts_features_list[i], tex_cloned._verts_features_list[i]
)
self.assertClose(
tex._verts_features_list[i], tex_cloned._verts_features_list[i]
)
def test_detach(self):
tex = TexturesVertex(
verts_features=torch.rand(size=(10, 100, 128), requires_grad=True)
)
tex.verts_features_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._verts_features_padded.requires_grad)
self.assertClose(
tex_detached._verts_features_padded, tex._verts_features_padded
)
for i in range(tex._N):
self.assertClose(
tex._verts_features_list[i], tex_detached._verts_features_list[i]
)
self.assertFalse(tex_detached._verts_features_list[i].requires_grad)
def test_extend(self): def test_extend(self):
B = 10 B = 10
...@@ -278,9 +304,25 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): ...@@ -278,9 +304,25 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
def test_clone(self): def test_clone(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3))) tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
tex.atlas_list()
tex_cloned = tex.clone() tex_cloned = tex.clone()
self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded) self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
self.assertClose(tex._atlas_padded, tex_cloned._atlas_padded)
self.assertSeparate(tex.valid, tex_cloned.valid) self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(tex._atlas_list[i], tex_cloned._atlas_list[i])
self.assertClose(tex._atlas_list[i], tex_cloned._atlas_list[i])
def test_detach(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3), requires_grad=True))
tex.atlas_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._atlas_padded.requires_grad)
self.assertClose(tex_detached._atlas_padded, tex._atlas_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._atlas_list[i].requires_grad)
self.assertClose(tex._atlas_list[i], tex_detached._atlas_list[i])
def test_extend(self): def test_extend(self):
B = 10 B = 10
...@@ -478,11 +520,49 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): ...@@ -478,11 +520,49 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
faces_uvs=torch.rand(size=(5, 10, 3)), faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)), verts_uvs=torch.rand(size=(5, 15, 2)),
) )
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_cloned = tex.clone() tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded) self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded) self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
self.assertSeparate(tex.valid, tex_cloned.valid) self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
def test_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3), requires_grad=True),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._maps_padded.requires_grad)
self.assertClose(tex._maps_padded, tex_detached._maps_padded)
self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
def test_extend(self): def test_extend(self):
B = 5 B = 5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment