# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmdet.structures import DetDataSample
from mmengine.structures import InstanceData

from .point_data import PointData


class Det3DDataSample(DetDataSample):
    """A data structure interface of MMDetection3D.

    It is used as the interface between different components.

    The attributes in ``Det3DDataSample`` are divided into several parts:

        - ``proposals`` (InstanceData): Region proposals used in two-stage
          detectors.
        - ``ignored_instances`` (InstanceData): Instances to be ignored during
          training/testing.
        - ``gt_instances_3d`` (InstanceData): Ground truth of 3D instance
          annotations.
        - ``gt_instances`` (InstanceData): Ground truth of 2D instance
          annotations.
        - ``pred_instances_3d`` (InstanceData): 3D instances of model
          predictions.

          - For point-cloud 3D object detection task whose input modality is
            `use_lidar=True, use_camera=False`, the 3D prediction results are
            saved in `pred_instances_3d`.
          - For vision-only (monocular/multi-view) 3D object detection task
            whose input modality is `use_lidar=False, use_camera=True`, the 3D
            predictions are saved in `pred_instances_3d`.

        - ``pred_instances`` (InstanceData): 2D instances of model predictions.

          - For multi-modality 3D detection task whose input modality is
            `use_lidar=True, use_camera=True`, the 2D predictions are saved in
            `pred_instances`.

        - ``pts_pred_instances_3d`` (InstanceData): 3D instances of model
          predictions based on point cloud.

          - For multi-modality 3D detection task whose input modality is
            `use_lidar=True, use_camera=True`, the 3D predictions based on
            point cloud are saved in `pts_pred_instances_3d` to distinguish
            them from `img_pred_instances_3d`, which is based on image.

        - ``img_pred_instances_3d`` (InstanceData): 3D instances of model
          predictions based on image.

          - For multi-modality 3D detection task whose input modality is
            `use_lidar=True, use_camera=True`, the 3D predictions based on
            image are saved in `img_pred_instances_3d` to distinguish them
            from `pts_pred_instances_3d`, which is based on point cloud.

        - ``gt_pts_seg`` (PointData): Ground truth of point cloud segmentation.
        - ``pred_pts_seg`` (PointData): Prediction of point cloud segmentation.
        - ``eval_ann_info`` (dict or None): Raw annotation, which will be
          passed to the evaluator for online evaluation.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData

        >>> from mmdet3d.structures import Det3DDataSample
        >>> from mmdet3d.structures.bbox_3d import BaseInstance3DBoxes

        >>> data_sample = Det3DDataSample()
        >>> meta_info = dict(
        ...     img_shape=(800, 1196, 3),
        ...     pad_shape=(800, 1216, 3))
        >>> gt_instances_3d = InstanceData(metainfo=meta_info)
        >>> gt_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
        >>> gt_instances_3d.labels_3d = torch.randint(0, 3, (5,))
        >>> data_sample.gt_instances_3d = gt_instances_3d
        >>> assert 'img_shape' in data_sample.gt_instances_3d.metainfo_keys()
        >>> len(data_sample.gt_instances_3d)
        5
        >>> print(data_sample)
        <Det3DDataSample(
            META INFORMATION
            DATA FIELDS
            gt_instances_3d: <InstanceData(
                    META INFORMATION
                    img_shape: (800, 1196, 3)
                    pad_shape: (800, 1216, 3)
                    DATA FIELDS
                    labels_3d: tensor([1, 0, 2, 0, 1])
                    bboxes_3d: BaseInstance3DBoxes(
                            tensor([[1.9115e-01, 3.6061e-01, 6.7707e-01, 5.2902e-01, 8.0736e-01, 8.2759e-01,
                                2.4328e-01],
                                [5.6272e-01, 2.7508e-01, 5.7966e-01, 9.2410e-01, 3.0456e-01, 1.8912e-01,
                                3.3176e-01],
                                [8.1069e-01, 2.8684e-01, 7.7689e-01, 9.2397e-02, 5.5849e-01, 3.8007e-01,
                                4.6719e-01],
                                [6.6346e-01, 4.8005e-01, 5.2318e-02, 4.4137e-01, 4.1163e-01, 8.9339e-01,
                                7.2847e-01],
                                [2.4800e-01, 7.1944e-01, 3.4766e-01, 7.8583e-01, 8.5507e-01, 6.3729e-02,
                                7.5161e-05]]))
                ) at 0x7f7e29de3a00>
        ) at 0x7f7e2a0e8640>
        >>> pred_instances = InstanceData(metainfo=meta_info)
        >>> pred_instances.bboxes = torch.rand((5, 4))
        >>> pred_instances.scores = torch.rand((5, ))
        >>> data_sample = Det3DDataSample(pred_instances=pred_instances)
        >>> assert 'pred_instances' in data_sample

        >>> pred_instances_3d = InstanceData(metainfo=meta_info)
        >>> pred_instances_3d.bboxes_3d = BaseInstance3DBoxes(
        ...     torch.rand((5, 7)))
        >>> pred_instances_3d.scores_3d = torch.rand((5, ))
        >>> pred_instances_3d.labels_3d = torch.randint(0, 3, (5, ))
        >>> data_sample = Det3DDataSample(pred_instances_3d=pred_instances_3d)
        >>> assert 'pred_instances_3d' in data_sample

        >>> data_sample = Det3DDataSample()
        >>> gt_instances_3d_data = dict(
        ...     bboxes_3d=BaseInstance3DBoxes(torch.rand((2, 7))),
        ...     labels_3d=torch.randint(0, 3, (2, )))
        >>> gt_instances_3d = InstanceData(**gt_instances_3d_data)
        >>> data_sample.gt_instances_3d = gt_instances_3d
        >>> assert 'gt_instances_3d' in data_sample
        >>> assert 'bboxes_3d' in data_sample.gt_instances_3d

        >>> from mmdet3d.structures import PointData
        >>> data_sample = Det3DDataSample()
        >>> gt_pts_seg_data = dict(
        ...     pts_instance_mask=torch.rand(2),
        ...     pts_semantic_mask=torch.rand(2))
        >>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
        >>> print(data_sample)
        <Det3DDataSample(
            META INFORMATION
            DATA FIELDS
            gt_pts_seg: <PointData(
                    META INFORMATION
                    DATA FIELDS
                    pts_semantic_mask: tensor([0.7199, 0.4006])
                    pts_instance_mask: tensor([0.7363, 0.8096])
                ) at 0x7f7e2962cc40>
        ) at 0x7f7e29ff0d60>
    """  # noqa: E501

    # Every field below follows the same pattern: the public property reads
    # the private ``_<name>`` attribute, the setter stores the value through
    # ``set_field`` (passing the expected ``dtype``), and the deleter removes
    # the private attribute again.

    @property
    def gt_instances_3d(self) -> InstanceData:
        """InstanceData: Ground truth of 3D instance annotations."""
        return self._gt_instances_3d

    @gt_instances_3d.setter
    def gt_instances_3d(self, value: InstanceData) -> None:
        self.set_field(value, '_gt_instances_3d', dtype=InstanceData)

    @gt_instances_3d.deleter
    def gt_instances_3d(self) -> None:
        del self._gt_instances_3d

    @property
    def pred_instances_3d(self) -> InstanceData:
        """InstanceData: 3D instances of model predictions."""
        return self._pred_instances_3d

    @pred_instances_3d.setter
    def pred_instances_3d(self, value: InstanceData) -> None:
        self.set_field(value, '_pred_instances_3d', dtype=InstanceData)

    @pred_instances_3d.deleter
    def pred_instances_3d(self) -> None:
        del self._pred_instances_3d

    @property
    def pts_pred_instances_3d(self) -> InstanceData:
        """InstanceData: 3D instances of model predictions based on point
        cloud."""
        return self._pts_pred_instances_3d

    @pts_pred_instances_3d.setter
    def pts_pred_instances_3d(self, value: InstanceData) -> None:
        self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)

    @pts_pred_instances_3d.deleter
    def pts_pred_instances_3d(self) -> None:
        del self._pts_pred_instances_3d

    @property
    def img_pred_instances_3d(self) -> InstanceData:
        """InstanceData: 3D instances of model predictions based on image."""
        return self._img_pred_instances_3d

    @img_pred_instances_3d.setter
    def img_pred_instances_3d(self, value: InstanceData) -> None:
        self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)

    @img_pred_instances_3d.deleter
    def img_pred_instances_3d(self) -> None:
        del self._img_pred_instances_3d

    @property
    def gt_pts_seg(self) -> PointData:
        """PointData: Ground truth of point cloud segmentation."""
        return self._gt_pts_seg

    @gt_pts_seg.setter
    def gt_pts_seg(self, value: PointData) -> None:
        self.set_field(value, '_gt_pts_seg', dtype=PointData)

    @gt_pts_seg.deleter
    def gt_pts_seg(self) -> None:
        del self._gt_pts_seg

    @property
    def pred_pts_seg(self) -> PointData:
        """PointData: Prediction of point cloud segmentation."""
        return self._pred_pts_seg

    @pred_pts_seg.setter
    def pred_pts_seg(self, value: PointData) -> None:
        self.set_field(value, '_pred_pts_seg', dtype=PointData)

    @pred_pts_seg.deleter
    def pred_pts_seg(self) -> None:
        del self._pred_pts_seg


# A batch of data samples, one ``Det3DDataSample`` per element.
SampleList = List[Det3DDataSample]
# Same as ``SampleList``, but ``None`` is also accepted.
OptSampleList = Optional[SampleList]
# Union of a ``str -> Tensor`` dict, a list of data samples, a tuple of
# tensors, or a single tensor.
# NOTE(review): going by the name, this presumably annotates what a model's
# forward pass may return -- confirm against callers. ``Tuple[torch.Tensor]``
# denotes a 1-tuple; ``Tuple[torch.Tensor, ...]`` may have been intended.
ForwardResults = Union[Dict[str, torch.Tensor], SampleList,
                       Tuple[torch.Tensor], torch.Tensor]