# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Callable, List, Union

import numpy as np
from mmengine import print_log
from mmengine.fileio import load

from mmdet3d.registry import DATASETS
from mmdet3d.structures import CameraInstance3DBoxes, LiDARInstance3DBoxes
from .det3d_dataset import Det3DDataset
from .kitti_dataset import KittiDataset


@DATASETS.register_module()
class WaymoDataset(KittiDataset):
    """Waymo Dataset.

    This class serves as the API for experiments on the Waymo Dataset.

    Please refer to `<https://waymo.com/open/download/>`_ for data downloading.
    It is recommended to symlink the dataset root to $MMDETECTION3D/data and
    organize them as the doc shows.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        data_prefix (dict): Data prefix for point cloud and
            camera data dict. Defaults to dict(
                                    pts='velodyne',
                                    CAM_FRONT='image_0',
                                    CAM_FRONT_LEFT='image_1',
                                    CAM_FRONT_RIGHT='image_2',
                                    CAM_SIDE_LEFT='image_3',
                                    CAM_SIDE_RIGHT='image_4')
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used
            as input. Defaults to dict(use_lidar=True).
        default_cam_key (str): Default camera key for lidar2img
            association. Defaults to 'CAM_FRONT'.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the
            boxes in their original format and then convert them to
            `box_type_3d`. Defaults to 'LiDAR' in this dataset. Available
            options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        load_type (str): Type of loading mode. Defaults to 'frame_based'.

            - 'frame_based': Load all of the instances in the frame.
            - 'mv_image_based': Load all of the instances in the frame and
                convert them to the FOV-based data type to support
                image-based detectors.
            - 'fov_image_based': Only load the instances inside the default
                cam, and convert them to the FOV-based data type to support
                image-based detectors.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it is set to True, examples with empty annotations after the
            data pipeline are dropped and a random example is chosen in
            `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        pcd_limit_range (List[float]): The range of point cloud
            used to filter invalid predicted boxes.
            Defaults to [-85, -85, -5, 85, 85, 5].
        cam_sync_instances (bool): Whether to use the camera-synced labels
            provided since Waymo dataset version 1.3.1. Defaults to False.
        load_interval (int): Interval of loaded frames. Defaults to 1.
        max_sweeps (int): Maximum number of sweeps for each frame.
            Defaults to 0.
    """
    METAINFO = {
        'classes': ('Car', 'Pedestrian', 'Cyclist'),
        'palette': [
            (0, 120, 255),  # Waymo Blue
            (0, 232, 157),  # Waymo Green
            (255, 205, 85)  # Amber
        ]
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 data_prefix: dict = dict(
                     pts='velodyne',
                     CAM_FRONT='image_0',
                     CAM_FRONT_LEFT='image_1',
                     CAM_FRONT_RIGHT='image_2',
                     CAM_SIDE_LEFT='image_3',
                     CAM_SIDE_RIGHT='image_4'),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True),
                 default_cam_key: str = 'CAM_FRONT',
                 box_type_3d: str = 'LiDAR',
                 load_type: str = 'frame_based',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 pcd_limit_range: List[float] = [0, -40, -3, 70.4, 40, 0.0],
                 cam_sync_instances: bool = False,
                 load_interval: int = 1,
                 max_sweeps: int = 0,
                 **kwargs) -> None:
        self.load_interval = load_interval
        # set loading mode for different task settings
        self.cam_sync_instances = cam_sync_instances
        # construct self.cat_ids for vision-only anns parsing
        self.cat_ids = range(len(self.METAINFO['classes']))
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
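        # With the default three Waymo classes, `cat2label` is the identity
        # mapping {0: 0, 1: 1, 2: 2}.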
        self.max_sweeps = max_sweeps
        # we do not provide backend_args to custom_3d init
        # because we want disk loading for info
        # while ceph loading for Prediction2Waymo
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            pcd_limit_range=pcd_limit_range,
            default_cam_key=default_cam_key,
            data_prefix=data_prefix,
            test_mode=test_mode,
            load_type=load_type,
            **kwargs)

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Data information of single data sample.

        Returns:
            dict: Annotation information consists of the following keys:

                - bboxes_3d (:obj:`LiDARInstance3DBoxes`):
                  3D ground truth bboxes.
                - bbox_labels_3d (np.ndarray): Labels of ground truths.
                - gt_bboxes (np.ndarray): 2D ground truth bboxes.
                - gt_labels (np.ndarray): Labels of ground truths.
                - difficulty (int): Difficulty defined by KITTI.
                  0, 1 and 2 represent easy, moderate and hard, respectively.
        """
        ann_info = Det3DDataset.parse_ann_info(self, info)
        if ann_info is None:
            # empty instance
            ann_info = {}
            ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)

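        # Drop instances marked as don't-care (label -1 in mmdet3d).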
        ann_info = self._remove_dontcare(ann_info)
        # in kitti, lidar2cam = R0_rect @ Tr_velo_to_cam
        # convert gt_bboxes_3d to velodyne coordinates with `lidar2cam`
        if 'gt_bboxes' in ann_info:
            gt_bboxes = ann_info['gt_bboxes']
            gt_bboxes_labels = ann_info['gt_bboxes_labels']
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_bboxes_labels = np.zeros(0, dtype=np.int64)
        if 'centers_2d' in ann_info:
            centers_2d = ann_info['centers_2d']
            depths = ann_info['depths']
        else:
            centers_2d = np.zeros((0, 2), dtype=np.float32)
            depths = np.zeros((0), dtype=np.float32)

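        # Frame-based (LiDAR / multi-view) samples keep boxes in LiDAR
        # coordinates, while image-based samples keep them in camera
        # coordinates.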
        if self.load_type == 'frame_based':
            gt_bboxes_3d = LiDARInstance3DBoxes(ann_info['gt_bboxes_3d'])
        else:
            gt_bboxes_3d = CameraInstance3DBoxes(ann_info['gt_bboxes_3d'])

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=ann_info['gt_labels_3d'],
            gt_bboxes=gt_bboxes,
            gt_bboxes_labels=gt_bboxes_labels,
            centers_2d=centers_2d,
            depths=depths)

        return anns_results

    def load_data_list(self) -> List[dict]:
        """Add the load interval.

        Returns:
            list[dict]: A list of annotation.
        """  # noqa: E501
        # `self.ann_file` denotes the absolute annotation file path if
        # `self.root=None` or relative path if `self.root=/path/to/data/`.
        annotations = load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations or 'metainfo' not in annotations:
            raise ValueError('Annotation must have data_list and metainfo '
                             'keys')
        metainfo = annotations['metainfo']
        raw_data_list = annotations['data_list']
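        # Keep every `load_interval`-th frame, e.g. an interval of 5 keeps
        # frames 0, 5, 10, ... of the original list.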
        raw_data_list = raw_data_list[::self.load_interval]
        if self.load_interval > 1:
            print_log(
                f'Sample size will be reduced to 1/{self.load_interval} of'
                ' the original data sample',
                logger='current')

        # Meta information loaded from the annotation file will not override
        # the existing meta information from `BaseDataset.METAINFO` and the
        # `metainfo` argument of the constructor.
        for k, v in metainfo.items():
            self._metainfo.setdefault(k, v)

        # load and parse data_infos.
        data_list = []
        for raw_data_info in raw_data_list:
            # parse raw data information to target format
            data_info = self.parse_data_info(raw_data_info)
            if isinstance(data_info, dict):
                # For image tasks, `data_info` should be the information of a
                # single image, such as dict(img_path='xxx', width=360, ...)
                data_list.append(data_info)
            elif isinstance(data_info, list):
                # For video tasks, `data_info` could contain image
                # information of multiple frames, such as
                # [dict(video_path='xxx', timestamps=...),
                #  dict(video_path='xxx', timestamps=...)]
                for item in data_info:
                    if not isinstance(item, dict):
                        raise TypeError('data_info must be list of dict, but '
                                        f'got {type(item)}')
                data_list.extend(data_info)
            else:
                raise TypeError('data_info should be a dict or list of dict, '
                                f'but got {type(data_info)}')

        return data_list

    def parse_data_info(self, info: dict) -> Union[dict, List[dict]]:
        """If the task is LiDAR or multi-view detection, use the super()
        method; if the task is mono3d, split the frame-wise info into
        image-wise info."""

        if self.cam_sync_instances:
            info['instances'] = info['cam_sync_instances']

        if self.load_type == 'frame_based':
            return super().parse_data_info(info)
        elif self.load_type == 'fov_image_based':
            # Only load the FOV image and the FOV instances.
            new_image_info = {}
            new_image_info[self.default_cam_key] = \
                info['images'][self.default_cam_key]
            info['images'] = new_image_info
            info['instances'] = info['cam_instances'][self.default_cam_key]
            return Det3DDataset.parse_data_info(self, info)
        else:
            # In mono3d, the instances come from the camera-synced labels.
            # Convert frame-based infos to multi-view image-based
            data_list = []
            for (cam_key, img_info) in info['images'].items():
                camera_info = dict()
                camera_info['sample_idx'] = info['sample_idx']
                camera_info['timestamp'] = info['timestamp']
                camera_info['context_name'] = info['context_name']
                camera_info['images'] = dict()
                camera_info['images'][cam_key] = img_info
                if 'img_path' in img_info:
                    cam_prefix = self.data_prefix.get(cam_key, '')
                    camera_info['images'][cam_key]['img_path'] = osp.join(
                        cam_prefix, img_info['img_path'])
                if 'lidar2cam' in img_info:
                    camera_info['lidar2cam'] = np.array(img_info['lidar2cam'])
                if 'cam2img' in img_info:
                    camera_info['cam2img'] = np.array(img_info['cam2img'])
                if 'lidar2img' in img_info:
                    camera_info['lidar2img'] = np.array(img_info['lidar2img'])
                else:
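                    # Compose the projection when it is not precomputed: a
                    # LiDAR point is first mapped into the camera frame by
                    # `lidar2cam`, then onto the image plane by `cam2img`.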
                    camera_info['lidar2img'] = camera_info[
                        'cam2img'] @ camera_info['lidar2cam']

                if not self.test_mode:
                    # used in training
                    camera_info['instances'] = info['cam_instances'][cam_key]
                    camera_info['ann_info'] = self.parse_ann_info(camera_info)
                if self.test_mode and self.load_eval_anns:
                    camera_info['instances'] = info['cam_instances'][cam_key]
                    camera_info['eval_ann_info'] = self.parse_ann_info(
                        camera_info)
                data_list.append(camera_info)
            return data_list