loading.py 50.7 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
import copy
from typing import List, Optional, Union
4

zhangwenwei's avatar
zhangwenwei committed
5
import mmcv
6
import mmengine
zhangwenwei's avatar
zhangwenwei committed
7
import numpy as np
8
from mmcv.transforms import LoadImageFromFile
9
from mmcv.transforms.base import BaseTransform
10
from mmdet.datasets.transforms import LoadAnnotations
11
from mmengine.fileio import get
zhangwenwei's avatar
zhangwenwei committed
12

13
from mmdet3d.registry import TRANSFORMS
14
from mmdet3d.structures.bbox_3d import get_box_type
zhangshilong's avatar
zhangshilong committed
15
from mmdet3d.structures.points import BasePoints, get_points_type
zhangwenwei's avatar
zhangwenwei committed
16
17


18
@TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(BaseTransform):
    """Load multi channel images from a list of separate channel files.

    Expects results['img_filename'] to be a list of filenames.

    Args:
        to_float32 (bool): Whether to convert the img to float32.
            Defaults to False.
        color_type (str): Color type of the file. Defaults to 'unchanged'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        num_views (int): Number of view in a frame. Defaults to 5.
        num_ref_frames (int): Number of frame in loading. Defaults to -1.
        test_mode (bool): Whether is test mode in loading. Defaults to False.
        set_default_scale (bool): Whether to set default scale.
            Defaults to True.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'unchanged',
                 backend_args: Optional[dict] = None,
                 num_views: int = 5,
                 num_ref_frames: int = -1,
                 test_mode: bool = False,
                 set_default_scale: bool = True) -> None:
        self.to_float32 = to_float32
        self.color_type = color_type
        self.backend_args = backend_args
        self.num_views = num_views
        # num_ref_frames is used for multi-sweep loading
        # (<= 0 disables the multi-sweep branch in ``transform``)
        self.num_ref_frames = num_ref_frames
        # when test_mode=False, we randomly select previous frames
        # otherwise, select the earliest one
        self.test_mode = test_mode
        self.set_default_scale = set_default_scale

    def transform(self, results: dict) -> Optional[dict]:
        """Call function to load multi-view image from files.

        Args:
            results (dict): Result dict containing multi-view image filenames.

        Returns:
            dict: The result dict containing the multi-view image data.
            Added keys and values are described below.

                - filename (str): Multi-view image filenames.
                - img (np.ndarray): Multi-view image arrays.
                - img_shape (tuple[int]): Shape of multi-view image arrays.
                - ori_shape (tuple[int]): Shape of original image arrays.
                - pad_shape (tuple[int]): Shape of padded image arrays.
                - scale_factor (float): Scale factor.
                - img_norm_cfg (dict): Normalization configuration of images.
        """
        # TODO: consider split the multi-sweep part out of this pipeline
        # Derive the mask and transform for loading of multi-sweep data
        if self.num_ref_frames > 0:
            # init choice with the current frame
            init_choice = np.array([0], dtype=np.int64)
            # results['img_filename'] holds num_views entries per frame,
            # so this counts the number of PREVIOUS frames available
            num_frames = len(results['img_filename']) // self.num_views - 1
            if num_frames == 0:  # no previous frame, then copy cur frames
                # choice 0 == current frame, so the keyframe is duplicated
                choices = np.random.choice(
                    1, self.num_ref_frames, replace=True)
            elif num_frames >= self.num_ref_frames:
                # NOTE: suppose the info is saved following the order
                # from latest to earlier frames
                if self.test_mode:
                    choices = np.arange(num_frames - self.num_ref_frames,
                                        num_frames) + 1
                # NOTE: +1 is for selecting previous frames
                else:
                    choices = np.random.choice(
                        num_frames, self.num_ref_frames, replace=False) + 1
            elif num_frames > 0 and num_frames < self.num_ref_frames:
                # fewer previous frames than requested: take them all and
                # fill the remainder by sampling with replacement
                if self.test_mode:
                    base_choices = np.arange(num_frames) + 1
                    random_choices = np.random.choice(
                        num_frames,
                        self.num_ref_frames - num_frames,
                        replace=True) + 1
                    choices = np.concatenate([base_choices, random_choices])
                else:
                    choices = np.random.choice(
                        num_frames, self.num_ref_frames, replace=True) + 1
            else:
                raise NotImplementedError
            choices = np.concatenate([init_choice, choices])
            # keep only the filenames of the selected frames (current first)
            select_filename = []
            for choice in choices:
                select_filename += results['img_filename'][choice *
                                                           self.num_views:
                                                           (choice + 1) *
                                                           self.num_views]
            results['img_filename'] = select_filename
            # per-view matrices are sliced frame-by-frame like the filenames
            for key in ['cam2img', 'lidar2cam']:
                if key in results:
                    select_results = []
                    for choice in choices:
                        select_results += results[key][choice *
                                                       self.num_views:(choice +
                                                                       1) *
                                                       self.num_views]
                    results[key] = select_results
            # ego2global is per-frame (one entry per frame, not per view)
            for key in ['ego2global']:
                if key in results:
                    select_results = []
                    for choice in choices:
                        select_results += [results[key][choice]]
                    results[key] = select_results
            # Transform lidar2cam to
            # [cur_lidar]2[prev_img] and [cur_lidar]2[prev_cam]
            for key in ['lidar2cam']:
                if key in results:
                    # only change matrices of previous frames
                    for choice_idx in range(1, len(choices)):
                        # pad both ego2global matrices to 4x4 homogeneous form
                        # (assumes they are at most 4x4 — TODO confirm)
                        pad_prev_ego2global = np.eye(4)
                        prev_ego2global = results['ego2global'][choice_idx]
                        pad_prev_ego2global[:prev_ego2global.
                                            shape[0], :prev_ego2global.
                                            shape[1]] = prev_ego2global
                        pad_cur_ego2global = np.eye(4)
                        cur_ego2global = results['ego2global'][0]
                        pad_cur_ego2global[:cur_ego2global.
                                           shape[0], :cur_ego2global.
                                           shape[1]] = cur_ego2global
                        # cur ego frame -> global -> prev ego frame
                        cur2prev = np.linalg.inv(pad_prev_ego2global).dot(
                            pad_cur_ego2global)
                        for result_idx in range(choice_idx * self.num_views,
                                                (choice_idx + 1) *
                                                self.num_views):
                            results[key][result_idx] = \
                                results[key][result_idx].dot(cur2prev)
        # Support multi-view images with different shapes
        # TODO: record the origin shape and padded shape
        filename, cam2img, lidar2cam = [], [], []
        for _, cam_item in results['images'].items():
            filename.append(cam_item['img_path'])
            cam2img.append(cam_item['cam2img'])
            lidar2cam.append(cam_item['lidar2cam'])
        results['filename'] = filename
        results['cam2img'] = cam2img
        results['lidar2cam'] = lidar2cam

        # keep a pristine copy since cam2img may be modified downstream
        results['ori_cam2img'] = copy.deepcopy(results['cam2img'])

        # img is of shape (h, w, c, num_views)
        # h and w can be different for different views
        img_bytes = [
            get(name, backend_args=self.backend_args) for name in filename
        ]
        imgs = [
            mmcv.imfrombytes(img_byte, flag=self.color_type)
            for img_byte in img_bytes
        ]
        # handle the image with different shape
        img_shapes = np.stack([img.shape for img in imgs], axis=0)
        img_shape_max = np.max(img_shapes, axis=0)
        img_shape_min = np.min(img_shapes, axis=0)
        # channel count must agree across views; only h/w may differ
        assert img_shape_min[-1] == img_shape_max[-1]
        if not np.all(img_shape_max == img_shape_min):
            pad_shape = img_shape_max[:2]
        else:
            pad_shape = None
        if pad_shape is not None:
            # zero-pad every view up to the largest (h, w)
            imgs = [
                mmcv.impad(img, shape=pad_shape, pad_val=0) for img in imgs
            ]
        img = np.stack(imgs, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        # unravel to list, see `DefaultFormatBundle` in formating.py
        # which will transpose each image separately and then stack into array
        results['img'] = [img[..., i] for i in range(img.shape[-1])]
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape[:2]
        if self.set_default_scale:
            results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        # identity normalization: downstream transforms may overwrite this
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        results['num_views'] = self.num_views
        results['num_ref_frames'] = self.num_ref_frames
        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32}, '
        repr_str += f"color_type='{self.color_type}', "
        repr_str += f'num_views={self.num_views}, '
        repr_str += f'num_ref_frames={self.num_ref_frames}, '
        repr_str += f'test_mode={self.test_mode})'
        return repr_str
zhangwenwei's avatar
zhangwenwei committed
219
220


221
@TRANSFORMS.register_module()
class LoadImageFromFileMono3D(LoadImageFromFile):
    """Load an image from file in monocular 3D object detection. Compared to 2D
    detection, additional camera parameters need to be loaded.

    Args:
        kwargs (dict): Arguments are the same as those in
            :class:`LoadImageFromFile`.
    """

    def transform(self, results: dict) -> dict:
        """Load the image for one camera and record its intrinsic matrix.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        # TODO: load different camera image from data info,
        # for kitti dataset, we load 'CAM2' image.
        # for nuscenes dataset, we load 'CAM_FRONT' image.

        images = results['images']
        if 'CAM2' in images:
            cam_info = images['CAM2']
        elif len(images) == 1:
            # exactly one camera present: use it, whatever its name
            cam_info = next(iter(images.values()))
        else:
            raise NotImplementedError(
                'Currently we only support load image from kitti and'
                'nuscenes datasets')
        filename = cam_info['img_path']
        results['cam2img'] = cam_info['cam2img']

        try:
            img_bytes = get(filename, backend_args=self.backend_args)
            img = mmcv.imfrombytes(
                img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        except Exception as e:
            # honour the parent class' ``ignore_empty`` contract
            if not self.ignore_empty:
                raise e
            return None

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results


275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Load an image from ``results['img']``.
    Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
    :obj:`np.ndarray` in ``results['img']``. Can be used when loading image
    from webcam.
    Required Keys:
    - img
    Modified Keys:
    - img
    - img_path
    - img_shape
    - ori_shape
    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            results (dict): Result dict with Webcam read image in
                ``results['img']``.
        Returns:
            dict: The dict contains loaded image and meta information.
        """
        img = results['img']
        if self.to_float32:
            img = img.astype(np.float32)

        # there is no backing file, so the path is explicitly None
        height_width = img.shape[:2]
        results.update(
            img_path=None,
            img=img,
            img_shape=height_width,
            ori_shape=height_width)
        return results


315
@TRANSFORMS.register_module()
class LoadPointsFromMultiSweeps(BaseTransform):
    """Load points from multiple sweeps.

    This is usually used for nuScenes dataset to utilize previous sweeps.

    Args:
        sweeps_num (int): Number of sweeps. Defaults to 10.
        load_dim (int): Dimension number of the loaded points. Defaults to 5.
        use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4].
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        pad_empty_sweeps (bool): Whether to repeat keyframe when
            sweeps is empty. Defaults to False.
        remove_close (bool): Whether to remove close points. Defaults to False.
        test_mode (bool): If `test_mode=True`, it will not randomly sample
            sweeps but select the nearest N frames. Defaults to False.
    """

    def __init__(self,
                 sweeps_num: int = 10,
                 load_dim: int = 5,
                 use_dim: List[int] = [0, 1, 2, 4],
                 backend_args: Optional[dict] = None,
                 pad_empty_sweeps: bool = False,
                 remove_close: bool = False,
                 test_mode: bool = False) -> None:
        self.load_dim = load_dim
        self.sweeps_num = sweeps_num
        # an int use_dim means "the first N dims"
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert max(use_dim) < load_dim, \
            f'Expect all used dimensions < {load_dim}, got {use_dim}'
        self.use_dim = use_dim
        self.backend_args = backend_args
        self.pad_empty_sweeps = pad_empty_sweeps
        self.remove_close = remove_close
        self.test_mode = test_mode

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: An array containing point clouds data.
        """
        try:
            pts_bytes = get(pts_filename, backend_args=self.backend_args)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            # backend unavailable: fall back to local file access
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)
        return points

    def _remove_close(self,
                      points: Union[np.ndarray, BasePoints],
                      radius: float = 1.0) -> Union[np.ndarray, BasePoints]:
        """Remove point too close within a certain radius from origin.

        Args:
            points (np.ndarray | :obj:`BasePoints`): Sweep points.
            radius (float): Radius below which points are removed.
                Defaults to 1.0.

        Returns:
            np.ndarray | :obj:`BasePoints`: Points after removing.
        """
        if isinstance(points, np.ndarray):
            points_numpy = points
        elif isinstance(points, BasePoints):
            points_numpy = points.tensor.numpy()
        else:
            raise NotImplementedError
        # a point is "close" when BOTH |x| and |y| are inside the radius
        # (a square around the origin, not a circle)
        x_filt = np.abs(points_numpy[:, 0]) < radius
        y_filt = np.abs(points_numpy[:, 1]) < radius
        not_close = np.logical_not(np.logical_and(x_filt, y_filt))
        return points[not_close]

    def transform(self, results: dict) -> dict:
        """Call function to load multi-sweep point clouds from files.

        Args:
            results (dict): Result dict containing multi-sweep point cloud
                filenames.

        Returns:
            dict: The result dict containing the multi-sweep points data.
            Updated key and value are described below.

                - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
                  cloud arrays.
        """
        points = results['points']
        # dim 4 stores the time lag to the keyframe; the keyframe itself is 0
        # (assumes points were loaded with load_dim >= 5 — TODO confirm)
        points.tensor[:, 4] = 0
        sweep_points_list = [points]
        ts = results['timestamp']
        if 'lidar_sweeps' not in results:
            if self.pad_empty_sweeps:
                # no sweeps available: repeat the keyframe sweeps_num times
                for i in range(self.sweeps_num):
                    if self.remove_close:
                        sweep_points_list.append(self._remove_close(points))
                    else:
                        sweep_points_list.append(points)
        else:
            if len(results['lidar_sweeps']) <= self.sweeps_num:
                # take every available sweep
                choices = np.arange(len(results['lidar_sweeps']))
            elif self.test_mode:
                # deterministic: the first sweeps_num entries
                choices = np.arange(self.sweeps_num)
            else:
                choices = np.random.choice(
                    len(results['lidar_sweeps']),
                    self.sweeps_num,
                    replace=False)
            for idx in choices:
                sweep = results['lidar_sweeps'][idx]
                points_sweep = self._load_points(
                    sweep['lidar_points']['lidar_path'])
                # np.frombuffer returns a read-only view; copy before mutating
                points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
                if self.remove_close:
                    points_sweep = self._remove_close(points_sweep)
                # bc-breaking: Timestamp has divided 1e6 in pkl infos.
                sweep_ts = sweep['timestamp']
                lidar2sensor = np.array(sweep['lidar_points']['lidar2sensor'])
                # rotate then translate sweep points into the keyframe frame
                points_sweep[:, :
                             3] = points_sweep[:, :3] @ lidar2sensor[:3, :3]
                points_sweep[:, :3] -= lidar2sensor[:3, 3]
                # record the time lag relative to the keyframe in dim 4
                points_sweep[:, 4] = ts - sweep_ts
                points_sweep = points.new_point(points_sweep)
                sweep_points_list.append(points_sweep)

        points = points.cat(sweep_points_list)
        points = points[:, self.use_dim]
        results['points'] = points
        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
