Commit 7e87d837 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Fix]fix datasample

parent 64265cec
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.data import BaseDataElement, InstanceData, PixelData
from mmengine.data import InstanceData, PixelData
from mmdet.core.data_structures import DetDataSample
class Det3DDataSample(BaseDataElement):
class Det3DDataSample(DetDataSample):
"""A data structure interface of MMDetection3D. They are used as interfaces
between different components.
......@@ -14,8 +16,33 @@ class Det3DDataSample(BaseDataElement):
training/testing.
- ``gt_instances_3d``(InstanceData): Ground truth of 3D instance
annotations.
- ``gt_instances``(InstanceData): Ground truth of 2D instance
annotations.
- ``pred_instances_3d``(InstanceData): 3D instances of model
predictions.
- For point-cloud 3d object detection task whose input modality
is `use_lidar=True, use_camera=False`, the 3D predictions results
are saved in `pred_instances_3d`.
- For vision-only(monocular/multi-view) 3D object detection task
whose input modality is `use_lidar=False, use_camera=True`, the 3D
predictions are saved in `pred_instances_3d`.
- ``pred_instances``(InstanceData): 2D instances of model
predictions.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 2D predictions
are saved in `pred_instances`.
- ``pts_pred_instances_3d``(InstanceData): 3D instances of model
predictions based on point cloud.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
point cloud are saved in `pts_pred_instances_3d` to distinguish
with `img_pred_instances_3d` which based on image.
- ``img_pred_instances_3d``(InstanceData): 3D instances of model
predictions based on image.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
image are saved in `img_pred_instances_3d` to distinguish with
`pts_pred_instances_3d` which based on point cloud.
- ``gt_pts_sem_seg``(PixelData): Ground truth of point cloud
semantic segmentation.
- ``pred_pts_sem_seg``(PixelData): Prediction of point cloud
......@@ -76,9 +103,16 @@ class Det3DDataSample(BaseDataElement):
[0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]]))
) at 0x7fb0d9354280>
) at 0x7fb0d93543d0>
>>> pred_instances = InstanceData(metainfo=meta_info)
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances=pred_instances)
>>> assert 'pred_instances' in data_sample
>>> pred_instances_3d = InstanceData(metainfo=meta_info)
>>> pred_instances_3d.bboxes = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> pred_instances_3d.scores = torch.rand((5, ))
>>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> pred_instances_3d.scores_3d = torch.rand((5, ))
>>> pred_instances_3d.labels_3d = torch.rand((5, ))
>>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
>>> assert 'pred_instances_3d' in data_sample
......@@ -126,30 +160,6 @@ class Det3DDataSample(BaseDataElement):
>>> assert 'segm_seg' in data_sample.gt_pts_sem_seg
"""
@property
def proposals(self) -> InstanceData:
return self._proposals
@proposals.setter
def proposals(self, value: InstanceData):
self.set_field(value, '_proposals', dtype=InstanceData)
@proposals.deleter
def proposals(self):
del self._proposals
@property
def ignored_instances(self) -> InstanceData:
return self._ignored_instances
@ignored_instances.setter
def ignored_instances(self, value: InstanceData):
self.set_field(value, '_ignored_instances', dtype=InstanceData)
@ignored_instances.deleter
def ignored_instances(self):
del self._ignored_instances
@property
def gt_instances_3d(self) -> InstanceData:
return self._gt_instances_3d
......@@ -174,6 +184,30 @@ class Det3DDataSample(BaseDataElement):
def pred_instances_3d(self):
del self._pred_instances_3d
@property
def pts_pred_instances_3d(self) -> InstanceData:
return self._pts_pred_instances_3d
@pts_pred_instances_3d.setter
def pts_pred_instances_3d(self, value: InstanceData):
self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)
@pts_pred_instances_3d.deleter
def pts_pred_instances_3d(self):
del self._pts_pred_instances_3d
@property
def img_pred_instances_3d(self) -> InstanceData:
return self._img_pred_instances_3d
@img_pred_instances_3d.setter
def img_pred_instances_3d(self, value: InstanceData):
self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)
@img_pred_instances_3d.deleter
def img_pred_instances_3d(self):
del self._img_pred_instances_3d
@property
def gt_pts_sem_seg(self) -> PixelData:
return self._gt_pts_sem_seg
......
......@@ -5,7 +5,7 @@ import pytest
import torch
from mmengine.data import InstanceData, PixelData
from mmdet3d.core import Det3DDataSample
from mmdet3d.core.data_structures import Det3DDataSample
def _equal(a, b):
......@@ -15,7 +15,7 @@ def _equal(a, b):
return a == b
class TestDetDataSample(TestCase):
class TestDet3DataSample(TestCase):
def test_init(self):
meta_info = dict(
......@@ -32,54 +32,59 @@ class TestDetDataSample(TestCase):
det3d_data_sample = Det3DDataSample()
# test gt_instances_3d
gt_instances_3d_data = dict(
bboxes=torch.rand(4, 4),
labels=torch.rand(4),
masks=np.random.rand(4, 2, 2))
bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
gt_instances_3d = InstanceData(**gt_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.gt_instances_3d.bboxes,
gt_instances_3d_data['bboxes'])
assert _equal(det3d_data_sample.gt_instances_3d.labels,
gt_instances_3d_data['labels'])
assert _equal(det3d_data_sample.gt_instances_3d.masks,
gt_instances_3d_data['masks'])
# test pred_instances
assert _equal(det3d_data_sample.gt_instances_3d.bboxes_3d,
gt_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.gt_instances_3d.labels_3d,
gt_instances_3d_data['labels_3d'])
# test pred_instances_3d
pred_instances_3d_data = dict(
bboxes=torch.rand(2, 4),
labels=torch.rand(2),
masks=np.random.rand(2, 2, 2))
bboxes_3d=torch.rand(2, 4),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
pred_instances_3d = InstanceData(**pred_instances_3d_data)
det3d_data_sample.pred_instances_3d = pred_instances_3d
assert 'pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.pred_instances_3d.bboxes,
pred_instances_3d_data['bboxes'])
assert _equal(det3d_data_sample.pred_instances_3d.labels,
pred_instances_3d_data['labels'])
assert _equal(det3d_data_sample.pred_instances_3d.masks,
pred_instances_3d_data['masks'])
# test proposals
proposals_data = dict(bboxes=torch.rand(4, 4), labels=torch.rand(4))
proposals = InstanceData(**proposals_data)
det3d_data_sample.proposals = proposals
assert 'proposals' in det3d_data_sample
assert _equal(det3d_data_sample.proposals.bboxes,
proposals_data['bboxes'])
assert _equal(det3d_data_sample.proposals.labels,
proposals_data['labels'])
# test ignored_instances
ignored_instances_data = dict(
bboxes=torch.rand(4, 4), labels=torch.rand(4))
ignored_instances = InstanceData(**ignored_instances_data)
det3d_data_sample.ignored_instances = ignored_instances
assert 'ignored_instances' in det3d_data_sample
assert _equal(det3d_data_sample.ignored_instances.bboxes,
ignored_instances_data['bboxes'])
assert _equal(det3d_data_sample.ignored_instances.labels,
ignored_instances_data['labels'])
assert _equal(det3d_data_sample.pred_instances_3d.bboxes_3d,
pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.pred_instances_3d.labels_3d,
pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.pred_instances_3d.scores_3d,
pred_instances_3d_data['scores_3d'])
# test pts_pred_instances_3d
pts_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
pts_pred_instances_3d = InstanceData(**pts_pred_instances_3d_data)
det3d_data_sample.pts_pred_instances_3d = pts_pred_instances_3d
assert 'pts_pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.pts_pred_instances_3d.bboxes_3d,
pts_pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.pts_pred_instances_3d.labels_3d,
pts_pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.pts_pred_instances_3d.scores_3d,
pts_pred_instances_3d_data['scores_3d'])
# test img_pred_instances_3d
img_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 4),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
img_pred_instances_3d = InstanceData(**img_pred_instances_3d_data)
det3d_data_sample.img_pred_instances_3d = img_pred_instances_3d
assert 'img_pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.img_pred_instances_3d.bboxes_3d,
img_pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.img_pred_instances_3d.labels_3d,
img_pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d,
img_pred_instances_3d_data['scores_3d'])
# test gt_panoptic_seg
gt_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
......@@ -124,18 +129,34 @@ class TestDetDataSample(TestCase):
det3d_data_sample.pred_pts_sem_seg = torch.rand(2, 4)
def test_deleter(self):
gt_instances_3d_data = dict(
bboxes=torch.rand(4, 4),
labels=torch.rand(4),
masks=np.random.rand(4, 2, 2))
tmp_instances_3d_data = dict(
bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
det3d_data_sample = Det3DDataSample()
gt_instances_3d = InstanceData(data=gt_instances_3d_data)
gt_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample
del det3d_data_sample.gt_instances_3d
assert 'gt_instances_3d' not in det3d_data_sample
pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.pred_instances_3d = pred_instances_3d
assert 'pred_instances_3d' in det3d_data_sample
del det3d_data_sample.pred_instances_3d
assert 'pred_instances_3d' not in det3d_data_sample
pts_pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.pts_pred_instances_3d = pts_pred_instances_3d
assert 'pts_pred_instances_3d' in det3d_data_sample
del det3d_data_sample.pts_pred_instances_3d
assert 'pts_pred_instances_3d' not in det3d_data_sample
img_pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.img_pred_instances_3d = img_pred_instances_3d
assert 'img_pred_instances_3d' in det3d_data_sample
del det3d_data_sample.img_pred_instances_3d
assert 'img_pred_instances_3d' not in det3d_data_sample
pred_pts_panoptic_seg_data = torch.rand(5, 4)
pred_pts_panoptic_seg_data = PixelData(data=pred_pts_panoptic_seg_data)
det3d_data_sample.pred_pts_panoptic_seg_data = \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment