# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
from typing import Callable, List, Optional, Union

import numpy as np

from mmdet3d.registry import DATASETS
from mmdet3d.structures import LiDARInstance3DBoxes
from mmdet3d.structures.bbox_3d.cam_box3d import CameraInstance3DBoxes
from .det3d_dataset import Det3DDataset

@DATASETS.register_module()
class NuScenesDataset(Det3DDataset):
    r"""NuScenes Dataset.

    This class serves as the API for experiments on the NuScenes Dataset.

    Please refer to `NuScenes Dataset <https://www.nuscenes.org/download>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        task (str, optional): Detection task. One of 'lidar_det',
            'mono_det' or 'multi-view_det'. Defaults to 'lidar_det'.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None, which is treated as an empty pipeline.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then convert it to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to dict(use_camera=False, use_lidar=True).
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        with_velocity (bool, optional): Whether to include velocity prediction
            into the experiments. Defaults to True.
        use_valid_flag (bool, optional): Whether to use `use_valid_flag` key
            in the info file as mask to filter gt_boxes and gt_names.
            Defaults to False.
    """

    METAINFO = {
        'CLASSES':
        ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
         'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
        'version':
        'v1.0-trainval'
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 task: str = 'lidar_det',
                 pipeline: Optional[List[Union[dict, Callable]]] = None,
                 box_type_3d: str = 'LiDAR',
                 modality: dict = dict(
                     use_camera=False,
                     use_lidar=True,
                 ),
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 with_velocity: bool = True,
                 use_valid_flag: bool = False,
                 **kwargs) -> None:
        self.use_valid_flag = use_valid_flag
        self.with_velocity = with_velocity

        # Avoid a mutable default argument: None stands in for an empty
        # pipeline and is converted here, which is behaviorally identical
        # to the previous `pipeline=[]` default.
        if pipeline is None:
            pipeline = []

        # TODO: Redesign multi-view data process in the future
        assert task in ('lidar_det', 'mono_det', 'multi-view_det')
        self.task = task

        assert box_type_3d.lower() in ('lidar', 'camera')
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            modality=modality,
            pipeline=pipeline,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

    def _filter_with_mask(self, ann_info: dict) -> dict:
        """Remove annotations that do not need to be cared.

        Args:
            ann_info (dict): Dict of annotation infos.

        Returns:
            dict: Annotations after filtering.
        """
        filtered_annotations = {}
        if self.use_valid_flag:
            filter_mask = ann_info['bbox_3d_isvalid']
        else:
            # Keep only instances observed by at least one lidar point.
            filter_mask = ann_info['num_lidar_pts'] > 0
        for key in ann_info.keys():
            if key != 'instances':
                filtered_annotations[key] = (ann_info[key][filter_mask])
            else:
                filtered_annotations[key] = ann_info[key]
        return filtered_annotations

    def parse_ann_info(self, info: dict) -> dict:
        """Get annotation info according to the given index.

        Args:
            info (dict): Data information of single data sample.

        Returns:
            dict: Annotation information consists of the following keys:

                - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` or
                  :obj:`CameraInstance3DBoxes`): 3D ground truth bboxes.
                - gt_labels_3d (np.ndarray): Labels of ground truths.
        """
        ann_info = super().parse_ann_info(info)
        if ann_info is not None:

            ann_info = self._filter_with_mask(ann_info)

            if self.with_velocity:
                gt_bboxes_3d = ann_info['gt_bboxes_3d']
                gt_velocities = ann_info['velocities']
                # nuScenes marks unknown velocities with NaN; zero them so
                # the concatenated 9-dim boxes stay numerically valid.
                nan_mask = np.isnan(gt_velocities[:, 0])
                gt_velocities[nan_mask] = [0.0, 0.0]
                gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocities],
                                              axis=-1)
                ann_info['gt_bboxes_3d'] = gt_bboxes_3d
        else:
            # empty instance
            ann_info = dict()
            # 9-dim boxes when velocity (vx, vy) is appended, 7-dim otherwise.
            if self.with_velocity:
                ann_info['gt_bboxes_3d'] = np.zeros((0, 9), dtype=np.float32)
            else:
                ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)

            # BUG FIX: this previously compared against 'mono3d', which can
            # never match because __init__ only accepts 'lidar_det',
            # 'mono_det' and 'multi-view_det'; the branch below (box
            # construction) already uses 'mono_det'.
            if self.task == 'mono_det':
                ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
                # BUG FIX: `np.array(0, ...)` creates a 0-d scalar, not an
                # empty label array; use `np.zeros(0, ...)` like
                # `gt_labels_3d` above.
                ann_info['gt_bboxes_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['attr_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
                ann_info['depths'] = np.zeros((0, ), dtype=np.float32)

        # the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
        # the same as KITTI (0.5, 0.5, 0)
        # TODO: Unify the coordinates
        if self.task == 'mono_det':
            gt_bboxes_3d = CameraInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=ann_info['gt_bboxes_3d'].shape[-1],
                origin=(0.5, 0.5, 0.5))
        else:
            gt_bboxes_3d = LiDARInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=ann_info['gt_bboxes_3d'].shape[-1],
                origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        ann_info['gt_bboxes_3d'] = gt_bboxes_3d

        return ann_info

    def parse_data_info(self, info: dict) -> Union[List[dict], dict]:
        """Process the raw data info.

        The only difference with `Det3DDataset` is that, for the 'mono_det'
        task, the raw info of one frame is split into one data sample per
        camera view.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict or list[dict]: Has `ann_info` in training stage. And
            all path has been converted to absolute path. A list of
            per-camera dicts is returned for the 'mono_det' task.
        """
        if self.task == 'mono_det':
            data_list = []
            if self.modality['use_lidar']:
                info['lidar_points']['lidar_path'] = \
                    osp.join(
                        self.data_prefix.get('pts', ''),
                        info['lidar_points']['lidar_path'])

            if self.modality['use_camera']:
                for cam_id, img_info in info['images'].items():
                    if 'img_path' in img_info:
                        # Per-camera prefixes take precedence over the
                        # generic 'img' prefix.
                        if cam_id in self.data_prefix:
                            cam_prefix = self.data_prefix[cam_id]
                        else:
                            cam_prefix = self.data_prefix.get('img', '')
                        img_info['img_path'] = osp.join(
                            cam_prefix, img_info['img_path'])

            for idx, (cam_id, img_info) in enumerate(info['images'].items()):
                camera_info = dict()
                camera_info['images'] = dict()
                camera_info['images'][cam_id] = img_info
                if 'cam_instances' in info and cam_id in info['cam_instances']:
                    camera_info['instances'] = info['cam_instances'][cam_id]
                else:
                    camera_info['instances'] = []
                # TODO: check whether to change sample_idx for 6 cameras
                #  in one frame
                camera_info['sample_idx'] = info['sample_idx'] * 6 + idx
                camera_info['token'] = info['token']
                camera_info['ego2global'] = info['ego2global']

                if not self.test_mode:
                    # used in training
                    camera_info['ann_info'] = self.parse_ann_info(camera_info)
                if self.test_mode and self.load_eval_anns:
                    camera_info['eval_ann_info'] = \
                        self.parse_ann_info(camera_info)
                data_list.append(camera_info)
            return data_list
        else:
            data_info = super().parse_data_info(info)
            return data_info