# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
from typing import Dict, List, Optional

import numpy as np

from mmdet3d.registry import DATASETS
from mmdet3d.structures import LiDARInstance3DBoxes
from mmdet3d.structures.bbox_3d.cam_box3d import CameraInstance3DBoxes
from .det3d_dataset import Det3DDataset

@DATASETS.register_module()
class NuScenesDataset(Det3DDataset):
    r"""NuScenes Dataset.

    This class serves as the API for experiments on the NuScenes Dataset.

    Please refer to `NuScenes Dataset <https://www.nuscenes.org/download>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        task (str): Task of the loaded annotations. One of '3d', 'mono3d'
            or 'multi-view'. Defaults to '3d'.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then convert them to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Camera': Box in camera coordinates.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to dict(use_camera=False, use_lidar=True).
        filter_empty_gt (bool): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        with_velocity (bool): Whether to include velocity prediction
            in the experiments. Defaults to True.
        use_valid_flag (bool): Whether to use the `use_valid_flag` key
            in the info file as a mask to filter gt_boxes and gt_names.
            Defaults to False.
    """

    METAINFO = {
        'CLASSES':
        ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
         'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
        'version':
        'v1.0-trainval'
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 task: str = '3d',
                 pipeline: Optional[List[dict]] = None,
                 box_type_3d: str = 'LiDAR',
                 modality: Optional[Dict] = None,
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 with_velocity: bool = True,
                 use_valid_flag: bool = False,
                 **kwargs):
        # Use a None sentinel instead of a mutable dict default to avoid
        # sharing one dict object across every instantiation.
        if modality is None:
            modality = dict(use_camera=False, use_lidar=True)
        self.use_valid_flag = use_valid_flag
        self.with_velocity = with_velocity

        # TODO: Redesign multi-view data process in the future
        assert task in ('3d', 'mono3d', 'multi-view')
        self.task = task

        # Only LiDAR and Camera box types are supported by this dataset.
        assert box_type_3d.lower() in ('lidar', 'camera')
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            modality=modality,
            pipeline=pipeline,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

    def parse_ann_info(self, info: dict) -> dict:
        """Get annotation info according to the given index.

        Args:
            info (dict): Data information of a single data sample.

        Returns:
            dict: Annotation information consisting of the following keys:

                - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
                  :obj:`CameraInstance3DBoxes`): 3D ground truth bboxes
                  (CameraInstance3DBoxes for the 'mono3d' task).
                - gt_labels_3d (np.ndarray): Labels of ground truths.
                - gt_bboxes (np.ndarray): 2D ground truth bboxes.
                - gt_labels (np.ndarray): Labels of the 2D ground truths.
                - attr_labels (np.ndarray): Attribute labels of instances.
                - centers_2d (np.ndarray): Projected 3D centers onto images.
                - depths (np.ndarray): Depths of the projected centers.
        """
        ann_info = super().parse_ann_info(info)
        if ann_info is None:
            # Empty instance: return empty boxes/labels so downstream
            # filtering can handle samples without annotations.
            anns_results = dict()
            anns_results['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
            anns_results['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
            return anns_results

        # Filter out instances that are either explicitly flagged invalid or
        # (fallback) contain no lidar points.
        if self.use_valid_flag:
            mask = ann_info['bbox_3d_isvalid']
        else:
            mask = ann_info['num_lidar_pts'] > 0
        gt_bboxes_3d = ann_info['gt_bboxes_3d'][mask]
        gt_labels_3d = ann_info['gt_labels_3d'][mask]

        # 2D annotations are only present for camera-based samples.
        if 'gt_bboxes' in ann_info:
            gt_bboxes = ann_info['gt_bboxes'][mask]
            gt_labels = ann_info['gt_labels'][mask]
            attr_labels = ann_info['attr_labels'][mask]
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)
            attr_labels = np.array([], dtype=np.int64)

        if 'centers_2d' in ann_info:
            centers_2d = ann_info['centers_2d'][mask]
            depths = ann_info['depths'][mask]
        else:
            centers_2d = np.zeros((0, 2), dtype=np.float32)
            depths = np.zeros((0), dtype=np.float32)

        if self.with_velocity:
            gt_velocity = ann_info['velocity'][mask]
            # Zero out NaN velocities before appending them as two extra
            # box dimensions.
            nan_mask = np.isnan(gt_velocity[:, 0])
            gt_velocity[nan_mask] = [0.0, 0.0]
            gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1)

        # the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
        # the same as KITTI (0.5, 0.5, 0)
        # TODO: Unify the coordinates
        if self.task == 'mono3d':
            gt_bboxes_3d = CameraInstance3DBoxes(
                gt_bboxes_3d,
                box_dim=gt_bboxes_3d.shape[-1],
                origin=(0.5, 0.5, 0.5))
        else:
            gt_bboxes_3d = LiDARInstance3DBoxes(
                gt_bboxes_3d,
                box_dim=gt_bboxes_3d.shape[-1],
                origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            attr_labels=attr_labels,
            centers_2d=centers_2d,
            depths=depths)

        return anns_results

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        The only difference with it in `Det3DDataset`
        is the specific process for `plane`.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all path has been converted to absolute path.
        """
        if self.task == 'mono3d':
            # For mono3d, one raw sample is expanded into one entry per
            # camera, so a list of dicts is returned instead of a dict.
            data_list = []
            if self.modality['use_lidar']:
                info['lidar_points']['lidar_path'] = \
                    osp.join(
                        self.data_prefix.get('pts', ''),
                        info['lidar_points']['lidar_path'])

            if self.modality['use_camera']:
                for cam_id, img_info in info['images'].items():
                    if 'img_path' in img_info:
                        # Per-camera prefixes take precedence over the
                        # generic 'img' prefix.
                        if cam_id in self.data_prefix:
                            cam_prefix = self.data_prefix[cam_id]
                        else:
                            cam_prefix = self.data_prefix.get('img', '')
                        img_info['img_path'] = osp.join(
                            cam_prefix, img_info['img_path'])

            for idx, (cam_id, img_info) in enumerate(info['images'].items()):
                camera_info = dict()
                camera_info['images'] = dict()
                camera_info['images'][cam_id] = img_info
                if 'cam_instances' in info and cam_id in info['cam_instances']:
                    camera_info['instances'] = info['cam_instances'][cam_id]
                else:
                    camera_info['instances'] = []
                # TODO: check whether to change sample_idx for 6 cameras
                #  in one frame
                camera_info['sample_idx'] = info['sample_idx'] * 6 + idx
                camera_info['token'] = info['token']
                camera_info['ego2global'] = info['ego2global']

                if not self.test_mode:
                    # used in training
                    camera_info['ann_info'] = self.parse_ann_info(camera_info)
                if self.test_mode and self.load_eval_anns:
                    camera_info['eval_ann_info'] = \
                        self.parse_ann_info(camera_info)
                data_list.append(camera_info)
            return data_list
        else:
            data_info = super().parse_data_info(info)
            return data_info