loading.py 29.7 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import List, Union
3

zhangwenwei's avatar
zhangwenwei committed
4
import mmcv
5
import mmengine
zhangwenwei's avatar
zhangwenwei committed
6
import numpy as np
7
from mmcv.transforms import LoadImageFromFile
8
from mmcv.transforms.base import BaseTransform
zhangwenwei's avatar
zhangwenwei committed
9

10
from mmdet3d.registry import TRANSFORMS
zhangshilong's avatar
zhangshilong committed
11
12
from mmdet3d.structures.points import BasePoints, get_points_type
from mmdet.datasets.transforms import LoadAnnotations
zhangwenwei's avatar
zhangwenwei committed
13
14


15
@TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(BaseTransform):
    """Load the images of several camera views from a list of image files.

    Expects results['img_filename'] to be a list of filenames, one per view.
    Every file is read with :func:`mmcv.imread`; the decoded images are
    stacked along a new trailing axis and then split back into a per-view
    list.

    Args:
        to_float32 (bool, optional): Whether to convert the img to float32.
            Defaults to False.
        color_type (str, optional): Color type of the file, forwarded to
            :func:`mmcv.imread`. Defaults to 'unchanged'.
    """

    def __init__(
        self,
        to_float32: bool = False,
        color_type: str = 'unchanged'
    ) -> None:
        self.to_float32 = to_float32
        self.color_type = color_type

    def transform(self, results: dict) -> dict:
        """Call function to load multi-view image from files.

        Args:
            results (dict): Result dict containing multi-view image filenames
                under ``'img_filename'``.

        Returns:
            dict: The result dict containing the multi-view image data.
                Added keys and values are described below.

                - filename (list[str]): Multi-view image filenames.
                - img (list[np.ndarray]): One image array per view.
                - img_shape (tuple[int]): Shape of the stacked multi-view
                  array (views on the last axis).
                - ori_shape (tuple[int]): Same as ``img_shape``.
                - pad_shape (tuple[int]): Same as ``img_shape``.
                - scale_factor (float): Scale factor, initialised to 1.0.
                - img_norm_cfg (dict): Identity normalization configuration
                  (zero mean, unit std, ``to_rgb=False``).
        """
        filename = results['img_filename']
        # Views are stacked on a new last axis: (h, w, c, num_views) for
        # multi-channel images, (h, w, num_views) for single-channel ones.
        img = np.stack(
            [mmcv.imread(name, self.color_type) for name in filename], axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)
        results['filename'] = filename
        # unravel to list, see `DefaultFormatBundle` in formatting.py
        # which will transpose each image separately and then stack into array
        results['img'] = [img[..., i] for i in range(img.shape[-1])]
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        # The extra view axis means a channel axis exists only when the
        # stacked array is 4-D; for single-channel views axis 2 is the view
        # axis, not a channel axis.
        num_channels = 1 if img.ndim < 4 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32}, '
        repr_str += f"color_type='{self.color_type}')"
        return repr_str
zhangwenwei's avatar
zhangwenwei committed
82
83


84
@TRANSFORMS.register_module()
class LoadImageFromFileMono3D(LoadImageFromFile):
    """Load an image from file in monocular 3D object detection. Compared to 2D
    detection, additional camera parameters need to be loaded.

    Args:
        kwargs (dict): Arguments are the same as those in
            :class:`LoadImageFromFile`.
    """

    def transform(self, results: dict) -> dict:
        """Call functions to load image and get image meta information.

        The camera is chosen from ``results['images']``: the ``'CAM2'`` entry
        when present, otherwise the single camera when exactly one is listed.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
            Added keys are ``cam2img``, ``img``, ``img_shape`` and
            ``ori_shape`` (the last two are ``img.shape[:2]``).

        Raises:
            NotImplementedError: If ``results['images']`` has no ``'CAM2'``
                entry and does not contain exactly one camera.
        """
        # TODO: load different camera image from data info,
        # for kitti dataset, we load 'CAM2' image.
        # for nuscenes dataset, we load 'CAM_FRONT' image.
        images = results['images']
        if 'CAM2' in images:
            cam_info = images['CAM2']
        elif len(images) == 1:
            cam_info = next(iter(images.values()))
        else:
            # The two literals are concatenated, so the first one needs the
            # trailing space to keep the words separate.
            raise NotImplementedError(
                'Currently we only support load image from kitti and '
                'nuscenes datasets')
        filename = cam_info['img_path']
        results['cam2img'] = cam_info['cam2img']

        img_bytes = self.file_client.get(filename)
        img = mmcv.imfrombytes(
            img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]

        return results