wuyuefeng's avatar
wuyuefeng committed
458
459


460
@TRANSFORMS.register_module()
class PointSegClassMapping(BaseTransform):
    """Map original semantic class to valid category ids.

    Required Keys:

    - seg_label_mapping (np.ndarray)
    - pts_semantic_mask (np.ndarray)

    Added Keys:

    - points (np.float32)

    Map valid classes as 0~len(valid_cat_ids)-1 and
    others as len(valid_cat_ids).
    """

    def transform(self, results: dict) -> dict:
        """Remap the per-point semantic labels through the dataset's table.

        Args:
            results (dict): Result dict containing point semantic masks.

        Returns:
            dict: The result dict containing the mapped category ids.
            Updated key and value are described below.

                - pts_semantic_mask (np.ndarray): Mapped semantic masks.
        """
        assert 'pts_semantic_mask' in results
        raw_mask = results['pts_semantic_mask']

        assert 'seg_label_mapping' in results
        mapping = results['seg_label_mapping']
        # fancy indexing remaps every point's label in a single pass
        mapped_mask = mapping[raw_mask]

        results['pts_semantic_mask'] = mapped_mask

        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            assert 'pts_semantic_mask' in results['eval_ann_info']
            results['eval_ann_info']['pts_semantic_mask'] = mapped_mask

        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        return self.__class__.__name__


