# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmengine.data import InstanceData

from mmdet.structures import DetDataSample
from .point_data import PointData


class Det3DDataSample(DetDataSample):
    """A data structure interface of MMDetection3D. They are used as interfaces
    between different components.

    The attributes in ``Det3DDataSample`` are divided into several parts:

        - ``proposals`` (InstanceData): Region proposals used in two-stage
            detectors.
        - ``ignored_instances`` (InstanceData): Instances to be ignored during
            training/testing.
        - ``gt_instances_3d`` (InstanceData): Ground truth of 3D instance
            annotations.
        - ``gt_instances`` (InstanceData): Ground truth of 2D instance
            annotations.
        - ``pred_instances_3d`` (InstanceData): 3D instances of model
            predictions.
            - For the point cloud 3D object detection task, whose input
            modality is `use_lidar=True, use_camera=False`, the 3D prediction
            results are saved in `pred_instances_3d`.
            - For the vision-only (monocular/multi-view) 3D object detection
            task, whose input modality is `use_lidar=False, use_camera=True`,
            the 3D predictions are saved in `pred_instances_3d`.
        - ``pred_instances`` (InstanceData): 2D instances of model
            predictions.
            - For the multi-modality 3D detection task, whose input modality
            is `use_lidar=True, use_camera=True`, the 2D predictions are
            saved in `pred_instances`.
        - ``pts_pred_instances_3d`` (InstanceData): 3D instances of model
            predictions based on point clouds.
            - For the multi-modality 3D detection task, whose input modality
            is `use_lidar=True, use_camera=True`, the 3D predictions based on
            point clouds are saved in `pts_pred_instances_3d` to distinguish
            them from `img_pred_instances_3d`, which is based on images.
        - ``img_pred_instances_3d`` (InstanceData): 3D instances of model
            predictions based on images.
            - For the multi-modality 3D detection task, whose input modality
            is `use_lidar=True, use_camera=True`, the 3D predictions based on
            images are saved in `img_pred_instances_3d` to distinguish them
            from `pts_pred_instances_3d`, which is based on point clouds.
        - ``gt_pts_seg`` (PointData): Ground truth of point cloud
            segmentation.
        - ``pred_pts_seg`` (PointData): Prediction of point cloud
            segmentation.
        - ``eval_ann_info`` (dict): Raw annotation, which will be passed to
            the evaluator for online evaluation.

    Examples:
    >>> import torch
    >>> from mmengine.data import InstanceData

    >>> from mmdet3d.structures import Det3DDataSample
    >>> from mmdet3d.structures import BaseInstance3DBoxes
    >>> from mmdet3d.structures import PointData

    >>> data_sample = Det3DDataSample()
    >>> meta_info = dict(img_shape=(800, 1196, 3),
    ...     pad_shape=(800, 1216, 3))
    >>> gt_instances_3d = InstanceData(metainfo=meta_info)
    >>> gt_instances_3d.bboxes = BaseInstance3DBoxes(torch.rand((5, 7)))
    >>> gt_instances_3d.labels = torch.randint(0, 3, (5, ))
    >>> data_sample.gt_instances_3d = gt_instances_3d
    >>> assert 'img_shape' in data_sample.gt_instances_3d.metainfo_keys()
    >>> print(data_sample)
    <Det3DDataSample(

        META INFORMATION

        DATA FIELDS
        _gt_instances_3d: <InstanceData(

            META INFORMATION
            pad_shape: (800, 1216, 3)
            img_shape: (800, 1196, 3)

            DATA FIELDS
            labels: tensor([0, 0, 1, 0, 2])
            bboxes: BaseInstance3DBoxes(
            tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944],
                    [0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326],
                    [0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917],
                    [0.4686, 0.7007, 0.4428, 0.0672, 0.3319, 0.3033, 0.8519],
                    [0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]]))
        ) at 0x7fb0d9354280>
        gt_instances_3d: <InstanceData(

            META INFORMATION
            pad_shape: (800, 1216, 3)
            img_shape: (800, 1196, 3)

            DATA FIELDS
            labels: tensor([0, 0, 1, 0, 2])
            bboxes: BaseInstance3DBoxes(
            tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944],
                    [0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326],
                    [0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917],
                    [0.4686, 0.7007, 0.4428, 0.0672, 0.3319, 0.3033, 0.8519],
                    [0.9693, 0.5315, 0.4642, 0.9079, 0.2481, 0.1781, 0.9557]]))
        ) at 0x7fb0d9354280>
    ) at 0x7fb0d93543d0>
    >>> pred_instances = InstanceData(metainfo=meta_info)
    >>> pred_instances.bboxes = torch.rand((5, 4))
    >>> pred_instances.scores = torch.rand((5, ))
    >>> data_sample = Det3DDataSample(pred_instances=pred_instances)
    >>> assert 'pred_instances' in data_sample

    >>> pred_instances_3d = InstanceData(metainfo=meta_info)
    >>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
    >>> pred_instances_3d.scores_3d = torch.rand((5, ))
    >>> pred_instances_3d.labels_3d = torch.randint(0, 3, (5, ))
    >>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
    >>> assert 'pred_instances_3d' in data_sample
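
    A minimal sketch of the multi-modality case (assuming an input modality
    of `use_lidar=True, use_camera=True`; the random boxes below are
    placeholders, not real predictions):

    >>> pts_pred = InstanceData()
    >>> pts_pred.bboxes_3d = BaseInstance3DBoxes(torch.rand((3, 7)))
    >>> img_pred = InstanceData()
    >>> img_pred.bboxes_3d = BaseInstance3DBoxes(torch.rand((3, 7)))
    >>> data_sample = Det3DDataSample()
    >>> data_sample.pts_pred_instances_3d = pts_pred
    >>> data_sample.img_pred_instances_3d = img_pred
    >>> assert 'pts_pred_instances_3d' in data_sample
    >>> assert 'img_pred_instances_3d' in data_sample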

    >>> data_sample = Det3DDataSample()
    >>> gt_instances_3d_data = dict(
    ...    bboxes=BaseInstance3DBoxes(torch.rand((2, 7))),
    ...    labels=torch.randint(0, 3, (2, )))
    >>> gt_instances_3d = InstanceData(**gt_instances_3d_data)
    >>> data_sample.gt_instances_3d = gt_instances_3d
    >>> assert 'gt_instances_3d' in data_sample
    >>> assert 'bboxes' in data_sample.gt_instances_3d

    >>> data_sample = Det3DDataSample()
    >>> gt_pts_seg_data = dict(
    ...    pts_instance_mask=torch.rand(2),
    ...    pts_semantic_mask=torch.rand(2))
    >>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
    >>> print(data_sample)
    <Det3DDataSample(

        META INFORMATION

        DATA FIELDS
        gt_pts_seg: <PointData(

                META INFORMATION

                DATA FIELDS
                pts_instance_mask: tensor([0.0576, 0.3067])
                pts_semantic_mask: tensor([0.9267, 0.7455])
            ) at 0x7f654a9c1590>
        _gt_pts_seg: <PointData(

                META INFORMATION

                DATA FIELDS
                pts_instance_mask: tensor([0.0576, 0.3067])
                pts_semantic_mask: tensor([0.9267, 0.7455])
            ) at 0x7f654a9c1590>
    ) at 0x7f654a9c1550>
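
    ``pred_pts_seg`` is set the same way; a minimal sketch (the random mask
    values are placeholders, not real segmentation results):

    >>> pred_pts_seg_data = dict(
    ...    pts_instance_mask=torch.rand(2),
    ...    pts_semantic_mask=torch.rand(2))
    >>> data_sample.pred_pts_seg = PointData(**pred_pts_seg_data)
    >>> assert 'pred_pts_seg' in data_sample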
    """

    @property
    def gt_instances_3d(self) -> InstanceData:
        return self._gt_instances_3d

    @gt_instances_3d.setter
    def gt_instances_3d(self, value: InstanceData):
        self.set_field(value, '_gt_instances_3d', dtype=InstanceData)

    @gt_instances_3d.deleter
    def gt_instances_3d(self):
        del self._gt_instances_3d

    @property
    def pred_instances_3d(self) -> InstanceData:
        return self._pred_instances_3d

    @pred_instances_3d.setter
    def pred_instances_3d(self, value: InstanceData):
        self.set_field(value, '_pred_instances_3d', dtype=InstanceData)

    @pred_instances_3d.deleter
    def pred_instances_3d(self):
        del self._pred_instances_3d

    @property
    def pts_pred_instances_3d(self) -> InstanceData:
        return self._pts_pred_instances_3d

    @pts_pred_instances_3d.setter
    def pts_pred_instances_3d(self, value: InstanceData):
        self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)

    @pts_pred_instances_3d.deleter
    def pts_pred_instances_3d(self):
        del self._pts_pred_instances_3d

    @property
    def img_pred_instances_3d(self) -> InstanceData:
        return self._img_pred_instances_3d

    @img_pred_instances_3d.setter
    def img_pred_instances_3d(self, value: InstanceData):
        self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)

    @img_pred_instances_3d.deleter
    def img_pred_instances_3d(self):
        del self._img_pred_instances_3d

    @property
    def gt_pts_seg(self) -> PointData:
        return self._gt_pts_seg

    @gt_pts_seg.setter
    def gt_pts_seg(self, value: PointData):
        self.set_field(value, '_gt_pts_seg', dtype=PointData)

    @gt_pts_seg.deleter
    def gt_pts_seg(self):
        del self._gt_pts_seg

    @property
    def pred_pts_seg(self) -> PointData:
        return self._pred_pts_seg

    @pred_pts_seg.setter
    def pred_pts_seg(self, value: PointData):
        self.set_field(value, '_pred_pts_seg', dtype=PointData)

    @pred_pts_seg.deleter
    def pred_pts_seg(self):
        del self._pred_pts_seg


SampleList = List[Det3DDataSample]
OptSampleList = Optional[SampleList]
ForwardResults = Union[Dict[str, torch.Tensor], List[Det3DDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
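
# A minimal sketch of how these aliases are typically consumed by a detector
# (this ``forward`` signature is illustrative, not a required interface):
#
#     def forward(self,
#                 batch_inputs: torch.Tensor,
#                 batch_data_samples: OptSampleList = None,
#                 mode: str = 'tensor') -> ForwardResults:
#         ...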