132
@TRANSFORMS.register_module()
class LoadPointsFromMultiSweeps(BaseTransform):
    """Load points from multiple sweeps.

    This is usually used for nuScenes dataset to utilize previous sweeps.

    Args:
        sweeps_num (int, optional): Number of sweeps. Defaults to 10.
        load_dim (int, optional): Dimension number of the loaded points.
            Defaults to 5.
        use_dim (list[int], optional): Which dimension to use.
            Defaults to [0, 1, 2, 4].
        file_client_args (dict, optional): Config dict of file clients,
            refer to
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
        pad_empty_sweeps (bool, optional): Whether to repeat the key frame
            ``sweeps_num`` times when no sweeps are available, i.e. when
            ``lidar_sweeps`` is missing from the results or is empty.
            Defaults to False.
        remove_close (bool, optional): Whether to remove close points.
            Defaults to False.
        test_mode (bool, optional): If `test_mode=True`, sweeps are not
            randomly sampled; the first ``sweeps_num`` entries of
            ``lidar_sweeps`` are used instead. Defaults to False.
    """

    def __init__(
        self,
        sweeps_num: int = 10,
        load_dim: int = 5,
        use_dim: List[int] = [0, 1, 2, 4],
        file_client_args: dict = dict(backend='disk'),
        pad_empty_sweeps: bool = False,
        remove_close: bool = False,
        test_mode: bool = False
    ) -> None:
        self.load_dim = load_dim
        self.sweeps_num = sweeps_num
        self.use_dim = use_dim
        self.file_client_args = file_client_args.copy()
        # Created lazily on first use in ``_load_points``.
        self.file_client = None
        self.pad_empty_sweeps = pad_empty_sweeps
        self.remove_close = remove_close
        self.test_mode = test_mode

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: A flat array containing point clouds data.
        """
        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)
        try:
            pts_bytes = self.file_client.get(pts_filename)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            # Fall back to reading the local file directly.
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)
        return points

    def _remove_close(
        self,
        points: Union[np.ndarray, BasePoints],
        radius: float = 1.0
    ) -> Union[np.ndarray, BasePoints]:
        """Remove points near the origin.

        A point is dropped when both ``|x| < radius`` and ``|y| < radius``,
        i.e. it lies inside an axis-aligned square around the origin.

        Args:
            points (np.ndarray | :obj:`BasePoints`): Sweep points.
            radius (float, optional): Half side length of the square inside
                which points are removed. Defaults to 1.0.

        Returns:
            np.ndarray | :obj:`BasePoints`: Points after removing.

        Raises:
            NotImplementedError: If ``points`` is of an unsupported type.
        """
        if isinstance(points, np.ndarray):
            points_numpy = points
        elif isinstance(points, BasePoints):
            points_numpy = points.tensor.numpy()
        else:
            raise NotImplementedError
        x_filt = np.abs(points_numpy[:, 0]) < radius
        y_filt = np.abs(points_numpy[:, 1]) < radius
        not_close = np.logical_not(np.logical_and(x_filt, y_filt))
        return points[not_close]

    def transform(self, results: dict) -> dict:
        """Call function to load multi-sweep point clouds from files.

        Args:
            results (dict): Result dict containing multi-sweep point cloud
                filenames.

        Returns:
            dict: The result dict containing the multi-sweep points data.
                Updated key and value are described below.

                - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
                    cloud arrays.
        """
        points = results['points']
        # Column 4 holds the time lag w.r.t. the key frame (filled for each
        # sweep below); the key frame itself has zero lag.
        points.tensor[:, 4] = 0
        sweep_points_list = [points]
        ts = results['timestamp']
        sweeps = results.get('lidar_sweeps')
        if sweeps is None or len(sweeps) == 0:
            # No sweeps available: optionally repeat the key frame so the
            # output still contains ``sweeps_num`` extra frames.
            if self.pad_empty_sweeps:
                pad_points = (
                    self._remove_close(points)
                    if self.remove_close else points)
                sweep_points_list.extend([pad_points] * self.sweeps_num)
        else:
            if len(sweeps) <= self.sweeps_num:
                choices = np.arange(len(sweeps))
            elif self.test_mode:
                choices = np.arange(self.sweeps_num)
            else:
                choices = np.random.choice(
                    len(sweeps), self.sweeps_num, replace=False)
            for idx in choices:
                sweep = sweeps[idx]
                points_sweep = self._load_points(
                    sweep['lidar_points']['lidar_path'])
                points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
                if self.remove_close:
                    points_sweep = self._remove_close(points_sweep)
                # bc-breaking: Timestamp has divided 1e6 in pkl infos.
                sweep_ts = sweep['timestamp']
                lidar2sensor = np.array(sweep['lidar_points']['lidar2sensor'])
                rotation = lidar2sensor[:3, :3]
                translation = lidar2sensor[:3, 3]
                points_sweep[:, :3] = points_sweep[:, :3] @ rotation
                points_sweep[:, :3] -= translation
                points_sweep[:, 4] = ts - sweep_ts
                points_sweep = points.new_point(points_sweep)
                sweep_points_list.append(points_sweep)

        points = points.cat(sweep_points_list)
        points = points[:, self.use_dim]
        results['points'] = points
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
wuyuefeng's avatar
wuyuefeng committed
284
285


286
@TRANSFORMS.register_module()
class PointSegClassMapping(BaseTransform):
    """Map original semantic class to valid category ids.

    Required Keys:

    - seg_label_mapping (np.ndarray)
    - pts_semantic_mask (np.ndarray)

    Modified Keys:

    - pts_semantic_mask (np.ndarray)
    - eval_ann_info['pts_semantic_mask'] (np.ndarray), only when
      ``eval_ann_info`` is present in the results.

    ``seg_label_mapping`` is used as a lookup table indexed by the original
    class ids. It is expected to map valid classes to
    0~len(valid_cat_ids)-1 and all others to len(valid_cat_ids); that
    convention is set by whoever builds the mapping, not checked here.
    """

    def transform(self, results: dict) -> dict:
        """Call function to map original semantic class to valid category ids.

        Args:
            results (dict): Result dict containing point semantic masks.

        Returns:
            dict: The result dict containing the mapped category ids.
                Updated key and value are described below.

                - pts_semantic_mask (np.ndarray): Mapped semantic masks.
        """
        assert 'pts_semantic_mask' in results
        pts_semantic_mask = results['pts_semantic_mask']

        assert 'seg_label_mapping' in results
        label_mapping = results['seg_label_mapping']
        # Fancy indexing applies the lookup table element-wise.
        converted_pts_sem_mask = label_mapping[pts_semantic_mask]

        results['pts_semantic_mask'] = converted_pts_sem_mask

        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            assert 'pts_semantic_mask' in results['eval_ann_info']
            results['eval_ann_info']['pts_semantic_mask'] = \
                converted_pts_sem_mask

        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        return repr_str


338
@TRANSFORMS.register_module()
class NormalizePointsColor(BaseTransform):
    """Normalize color of points.

    The color attribute is optionally mean-subtracted and then divided by
    255.

    Args:
        color_mean (list[float] | None): Mean color of the point cloud,
            subtracted before scaling. Pass None to skip mean subtraction.
    """

    def __init__(self, color_mean: Union[List[float], None]) -> None:
        self.color_mean = color_mean

    def transform(self, input_dict: dict) -> dict:
        """Call function to normalize color of points.

        Args:
            input_dict (dict): Result dict containing point clouds data.

        Returns:
            dict: The result dict containing the normalized points.
                Updated key and value are described below.

                - points (:obj:`BasePoints`): Points after color normalization.
        """
        points = input_dict['points']
        assert points.attribute_dims is not None and \
               'color' in points.attribute_dims.keys(), \
               'Expect points have color attribute'
        if self.color_mean is not None:
            # new_tensor keeps dtype/device consistent with the color data.
            points.color = points.color - \
                           points.color.new_tensor(self.color_mean)
        points.color = points.color / 255.0
        input_dict['points'] = points
        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(color_mean={self.color_mean})'
        return repr_str


379
@TRANSFORMS.register_module()
class LoadPointsFromFile(BaseTransform):
    """Read a point cloud file and wrap it into a points structure.

    Required Keys:

    - lidar_points (dict)

        - lidar_path (str)

    Added Keys:

    - points (np.float32)

    Args:
        coord_type (str): Coordinate system of the loaded points. One of:

            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
            - 'CAMERA': Points in camera coordinates.
        load_dim (int, optional): Number of values stored per point in the
            file. Defaults to 6.
        use_dim (list[int] | int, optional): Columns to keep. An int ``k`` is
            shorthand for the first ``k`` columns. Defaults to [0, 1, 2]. For
            KITTI dataset, set use_dim=4 or use_dim=[0, 1, 2, 3] to use the
            intensity dimension.
        shift_height (bool, optional): Whether to insert a height-above-floor
            column after xyz. Defaults to False.
        use_color (bool, optional): Whether to expose the last three columns
            as the color attribute. Defaults to False.
        file_client_args (dict, optional): Config dict of file clients,
            refer to
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
    """

    def __init__(self,
                 coord_type: str,
                 load_dim: int = 6,
                 use_dim: Union[int, List[int]] = [0, 1, 2],
                 shift_height: bool = False,
                 use_color: bool = False,
                 file_client_args: dict = dict(backend='disk')) -> None:
        selected = list(range(use_dim)) if isinstance(use_dim, int) \
            else use_dim
        assert max(selected) < load_dim, \
            f'Expect all used dimensions < {load_dim}, got {selected}'
        assert coord_type in ['CAMERA', 'LIDAR', 'DEPTH']

        self.shift_height = shift_height
        self.use_color = use_color
        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = selected
        self.file_client_args = file_client_args.copy()
        # The file client is built on first use.
        self.file_client = None

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Read raw float32 point data from ``pts_filename``.

        Args:
            pts_filename (str): Path of the point cloud file.

        Returns:
            np.ndarray: Flat array with the file contents.
        """
        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)
        try:
            raw = self.file_client.get(pts_filename)
        except ConnectionError:
            # Read the local file directly when the client cannot connect.
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                return np.load(pts_filename)
            return np.fromfile(pts_filename, dtype=np.float32)
        return np.frombuffer(raw, dtype=np.float32)

    def transform(self, results: dict) -> dict:
        """Load the point cloud referenced by ``results``.

        Args:
            results (dict): Result dict containing
                ``results['lidar_points']['lidar_path']``.

        Returns:
            dict: ``results`` with the added key:

                - points (:obj:`BasePoints`): Point clouds data.
        """
        source = results['lidar_points']['lidar_path']
        cloud = self._load_points(source).reshape(-1, self.load_dim)
        cloud = cloud[:, self.use_dim]
        attribute_dims = None

        if self.shift_height:
            # np.percentile takes q on a 0-100 scale, so 0.99 is a value
            # close to the lowest z in the cloud.
            floor_height = np.percentile(cloud[:, 2], 0.99)
            height = cloud[:, 2] - floor_height
            cloud = np.concatenate(
                [cloud[:, :3], height[:, None], cloud[:, 3:]], 1)
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            width = cloud.shape[1]
            if attribute_dims is None:
                attribute_dims = {}
            attribute_dims['color'] = [width - 3, width - 2, width - 1]

        points_class = get_points_type(self.coord_type)
        results['points'] = points_class(
            cloud, points_dim=cloud.shape[-1], attribute_dims=attribute_dims)
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        parts = [
            f'shift_height={self.shift_height}',
            f'use_color={self.use_color}',
            f'file_client_args={self.file_client_args}',
            f'load_dim={self.load_dim}',
            f'use_dim={self.use_dim}',
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"


516
@TRANSFORMS.register_module()
class LoadPointsFromDict(LoadPointsFromFile):
    """Pass-through loader for points that are already in the results.

    Unlike :class:`LoadPointsFromFile`, nothing is read from disk: the points
    must already be stored under ``results['points']``. The constructor is
    inherited unchanged from :class:`LoadPointsFromFile`.
    """

    def transform(self, results: dict) -> dict:
        """Check that ``results`` already carries points.

        Args:
            results (dict): Result dict that must contain ``'points'``.

        Returns:
            dict: ``results``, returned unmodified.
        """
        assert 'points' in results.keys()
        return results


525
@TRANSFORMS.register_module()
wuyuefeng's avatar
wuyuefeng committed
526
527
528
529
530
531
class LoadAnnotations3D(LoadAnnotations):
    """Load Annotations3D.

    Load instance mask and semantic mask of points and
    encapsulate the items into related fields.

jshilong's avatar
jshilong committed
532
533
534
    Required Keys:

    - ann_info (dict)
535

jshilong's avatar
jshilong committed
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
        - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
          :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
          3D ground truth bboxes. Only when `with_bbox_3d` is True
        - gt_labels_3d (np.int64): Labels of ground truths.
          Only when `with_label_3d` is True.
        - gt_bboxes (np.float32): 2D ground truth bboxes.
          Only when `with_bbox` is True.
        - gt_labels (np.ndarray): Labels of ground truths.
          Only when `with_label` is True.
        - depths (np.ndarray): Only when
          `with_bbox_depth` is True.
        - centers_2d (np.ndarray): Only when
          `with_bbox_depth` is True.
        - attr_labels (np.ndarray): Attribute labels of instances.
          Only when `with_attr_label` is True.

    - pts_instance_mask_path (str): Path of instance mask file.
      Only when `with_mask_3d` is True.
    - pts_semantic_mask_path (str): Path of semantic mask file.
      Only when

    Added Keys:

    - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
      :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
      3D ground truth bboxes. Only when `with_bbox_3d` is True
    - gt_labels_3d (np.int64): Labels of ground truths.
      Only when `with_label_3d` is True.
    - gt_bboxes (np.float32): 2D ground truth bboxes.
      Only when `with_bbox` is True.
    - gt_labels (np.int64): Labels of ground truths.
      Only when `with_label` is True.
    - depths (np.float32): Only when
      `with_bbox_depth` is True.
    - centers_2d (np.ndarray): Only when
      `with_bbox_depth` is True.
    - attr_labels (np.int64): Attribute labels of instances.
      Only when `with_attr_label` is True.
    - pts_instance_mask (np.int64): Instance mask of each point.
      Only when `with_mask_3d` is True.
    - pts_semantic_mask (np.int64): Semantic mask of each point.
      Only when `with_seg_3d` is True.

wuyuefeng's avatar
wuyuefeng committed
579
580
581
582
583
    Args:
        with_bbox_3d (bool, optional): Whether to load 3D boxes.
            Defaults to True.
        with_label_3d (bool, optional): Whether to load 3D labels.
            Defaults to True.
584
585
        with_attr_label (bool, optional): Whether to load attribute label.
            Defaults to False.
wuyuefeng's avatar
wuyuefeng committed
586
587
588
589
590
591
592
593
594
595
596
597
        with_mask_3d (bool, optional): Whether to load 3D instance masks.
            for points. Defaults to False.
        with_seg_3d (bool, optional): Whether to load 3D semantic masks.
            for points. Defaults to False.
        with_bbox (bool, optional): Whether to load 2D boxes.
            Defaults to False.
        with_label (bool, optional): Whether to load 2D labels.
            Defaults to False.
        with_mask (bool, optional): Whether to load 2D instance masks.
            Defaults to False.
        with_seg (bool, optional): Whether to load 2D semantic masks.
            Defaults to False.
598
599
        with_bbox_depth (bool, optional): Whether to load 2.5D boxes.
            Defaults to False.
wuyuefeng's avatar
wuyuefeng committed
600
601
        poly2mask (bool, optional): Whether to convert polygon annotations
            to bitmasks. Defaults to True.
602
        seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
jshilong's avatar
jshilong committed
603
            Defaults to int64.
wuyuefeng's avatar
wuyuefeng committed
604
        file_client_args (dict): Config dict of file clients, refer to
605
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
wuyuefeng's avatar
wuyuefeng committed
606
607
608
            for more details.
    """

jshilong's avatar
jshilong committed
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
    def __init__(self,
                 with_bbox_3d: bool = True,
                 with_label_3d: bool = True,
                 with_attr_label: bool = False,
                 with_mask_3d: bool = False,
                 with_seg_3d: bool = False,
                 with_bbox: bool = False,
                 with_label: bool = False,
                 with_mask: bool = False,
                 with_seg: bool = False,
                 with_bbox_depth: bool = False,
                 poly2mask: bool = True,
                 seg_3d_dtype: np.dtype = np.int64,
                 file_client_args: dict = dict(backend='disk')) -> None:
        # The 2D switches and file-client settings are handled by the parent
        # LoadAnnotations transform.
        super().__init__(
            with_bbox=with_bbox,
            with_label=with_label,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            file_client_args=file_client_args)
        # Switches specific to 3D / 2.5D annotations.
        self.with_bbox_3d = with_bbox_3d
        self.with_label_3d = with_label_3d
        self.with_attr_label = with_attr_label
        self.with_bbox_depth = with_bbox_depth
        self.with_mask_3d = with_mask_3d
        self.with_seg_3d = with_seg_3d
        self.seg_3d_dtype = seg_3d_dtype
wuyuefeng's avatar
wuyuefeng committed
639

jshilong's avatar
jshilong committed
640
641
642
    def _load_bboxes_3d(self, results: dict) -> dict:
        """Expose the 3D ground-truth boxes at the top level of ``results``.

        The boxes are taken from ``results['ann_info']['gt_bboxes_3d']``
        (not copied) and stored under ``results['gt_bboxes_3d']``.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The same dict with ``gt_bboxes_3d`` added.
        """
        ann_info = results['ann_info']
        results['gt_bboxes_3d'] = ann_info['gt_bboxes_3d']
        return results

jshilong's avatar
jshilong committed
654
    def _load_bboxes_depth(self, results: dict) -> dict:
        """Expose the 2.5D box annotations at the top level of ``results``.

        ``depths`` and then ``centers_2d`` are taken from
        ``results['ann_info']`` and stored under the same keys in
        ``results``.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The same dict with ``depths`` and ``centers_2d`` added.
        """
        ann_info = results['ann_info']
        for key in ('depths', 'centers_2d'):
            results[key] = ann_info[key]
        return results

jshilong's avatar
jshilong committed
668
    def _load_labels_3d(self, results: dict) -> dict:
        """Expose the 3D ground-truth labels at the top level of ``results``.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The same dict with ``gt_labels_3d`` taken from
            ``results['ann_info']``.
        """
        ann_info = results['ann_info']
        results['gt_labels_3d'] = ann_info['gt_labels_3d']
        return results

jshilong's avatar
jshilong committed
681
    def _load_attr_labels(self, results: dict) -> dict:
682
683
684
685
686
687
688
689
690
691
692
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded label annotations.
        """
        results['attr_labels'] = results['ann_info']['attr_labels']
        return results

jshilong's avatar
jshilong committed
693
    def _load_masks_3d(self, results: dict) -> dict:
694
695
696
697
698
699
700
701
        """Private function to load 3D mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D mask annotations.
        """
jshilong's avatar
jshilong committed
702
        pts_instance_mask_path = results['pts_instance_mask_path']
wuyuefeng's avatar
wuyuefeng committed
703
704

        if self.file_client is None:
705
            self.file_client = mmengine.FileClient(**self.file_client_args)
wuyuefeng's avatar
wuyuefeng committed
706
707
        try:
            mask_bytes = self.file_client.get(pts_instance_mask_path)
708
            pts_instance_mask = np.frombuffer(mask_bytes, dtype=np.int64)
wuyuefeng's avatar
wuyuefeng committed
709
        except ConnectionError:
710
            mmengine.check_file_exist(pts_instance_mask_path)
wuyuefeng's avatar
wuyuefeng committed
711
            pts_instance_mask = np.fromfile(
WRH's avatar
WRH committed
712
                pts_instance_mask_path, dtype=np.int64)
wuyuefeng's avatar
wuyuefeng committed
713
714

        results['pts_instance_mask'] = pts_instance_mask
jshilong's avatar
jshilong committed
715
716
717
        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_instance_mask'] = pts_instance_mask
wuyuefeng's avatar
wuyuefeng committed
718
719
        return results

jshilong's avatar
jshilong committed
720
    def _load_semantic_seg_3d(self, results: dict) -> dict:
721
722
723
724
725
726
727
728
        """Private function to load 3D semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing the semantic segmentation annotations.
        """
jshilong's avatar
jshilong committed
729
        pts_semantic_mask_path = results['pts_semantic_mask_path']
wuyuefeng's avatar
wuyuefeng committed
730
731

        if self.file_client is None:
732
            self.file_client = mmengine.FileClient(**self.file_client_args)
wuyuefeng's avatar
wuyuefeng committed
733
734
735
        try:
            mask_bytes = self.file_client.get(pts_semantic_mask_path)
            # add .copy() to fix read-only bug
736
737
            pts_semantic_mask = np.frombuffer(
                mask_bytes, dtype=self.seg_3d_dtype).copy()
wuyuefeng's avatar
wuyuefeng committed
738
        except ConnectionError:
739
            mmengine.check_file_exist(pts_semantic_mask_path)
wuyuefeng's avatar
wuyuefeng committed
740
            pts_semantic_mask = np.fromfile(
WRH's avatar
WRH committed
741
                pts_semantic_mask_path, dtype=np.int64)
wuyuefeng's avatar
wuyuefeng committed
742
743

        results['pts_semantic_mask'] = pts_semantic_mask
jshilong's avatar
jshilong committed
744
745
746
        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
wuyuefeng's avatar
wuyuefeng committed
747
748
        return results

zhangshilong's avatar
zhangshilong committed
749
750
751
752
753
754
755
756
757
758
759
760
    def _load_bboxes(self, results: dict) -> None:
        """Private function to load bounding box annotations.

        The only difference is it remove the proceess for
        `ignore_flag`

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
        Returns:
            dict: The dict contains loaded bounding box annotations.
        """

761
        results['gt_bboxes'] = results['ann_info']['gt_bboxes']
zhangshilong's avatar
zhangshilong committed
762
763
764
765
766
767
768
769
770
771

    def _load_labels(self, results: dict) -> None:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded label annotations.
        """
772
        results['gt_bboxes_labels'] = results['ann_info']['gt_bboxes_labels']
zhangshilong's avatar
zhangshilong committed
773

jshilong's avatar
jshilong committed
774
775
    def transform(self, results: dict) -> dict:
        """Load every enabled kind of annotation into ``results``.

        The parent class's ``transform`` runs first. After that, each enabled
        3D loader runs in order: 3D boxes, 2.5D box depths, 3D labels,
        attribute labels, 3D instance masks, then 3D semantic masks.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D bounding box, label, mask and
            semantic segmentation annotations.
        """
        results = super().transform(results)

        # (enabled flag, loader) pairs. Order matters and mirrors the
        # sequence in which the annotations have always been loaded.
        loaders = (
            (self.with_bbox_3d, self._load_bboxes_3d),
            (self.with_bbox_depth, self._load_bboxes_depth),
            (self.with_label_3d, self._load_labels_3d),
            (self.with_attr_label, self._load_attr_labels),
            (self.with_mask_3d, self._load_masks_3d),
            (self.with_seg_3d, self._load_semantic_seg_3d),
        )
        for enabled, loader in loaders:
            if enabled:
                results = loader(results)

        return results

    def __repr__(self):
801
        """str: Return a string that describes the module."""
wuyuefeng's avatar
wuyuefeng committed
802
803
        indent_str = '    '
        repr_str = self.__class__.__name__ + '(\n'
liyinhao's avatar
liyinhao committed
804
805
        repr_str += f'{indent_str}with_bbox_3d={self.with_bbox_3d}, '
        repr_str += f'{indent_str}with_label_3d={self.with_label_3d}, '
806
        repr_str += f'{indent_str}with_attr_label={self.with_attr_label}, '
liyinhao's avatar
liyinhao committed
807
808
809
810
811
812
        repr_str += f'{indent_str}with_mask_3d={self.with_mask_3d}, '
        repr_str += f'{indent_str}with_seg_3d={self.with_seg_3d}, '
        repr_str += f'{indent_str}with_bbox={self.with_bbox}, '
        repr_str += f'{indent_str}with_label={self.with_label}, '
        repr_str += f'{indent_str}with_mask={self.with_mask}, '
        repr_str += f'{indent_str}with_seg={self.with_seg}, '
813
        repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, '
wuyuefeng's avatar
wuyuefeng committed
814
815
        repr_str += f'{indent_str}poly2mask={self.poly2mask})'
        return repr_str