512
@TRANSFORMS.register_module()
class NormalizePointsColor(BaseTransform):
    """Normalize color of points.

    Args:
        color_mean (list[float]): Mean color of the point cloud.
    """

    def __init__(self, color_mean: List[float]) -> None:
        self.color_mean = color_mean

    def transform(self, input_dict: dict) -> dict:
        """Center (optionally) and rescale point colors into [0, 1].

        Args:
            input_dict (dict): Result dict containing point clouds data.

        Returns:
            dict: The result dict containing the normalized points.
            Updated key and value are described below.

                - points (:obj:`BasePoints`): Points after color normalization.
        """
        points = input_dict['points']
        has_color = points.attribute_dims is not None and \
            'color' in points.attribute_dims.keys()
        assert has_color, 'Expect points have color attribute'
        if self.color_mean is not None:
            # subtract the dataset-wide mean color first
            mean = points.color.new_tensor(self.color_mean)
            points.color = points.color - mean
        # scale 8-bit color values down to [0, 1]
        points.color = points.color / 255.0
        input_dict['points'] = points
        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__}(color_mean={self.color_mean})'


553
@TRANSFORMS.register_module()
class LoadPointsFromFile(BaseTransform):
    """Load Points From File.

    Required Keys:

    - lidar_points (dict)

        - lidar_path (str)

    Added Keys:

    - points (np.float32)

    Args:
        coord_type (str): The type of coordinates of points cloud.
            Available options includes:

            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
            - 'CAMERA': Points in camera coordinates.

        load_dim (int): The dimension of the loaded points. Defaults to 6.
        use_dim (list[int] | int): Which dimensions of the points to use.
            Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
            or use_dim=[0, 1, 2, 3] to use the intensity dimension.
        shift_height (bool): Whether to use shifted height. Defaults to False.
        use_color (bool): Whether to use color features. Defaults to False.
        norm_intensity (bool): Whether to normlize the intensity. Defaults to
            False.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    def __init__(self,
                 coord_type: str,
                 load_dim: int = 6,
                 use_dim: Union[int, List[int]] = [0, 1, 2],
                 shift_height: bool = False,
                 use_color: bool = False,
                 norm_intensity: bool = False,
                 backend_args: Optional[dict] = None) -> None:
        self.shift_height = shift_height
        self.use_color = use_color
        # an int use_dim means "the first N dims"
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert max(use_dim) < load_dim, \
            f'Expect all used dimensions < {load_dim}, got {use_dim}'
        assert coord_type in ['CAMERA', 'LIDAR', 'DEPTH']

        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.norm_intensity = norm_intensity
        self.backend_args = backend_args

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: An array containing point clouds data.
        """
        try:
            pts_bytes = get(pts_filename, backend_args=self.backend_args)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            # backend unavailable: fall back to local file access
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)

        return points

    def transform(self, results: dict) -> dict:
        """Method to load points data from file.

        Args:
            results (dict): Result dict containing point clouds data.

        Returns:
            dict: The result dict containing the point clouds data.
            Added key and value are described below.

                - points (:obj:`BasePoints`): Point clouds data.
        """
        pts_file_path = results['lidar_points']['lidar_path']
        points = self._load_points(pts_file_path)
        points = points.reshape(-1, self.load_dim)
        points = points[:, self.use_dim]
        if self.norm_intensity:
            assert len(self.use_dim) >= 4, \
                f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}'  # noqa: E501
            # squash raw intensity values into (-1, 1)
            points[:, 3] = np.tanh(points[:, 3])
        attribute_dims = None

        if self.shift_height:
            # NOTE(review): np.percentile's q is on a 0-100 scale, so 0.99
            # is the 0.99th percentile (near-minimum z), not the 99th —
            # confirm this is intentional before changing.
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            # insert height-above-floor as column 3, shifting the rest right
            points = np.concatenate(
                [points[:, :3],
                 np.expand_dims(height, 1), points[:, 3:]], 1)
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            # the last three columns are treated as the RGB channels
            attribute_dims.update(
                dict(color=[
                    points.shape[1] - 3,
                    points.shape[1] - 2,
                    points.shape[1] - 1,
                ]))

        points_class = get_points_type(self.coord_type)
        points = points_class(
            points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
        results['points'] = points

        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__ + '('
        repr_str += f'shift_height={self.shift_height}, '
        repr_str += f'use_color={self.use_color}, '
        repr_str += f'backend_args={self.backend_args}, '
        repr_str += f'load_dim={self.load_dim}, '
        repr_str += f'use_dim={self.use_dim})'
        return repr_str


688
@TRANSFORMS.register_module()
class LoadPointsFromDict(LoadPointsFromFile):
    """Load Points From Dict."""

    def transform(self, results: dict) -> dict:
        """Convert the type of points from ndarray to corresponding
        `point_class`.

        Args:
            results (dict): input result. The value of key `points` is a
                numpy array.

        Returns:
            dict: The processed results.
        """
        assert 'points' in results
        points = results['points']

        if self.norm_intensity:
            # Squash raw intensity into (-1, 1) via tanh.
            assert len(self.use_dim) >= 4, \
                f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}'  # noqa: E501
            points[:, 3] = np.tanh(points[:, 3])

        attribute_dims = None

        if self.shift_height:
            # NOTE(review): 0.99 is passed as an absolute percentile
            # (i.e. the 0.99th, near-minimum), matching upstream behavior;
            # kept as-is on purpose.
            floor_height = np.percentile(points[:, 2], 0.99)
            heights = (points[:, 2] - floor_height)[:, None]
            points = np.hstack([points[:, :3], heights, points[:, 3:]])
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            # The color channels are the last three columns.
            num_cols = points.shape[1]
            attribute_dims.update(
                dict(color=list(range(num_cols - 3, num_cols))))

        points_class = get_points_type(self.coord_type)
        results['points'] = points_class(
            points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
        return results


738
@TRANSFORMS.register_module()
class LoadAnnotations3D(LoadAnnotations):
    """Load Annotations3D.

    Load instance mask and semantic mask of points and
    encapsulate the items into related fields.

    Required Keys:

    - ann_info (dict)

        - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
          :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
          3D ground truth bboxes. Only when `with_bbox_3d` is True.
        - gt_labels_3d (np.int64): Labels of ground truths.
          Only when `with_label_3d` is True.
        - gt_bboxes (np.float32): 2D ground truth bboxes.
          Only when `with_bbox` is True.
        - gt_labels (np.ndarray): Labels of ground truths.
          Only when `with_label` is True.
        - depths (np.ndarray): Only when `with_bbox_depth` is True.
        - centers_2d (np.ndarray): Only when `with_bbox_depth` is True.
        - attr_labels (np.ndarray): Attribute labels of instances.
          Only when `with_attr_label` is True.

    - pts_instance_mask_path (str): Path of instance mask file.
      Only when `with_mask_3d` is True.
    - pts_semantic_mask_path (str): Path of semantic mask file.
      Only when `with_seg_3d` is True.
    - pts_panoptic_mask_path (str): Path of panoptic mask file.
      Only when `with_panoptic_3d` is True.

    Added Keys:

    - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
      :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
      3D ground truth bboxes. Only when `with_bbox_3d` is True.
    - gt_labels_3d (np.int64): Labels of ground truths.
      Only when `with_label_3d` is True.
    - gt_bboxes (np.float32): 2D ground truth bboxes.
      Only when `with_bbox` is True.
    - gt_labels (np.int64): Labels of ground truths.
      Only when `with_label` is True.
    - depths (np.float32): Only when `with_bbox_depth` is True.
    - centers_2d (np.ndarray): Only when `with_bbox_depth` is True.
    - attr_labels (np.int64): Attribute labels of instances.
      Only when `with_attr_label` is True.
    - pts_instance_mask (np.int64): Instance mask of each point.
      Only when `with_mask_3d` is True.
    - pts_semantic_mask (np.int64): Semantic mask of each point.
      Only when `with_seg_3d` is True.

    Args:
        with_bbox_3d (bool): Whether to load 3D boxes. Defaults to True.
        with_label_3d (bool): Whether to load 3D labels. Defaults to True.
        with_attr_label (bool): Whether to load attribute label.
            Defaults to False.
        with_mask_3d (bool): Whether to load 3D instance masks for points.
            Defaults to False.
        with_seg_3d (bool): Whether to load 3D semantic masks for points.
            Defaults to False.
        with_bbox (bool): Whether to load 2D boxes. Defaults to False.
        with_label (bool): Whether to load 2D labels. Defaults to False.
        with_mask (bool): Whether to load 2D instance masks. Defaults to False.
        with_seg (bool): Whether to load 2D semantic masks. Defaults to False.
        with_bbox_depth (bool): Whether to load 2.5D boxes. Defaults to False.
        with_panoptic_3d (bool): Whether to load 3D panoptic masks for points.
            Defaults to False.
        poly2mask (bool): Whether to convert polygon annotations to bitmasks.
            Defaults to True.
        seg_3d_dtype (str): String of dtype of 3D semantic masks.
            Defaults to 'np.int64'.
        seg_offset (int): The offset to split semantic and instance labels from
            panoptic labels. Defaults to None.
        dataset_type (str): Type of dataset used for splitting semantic and
            instance labels. Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    def __init__(self,
                 with_bbox_3d: bool = True,
                 with_label_3d: bool = True,
                 with_attr_label: bool = False,
                 with_mask_3d: bool = False,
                 with_seg_3d: bool = False,
                 with_bbox: bool = False,
                 with_label: bool = False,
                 with_mask: bool = False,
                 with_seg: bool = False,
                 with_bbox_depth: bool = False,
                 with_panoptic_3d: bool = False,
                 poly2mask: bool = True,
                 seg_3d_dtype: str = 'np.int64',
                 seg_offset: int = None,
                 dataset_type: str = None,
                 backend_args: Optional[dict] = None) -> None:
        super().__init__(
            with_bbox=with_bbox,
            with_label=with_label,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            backend_args=backend_args)
        self.with_bbox_3d = with_bbox_3d
        self.with_bbox_depth = with_bbox_depth
        self.with_label_3d = with_label_3d
        self.with_attr_label = with_attr_label
        self.with_mask_3d = with_mask_3d
        self.with_seg_3d = with_seg_3d
        self.with_panoptic_3d = with_panoptic_3d
        # SECURITY NOTE: `eval` turns the config string (e.g. 'np.int64')
        # into a dtype object and will execute arbitrary expressions from
        # untrusted configs. Kept for backward compatibility with existing
        # configs; consider np.dtype(...) parsing instead.
        self.seg_3d_dtype = eval(seg_3d_dtype)
        self.seg_offset = seg_offset
        self.dataset_type = dataset_type

    def _load_bboxes_3d(self, results: dict) -> dict:
        """Private function to move the 3D bounding box annotation from
        `ann_info` field to the root of `results`.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D bounding box annotations.
        """
        results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d']
        return results

    def _load_bboxes_depth(self, results: dict) -> dict:
        """Private function to load 2.5D bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 2.5D bounding box annotations.
        """
        results['depths'] = results['ann_info']['depths']
        results['centers_2d'] = results['ann_info']['centers_2d']
        return results

    def _load_labels_3d(self, results: dict) -> dict:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded label annotations.
        """
        results['gt_labels_3d'] = results['ann_info']['gt_labels_3d']
        return results

    def _load_attr_labels(self, results: dict) -> dict:
        """Private function to load attribute label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded label annotations.
        """
        results['attr_labels'] = results['ann_info']['attr_labels']
        return results

    def _load_masks_3d(self, results: dict) -> dict:
        """Private function to load 3D mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D mask annotations.
        """
        pts_instance_mask_path = results['pts_instance_mask_path']

        try:
            mask_bytes = get(
                pts_instance_mask_path, backend_args=self.backend_args)
            # .copy() because np.frombuffer returns a read-only view
            # (consistent with the semantic-mask branch below).
            pts_instance_mask = np.frombuffer(
                mask_bytes, dtype=np.int64).copy()
        except ConnectionError:
            # Fall back to local file access when the backend is unreachable.
            mmengine.check_file_exist(pts_instance_mask_path)
            pts_instance_mask = np.fromfile(
                pts_instance_mask_path, dtype=np.int64)

        results['pts_instance_mask'] = pts_instance_mask
        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_instance_mask'] = pts_instance_mask
        return results

    def _load_semantic_seg_3d(self, results: dict) -> dict:
        """Private function to load 3D semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing the semantic segmentation annotations.
        """
        pts_semantic_mask_path = results['pts_semantic_mask_path']

        try:
            mask_bytes = get(
                pts_semantic_mask_path, backend_args=self.backend_args)
            # add .copy() to fix read-only bug
            pts_semantic_mask = np.frombuffer(
                mask_bytes, dtype=self.seg_3d_dtype).copy()
        except ConnectionError:
            mmengine.check_file_exist(pts_semantic_mask_path)
            pts_semantic_mask = np.fromfile(
                pts_semantic_mask_path, dtype=np.int64)

        if self.dataset_type == 'semantickitti':
            # SemanticKITTI packs instance id in the upper bits; modulo by
            # `seg_offset` keeps only the semantic label.
            pts_semantic_mask = pts_semantic_mask.astype(np.int64)
            pts_semantic_mask = pts_semantic_mask % self.seg_offset
        # nuScenes loads semantic and panoptic labels from different files.

        results['pts_semantic_mask'] = pts_semantic_mask

        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
        return results

    def _load_panoptic_3d(self, results: dict) -> dict:
        """Private function to load 3D panoptic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing the panoptic segmentation annotations.
        """
        pts_panoptic_mask_path = results['pts_panoptic_mask_path']

        try:
            mask_bytes = get(
                pts_panoptic_mask_path, backend_args=self.backend_args)
            # add .copy() to fix read-only bug
            pts_panoptic_mask = np.frombuffer(
                mask_bytes, dtype=self.seg_3d_dtype).copy()
        except ConnectionError:
            mmengine.check_file_exist(pts_panoptic_mask_path)
            pts_panoptic_mask = np.fromfile(
                pts_panoptic_mask_path, dtype=np.int64)

        if self.dataset_type == 'semantickitti':
            pts_semantic_mask = pts_panoptic_mask.astype(np.int64)
            pts_semantic_mask = pts_semantic_mask % self.seg_offset
        elif self.dataset_type == 'nuscenes':
            # BUGFIX: derive the semantic label from the *panoptic* mask;
            # previously this read the not-yet-defined `pts_semantic_mask`
            # and raised UnboundLocalError.
            pts_semantic_mask = pts_panoptic_mask // self.seg_offset

        results['pts_semantic_mask'] = pts_semantic_mask

        # We can directly take panoptic labels as instance ids.
        pts_instance_mask = pts_panoptic_mask.astype(np.int64)
        results['pts_instance_mask'] = pts_instance_mask

        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
            results['eval_ann_info']['pts_instance_mask'] = pts_instance_mask
        return results

    def _load_bboxes(self, results: dict) -> None:
        """Private function to load bounding box annotations.

        The only difference from the parent implementation is that the
        process for `ignore_flag` is removed.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        results['gt_bboxes'] = results['ann_info']['gt_bboxes']

    def _load_labels(self, results: dict) -> None:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        results['gt_bboxes_labels'] = results['ann_info']['gt_bboxes_labels']

    def transform(self, results: dict) -> dict:
        """Function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D bounding box, label, mask and
            semantic segmentation annotations.
        """
        results = super().transform(results)
        if self.with_bbox_3d:
            results = self._load_bboxes_3d(results)
        if self.with_bbox_depth:
            results = self._load_bboxes_depth(results)
        if self.with_label_3d:
            results = self._load_labels_3d(results)
        if self.with_attr_label:
            results = self._load_attr_labels(results)
        if self.with_panoptic_3d:
            results = self._load_panoptic_3d(results)
        if self.with_mask_3d:
            results = self._load_masks_3d(results)
        if self.with_seg_3d:
            results = self._load_semantic_seg_3d(results)
        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        indent_str = '    '
        repr_str = self.__class__.__name__ + '(\n'
        repr_str += f'{indent_str}with_bbox_3d={self.with_bbox_3d}, '
        repr_str += f'{indent_str}with_label_3d={self.with_label_3d}, '
        repr_str += f'{indent_str}with_attr_label={self.with_attr_label}, '
        repr_str += f'{indent_str}with_mask_3d={self.with_mask_3d}, '
        repr_str += f'{indent_str}with_seg_3d={self.with_seg_3d}, '
        repr_str += f'{indent_str}with_panoptic_3d={self.with_panoptic_3d}, '
        repr_str += f'{indent_str}with_bbox={self.with_bbox}, '
        repr_str += f'{indent_str}with_label={self.with_label}, '
        repr_str += f'{indent_str}with_mask={self.with_mask}, '
        repr_str += f'{indent_str}with_seg={self.with_seg}, '
        repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, '
        # BUGFIX: previously ended with 'poly2mask=...)' followed by
        # 'seg_offset=...)' — a stray close paren and missing comma.
        repr_str += f'{indent_str}poly2mask={self.poly2mask}, '
        repr_str += f'{indent_str}seg_offset={self.seg_offset})'
        return repr_str
1082
1083
1084


@TRANSFORMS.register_module()
class LidarDet3DInferencerLoader(BaseTransform):
    """Load point cloud in the Inferencer's pipeline.

    Added keys:
      - points
      - timestamp
      - axis_align_matrix
      - box_type_3d
      - box_mode_3d
    """

    def __init__(self, coord_type='LIDAR', **kwargs) -> None:
        super().__init__()
        # One loader per supported input form: file path vs. in-memory array.
        file_cfg = dict(type='LoadPointsFromFile', coord_type=coord_type)
        ndarray_cfg = dict(type='LoadPointsFromDict', coord_type=coord_type)
        file_cfg.update(kwargs)
        ndarray_cfg.update(kwargs)
        self.from_file = TRANSFORMS.build(file_cfg)
        self.from_ndarray = TRANSFORMS.build(ndarray_cfg)
        self.box_type_3d, self.box_mode_3d = get_box_type(coord_type)

    def transform(self, single_input: dict) -> dict:
        """Transform function to add image meta information.
        Args:
            single_input (dict): Single input.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        assert 'points' in single_input, "key 'points' must be in input dict"
        pts = single_input['points']
        # Meta fields shared by both input forms.
        meta = dict(
            timestamp=1,
            # for ScanNet demo we need axis_align_matrix
            axis_align_matrix=np.eye(4),
            box_type_3d=self.box_type_3d,
            box_mode_3d=self.box_mode_3d)

        if isinstance(pts, str):
            inputs = dict(lidar_points=dict(lidar_path=pts), **meta)
            return self.from_file(inputs)
        if isinstance(pts, np.ndarray):
            inputs = dict(points=pts, **meta)
            return self.from_ndarray(inputs)
        raise ValueError('Unsupported input points type: '
                         f"{type(single_input['points'])}")


@TRANSFORMS.register_module()
class MonoDet3DInferencerLoader(BaseTransform):
    """Load an image from ``results['images']['CAMX']['img']``. Similar with
    :obj:`LoadImageFromFileMono3D`, but the image has been loaded as
    :obj:`np.ndarray` in ``results['images']['CAMX']['img']``.

    Added keys:
      - img
      - cam2img
      - box_type_3d
      - box_mode_3d
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFileMono3D', **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='LoadImageFromNDArray', **kwargs))

    @staticmethod
    def _read_cam2img(calib) -> np.ndarray:
        """Return the 4x4 camera intrinsic matrix.

        Accepts either a calib file path (first line holds 16 floats) or an
        already-built :obj:`np.ndarray`.
        """
        if isinstance(calib, str):
            with open(calib, 'r') as f:
                first_line = f.readlines()[0]
            values = [float(info) for info in first_line.split(' ')[0:16]]
            return np.array(values).reshape([4, 4])
        if isinstance(calib, np.ndarray):
            return calib
        raise ValueError('Unsupported input calib type: '
                         f'{type(calib)}')

    def transform(self, single_input: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            single_input (dict): Result dict with Webcam read image in
                ``results['images']['CAMX']['img']``.
        Returns:
            dict: The dict contains loaded image and meta information.
        """
        box_type_3d, box_mode_3d = get_box_type('camera')
        assert 'calib' in single_input and 'img' in single_input, \
            "key 'calib' and 'img' must be in input dict"
        cam2img = self._read_cam2img(single_input['calib'])

        img = single_input['img']
        if isinstance(img, str):
            inputs = dict(
                images=dict(
                    CAM_FRONT=dict(img_path=img, cam2img=cam2img)),
                box_mode_3d=box_mode_3d,
                box_type_3d=box_type_3d)
            return self.from_file(inputs)
        if isinstance(img, np.ndarray):
            inputs = dict(
                img=img,
                cam2img=cam2img,
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
            return self.from_ndarray(inputs)
        raise ValueError('Unsupported input image type: '
                         f"{type(single_input['img'])}")
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319


@TRANSFORMS.register_module()
class MultiModalityDet3DInferencerLoader(BaseTransform):
    """Load point cloud and image in the Inferencer's pipeline.

    Added keys:
      - points
      - img
      - cam2img
      - lidar2cam
      - lidar2img
      - timestamp
      - axis_align_matrix
      - box_type_3d
      - box_mode_3d
    """

    def __init__(self, load_point_args: dict, load_img_args: dict) -> None:
        super().__init__()
        self.points_from_file = TRANSFORMS.build(
            dict(type='LoadPointsFromFile', **load_point_args))
        self.points_from_ndarray = TRANSFORMS.build(
            dict(type='LoadPointsFromDict', **load_point_args))
        coord_type = load_point_args['coord_type']
        self.box_type_3d, self.box_mode_3d = get_box_type(coord_type)

        self.imgs_from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFile', **load_img_args))
        self.imgs_from_ndarray = TRANSFORMS.build(
            dict(type='LoadImageFromNDArray', **load_img_args))

    def transform(self, single_input: dict) -> dict:
        """Transform function to add image meta information.
        Args:
            single_input (dict): Single input.

        Returns:
            dict: The dict contains loaded image, point cloud and meta
            information.
        """
        # BUGFIX: the second half of this message used to be a standalone
        # f-string statement (a no-op), so the assert message was truncated.
        assert 'points' in single_input and 'img' in single_input and \
            'calib' in single_input, (
                "key 'points', 'img' and 'calib' must be in input dict, "
                f'but got {single_input}')
        if isinstance(single_input['points'], str):
            inputs = dict(
                lidar_points=dict(lidar_path=single_input['points']),
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=self.box_type_3d,
                box_mode_3d=self.box_mode_3d)
        elif isinstance(single_input['points'], np.ndarray):
            inputs = dict(
                points=single_input['points'],
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=self.box_type_3d,
                box_mode_3d=self.box_mode_3d)
        else:
            raise ValueError('Unsupported input points type: '
                             f"{type(single_input['points'])}")

        if 'points' in inputs:
            points_inputs = self.points_from_ndarray(inputs)
        else:
            points_inputs = self.points_from_file(inputs)

        multi_modality_inputs = points_inputs

        box_type_3d, box_mode_3d = get_box_type('lidar')
        if isinstance(single_input['calib'], str):
            calib = mmengine.load(single_input['calib'])
        elif isinstance(single_input['calib'], dict):
            calib = single_input['calib']
        else:
            raise ValueError('Unsupported input calib type: '
                             f"{type(single_input['calib'])}")

        cam2img = np.asarray(calib['cam2img'], dtype=np.float32)
        lidar2cam = np.asarray(calib['lidar2cam'], dtype=np.float32)
        # BUGFIX: previously tested `'lidar2cam' in calib`, which is always
        # true (it was just read above), so a calib without 'lidar2img'
        # raised KeyError instead of using the composed fallback.
        if 'lidar2img' in calib:
            lidar2img = np.asarray(calib['lidar2img'], dtype=np.float32)
        else:
            lidar2img = cam2img @ lidar2cam

        if isinstance(single_input['img'], str):
            inputs = dict(
                img_path=single_input['img'],
                cam2img=cam2img,
                lidar2img=lidar2img,
                lidar2cam=lidar2cam,
                box_mode_3d=box_mode_3d,
                box_type_3d=box_type_3d)
        elif isinstance(single_input['img'], np.ndarray):
            inputs = dict(
                img=single_input['img'],
                cam2img=cam2img,
                lidar2img=lidar2img,
                lidar2cam=lidar2cam,
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            raise ValueError('Unsupported input image type: '
                             f"{type(single_input['img'])}")

        if isinstance(single_input['img'], np.ndarray):
            imgs_inputs = self.imgs_from_ndarray(inputs)
        else:
            imgs_inputs = self.imgs_from_file(inputs)

        multi_modality_inputs.update(imgs_inputs)

        return multi_modality_inputs