nuscenes_dataset.py 9.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ZCMax's avatar
ZCMax committed
2
from os import path as osp
3
from typing import Callable, List, Union
4

zhangwenwei's avatar
zhangwenwei committed
5
6
import numpy as np

7
from mmdet3d.registry import DATASETS
zhangshilong's avatar
zhangshilong committed
8
9
from mmdet3d.structures import LiDARInstance3DBoxes
from mmdet3d.structures.bbox_3d.cam_box3d import CameraInstance3DBoxes
jshilong's avatar
jshilong committed
10
from .det3d_dataset import Det3DDataset
zhangwenwei's avatar
zhangwenwei committed
11
12


13
@DATASETS.register_module()
class NuScenesDataset(Det3DDataset):
    r"""NuScenes Dataset.

    This class serves as the API for experiments on the NuScenes Dataset.

    Please refer to `NuScenes Dataset <https://www.nuscenes.org/download>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict | Callable], optional): Pipeline used for data
            processing. ``None`` is treated as an empty pipeline.
            Defaults to None.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options
            (case-insensitive) are:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Camera': Box in camera coordinates.
        load_type (str): Type of loading mode. Defaults to 'frame_based'.

            - 'frame_based': Load all of the instances in the frame.
            - 'mv_image_based': Load all of the instances in the frame and need
                to convert to the FOV-based data type to support image-based
                detector.
            - 'fov_image_based': Only load the instances inside the default
                cam, and need to convert to the FOV-based data type to support
                image-based detector.
        modality (dict, optional): Modality to specify the sensor data used as
            input. ``None`` means ``dict(use_camera=False, use_lidar=True)``.
            Defaults to None.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        with_velocity (bool): Whether to include velocity prediction
            into the experiments. If True, the two velocity components are
            appended to each 3D box. Defaults to True.
        use_valid_flag (bool): Whether to filter annotations with their
            ``bbox_3d_isvalid`` mask. If False, annotations whose
            ``num_lidar_pts`` is not positive are dropped instead.
            Defaults to False.
    """
    METAINFO = {
        'classes':
        ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
         'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
        'version':
        'v1.0-trainval'
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 pipeline: Union[List[Union[dict, Callable]], None] = None,
                 box_type_3d: str = 'LiDAR',
                 load_type: str = 'frame_based',
                 modality: Union[dict, None] = None,
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 with_velocity: bool = True,
                 use_valid_flag: bool = False,
                 **kwargs) -> None:
        # Resolve the sentinels to fresh objects per call so that no mutable
        # default is shared between dataset instances.
        if pipeline is None:
            pipeline = []
        if modality is None:
            modality = dict(use_camera=False, use_lidar=True)

        # These attributes are read by the overridden `parse_ann_info` /
        # `parse_data_info`. They are set before `super().__init__()` because
        # the parent constructor may already load and parse the annotation
        # file, which would call those overrides.
        self.use_valid_flag = use_valid_flag
        self.with_velocity = with_velocity

        # TODO: Redesign multi-view data process in the future
        assert load_type in ('frame_based', 'mv_image_based',
                             'fov_image_based'), \
            f'Unsupported load_type: {load_type}'
        self.load_type = load_type

        assert box_type_3d.lower() in ('lidar', 'camera'), \
            f'Unsupported box_type_3d for NuScenesDataset: {box_type_3d}'
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            modality=modality,
            pipeline=pipeline,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

    def _filter_with_mask(self, ann_info: dict) -> dict:
        """Remove annotations that do not need to be cared.

        The mask is ``ann_info['bbox_3d_isvalid']`` when ``use_valid_flag``
        is set, otherwise ``ann_info['num_lidar_pts'] > 0``. It is applied to
        every entry except ``'instances'``, which is passed through unchanged.

        Args:
            ann_info (dict): Dict of annotation infos. Every value other than
                ``'instances'`` must support boolean-mask indexing
                (presumably per-instance numpy arrays).

        Returns:
            dict: Annotations after filtering.
        """
        if self.use_valid_flag:
            filter_mask = ann_info['bbox_3d_isvalid']
        else:
            filter_mask = ann_info['num_lidar_pts'] > 0
        return {
            key: value if key == 'instances' else value[filter_mask]
            for key, value in ann_info.items()
        }

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Data information of single data sample.

        Returns:
            dict: Annotation information consists of the following keys:

                - gt_bboxes_3d (:obj:`CameraInstance3DBoxes` or the box type
                  given by ``box_type_3d``): 3D ground truth bboxes. Camera
                  boxes are used for the image-based load types; otherwise
                  LiDAR boxes converted to ``self.box_mode_3d``. Each box has
                  9 values when ``with_velocity`` is True, else 7.
                - gt_labels_3d (np.ndarray): Labels of ground truths.
                - For the image-based load types with no instances, empty
                  ``gt_bboxes``, ``gt_bboxes_labels``, ``attr_labels``,
                  ``centers_2d`` and ``depths`` are added as well.
        """
        ann_info = super().parse_ann_info(info)
        if ann_info is not None:

            ann_info = self._filter_with_mask(ann_info)

            if self.with_velocity:
                gt_bboxes_3d = ann_info['gt_bboxes_3d']
                gt_velocities = ann_info['velocities']
                # Missing velocities are stored as NaN; zero them so they can
                # be concatenated onto the boxes.
                nan_mask = np.isnan(gt_velocities[:, 0])
                gt_velocities[nan_mask] = [0.0, 0.0]
                gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocities],
                                              axis=-1)
                ann_info['gt_bboxes_3d'] = gt_bboxes_3d
        else:
            # empty instance
            ann_info = dict()
            box_dim = 9 if self.with_velocity else 7
            ann_info['gt_bboxes_3d'] = np.zeros((0, box_dim),
                                                dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)

            if self.load_type in ['fov_image_based', 'mv_image_based']:
                ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
                # Empty 1-D label arrays (length 0), consistent with
                # `gt_labels_3d`; `np.array(0)` would be a 0-d array holding
                # a bogus label 0.
                ann_info['gt_bboxes_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['attr_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
                ann_info['depths'] = np.zeros(0, dtype=np.float32)

        # The nuScenes box center is (0.5, 0.5, 0.5); passing that origin lets
        # the box classes convert it to their own convention (the same as
        # KITTI, (0.5, 0.5, 0)).
        # TODO: Unify the coordinates
        if self.load_type in ['fov_image_based', 'mv_image_based']:
            gt_bboxes_3d = CameraInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=ann_info['gt_bboxes_3d'].shape[-1],
                origin=(0.5, 0.5, 0.5))
        else:
            gt_bboxes_3d = LiDARInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=ann_info['gt_bboxes_3d'].shape[-1],
                origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        ann_info['gt_bboxes_3d'] = gt_bboxes_3d

        return ann_info

    def parse_data_info(self, info: dict) -> Union[List[dict], dict]:
        """Process the raw data info.

        In ``'mv_image_based'`` mode, one frame is split into one data info
        per camera image. Each one carries that camera's image, its
        ``cam_instances`` (or an empty list), and the frame's ``token`` and
        ``ego2global``. In every other mode the frame is handled by
        ``Det3DDataset.parse_data_info``.

        Args:
            info (dict): Raw info dict.

        Returns:
            List[dict] or dict: A list of per-camera infos in
            ``'mv_image_based'`` mode, otherwise a single info dict. Infos
            get ``ann_info`` in training mode (``eval_ann_info`` in test mode
            when ``load_eval_anns`` is set), and their data paths are joined
            with the configured ``data_prefix`` entries.
        """
        if self.load_type != 'mv_image_based':
            return super().parse_data_info(info)

        if self.modality['use_lidar']:
            info['lidar_points']['lidar_path'] = osp.join(
                self.data_prefix.get('pts', ''),
                info['lidar_points']['lidar_path'])

        if self.modality['use_camera']:
            for cam_id, img_info in info['images'].items():
                if 'img_path' in img_info:
                    # A camera-specific prefix wins over the generic 'img' one.
                    cam_prefix = self.data_prefix.get(
                        cam_id, self.data_prefix.get('img', ''))
                    img_info['img_path'] = osp.join(cam_prefix,
                                                    img_info['img_path'])

        data_list = []
        for idx, (cam_id, img_info) in enumerate(info['images'].items()):
            camera_info = dict()
            camera_info['images'] = {cam_id: img_info}
            camera_info['instances'] = info.get('cam_instances',
                                                {}).get(cam_id, [])
            # TODO: check whether to change sample_idx for 6 cameras
            #  in one frame
            # NOTE(review): assumes 6 images per frame (the nuScenes camera
            # rig); code mapping these indices back to frames presumably
            # relies on the same factor — confirm before changing.
            camera_info['sample_idx'] = info['sample_idx'] * 6 + idx
            camera_info['token'] = info['token']
            camera_info['ego2global'] = info['ego2global']

            if not self.test_mode:
                # used in training
                camera_info['ann_info'] = self.parse_ann_info(camera_info)
            elif self.load_eval_anns:
                camera_info['eval_ann_info'] = \
                    self.parse_ann_info(camera_info)
            data_list.append(camera_info)
        return data_list