Commit b7c826b7 authored by Roman Shapovalov's avatar Roman Shapovalov Committed by Facebook GitHub Bot
Browse files

Boolean indexing of cameras

Summary: Reasonable to expect bool indexing.

Reviewed By: bottler, kjchalup

Differential Revision: D38741446

fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
parent 60808972
......@@ -385,31 +385,45 @@ class CamerasBase(TensorProperties):
return self.image_size if hasattr(self, "image_size") else None
def __getitem__(
self, index: Union[int, List[int], torch.LongTensor]
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
) -> "CamerasBase":
"""
Override for the __getitem__ method in TensorProperties which needs to be
refactored.
Args:
index: an int/list/long tensor used to index all the fields in the cameras given by
self._FIELDS.
index: an integer index, list/tensor of integer indices, or tensor of boolean
indicators used to filter all the fields in the cameras given by self._FIELDS.
Returns:
if `index` is an index int/list/long tensor return an instance of the current
cameras class with only the values at the selected index.
an instance of the current cameras class with only the values at the selected index.
"""
kwargs = {}
# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
tensor_types = {
"bool": (torch.BoolTensor, torch.cuda.BoolTensor),
"long": (torch.LongTensor, torch.cuda.LongTensor),
}
if not isinstance(
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
) or (
isinstance(index, list)
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
):
msg = (
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
)
raise ValueError(msg % type(index))
if isinstance(index, int):
index = [index]
if max(index) >= len(self):
if isinstance(index, tensor_types["bool"]):
if index.ndim != 1 or index.shape[0] != len(self):
raise ValueError(
f"Boolean index of shape {index.shape} does not match cameras"
)
elif max(index) >= len(self):
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
for field in self._FIELDS:
......
......@@ -472,7 +472,9 @@ class Meshes:
def __len__(self) -> int:
return self._N
def __getitem__(self, index) -> "Meshes":
def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Meshes":
"""
Args:
index: Specifying the index of the mesh to retrieve.
......
......@@ -360,7 +360,10 @@ class Pointclouds:
def __len__(self) -> int:
return self._N
def __getitem__(self, index) -> "Pointclouds":
def __getitem__(
self,
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
) -> "Pointclouds":
"""
Args:
index: Specifying the index of the cloud to retrieve.
......
......@@ -501,7 +501,10 @@ class Volumes:
return self._densities.shape[0]
def __getitem__(
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes":
"""
Args:
......
......@@ -181,7 +181,7 @@ class Transform3d:
return self.get_matrix().shape[0]
def __getitem__(
self, index: Union[int, List[int], slice, torch.Tensor]
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d":
"""
Args:
......
......@@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertTrue(new_cam.device == device)
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
N_CAMERAS = 6
R_matrix = torch.randn((N_CAMERAS, 3, 3))
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
# Check get item returns an instance of the same class
......@@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(c012.R, R_matrix[0:3, ...])
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
SLICE = [1, 3, 5]
index = torch.tensor(SLICE, dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check torch.BoolTensor index
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
index = torch.tensor(bool_slice, dtype=torch.bool)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check errors with get item
with self.assertRaisesRegex(ValueError, "out of bounds"):
cam[6]
cam[N_CAMERAS]
with self.assertRaisesRegex(ValueError, "does not match cameras"):
index = torch.tensor([1, 0, 1], dtype=torch.bool)
cam[index]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[slice(0, 1)]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor([1, 3, 5], dtype=torch.float32)
cam[[True, False]]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor(SLICE, dtype=torch.float32)
cam[index]
def test_get_full_transform(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment