"graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "5185c522b0acb798ed8ebb9084510bbe5b58e73b"
Commit 569e5229 authored by Evgeniy Zheltonozhskiy's avatar Evgeniy Zheltonozhskiy Committed by Facebook GitHub Bot
Browse files

Add check for verts and faces being on same device and also checks for...

Add check for verts and faces being on same device and also checks for pointclouds/features/normals being on the same device (#384)

Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/384

Test Plan: `test_meshes` and `test_points`

Reviewed By: gkioxari

Differential Revision: D24730524

Pulled By: nikhilaravi

fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
parent 19340462
...@@ -325,6 +325,13 @@ class Meshes(object): ...@@ -325,6 +325,13 @@ class Meshes(object):
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0: if self._N > 0:
self.device = self._verts_list[0].device self.device = self._verts_list[0].device
if not (
all(v.device == self.device for v in verts)
and all(f.device == self.device for f in faces)
):
raise ValueError(
"All Verts and Faces tensors should be on same device."
)
self._num_verts_per_mesh = torch.tensor( self._num_verts_per_mesh = torch.tensor(
[len(v) for v in self._verts_list], device=self.device [len(v) for v in self._verts_list], device=self.device
) )
...@@ -341,7 +348,6 @@ class Meshes(object): ...@@ -341,7 +348,6 @@ class Meshes(object):
dtype=torch.bool, dtype=torch.bool,
device=self.device, device=self.device,
) )
if (len(self._num_verts_per_mesh.unique()) == 1) and ( if (len(self._num_verts_per_mesh.unique()) == 1) and (
len(self._num_faces_per_mesh.unique()) == 1 len(self._num_faces_per_mesh.unique()) == 1
): ):
...@@ -355,6 +361,10 @@ class Meshes(object): ...@@ -355,6 +361,10 @@ class Meshes(object):
self._N = self._verts_padded.shape[0] self._N = self._verts_padded.shape[0]
self._V = self._verts_padded.shape[1] self._V = self._verts_padded.shape[1]
if verts.device != faces.device:
msg = "Verts and Faces tensors should be on same device. \n Got {} and {}."
raise ValueError(msg.format(verts.device, faces.device))
self.device = self._verts_padded.device self.device = self._verts_padded.device
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0: if self._N > 0:
......
...@@ -180,11 +180,13 @@ class Pointclouds(object): ...@@ -180,11 +180,13 @@ class Pointclouds(object):
self._num_points_per_cloud = [] self._num_points_per_cloud = []
if self._N > 0: if self._N > 0:
self.device = self._points_list[0].device
for p in self._points_list: for p in self._points_list:
if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3): if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
raise ValueError("Clouds in list must be of shape Px3 or empty") raise ValueError("Clouds in list must be of shape Px3 or empty")
if p.device != self.device:
raise ValueError("All points must be on the same device")
self.device = self._points_list[0].device
num_points_per_cloud = torch.tensor( num_points_per_cloud = torch.tensor(
[len(p) for p in self._points_list], device=self.device [len(p) for p in self._points_list], device=self.device
) )
...@@ -261,6 +263,10 @@ class Pointclouds(object): ...@@ -261,6 +263,10 @@ class Pointclouds(object):
raise ValueError( raise ValueError(
"A cloud has mismatched numbers of points and inputs" "A cloud has mismatched numbers of points and inputs"
) )
if d.device != self.device:
raise ValueError(
"All auxillary inputs must be on the same device as the points."
)
if p > 0: if p > 0:
if d.dim() != 2: if d.dim() != 2:
raise ValueError( raise ValueError(
...@@ -283,6 +289,10 @@ class Pointclouds(object): ...@@ -283,6 +289,10 @@ class Pointclouds(object):
"Inputs tensor must have the right maximum \ "Inputs tensor must have the right maximum \
number of points in each cloud." number of points in each cloud."
) )
if aux_input.device != self.device:
raise ValueError(
"All auxillary inputs must be on the same device as the points."
)
aux_input_C = aux_input.shape[2] aux_input_C = aux_input.shape[2]
return None, aux_input, aux_input_C return None, aux_input, aux_input_C
else: else:
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import random
import unittest import unittest
import numpy as np import numpy as np
...@@ -162,6 +163,29 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): ...@@ -162,6 +163,29 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
torch.tensor([0, 3, 8], dtype=torch.int64), torch.tensor([0, 3, 8], dtype=torch.int64),
) )
def test_init_error(self):
# Check if correct errors are raised when verts/faces are on
# different devices
mesh = TestMeshes.init_mesh(10, 10, 100)
verts_list = mesh.verts_list() # all tensors on cpu
verts_list = [
v.to("cuda:0") if random.uniform(0, 1) > 0.5 else v for v in verts_list
]
faces_list = mesh.faces_list()
with self.assertRaises(ValueError) as cm:
Meshes(verts=verts_list, faces=faces_list)
self.assertTrue("same device" in cm.msg)
verts_padded = mesh.verts_padded() # on cpu
verts_padded = verts_padded.to("cuda:0")
faces_padded = mesh.faces_padded()
with self.assertRaises(ValueError) as cm:
Meshes(verts=verts_padded, faces=faces_padded)
self.assertTrue("same device" in cm.msg)
def test_simple_random_meshes(self): def test_simple_random_meshes(self):
# Define the test mesh object either as a list or tensor of faces/verts. # Define the test mesh object either as a list or tensor of faces/verts.
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import random
import unittest import unittest
import numpy as np import numpy as np
...@@ -126,6 +127,44 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): ...@@ -126,6 +127,44 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]), torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
) )
def test_init_error(self):
# Check if correct errors are raised when verts/faces are on
# different devices
clouds = self.init_cloud(10, 100, 5)
points_list = clouds.points_list() # all tensors on cuda:0
points_list = [
p.to("cpu") if random.uniform(0, 1) > 0.5 else p for p in points_list
]
features_list = clouds.features_list()
normals_list = clouds.normals_list()
with self.assertRaises(ValueError) as cm:
Pointclouds(
points=points_list, features=features_list, normals=normals_list
)
self.assertTrue("same device" in cm.msg)
points_list = clouds.points_list()
features_list = [
f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list
]
with self.assertRaises(ValueError) as cm:
Pointclouds(
points=points_list, features=features_list, normals=normals_list
)
self.assertTrue("same device" in cm.msg)
points_padded = clouds.points_padded() # on cuda:0
features_padded = clouds.features_padded().to("cpu")
normals_padded = clouds.normals_padded()
with self.assertRaises(ValueError) as cm:
Pointclouds(
points=points_padded, features=features_padded, normals=normals_padded
)
self.assertTrue("same device" in cm.msg)
def test_all_constructions(self): def test_all_constructions(self):
public_getters = [ public_getters = [
"points_list", "points_list",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment