Unverified Commit aeb42905 authored by Xiang Xu, committed by GitHub

[Enhance] Update data structures (#2155)

* update pointsdata and det3ddatasample

* update eval_ann_info when it is None

* update unittest

* revert

* delete `eval_ann_info`

* remove unittest

* update doc

* Update point_data.py

* Update det3d_data_sample.py

* remove

* update docs

* fix unittest

* update

* Update instance_seg_metric.py

* fix lint
parent 278df1eb
@@ -14,97 +14,84 @@ class Det3DDataSample(DetDataSample):
The attributes in ``Det3DDataSample`` are divided into several parts:
- ``proposals``(InstanceData): Region proposals used in two-stage
- ``proposals`` (InstanceData): Region proposals used in two-stage
detectors.
- ``ignored_instances``(InstanceData): Instances to be ignored during
- ``ignored_instances`` (InstanceData): Instances to be ignored during
training/testing.
- ``gt_instances_3d``(InstanceData): Ground truth of 3D instance
- ``gt_instances_3d`` (InstanceData): Ground truth of 3D instance
annotations.
- ``gt_instances``(InstanceData): Ground truth of 2D instance
- ``gt_instances`` (InstanceData): Ground truth of 2D instance
annotations.
- ``pred_instances_3d``(InstanceData): 3D instances of model
- ``pred_instances_3d`` (InstanceData): 3D instances of model
predictions.
- For point-cloud 3d object detection task whose input modality
is `use_lidar=True, use_camera=False`, the 3D predictions results
are saved in `pred_instances_3d`.
- For vision-only(monocular/multi-view) 3D object detection task
- For the point-cloud 3D object detection task whose input modality is
`use_lidar=True, use_camera=False`, the 3D prediction results are
saved in `pred_instances_3d`.
- For vision-only (monocular/multi-view) 3D object detection task
whose input modality is `use_lidar=False, use_camera=True`, the 3D
predictions are saved in `pred_instances_3d`.
- ``pred_instances``(InstanceData): 2D instances of model
predictions.
- ``pred_instances`` (InstanceData): 2D instances of model predictions.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 2D predictions
are saved in `pred_instances`.
- ``pts_pred_instances_3d``(InstanceData): 3D instances of model
`use_lidar=True, use_camera=True`, the 2D predictions are saved in
`pred_instances`.
- ``pts_pred_instances_3d`` (InstanceData): 3D instances of model
predictions based on point cloud.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
point cloud are saved in `pts_pred_instances_3d` to distinguish them
from `img_pred_instances_3d`, which is based on images.
- ``img_pred_instances_3d``(InstanceData): 3D instances of model
- ``img_pred_instances_3d`` (InstanceData): 3D instances of model
predictions based on image.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
image are saved in `img_pred_instances_3d` to distinguish them from
`pts_pred_instances_3d`, which is based on the point cloud.
- ``gt_pts_seg``(PointData): Ground truth of point cloud
segmentation.
- ``pred_pts_seg``(PointData): Prediction of point cloud
segmentation.
- ``eval_ann_info``(dict): Raw annotation, which will be passed to
evaluator and do the online evaluation.
- ``gt_pts_seg`` (PointData): Ground truth of point cloud segmentation.
- ``pred_pts_seg`` (PointData): Prediction of point cloud segmentation.
- ``eval_ann_info`` (dict or None): Raw annotation, which will be
passed to the evaluator for online evaluation.
Examples:
>>> import torch
>>> from mmengine.structures import InstanceData
>>> from mmdet3d.structures import Det3DDataSample
>>> from mmdet3d.structures import BaseInstance3DBoxes
>>> data_sample = Det3DDataSample()
>>> meta_info = dict(img_shape=(800, 1196, 3),
>>> meta_info = dict(
... img_shape=(800, 1196, 3),
... pad_shape=(800, 1216, 3))
>>> gt_instances_3d = InstanceData(metainfo=meta_info)
>>> gt_instances_3d.bboxes = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> gt_instances_3d.labels = torch.randint(0,3,(5, ))
>>> gt_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> gt_instances_3d.labels_3d = torch.randint(0, 3, (5,))
>>> data_sample.gt_instances_3d = gt_instances_3d
>>> assert 'img_shape' in data_sample.gt_instances_3d.metainfo_keys()
>>> len(data_sample.gt_instances_3d)
5
>>> print(data_sample)
<Det3DDataSample(
META INFORMATION
DATA FIELDS
_gt_instances_3d: <InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
labels: tensor([0, 0, 1, 0, 2])
bboxes: BaseInstance3DBoxes(
tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944],
[0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326],
[0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917],
[0.4686, 0.7007, 0.4428, 0.0672, 0.3319, 0.3033, 0.8519],
[0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]]))
) at 0x7fb0d9354280>
gt_instances_3d: <InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
labels: tensor([0, 0, 1, 0, 2])
bboxes: BaseInstance3DBoxes(
tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944],
[0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326],
[0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917],
[0.4686, 0.7007, 0.4428, 0.0672, 0.3319, 0.3033, 0.8519],
[0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]]))
) at 0x7fb0d9354280>
) at 0x7fb0d93543d0>
labels_3d: tensor([1, 0, 2, 0, 1])
bboxes_3d: BaseInstance3DBoxes(
tensor([[1.9115e-01, 3.6061e-01, 6.7707e-01, 5.2902e-01, 8.0736e-01, 8.2759e-01, # noqa E501
2.4328e-01],
[5.6272e-01, 2.7508e-01, 5.7966e-01, 9.2410e-01, 3.0456e-01, 1.8912e-01, # noqa E501
3.3176e-01],
[8.1069e-01, 2.8684e-01, 7.7689e-01, 9.2397e-02, 5.5849e-01, 3.8007e-01, # noqa E501
4.6719e-01],
[6.6346e-01, 4.8005e-01, 5.2318e-02, 4.4137e-01, 4.1163e-01, 8.9339e-01, # noqa E501
7.2847e-01],
[2.4800e-01, 7.1944e-01, 3.4766e-01, 7.8583e-01, 8.5507e-01, 6.3729e-02, # noqa E501
7.5161e-05]]))
) at 0x7f7e29de3a00>
) at 0x7f7e2a0e8640>
>>> pred_instances = InstanceData(metainfo=meta_info)
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5, ))
@@ -112,7 +99,8 @@ class Det3DDataSample(DetDataSample):
>>> assert 'pred_instances' in data_sample
>>> pred_instances_3d = InstanceData(metainfo=meta_info)
>>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(
... torch.rand((5, 7)))
>>> pred_instances_3d.scores_3d = torch.rand((5, ))
>>> pred_instances_3d.labels_3d = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
@@ -120,41 +108,30 @@ class Det3DDataSample(DetDataSample):
>>> data_sample = Det3DDataSample()
>>> gt_instances_3d_data = dict(
... bboxes=BaseInstance3DBoxes(torch.rand((2, 7))),
... labels=torch.rand(2))
... bboxes_3d=BaseInstance3DBoxes(torch.rand((2, 7))),
... labels_3d=torch.rand(2))
>>> gt_instances_3d = InstanceData(**gt_instances_3d_data)
>>> data_sample.gt_instances_3d = gt_instances_3d
>>> assert 'gt_instances_3d' in data_sample
>>> assert 'bboxes' in data_sample.gt_instances_3d
>>> assert 'bboxes_3d' in data_sample.gt_instances_3d
>>> from mmdet3d.structures import PointData
>>> data_sample = Det3DDataSample()
... gt_pts_seg_data = dict(
>>> gt_pts_seg_data = dict(
... pts_instance_mask=torch.rand(2),
... pts_semantic_mask=torch.rand(2))
>>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
>>> print(data_sample)
<Det3DDataSample(
META INFORMATION
DATA FIELDS
gt_pts_seg: <PointData(
META INFORMATION
DATA FIELDS
pts_instance_mask: tensor([0.0576, 0.3067])
pts_semantic_mask: tensor([0.9267, 0.7455])
) at 0x7f654a9c1590>
_gt_pts_seg: <PointData(
META INFORMATION
DATA FIELDS
pts_instance_mask: tensor([0.0576, 0.3067])
pts_semantic_mask: tensor([0.9267, 0.7455])
) at 0x7f654a9c1590>
) at 0x7f654a9c1550>
pts_semantic_mask: tensor([0.7199, 0.4006])
pts_instance_mask: tensor([0.7363, 0.8096])
) at 0x7f7e2962cc40>
) at 0x7f7e29ff0d60>
"""
@property
@@ -162,11 +139,11 @@ class Det3DDataSample(DetDataSample):
return self._gt_instances_3d
@gt_instances_3d.setter
def gt_instances_3d(self, value: InstanceData):
def gt_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_gt_instances_3d', dtype=InstanceData)
@gt_instances_3d.deleter
def gt_instances_3d(self):
def gt_instances_3d(self) -> None:
del self._gt_instances_3d
@property
@@ -174,11 +151,11 @@ class Det3DDataSample(DetDataSample):
return self._pred_instances_3d
@pred_instances_3d.setter
def pred_instances_3d(self, value: InstanceData):
def pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_pred_instances_3d', dtype=InstanceData)
@pred_instances_3d.deleter
def pred_instances_3d(self):
def pred_instances_3d(self) -> None:
del self._pred_instances_3d
@property
@@ -186,11 +163,11 @@ class Det3DDataSample(DetDataSample):
return self._pts_pred_instances_3d
@pts_pred_instances_3d.setter
def pts_pred_instances_3d(self, value: InstanceData):
def pts_pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)
@pts_pred_instances_3d.deleter
def pts_pred_instances_3d(self):
def pts_pred_instances_3d(self) -> None:
del self._pts_pred_instances_3d
@property
@@ -198,11 +175,11 @@ class Det3DDataSample(DetDataSample):
return self._img_pred_instances_3d
@img_pred_instances_3d.setter
def img_pred_instances_3d(self, value: InstanceData):
def img_pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)
@img_pred_instances_3d.deleter
def img_pred_instances_3d(self):
def img_pred_instances_3d(self) -> None:
del self._img_pred_instances_3d
@property
@@ -210,11 +187,11 @@ class Det3DDataSample(DetDataSample):
return self._gt_pts_seg
@gt_pts_seg.setter
def gt_pts_seg(self, value: PointData):
def gt_pts_seg(self, value: PointData) -> None:
self.set_field(value, '_gt_pts_seg', dtype=PointData)
@gt_pts_seg.deleter
def gt_pts_seg(self):
def gt_pts_seg(self) -> None:
del self._gt_pts_seg
@property
@@ -222,11 +199,11 @@ class Det3DDataSample(DetDataSample):
return self._pred_pts_seg
@pred_pts_seg.setter
def pred_pts_seg(self, value: PointData):
def pred_pts_seg(self, value: PointData) -> None:
self.set_field(value, '_pred_pts_seg', dtype=PointData)
@pred_pts_seg.deleter
def pred_pts_seg(self):
def pred_pts_seg(self) -> None:
del self._pred_pts_seg
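All six properties above follow the same `set_field` pattern, so extending the data sample with another typed container is mechanical. A sketch under stated assumptions: `MyDataSample` and `pred_pts_seg_aux` are hypothetical names for illustration, not part of this diff.

from mmdet3d.structures import Det3DDataSample, PointData

class MyDataSample(Det3DDataSample):
    """Hypothetical subclass adding one more PointData container."""

    @property
    def pred_pts_seg_aux(self) -> PointData:
        return self._pred_pts_seg_aux

    @pred_pts_seg_aux.setter
    def pred_pts_seg_aux(self, value: PointData) -> None:
        # `set_field` validates the dtype before storing the value.
        self.set_field(value, '_pred_pts_seg_aux', dtype=PointData)

    @pred_pts_seg_aux.deleter
    def pred_pts_seg_aux(self) -> None:
        del self._pred_pts_seg_aux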
@@ -12,7 +12,7 @@ IndexType = Union[str, slice, int, list, torch.LongTensor,
class PointData(BaseDataElement):
"""Data structure for point-level annnotations or predictions.
"""Data structure for point-level annotations or predictions.
All data items in ``data_fields`` of ``PointData`` meet the following
requirements:
@@ -27,41 +27,41 @@ class PointData(BaseDataElement):
Examples:
>>> metainfo = dict(
... sample_id=random.randint(0, 100))
... sample_idx=random.randint(0, 100))
>>> points = np.random.randint(0, 255, (100, 3))
>>> point_data = PointData(metainfo=metainfo,
... points=points)
>>> print(len(point_data))
>>> (100)
100
>>> # slice
>>> slice_data = pixel_data[10:60]
>>> assert slice_data.shape == (50,)
>>> slice_data = point_data[10:60]
>>> assert len(slice_data) == 50
>>> # set
>>> point_data.pts_semantic_mask = torch.randint(0, 255, (100))
>>> point_data.pts_instance_mask = torch.randint(0, 255, (100))
>>> assert tuple(point_data.pts_semantic_mask.shape) == (100)
>>> assert tuple(point_data.pts_instance_mask.shape) == (100)
>>> point_data.pts_semantic_mask = torch.randint(0, 255, (100,))
>>> point_data.pts_instance_mask = torch.randint(0, 255, (100,))
>>> assert tuple(point_data.pts_semantic_mask.shape) == (100,)
>>> assert tuple(point_data.pts_instance_mask.shape) == (100,)
"""
def __setattr__(self, name: str, value: Sized):
def __setattr__(self, name: str, value: Sized) -> None:
"""setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length
of PointData.
The value must have the attribute `__len__` and have the same length
as `PointData`.
"""
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f'{name} has been used as a '
f'private attribute, which is immutable. ')
raise AttributeError(f'{name} has been used as a '
'private attribute, which is immutable.')
else:
assert isinstance(value,
Sized), 'value must contain `_len__` attribute'
Sized), 'value must contain `__len__` attribute'
# TODO: make sure the input value shares the same length
super().__setattr__(name, value)
__setitem__ = __setattr__
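The assignment rules above are easy to exercise directly. A minimal sketch (note that, per the TODO above, a length mismatch between fields is not yet rejected):

import torch
from mmdet3d.structures import PointData

point_data = PointData(points=torch.rand((4, 3)))
point_data.pts_semantic_mask = torch.randint(0, 10, (4, ))  # Sized: accepted
try:
    point_data._data_fields = set()  # private bookkeeping is immutable
except AttributeError as err:
    print(err)
try:
    point_data.score = 1.0  # a float has no __len__
except AssertionError as err:
    print(err)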
@@ -69,16 +69,21 @@ class PointData(BaseDataElement):
def __getitem__(self, item: IndexType) -> 'PointData':
"""
Args:
item (str, obj:`slice`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
get the corresponding values according to item.
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
Get the corresponding values according to item.
Returns:
obj:`PointData`: Corresponding values.
:obj:`PointData`: Corresponding values.
"""
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
# The default int type of numpy is platform dependent, int32 for
# Windows and int64 for Linux. `torch.Tensor` requires the index
# to be int64, therefore we simply convert it to int64 here.
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
@@ -87,8 +92,8 @@ class PointData(BaseDataElement):
if isinstance(item, str):
return getattr(self, item)
if type(item) == int:
if item >= len(self) or item < -len(self): # type:ignore
if isinstance(item, int):
if item >= len(self) or item < -len(self): # type: ignore
raise IndexError(f'Index {item} out of range!')
else:
# keep the dimension
@@ -99,14 +104,14 @@ class PointData(BaseDataElement):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the' \
f' input(BoolTensor)) ' \
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
f' does not match the shape ' \
f'of the indexed tensor ' \
f'in results_filed ' \
'does not match the shape ' \
'of the indexed tensor ' \
'in results_field ' \
f'{len(self)} at ' \
f'first dimension. '
'first dimension.'
for k, v in self.items():
if isinstance(v, torch.Tensor):
@@ -116,7 +121,7 @@ class PointData(BaseDataElement):
elif isinstance(
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from boolTensor
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(
@@ -141,16 +146,15 @@ class PointData(BaseDataElement):
raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not '
f'support slice with `bool`')
'support slice with `bool`')
else:
# item is a slice
for k, v in self.items():
new_data[k] = v[item]
return new_data # type:ignore
return new_data # type: ignore
def __len__(self) -> int:
"""int: the length of PointData"""
"""int: The length of `PointData`."""
if len(self._data_fields) > 0:
return len(self.values()[0])
else:
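Putting the new `__getitem__` branches together, a usage sketch covering every supported index type (the expected lengths follow from the conversions above):

import numpy as np
import torch
from mmdet3d.structures import PointData

point_data = PointData(points=torch.rand((5, 3)))
assert len(point_data[2]) == 1                  # int keeps the dimension
assert len(point_data[1:4]) == 3                # slice
assert len(point_data[[0, 2]]) == 2             # list -> ndarray -> LongTensor
assert len(point_data[np.array([1, 3])]) == 2   # int ndarray -> LongTensor
mask = torch.tensor([True, False, True, False, True])
assert len(point_data[mask]) == 3               # BoolTensor of the full length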
@@ -16,7 +16,7 @@ def _equal(a, b):
return a == b
class TestDet3DataSample(TestCase):
class TestDet3DDataSample(TestCase):
def test_init(self):
meta_info = dict(
@@ -33,7 +33,7 @@ class TestDet3DataSample(TestCase):
det3d_data_sample = Det3DDataSample()
# test gt_instances_3d
gt_instances_3d_data = dict(
bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
gt_instances_3d = InstanceData(**gt_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample
@@ -44,7 +44,7 @@ class TestDet3DataSample(TestCase):
# test pred_instances_3d
pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4),
bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
pred_instances_3d = InstanceData(**pred_instances_3d_data)
@@ -59,7 +59,7 @@ class TestDet3DataSample(TestCase):
# test pts_pred_instances_3d
pts_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4),
bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
pts_pred_instances_3d = InstanceData(**pts_pred_instances_3d_data)
@@ -74,7 +74,7 @@ class TestDet3DataSample(TestCase):
# test img_pred_instances_3d
img_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4),
bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
img_pred_instances_3d = InstanceData(**img_pred_instances_3d_data)
@@ -87,7 +87,7 @@ class TestDet3DataSample(TestCase):
assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d,
img_pred_instances_3d_data['scores_3d'])
# test gt_seg
# test gt_pts_seg
gt_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
gt_pts_seg = PointData(**gt_pts_seg_data)
@@ -98,7 +98,7 @@ class TestDet3DataSample(TestCase):
assert _equal(det3d_data_sample.gt_pts_seg.pts_semantic_mask,
gt_pts_seg_data['pts_semantic_mask'])
# test pred_seg
# test pred_pts_seg
pred_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
pred_pts_seg = PointData(**pred_pts_seg_data)
# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmdet3d.structures import PointData
class TestPointData(TestCase):
def setup_data(self):
metainfo = dict(sample_idx=random.randint(0, 100))
points = torch.rand((5, 3))
point_data = PointData(metainfo=metainfo, points=points)
return point_data
def test_set_data(self):
point_data = self.setup_data()
# test set '_metainfo_fields' or '_data_fields'
with self.assertRaises(AttributeError):
point_data._metainfo_fields = 1
with self.assertRaises(AttributeError):
point_data._data_fields = 1
point_data.keypoints = torch.rand((5, 2))
assert 'keypoints' in point_data
def test_getitem(self):
point_data = PointData()
# length must be greater than 0
with self.assertRaises(IndexError):
point_data[1]
point_data = self.setup_data()
assert len(point_data) == 5
slice_point_data = point_data[:2]
assert len(slice_point_data) == 2
slice_point_data = point_data[1]
assert len(slice_point_data) == 1
# the index should be in 0 ~ len(point_data) - 1
with pytest.raises(IndexError):
point_data[5]
# the index must be a str, slice, int, torch.LongTensor or torch.BoolTensor
item = torch.Tensor([1, 2, 3, 4]) # float
with pytest.raises(AssertionError):
point_data[item]
# when the input is a bool tensor, its shape at dim 0
# should equal the value length in the data fields
with pytest.raises(AssertionError):
point_data[item.bool()]
# test LongTensor
long_tensor = torch.randint(5, (2, ))
long_index_point_data = point_data[long_tensor]
assert len(long_index_point_data) == len(long_tensor)
# test BoolTensor
bool_tensor = torch.rand(5) > 0.5
bool_index_point_data = point_data[bool_tensor]
assert len(bool_index_point_data) == bool_tensor.sum()
bool_tensor = torch.rand(5) > 1
empty_point_data = point_data[bool_tensor]
assert len(empty_point_data) == bool_tensor.sum()
# test list index
list_index = [1, 2]
list_index_point_data = point_data[list_index]
assert len(list_index_point_data) == len(list_index)
# test list bool
list_bool = [True, False, True, False, False]
list_bool_point_data = point_data[list_bool]
assert len(list_bool_point_data) == 2
# test numpy
long_numpy = np.random.randint(5, size=2)
long_numpy_point_data = point_data[long_numpy]
assert len(long_numpy_point_data) == len(long_numpy)
bool_numpy = np.random.rand(5) > 0.5
bool_numpy_point_data = point_data[bool_numpy]
assert len(bool_numpy_point_data) == bool_numpy.sum()
def test_len(self):
point_data = self.setup_data()
assert len(point_data) == 5
point_data = PointData()
assert len(point_data) == 0