Unverified Commit aeb42905 authored by Xiang Xu's avatar Xiang Xu Committed by GitHub
Browse files

[Enhance] Update data structures (#2155)

* update pointsdata and det3ddatasample

* update eval_ann_info when it is None

* update unittest

* revert

* delete `eval_ann_info`

* remove unittest

* update doc

* Update point_data.py

* Update det3d_data_sample.py

* remove

* update docs

* fix unittest

* update

* Update instance_seg_metric.py

* fix lint
parent 278df1eb
...@@ -14,147 +14,124 @@ class Det3DDataSample(DetDataSample): ...@@ -14,147 +14,124 @@ class Det3DDataSample(DetDataSample):
The attributes in ``Det3DDataSample`` are divided into several parts: The attributes in ``Det3DDataSample`` are divided into several parts:
- ``proposals``(InstanceData): Region proposals used in two-stage - ``proposals`` (InstanceData): Region proposals used in two-stage
detectors. detectors.
- ``ignored_instances``(InstanceData): Instances to be ignored during - ``ignored_instances`` (InstanceData): Instances to be ignored during
training/testing. training/testing.
- ``gt_instances_3d``(InstanceData): Ground truth of 3D instance - ``gt_instances_3d`` (InstanceData): Ground truth of 3D instance
annotations. annotations.
- ``gt_instances``(InstanceData): Ground truth of 2D instance - ``gt_instances`` (InstanceData): Ground truth of 2D instance
annotations. annotations.
- ``pred_instances_3d``(InstanceData): 3D instances of model - ``pred_instances_3d`` (InstanceData): 3D instances of model
predictions. predictions.
- For point-cloud 3d object detection task whose input modality - For point-cloud 3D object detection task whose input modality is
is `use_lidar=True, use_camera=False`, the 3D predictions results `use_lidar=True, use_camera=False`, the 3D predictions results are
are saved in `pred_instances_3d`. saved in `pred_instances_3d`.
- For vision-only(monocular/multi-view) 3D object detection task - For vision-only (monocular/multi-view) 3D object detection task
whose input modality is `use_lidar=False, use_camera=True`, the 3D whose input modality is `use_lidar=False, use_camera=True`, the 3D
predictions are saved in `pred_instances_3d`. predictions are saved in `pred_instances_3d`.
- ``pred_instances``(InstanceData): 2D instances of model - ``pred_instances`` (InstanceData): 2D instances of model predictions.
predictions. - For multi-modality 3D detection task whose input modality is
- For multi-modality 3D detection task whose input modality is `use_lidar=True, use_camera=True`, the 2D predictions are saved in
`use_lidar=True, use_camera=True`, the 2D predictions `pred_instances`.
are saved in `pred_instances`. - ``pts_pred_instances_3d`` (InstanceData): 3D instances of model
- ``pts_pred_instances_3d``(InstanceData): 3D instances of model predictions based on point cloud.
predictions based on point cloud. - For multi-modality 3D detection task whose input modality is
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on `use_lidar=True, use_camera=True`, the 3D predictions based on
point cloud are saved in `pts_pred_instances_3d` to distinguish point cloud are saved in `pts_pred_instances_3d` to distinguish
with `img_pred_instances_3d` which based on image. with `img_pred_instances_3d` which based on image.
- ``img_pred_instances_3d``(InstanceData): 3D instances of model - ``img_pred_instances_3d`` (InstanceData): 3D instances of model
predictions based on image. predictions based on image.
- For multi-modality 3D detection task whose input modality is - For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on `use_lidar=True, use_camera=True`, the 3D predictions based on
image are saved in `img_pred_instances_3d` to distinguish with image are saved in `img_pred_instances_3d` to distinguish with
`pts_pred_instances_3d` which based on point cloud. `pts_pred_instances_3d` which based on point cloud.
- ``gt_pts_seg``(PointData): Ground truth of point cloud - ``gt_pts_seg`` (PointData): Ground truth of point cloud segmentation.
segmentation. - ``pred_pts_seg`` (PointData): Prediction of point cloud segmentation.
- ``pred_pts_seg``(PointData): Prediction of point cloud - ``eval_ann_info`` (dict or None): Raw annotation, which will be
segmentation. passed to evaluator and do the online evaluation.
- ``eval_ann_info``(dict): Raw annotation, which will be passed to
evaluator and do the online evaluation.
Examples: Examples:
>>> from mmengine.structures import InstanceData >>> import torch
>>> from mmengine.structures import InstanceData
>>> from mmdet3d.structures import Det3DDataSample
>>> from mmdet3d.structures import BaseInstance3DBoxes >>> from mmdet3d.structures import Det3DDataSample
>>> from mmdet3d.structures import BaseInstance3DBoxes
>>> data_sample = Det3DDataSample()
>>> meta_info = dict(img_shape=(800, 1196, 3), >>> data_sample = Det3DDataSample()
... pad_shape=(800, 1216, 3)) >>> meta_info = dict(
>>> gt_instances_3d = InstanceData(metainfo=meta_info) ... img_shape=(800, 1196, 3),
>>> gt_instances_3d.bboxes = BaseInstance3DBoxes(torch.rand((5, 7))) ... pad_shape=(800, 1216, 3))
>>> gt_instances_3d.labels = torch.randint(0,3,(5, )) >>> gt_instances_3d = InstanceData(metainfo=meta_info)
>>> data_sample.gt_instances_3d = gt_instances_3d >>> gt_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> assert 'img_shape' in data_sample.gt_instances_3d.metainfo_keys() >>> gt_instances_3d.labels_3d = torch.randint(0, 3, (5,))
>>> print(data_sample) >>> data_sample.gt_instances_3d = gt_instances_3d
<Det3DDataSample( >>> assert 'img_shape' in data_sample.gt_instances_3d.metainfo_keys()
>>> len(data_sample.gt_instances_3d)
META INFORMATION 5
>>> print(data_sample)
DATA FIELDS <Det3DDataSample(
_gt_instances_3d: <InstanceData(
META INFORMATION META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS DATA FIELDS
labels: tensor([0, 0, 1, 0, 2]) gt_instances_3d: <InstanceData(
bboxes: BaseInstance3DBoxes( META INFORMATION
tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944], img_shape: (800, 1196, 3)
[0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326], pad_shape: (800, 1216, 3)
[0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917], DATA FIELDS
[0.4686, 0.7007, 0.4428, 0.0672, 0.3319, 0.3033, 0.8519], labels_3d: tensor([1, 0, 2, 0, 1])
[0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]])) bboxes_3d: BaseInstance3DBoxes(
) at 0x7fb0d9354280> tensor([[1.9115e-01, 3.6061e-01, 6.7707e-01, 5.2902e-01, 8.0736e-01, 8.2759e-01, # noqa E501
gt_instances_3d: <InstanceData( 2.4328e-01],
[5.6272e-01, 2.7508e-01, 5.7966e-01, 9.2410e-01, 3.0456e-01, 1.8912e-01, # noqa E501
3.3176e-01],
[8.1069e-01, 2.8684e-01, 7.7689e-01, 9.2397e-02, 5.5849e-01, 3.8007e-01, # noqa E501
4.6719e-01],
[6.6346e-01, 4.8005e-01, 5.2318e-02, 4.4137e-01, 4.1163e-01, 8.9339e-01, # noqa E501
7.2847e-01],
[2.4800e-01, 7.1944e-01, 3.4766e-01, 7.8583e-01, 8.5507e-01, 6.3729e-02, # noqa E501
7.5161e-05]]))
) at 0x7f7e29de3a00>
) at 0x7f7e2a0e8640>
>>> pred_instances = InstanceData(metainfo=meta_info)
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances=pred_instances)
>>> assert 'pred_instances' in data_sample
>>> pred_instances_3d = InstanceData(metainfo=meta_info)
>>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(
... torch.rand((5, 7)))
>>> pred_instances_3d.scores_3d = torch.rand((5, ))
>>> pred_instances_3d.labels_3d = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
>>> assert 'pred_instances_3d' in data_sample
>>> data_sample = Det3DDataSample()
>>> gt_instances_3d_data = dict(
... bboxes_3d=BaseInstance3DBoxes(torch.rand((2, 7))),
... labels_3d=torch.rand(2))
>>> gt_instances_3d = InstanceData(**gt_instances_3d_data)
>>> data_sample.gt_instances_3d = gt_instances_3d
>>> assert 'gt_instances_3d' in data_sample
>>> assert 'bboxes_3d' in data_sample.gt_instances_3d
>>> from mmdet3d.structures import PointData
>>> data_sample = Det3DDataSample()
>>> gt_pts_seg_data = dict(
... pts_instance_mask=torch.rand(2),
... pts_semantic_mask=torch.rand(2))
>>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
>>> print(data_sample)
<Det3DDataSample(
META INFORMATION META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS DATA FIELDS
labels: tensor([0, 0, 1, 0, 2]) gt_pts_seg: <PointData(
bboxes: BaseInstance3DBoxes( META INFORMATION
tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944], DATA FIELDS
[0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326], pts_semantic_mask: tensor([0.7199, 0.4006])
[0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917], pts_instance_mask: tensor([0.7363, 0.8096])
[0.4686, 0.7007, 0.4428, 0.0672, 0.3319, 0.3033, 0.8519], ) at 0x7f7e2962cc40>
[0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]])) ) at 0x7f7e29ff0d60>
) at 0x7fb0d9354280>
) at 0x7fb0d93543d0>
>>> pred_instances = InstanceData(metainfo=meta_info)
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances=pred_instances)
>>> assert 'pred_instances' in data_sample
>>> pred_instances_3d = InstanceData(metainfo=meta_info)
>>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> pred_instances_3d.scores_3d = torch.rand((5, ))
>>> pred_instances_3d.labels_3d = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
>>> assert 'pred_instances_3d' in data_sample
>>> data_sample = Det3DDataSample()
>>> gt_instances_3d_data = dict(
... bboxes=BaseInstance3DBoxes(torch.rand((2, 7))),
... labels=torch.rand(2))
>>> gt_instances_3d = InstanceData(**gt_instances_3d_data)
>>> data_sample.gt_instances_3d = gt_instances_3d
>>> assert 'gt_instances_3d' in data_sample
>>> assert 'bboxes' in data_sample.gt_instances_3d
>>> data_sample = Det3DDataSample()
... gt_pts_seg_data = dict(
... pts_instance_mask=torch.rand(2),
... pts_semantic_mask=torch.rand(2))
>>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
>>> print(data_sample)
<Det3DDataSample(
META INFORMATION
DATA FIELDS
gt_pts_seg: <PointData(
META INFORMATION
DATA FIELDS
pts_instance_mask: tensor([0.0576, 0.3067])
pts_semantic_mask: tensor([0.9267, 0.7455])
) at 0x7f654a9c1590>
_gt_pts_seg: <PointData(
META INFORMATION
DATA FIELDS
pts_instance_mask: tensor([0.0576, 0.3067])
pts_semantic_mask: tensor([0.9267, 0.7455])
) at 0x7f654a9c1590>
) at 0x7f654a9c1550>
""" """
@property @property
...@@ -162,11 +139,11 @@ class Det3DDataSample(DetDataSample): ...@@ -162,11 +139,11 @@ class Det3DDataSample(DetDataSample):
return self._gt_instances_3d return self._gt_instances_3d
@gt_instances_3d.setter @gt_instances_3d.setter
def gt_instances_3d(self, value: InstanceData): def gt_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_gt_instances_3d', dtype=InstanceData) self.set_field(value, '_gt_instances_3d', dtype=InstanceData)
@gt_instances_3d.deleter @gt_instances_3d.deleter
def gt_instances_3d(self): def gt_instances_3d(self) -> None:
del self._gt_instances_3d del self._gt_instances_3d
@property @property
...@@ -174,11 +151,11 @@ class Det3DDataSample(DetDataSample): ...@@ -174,11 +151,11 @@ class Det3DDataSample(DetDataSample):
return self._pred_instances_3d return self._pred_instances_3d
@pred_instances_3d.setter @pred_instances_3d.setter
def pred_instances_3d(self, value: InstanceData): def pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_pred_instances_3d', dtype=InstanceData) self.set_field(value, '_pred_instances_3d', dtype=InstanceData)
@pred_instances_3d.deleter @pred_instances_3d.deleter
def pred_instances_3d(self): def pred_instances_3d(self) -> None:
del self._pred_instances_3d del self._pred_instances_3d
@property @property
...@@ -186,11 +163,11 @@ class Det3DDataSample(DetDataSample): ...@@ -186,11 +163,11 @@ class Det3DDataSample(DetDataSample):
return self._pts_pred_instances_3d return self._pts_pred_instances_3d
@pts_pred_instances_3d.setter @pts_pred_instances_3d.setter
def pts_pred_instances_3d(self, value: InstanceData): def pts_pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData) self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)
@pts_pred_instances_3d.deleter @pts_pred_instances_3d.deleter
def pts_pred_instances_3d(self): def pts_pred_instances_3d(self) -> None:
del self._pts_pred_instances_3d del self._pts_pred_instances_3d
@property @property
...@@ -198,11 +175,11 @@ class Det3DDataSample(DetDataSample): ...@@ -198,11 +175,11 @@ class Det3DDataSample(DetDataSample):
return self._img_pred_instances_3d return self._img_pred_instances_3d
@img_pred_instances_3d.setter @img_pred_instances_3d.setter
def img_pred_instances_3d(self, value: InstanceData): def img_pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData) self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)
@img_pred_instances_3d.deleter @img_pred_instances_3d.deleter
def img_pred_instances_3d(self): def img_pred_instances_3d(self) -> None:
del self._img_pred_instances_3d del self._img_pred_instances_3d
@property @property
...@@ -210,11 +187,11 @@ class Det3DDataSample(DetDataSample): ...@@ -210,11 +187,11 @@ class Det3DDataSample(DetDataSample):
return self._gt_pts_seg return self._gt_pts_seg
@gt_pts_seg.setter @gt_pts_seg.setter
def gt_pts_seg(self, value: PointData): def gt_pts_seg(self, value: PointData) -> None:
self.set_field(value, '_gt_pts_seg', dtype=PointData) self.set_field(value, '_gt_pts_seg', dtype=PointData)
@gt_pts_seg.deleter @gt_pts_seg.deleter
def gt_pts_seg(self): def gt_pts_seg(self) -> None:
del self._gt_pts_seg del self._gt_pts_seg
@property @property
...@@ -222,11 +199,11 @@ class Det3DDataSample(DetDataSample): ...@@ -222,11 +199,11 @@ class Det3DDataSample(DetDataSample):
return self._pred_pts_seg return self._pred_pts_seg
@pred_pts_seg.setter @pred_pts_seg.setter
def pred_pts_seg(self, value: PointData): def pred_pts_seg(self, value: PointData) -> None:
self.set_field(value, '_pred_pts_seg', dtype=PointData) self.set_field(value, '_pred_pts_seg', dtype=PointData)
@pred_pts_seg.deleter @pred_pts_seg.deleter
def pred_pts_seg(self): def pred_pts_seg(self) -> None:
del self._pred_pts_seg del self._pred_pts_seg
......
...@@ -12,7 +12,7 @@ IndexType = Union[str, slice, int, list, torch.LongTensor, ...@@ -12,7 +12,7 @@ IndexType = Union[str, slice, int, list, torch.LongTensor,
class PointData(BaseDataElement): class PointData(BaseDataElement):
"""Data structure for point-level annnotations or predictions. """Data structure for point-level annotations or predictions.
All data items in ``data_fields`` of ``PointData`` meet the following All data items in ``data_fields`` of ``PointData`` meet the following
requirements: requirements:
...@@ -27,41 +27,41 @@ class PointData(BaseDataElement): ...@@ -27,41 +27,41 @@ class PointData(BaseDataElement):
Examples: Examples:
>>> metainfo = dict( >>> metainfo = dict(
... sample_id=random.randint(0, 100)) ... sample_idx=random.randint(0, 100))
>>> points = np.random.randint(0, 255, (100, 3)) >>> points = np.random.randint(0, 255, (100, 3))
>>> point_data = PointData(metainfo=metainfo, >>> point_data = PointData(metainfo=metainfo,
... points=points) ... points=points)
>>> print(len(point_data)) >>> print(len(point_data))
>>> (100) 100
>>> # slice >>> # slice
>>> slice_data = pixel_data[10:60] >>> slice_data = point_data[10:60]
>>> assert slice_data.shape == (50,) >>> assert len(slice_data) == 50
>>> # set >>> # set
>>> point_data.pts_semantic_mask = torch.randint(0, 255, (100)) >>> point_data.pts_semantic_mask = torch.randint(0, 255, (100,))
>>> point_data.pts_instance_mask = torch.randint(0, 255, (100)) >>> point_data.pts_instance_mask = torch.randint(0, 255, (100,))
>>> assert tuple(point_data.pts_semantic_mask.shape) == (100) >>> assert tuple(point_data.pts_semantic_mask.shape) == (100,)
>>> assert tuple(point_data.pts_instance_mask.shape) == (100) >>> assert tuple(point_data.pts_instance_mask.shape) == (100,)
""" """
def __setattr__(self, name: str, value: Sized): def __setattr__(self, name: str, value: Sized) -> None:
"""setattr is only used to set data. """setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length The value must have the attribute of `__len__` and have the same length
of PointData. of `PointData`.
""" """
if name in ('_metainfo_fields', '_data_fields'): if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name): if not hasattr(self, name):
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
raise AttributeError( raise AttributeError(f'{name} has been used as a '
f'{name} has been used as a ' 'private attribute, which is immutable.')
f'private attribute, which is immutable. ')
else: else:
assert isinstance(value, assert isinstance(value,
Sized), 'value must contain `_len__` attribute' Sized), 'value must contain `__len__` attribute'
# TODO: make sure the input value share the same length
super().__setattr__(name, value) super().__setattr__(name, value)
__setitem__ = __setattr__ __setitem__ = __setattr__
...@@ -69,16 +69,21 @@ class PointData(BaseDataElement): ...@@ -69,16 +69,21 @@ class PointData(BaseDataElement):
def __getitem__(self, item: IndexType) -> 'PointData': def __getitem__(self, item: IndexType) -> 'PointData':
""" """
Args: Args:
item (str, obj:`slice`, item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`): :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
get the corresponding values according to item. Get the corresponding values according to item.
Returns: Returns:
obj:`PointData`: Corresponding values. :obj:`PointData`: Corresponding values.
""" """
if isinstance(item, list): if isinstance(item, list):
item = np.array(item) item = np.array(item)
if isinstance(item, np.ndarray): if isinstance(item, np.ndarray):
# The default int type of numpy is platform dependent, int32 for
# windows and int64 for linux. `torch.Tensor` requires the index
# should be int64, therefore we simply convert it to int64 here.
# Mode details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item) item = torch.from_numpy(item)
assert isinstance( assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor, item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
...@@ -87,8 +92,8 @@ class PointData(BaseDataElement): ...@@ -87,8 +92,8 @@ class PointData(BaseDataElement):
if isinstance(item, str): if isinstance(item, str):
return getattr(self, item) return getattr(self, item)
if type(item) == int: if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore if item >= len(self) or item < -len(self): # type: ignore
raise IndexError(f'Index {item} out of range!') raise IndexError(f'Index {item} out of range!')
else: else:
# keep the dimension # keep the dimension
...@@ -99,14 +104,14 @@ class PointData(BaseDataElement): ...@@ -99,14 +104,14 @@ class PointData(BaseDataElement):
assert item.dim() == 1, 'Only support to get the' \ assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.' ' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the' \ assert len(item) == len(self), 'The shape of the ' \
f' input(BoolTensor)) ' \ 'input(BoolTensor) ' \
f'{len(item)} ' \ f'{len(item)} ' \
f' does not match the shape ' \ 'does not match the shape ' \
f'of the indexed tensor ' \ 'of the indexed tensor ' \
f'in results_filed ' \ 'in results_field ' \
f'{len(self)} at ' \ f'{len(self)} at ' \
f'first dimension. ' 'first dimension.'
for k, v in self.items(): for k, v in self.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
...@@ -116,7 +121,7 @@ class PointData(BaseDataElement): ...@@ -116,7 +121,7 @@ class PointData(BaseDataElement):
elif isinstance( elif isinstance(
v, (str, list, tuple)) or (hasattr(v, '__getitem__') v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')): and hasattr(v, 'cat')):
# convert to indexes from boolTensor # convert to indexes from BoolTensor
if isinstance(item, if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)): (torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view( indexes = torch.nonzero(item).view(
...@@ -141,16 +146,15 @@ class PointData(BaseDataElement): ...@@ -141,16 +146,15 @@ class PointData(BaseDataElement):
raise ValueError( raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no ' f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not ' 'attribute of `cat`, so it does not '
f'support slice with `bool`') 'support slice with `bool`')
else: else:
# item is a slice # item is a slice
for k, v in self.items(): for k, v in self.items():
new_data[k] = v[item] new_data[k] = v[item]
return new_data # type:ignore return new_data # type: ignore
def __len__(self) -> int: def __len__(self) -> int:
"""int: the length of PointData""" """int: The length of `PointData`."""
if len(self._data_fields) > 0: if len(self._data_fields) > 0:
return len(self.values()[0]) return len(self.values()[0])
else: else:
......
...@@ -16,7 +16,7 @@ def _equal(a, b): ...@@ -16,7 +16,7 @@ def _equal(a, b):
return a == b return a == b
class TestDet3DataSample(TestCase): class TestDet3DDataSample(TestCase):
def test_init(self): def test_init(self):
meta_info = dict( meta_info = dict(
...@@ -33,7 +33,7 @@ class TestDet3DataSample(TestCase): ...@@ -33,7 +33,7 @@ class TestDet3DataSample(TestCase):
det3d_data_sample = Det3DDataSample() det3d_data_sample = Det3DDataSample()
# test gt_instances_3d # test gt_instances_3d
gt_instances_3d_data = dict( gt_instances_3d_data = dict(
bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4)) bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
gt_instances_3d = InstanceData(**gt_instances_3d_data) gt_instances_3d = InstanceData(**gt_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample assert 'gt_instances_3d' in det3d_data_sample
...@@ -44,7 +44,7 @@ class TestDet3DataSample(TestCase): ...@@ -44,7 +44,7 @@ class TestDet3DataSample(TestCase):
# test pred_instances_3d # test pred_instances_3d
pred_instances_3d_data = dict( pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4), bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2), labels_3d=torch.rand(2),
scores_3d=torch.rand(2)) scores_3d=torch.rand(2))
pred_instances_3d = InstanceData(**pred_instances_3d_data) pred_instances_3d = InstanceData(**pred_instances_3d_data)
...@@ -59,7 +59,7 @@ class TestDet3DataSample(TestCase): ...@@ -59,7 +59,7 @@ class TestDet3DataSample(TestCase):
# test pts_pred_instances_3d # test pts_pred_instances_3d
pts_pred_instances_3d_data = dict( pts_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4), bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2), labels_3d=torch.rand(2),
scores_3d=torch.rand(2)) scores_3d=torch.rand(2))
pts_pred_instances_3d = InstanceData(**pts_pred_instances_3d_data) pts_pred_instances_3d = InstanceData(**pts_pred_instances_3d_data)
...@@ -74,7 +74,7 @@ class TestDet3DataSample(TestCase): ...@@ -74,7 +74,7 @@ class TestDet3DataSample(TestCase):
# test img_pred_instances_3d # test img_pred_instances_3d
img_pred_instances_3d_data = dict( img_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4), bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2), labels_3d=torch.rand(2),
scores_3d=torch.rand(2)) scores_3d=torch.rand(2))
img_pred_instances_3d = InstanceData(**img_pred_instances_3d_data) img_pred_instances_3d = InstanceData(**img_pred_instances_3d_data)
...@@ -87,7 +87,7 @@ class TestDet3DataSample(TestCase): ...@@ -87,7 +87,7 @@ class TestDet3DataSample(TestCase):
assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d, assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d,
img_pred_instances_3d_data['scores_3d']) img_pred_instances_3d_data['scores_3d'])
# test gt_seg # test gt_pts_seg
gt_pts_seg_data = dict( gt_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20)) pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
gt_pts_seg = PointData(**gt_pts_seg_data) gt_pts_seg = PointData(**gt_pts_seg_data)
...@@ -98,7 +98,7 @@ class TestDet3DataSample(TestCase): ...@@ -98,7 +98,7 @@ class TestDet3DataSample(TestCase):
assert _equal(det3d_data_sample.gt_pts_seg.pts_semantic_mask, assert _equal(det3d_data_sample.gt_pts_seg.pts_semantic_mask,
gt_pts_seg_data['pts_semantic_mask']) gt_pts_seg_data['pts_semantic_mask'])
# test pred_seg # test pred_pts_seg
pred_pts_seg_data = dict( pred_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20)) pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
pred_pts_seg = PointData(**pred_pts_seg_data) pred_pts_seg = PointData(**pred_pts_seg_data)
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmdet3d.structures import PointData
class TestPointData(TestCase):
    """Unit tests for ``PointData``: attribute setting, indexing and length.

    Uses ``self.assertRaises`` consistently throughout (the original mixed
    ``self.assertRaises`` and ``pytest.raises``, which are equivalent here
    but inconsistent within one ``TestCase``).
    """

    def setup_data(self):
        """Build a ``PointData`` holding 5 random 3D points for the tests."""
        metainfo = dict(sample_idx=random.randint(0, 100))
        points = torch.rand((5, 3))
        point_data = PointData(metainfo=metainfo, points=points)
        return point_data

    def test_set_data(self):
        point_data = self.setup_data()

        # '_metainfo_fields' and '_data_fields' are private bookkeeping
        # attributes of the data element and must be immutable once set.
        with self.assertRaises(AttributeError):
            point_data._metainfo_fields = 1
        with self.assertRaises(AttributeError):
            point_data._data_fields = 1

        # any sized value can be attached as a new data field
        point_data.keypoints = torch.rand((5, 2))
        assert 'keypoints' in point_data

    def test_getitem(self):
        point_data = PointData()
        # indexing an empty PointData must fail (length is 0)
        with self.assertRaises(IndexError):
            point_data[1]

        point_data = self.setup_data()
        assert len(point_data) == 5

        slice_point_data = point_data[:2]
        assert len(slice_point_data) == 2
        # integer indexing keeps the first dimension, so length is 1
        slice_point_data = point_data[1]
        assert len(slice_point_data) == 1

        # index must lie in [-len(point_data), len(point_data) - 1]
        with self.assertRaises(IndexError):
            point_data[5]

        # only str, slice, int, torch.LongTensor and torch.BoolTensor
        # (or list/numpy equivalents) are accepted: a float tensor is rejected
        item = torch.Tensor([1, 2, 3, 4])
        with self.assertRaises(AssertionError):
            point_data[item]

        # a BoolTensor index must match the data length at dimension 0
        with self.assertRaises(AssertionError):
            point_data[item.bool()]

        # test LongTensor indexing
        long_tensor = torch.randint(5, (2, ))
        long_index_point_data = point_data[long_tensor]
        assert len(long_index_point_data) == len(long_tensor)

        # test BoolTensor indexing (including an all-False mask)
        bool_tensor = torch.rand(5) > 0.5
        bool_index_point_data = point_data[bool_tensor]
        assert len(bool_index_point_data) == bool_tensor.sum()
        bool_tensor = torch.rand(5) > 1
        empty_point_data = point_data[bool_tensor]
        assert len(empty_point_data) == bool_tensor.sum()

        # test list of ints
        list_index = [1, 2]
        list_index_point_data = point_data[list_index]
        assert len(list_index_point_data) == len(list_index)

        # test list of bools
        list_bool = [True, False, True, False, False]
        list_bool_point_data = point_data[list_bool]
        assert len(list_bool_point_data) == 2

        # test numpy integer and bool indices
        long_numpy = np.random.randint(5, size=2)
        long_numpy_point_data = point_data[long_numpy]
        assert len(long_numpy_point_data) == len(long_numpy)
        bool_numpy = np.random.rand(5) > 0.5
        bool_numpy_point_data = point_data[bool_numpy]
        assert len(bool_numpy_point_data) == bool_numpy.sum()

    def test_len(self):
        point_data = self.setup_data()
        assert len(point_data) == 5
        # an empty PointData reports length 0
        point_data = PointData()
        assert len(point_data) == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment