"docs/en/_static/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "fdeee889589df413e368b05fd702b5c3f76ac7d3"
Commit b7c826b7 authored by Roman Shapovalov's avatar Roman Shapovalov Committed by Facebook GitHub Bot
Browse files

Boolean indexing of cameras

Summary: Reasonable to expect bool indexing.

Reviewed By: bottler, kjchalup

Differential Revision: D38741446

fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
parent 60808972
...@@ -385,31 +385,45 @@ class CamerasBase(TensorProperties): ...@@ -385,31 +385,45 @@ class CamerasBase(TensorProperties):
return self.image_size if hasattr(self, "image_size") else None return self.image_size if hasattr(self, "image_size") else None
def __getitem__( def __getitem__(
self, index: Union[int, List[int], torch.LongTensor] self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
) -> "CamerasBase": ) -> "CamerasBase":
""" """
Override for the __getitem__ method in TensorProperties which needs to be Override for the __getitem__ method in TensorProperties which needs to be
refactored. refactored.
Args: Args:
index: an int/list/long tensor used to index all the fields in the cameras given by index: an integer index, list/tensor of integer indices, or tensor of boolean
self._FIELDS. indicators used to filter all the fields in the cameras given by self._FIELDS.
Returns: Returns:
if `index` is an index int/list/long tensor return an instance of the current an instance of the current cameras class with only the values at the selected index.
cameras class with only the values at the selected index.
""" """
kwargs = {} kwargs = {}
# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`. tensor_types = {
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)): "bool": (torch.BoolTensor, torch.cuda.BoolTensor),
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r" "long": (torch.LongTensor, torch.cuda.LongTensor),
}
if not isinstance(
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
) or (
isinstance(index, list)
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
):
msg = (
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
)
raise ValueError(msg % type(index)) raise ValueError(msg % type(index))
if isinstance(index, int): if isinstance(index, int):
index = [index] index = [index]
if max(index) >= len(self): if isinstance(index, tensor_types["bool"]):
if index.ndim != 1 or index.shape[0] != len(self):
raise ValueError(
f"Boolean index of shape {index.shape} does not match cameras"
)
elif max(index) >= len(self):
raise ValueError(f"Index {max(index)} is out of bounds for select cameras") raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
for field in self._FIELDS: for field in self._FIELDS:
......
...@@ -472,7 +472,9 @@ class Meshes: ...@@ -472,7 +472,9 @@ class Meshes:
def __len__(self) -> int: def __len__(self) -> int:
return self._N return self._N
def __getitem__(self, index) -> "Meshes": def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Meshes":
""" """
Args: Args:
index: Specifying the index of the mesh to retrieve. index: Specifying the index of the mesh to retrieve.
......
...@@ -360,7 +360,10 @@ class Pointclouds: ...@@ -360,7 +360,10 @@ class Pointclouds:
def __len__(self) -> int: def __len__(self) -> int:
return self._N return self._N
def __getitem__(self, index) -> "Pointclouds": def __getitem__(
self,
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
) -> "Pointclouds":
""" """
Args: Args:
index: Specifying the index of the cloud to retrieve. index: Specifying the index of the cloud to retrieve.
......
...@@ -501,7 +501,10 @@ class Volumes: ...@@ -501,7 +501,10 @@ class Volumes:
return self._densities.shape[0] return self._densities.shape[0]
def __getitem__( def __getitem__(
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor] self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes": ) -> "Volumes":
""" """
Args: Args:
......
...@@ -181,7 +181,7 @@ class Transform3d: ...@@ -181,7 +181,7 @@ class Transform3d:
return self.get_matrix().shape[0] return self.get_matrix().shape[0]
def __getitem__( def __getitem__(
self, index: Union[int, List[int], slice, torch.Tensor] self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d": ) -> "Transform3d":
""" """
Args: Args:
......
...@@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase): ...@@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertTrue(new_cam.device == device) self.assertTrue(new_cam.device == device)
def test_getitem(self): def test_getitem(self):
R_matrix = torch.randn((6, 3, 3)) N_CAMERAS = 6
R_matrix = torch.randn((N_CAMERAS, 3, 3))
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix) cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
# Check get item returns an instance of the same class # Check get item returns an instance of the same class
...@@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase): ...@@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(c012.R, R_matrix[0:3, ...]) self.assertClose(c012.R, R_matrix[0:3, ...])
# Check torch.LongTensor index # Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64) SLICE = [1, 3, 5]
index = torch.tensor(SLICE, dtype=torch.int64)
c135 = cam[index] c135 = cam[index]
self.assertEqual(len(c135), 3) self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3)) self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3)) self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...]) self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check torch.BoolTensor index
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
index = torch.tensor(bool_slice, dtype=torch.bool)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check errors with get item # Check errors with get item
with self.assertRaisesRegex(ValueError, "out of bounds"): with self.assertRaisesRegex(ValueError, "out of bounds"):
cam[6] cam[N_CAMERAS]
with self.assertRaisesRegex(ValueError, "does not match cameras"):
index = torch.tensor([1, 0, 1], dtype=torch.bool)
cam[index]
with self.assertRaisesRegex(ValueError, "Invalid index type"): with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[slice(0, 1)] cam[slice(0, 1)]
with self.assertRaisesRegex(ValueError, "Invalid index type"): with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor([1, 3, 5], dtype=torch.float32) cam[[True, False]]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor(SLICE, dtype=torch.float32)
cam[index] cam[index]
def test_get_full_transform(self): def test_get_full_transform(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment