# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
from typing import Callable, List, Optional, Union

import numpy as np

from mmdet3d.registry import DATASETS
from mmdet3d.structures import LiDARInstance3DBoxes
from mmdet3d.structures.bbox_3d.cam_box3d import CameraInstance3DBoxes
from .det3d_dataset import Det3DDataset


@DATASETS.register_module()
class NuScenesDataset(Det3DDataset):
    r"""NuScenes Dataset.

    This class serves as the API for experiments on the NuScenes Dataset.

    Please refer to `NuScenes Dataset <https://www.nuscenes.org/download>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        task (str): Detection task. One of 'lidar_det', 'mono_det' and
            'multi-view_det'. Defaults to 'lidar_det'.
        pipeline (list[dict | Callable], optional): Pipeline used for data
            processing. Defaults to None, which is treated as an empty
            pipeline.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will wrap the boxes in
            their original format and then convert them to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options
            (case-insensitive) are:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Camera': Box in camera coordinates.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None, which is treated as
            dict(use_camera=False, use_lidar=True).
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        with_velocity (bool): Whether to include velocity prediction
            into the experiments. Defaults to True.
        use_valid_flag (bool): Whether to use the `bbox_3d_isvalid` field of
            the annotations as the mask to filter instances. If False,
            instances with ``num_lidar_pts > 0`` are kept instead.
            Defaults to False.
    """
    METAINFO = {
        'CLASSES':
        ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
         'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
        'version':
        'v1.0-trainval'
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 task: str = 'lidar_det',
                 pipeline: Optional[List[Union[dict, Callable]]] = None,
                 box_type_3d: str = 'LiDAR',
                 modality: Optional[dict] = None,
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 with_velocity: bool = True,
                 use_valid_flag: bool = False,
                 **kwargs) -> None:
        # Build fresh containers per instance rather than sharing one mutable
        # default object between every dataset constructed with defaults.
        if pipeline is None:
            pipeline = []
        if modality is None:
            modality = dict(use_camera=False, use_lidar=True)

        # These attributes are read by `parse_ann_info` / `parse_data_info`,
        # so they are set before `super().__init__()` (presumably because the
        # base class may parse the annotation file during construction).
        self.use_valid_flag = use_valid_flag
        self.with_velocity = with_velocity

        # TODO: Redesign multi-view data process in the future
        assert task in ('lidar_det', 'mono_det', 'multi-view_det')
        self.task = task

        assert box_type_3d.lower() in ('lidar', 'camera')
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            modality=modality,
            pipeline=pipeline,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

    def _filter_with_mask(self, ann_info: dict) -> dict:
        """Filter per-instance annotations with a validity mask.

        The mask is ``ann_info['bbox_3d_isvalid']`` when `use_valid_flag` is
        True, otherwise ``ann_info['num_lidar_pts'] > 0``. Every entry except
        ``'instances'`` is indexed with the mask; ``'instances'`` is passed
        through unchanged.

        Args:
            ann_info (dict): Dict of annotation infos. Every value other than
                ``'instances'`` is assumed to be an array indexable by a
                boolean mask.

        Returns:
            dict: Annotations after filtering.
        """
        if self.use_valid_flag:
            filter_mask = ann_info['bbox_3d_isvalid']
        else:
            filter_mask = ann_info['num_lidar_pts'] > 0
        return {
            key: value if key == 'instances' else value[filter_mask]
            for key, value in ann_info.items()
        }

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Data information of single data sample.

        Returns:
            dict: Annotation information consists of the following keys:

                - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
                  :obj:`CameraInstance3DBoxes`): 3D ground truth bboxes.
                  Camera boxes are used when ``task == 'mono_det'``;
                  otherwise LiDAR boxes converted to ``self.box_mode_3d``.
                  When `with_velocity` is True, each box carries two extra
                  velocity values.
                - gt_labels_3d (np.ndarray): Labels of ground truths.
                - For an empty 'mono_det' sample, also empty ``gt_bboxes``,
                  ``gt_bboxes_labels``, ``attr_labels``, ``centers_2d`` and
                  ``depths`` arrays.
        """
        ann_info = super().parse_ann_info(info)
        if ann_info is not None:
            ann_info = self._filter_with_mask(ann_info)

            if self.with_velocity:
                gt_bboxes_3d = ann_info['gt_bboxes_3d']
                gt_velocities = ann_info['velocities']
                # NaN velocities (presumably missing annotations) are
                # replaced with zero velocity before being appended.
                nan_mask = np.isnan(gt_velocities[:, 0])
                gt_velocities[nan_mask] = [0.0, 0.0]
                ann_info['gt_bboxes_3d'] = np.concatenate(
                    [gt_bboxes_3d, gt_velocities], axis=-1)
        else:
            # Empty instance: build zero-length arrays with the same layout
            # as a non-empty sample (7 box values, plus 2 for velocity).
            ann_info = dict()
            empty_box_dim = 9 if self.with_velocity else 7
            ann_info['gt_bboxes_3d'] = np.zeros((0, empty_box_dim),
                                                dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)

            if self.task == 'mono_det':
                ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
                ann_info['gt_bboxes_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['attr_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
                ann_info['depths'] = np.zeros(0, dtype=np.float32)

        # The nuScenes box center is (0.5, 0.5, 0.5); building the box
        # structures with that origin changes it to be the same as KITTI
        # (0.5, 0.5, 0).
        # TODO: Unify the coordinates
        box_dim = ann_info['gt_bboxes_3d'].shape[-1]
        if self.task == 'mono_det':
            gt_bboxes_3d = CameraInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=box_dim,
                origin=(0.5, 0.5, 0.5))
        else:
            gt_bboxes_3d = LiDARInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=box_dim,
                origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        ann_info['gt_bboxes_3d'] = gt_bboxes_3d
        return ann_info

    def parse_data_info(self, info: dict) -> Union[List[dict], dict]:
        """Process the raw data info.

        Unlike `Det3DDataset`, when ``task == 'mono_det'`` a single frame is
        split into one data sample per entry of ``info['images']`` and a list
        of samples is returned. Other tasks use the base implementation.

        Args:
            info (dict): Raw info dict.

        Returns:
            list[dict] | dict: For 'mono_det', one dict per camera image,
            with file paths joined to the matching `data_prefix` entries,
            `ann_info` in training mode and `eval_ann_info` in test mode
            when `load_eval_anns` is set. Otherwise the result of
            `Det3DDataset.parse_data_info`.
        """
        if self.task != 'mono_det':
            return super().parse_data_info(info)

        if self.modality['use_lidar']:
            info['lidar_points']['lidar_path'] = osp.join(
                self.data_prefix.get('pts', ''),
                info['lidar_points']['lidar_path'])

        if self.modality['use_camera']:
            for cam_id, img_info in info['images'].items():
                if 'img_path' in img_info:
                    # A camera-specific prefix takes precedence over the
                    # generic 'img' prefix.
                    if cam_id in self.data_prefix:
                        cam_prefix = self.data_prefix[cam_id]
                    else:
                        cam_prefix = self.data_prefix.get('img', '')
                    img_info['img_path'] = osp.join(cam_prefix,
                                                    img_info['img_path'])

        data_list = []
        for idx, (cam_id, img_info) in enumerate(info['images'].items()):
            camera_info = dict()
            camera_info['images'] = {cam_id: img_info}
            if 'cam_instances' in info and cam_id in info['cam_instances']:
                camera_info['instances'] = info['cam_instances'][cam_id]
            else:
                camera_info['instances'] = []
            # Unique index per (frame, camera) pair; this assumes at most
            # 6 cameras per frame.
            # TODO: check whether to change sample_idx for 6 cameras
            #  in one frame
            camera_info['sample_idx'] = info['sample_idx'] * 6 + idx
            camera_info['token'] = info['token']
            camera_info['ego2global'] = info['ego2global']

            if not self.test_mode:
                # used in training
                camera_info['ann_info'] = self.parse_ann_info(camera_info)
            if self.test_mode and self.load_eval_anns:
                camera_info['eval_ann_info'] = \
                    self.parse_ann_info(camera_info)
            data_list.append(camera_info)
        return data_list