"megatron/vscode:/vscode.git/clone" did not exist on "83d26f036154dfad8cab59a26425c2bccb1089de"
Commit 7e87d837 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Fix]fix datasample

parent 64265cec
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmengine.data import BaseDataElement, InstanceData, PixelData from mmengine.data import InstanceData, PixelData
from mmdet.core.data_structures import DetDataSample
class Det3DDataSample(BaseDataElement):
class Det3DDataSample(DetDataSample):
"""A data structure interface of MMDetection3D. They are used as interfaces """A data structure interface of MMDetection3D. They are used as interfaces
between different components. between different components.
...@@ -14,8 +16,33 @@ class Det3DDataSample(BaseDataElement): ...@@ -14,8 +16,33 @@ class Det3DDataSample(BaseDataElement):
training/testing. training/testing.
- ``gt_instances_3d``(InstanceData): Ground truth of 3D instance - ``gt_instances_3d``(InstanceData): Ground truth of 3D instance
annotations. annotations.
- ``gt_instances``(InstanceData): Ground truth of 2D instance
annotations.
- ``pred_instances_3d``(InstanceData): 3D instances of model - ``pred_instances_3d``(InstanceData): 3D instances of model
predictions. predictions.
- For point-cloud 3d object detection task whose input modality
is `use_lidar=True, use_camera=False`, the 3D predictions results
are saved in `pred_instances_3d`.
- For vision-only(monocular/multi-view) 3D object detection task
whose input modality is `use_lidar=False, use_camera=True`, the 3D
predictions are saved in `pred_instances_3d`.
- ``pred_instances``(InstanceData): 2D instances of model
predictions.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 2D predictions
are saved in `pred_instances`.
- ``pts_pred_instances_3d``(InstanceData): 3D instances of model
predictions based on point cloud.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
point cloud are saved in `pts_pred_instances_3d` to distinguish
with `img_pred_instances_3d` which based on image.
- ``img_pred_instances_3d``(InstanceData): 3D instances of model
predictions based on image.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
image are saved in `img_pred_instances_3d` to distinguish with
`pts_pred_instances_3d` which based on point cloud.
- ``gt_pts_sem_seg``(PixelData): Ground truth of point cloud - ``gt_pts_sem_seg``(PixelData): Ground truth of point cloud
semantic segmentation. semantic segmentation.
- ``pred_pts_sem_seg``(PixelData): Prediction of point cloud - ``pred_pts_sem_seg``(PixelData): Prediction of point cloud
...@@ -76,9 +103,16 @@ class Det3DDataSample(BaseDataElement): ...@@ -76,9 +103,16 @@ class Det3DDataSample(BaseDataElement):
[0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]])) [0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]]))
) at 0x7fb0d9354280> ) at 0x7fb0d9354280>
) at 0x7fb0d93543d0> ) at 0x7fb0d93543d0>
>>> pred_instances = InstanceData(metainfo=meta_info)
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances=pred_instances)
>>> assert 'pred_instances' in data_sample
>>> pred_instances_3d = InstanceData(metainfo=meta_info) >>> pred_instances_3d = InstanceData(metainfo=meta_info)
>>> pred_instances_3d.bboxes = BaseInstance3DBoxes(torch.rand((5, 7))) >>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> pred_instances_3d.scores = torch.rand((5, )) >>> pred_instances_3d.scores_3d = torch.rand((5, ))
>>> pred_instances_3d.labels_3d = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d) >>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
>>> assert 'pred_instances_3d' in data_sample >>> assert 'pred_instances_3d' in data_sample
...@@ -126,30 +160,6 @@ class Det3DDataSample(BaseDataElement): ...@@ -126,30 +160,6 @@ class Det3DDataSample(BaseDataElement):
>>> assert 'segm_seg' in data_sample.gt_pts_sem_seg >>> assert 'segm_seg' in data_sample.gt_pts_sem_seg
""" """
@property
def proposals(self) -> InstanceData:
return self._proposals
@proposals.setter
def proposals(self, value: InstanceData):
self.set_field(value, '_proposals', dtype=InstanceData)
@proposals.deleter
def proposals(self):
del self._proposals
@property
def ignored_instances(self) -> InstanceData:
return self._ignored_instances
@ignored_instances.setter
def ignored_instances(self, value: InstanceData):
self.set_field(value, '_ignored_instances', dtype=InstanceData)
@ignored_instances.deleter
def ignored_instances(self):
del self._ignored_instances
@property @property
def gt_instances_3d(self) -> InstanceData: def gt_instances_3d(self) -> InstanceData:
return self._gt_instances_3d return self._gt_instances_3d
...@@ -174,6 +184,30 @@ class Det3DDataSample(BaseDataElement): ...@@ -174,6 +184,30 @@ class Det3DDataSample(BaseDataElement):
def pred_instances_3d(self): def pred_instances_3d(self):
del self._pred_instances_3d del self._pred_instances_3d
@property
def pts_pred_instances_3d(self) -> InstanceData:
return self._pts_pred_instances_3d
@pts_pred_instances_3d.setter
def pts_pred_instances_3d(self, value: InstanceData):
self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)
@pts_pred_instances_3d.deleter
def pts_pred_instances_3d(self):
del self._pts_pred_instances_3d
@property
def img_pred_instances_3d(self) -> InstanceData:
return self._img_pred_instances_3d
@img_pred_instances_3d.setter
def img_pred_instances_3d(self, value: InstanceData):
self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)
@img_pred_instances_3d.deleter
def img_pred_instances_3d(self):
del self._img_pred_instances_3d
@property @property
def gt_pts_sem_seg(self) -> PixelData: def gt_pts_sem_seg(self) -> PixelData:
return self._gt_pts_sem_seg return self._gt_pts_sem_seg
......
...@@ -5,7 +5,7 @@ import pytest ...@@ -5,7 +5,7 @@ import pytest
import torch import torch
from mmengine.data import InstanceData, PixelData from mmengine.data import InstanceData, PixelData
from mmdet3d.core import Det3DDataSample from mmdet3d.core.data_structures import Det3DDataSample
def _equal(a, b): def _equal(a, b):
...@@ -15,7 +15,7 @@ def _equal(a, b): ...@@ -15,7 +15,7 @@ def _equal(a, b):
return a == b return a == b
class TestDetDataSample(TestCase): class TestDet3DataSample(TestCase):
def test_init(self): def test_init(self):
meta_info = dict( meta_info = dict(
...@@ -32,54 +32,59 @@ class TestDetDataSample(TestCase): ...@@ -32,54 +32,59 @@ class TestDetDataSample(TestCase):
det3d_data_sample = Det3DDataSample() det3d_data_sample = Det3DDataSample()
# test gt_instances_3d # test gt_instances_3d
gt_instances_3d_data = dict( gt_instances_3d_data = dict(
bboxes=torch.rand(4, 4), bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
labels=torch.rand(4),
masks=np.random.rand(4, 2, 2))
gt_instances_3d = InstanceData(**gt_instances_3d_data) gt_instances_3d = InstanceData(**gt_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample assert 'gt_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.gt_instances_3d.bboxes, assert _equal(det3d_data_sample.gt_instances_3d.bboxes_3d,
gt_instances_3d_data['bboxes']) gt_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.gt_instances_3d.labels, assert _equal(det3d_data_sample.gt_instances_3d.labels_3d,
gt_instances_3d_data['labels']) gt_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.gt_instances_3d.masks,
gt_instances_3d_data['masks']) # test pred_instances_3d
# test pred_instances
pred_instances_3d_data = dict( pred_instances_3d_data = dict(
bboxes=torch.rand(2, 4), bboxes_3d=torch.rand(2, 4),
labels=torch.rand(2), labels_3d=torch.rand(2),
masks=np.random.rand(2, 2, 2)) scores_3d=torch.rand(2))
pred_instances_3d = InstanceData(**pred_instances_3d_data) pred_instances_3d = InstanceData(**pred_instances_3d_data)
det3d_data_sample.pred_instances_3d = pred_instances_3d det3d_data_sample.pred_instances_3d = pred_instances_3d
assert 'pred_instances_3d' in det3d_data_sample assert 'pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.pred_instances_3d.bboxes, assert _equal(det3d_data_sample.pred_instances_3d.bboxes_3d,
pred_instances_3d_data['bboxes']) pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.pred_instances_3d.labels, assert _equal(det3d_data_sample.pred_instances_3d.labels_3d,
pred_instances_3d_data['labels']) pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.pred_instances_3d.masks, assert _equal(det3d_data_sample.pred_instances_3d.scores_3d,
pred_instances_3d_data['masks']) pred_instances_3d_data['scores_3d'])
# test proposals # test pts_pred_instances_3d
proposals_data = dict(bboxes=torch.rand(4, 4), labels=torch.rand(4)) pts_pred_instances_3d_data = dict(
proposals = InstanceData(**proposals_data) bboxes_3d=torch.rand(2, 4),
det3d_data_sample.proposals = proposals labels_3d=torch.rand(2),
assert 'proposals' in det3d_data_sample scores_3d=torch.rand(2))
assert _equal(det3d_data_sample.proposals.bboxes, pts_pred_instances_3d = InstanceData(**pts_pred_instances_3d_data)
proposals_data['bboxes']) det3d_data_sample.pts_pred_instances_3d = pts_pred_instances_3d
assert _equal(det3d_data_sample.proposals.labels, assert 'pts_pred_instances_3d' in det3d_data_sample
proposals_data['labels']) assert _equal(det3d_data_sample.pts_pred_instances_3d.bboxes_3d,
pts_pred_instances_3d_data['bboxes_3d'])
# test ignored_instances assert _equal(det3d_data_sample.pts_pred_instances_3d.labels_3d,
ignored_instances_data = dict( pts_pred_instances_3d_data['labels_3d'])
bboxes=torch.rand(4, 4), labels=torch.rand(4)) assert _equal(det3d_data_sample.pts_pred_instances_3d.scores_3d,
ignored_instances = InstanceData(**ignored_instances_data) pts_pred_instances_3d_data['scores_3d'])
det3d_data_sample.ignored_instances = ignored_instances
assert 'ignored_instances' in det3d_data_sample # test img_pred_instances_3d
assert _equal(det3d_data_sample.ignored_instances.bboxes, img_pred_instances_3d_data = dict(
ignored_instances_data['bboxes']) bboxes_3d=torch.rand(2, 4),
assert _equal(det3d_data_sample.ignored_instances.labels, labels_3d=torch.rand(2),
ignored_instances_data['labels']) scores_3d=torch.rand(2))
img_pred_instances_3d = InstanceData(**img_pred_instances_3d_data)
det3d_data_sample.img_pred_instances_3d = img_pred_instances_3d
assert 'img_pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.img_pred_instances_3d.bboxes_3d,
img_pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.img_pred_instances_3d.labels_3d,
img_pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d,
img_pred_instances_3d_data['scores_3d'])
# test gt_panoptic_seg # test gt_panoptic_seg
gt_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4)) gt_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
...@@ -124,18 +129,34 @@ class TestDetDataSample(TestCase): ...@@ -124,18 +129,34 @@ class TestDetDataSample(TestCase):
det3d_data_sample.pred_pts_sem_seg = torch.rand(2, 4) det3d_data_sample.pred_pts_sem_seg = torch.rand(2, 4)
def test_deleter(self): def test_deleter(self):
gt_instances_3d_data = dict( tmp_instances_3d_data = dict(
bboxes=torch.rand(4, 4), bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
labels=torch.rand(4),
masks=np.random.rand(4, 2, 2))
det3d_data_sample = Det3DDataSample() det3d_data_sample = Det3DDataSample()
gt_instances_3d = InstanceData(data=gt_instances_3d_data) gt_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample assert 'gt_instances_3d' in det3d_data_sample
del det3d_data_sample.gt_instances_3d del det3d_data_sample.gt_instances_3d
assert 'gt_instances_3d' not in det3d_data_sample assert 'gt_instances_3d' not in det3d_data_sample
pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.pred_instances_3d = pred_instances_3d
assert 'pred_instances_3d' in det3d_data_sample
del det3d_data_sample.pred_instances_3d
assert 'pred_instances_3d' not in det3d_data_sample
pts_pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.pts_pred_instances_3d = pts_pred_instances_3d
assert 'pts_pred_instances_3d' in det3d_data_sample
del det3d_data_sample.pts_pred_instances_3d
assert 'pts_pred_instances_3d' not in det3d_data_sample
img_pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.img_pred_instances_3d = img_pred_instances_3d
assert 'img_pred_instances_3d' in det3d_data_sample
del det3d_data_sample.img_pred_instances_3d
assert 'img_pred_instances_3d' not in det3d_data_sample
pred_pts_panoptic_seg_data = torch.rand(5, 4) pred_pts_panoptic_seg_data = torch.rand(5, 4)
pred_pts_panoptic_seg_data = PixelData(data=pred_pts_panoptic_seg_data) pred_pts_panoptic_seg_data = PixelData(data=pred_pts_panoptic_seg_data)
det3d_data_sample.pred_pts_panoptic_seg_data = \ det3d_data_sample.pred_pts_panoptic_seg_data = \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment