# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional, Union

import mmcv
import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmcv.transforms import LoadImageFromFile
from mmcv.transforms.base import BaseTransform
from mmdet.datasets.transforms import LoadAnnotations

from mmdet3d.registry import TRANSFORMS
from mmdet3d.structures.bbox_3d import get_box_type
from mmdet3d.structures.points import BasePoints, get_points_type


@TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(BaseTransform):
    """Load multi-channel images from a list of separate channel files.

    Expects results['img_filename'] to be a list of filenames.

    Args:
        to_float32 (bool): Whether to convert the image to float32.
            Defaults to False.
        color_type (str): Color type of the file. Defaults to 'unchanged'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmengine.fileio.FileClient` for details.
            Defaults to dict(backend='disk').
        num_views (int): Number of views in a frame. Defaults to 5.
        num_ref_frames (int): Number of frames in loading. Defaults to -1.
        test_mode (bool): Whether it is test mode in loading.
            Defaults to False.
        set_default_scale (bool): Whether to set the default scale.
            Defaults to True.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'unchanged',
                 file_client_args: dict = dict(backend='disk'),
                 num_views: int = 5,
                 num_ref_frames: int = -1,
                 test_mode: bool = False,
                 set_default_scale: bool = True) -> None:
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.num_views = num_views
        # num_ref_frames is used for multi-sweep loading
        self.num_ref_frames = num_ref_frames
        # when test_mode=False, we randomly select previous frames
        # otherwise, select the earliest one
        self.test_mode = test_mode
        self.set_default_scale = set_default_scale

    def transform(self, results: dict) -> Optional[dict]:
        """Call function to load multi-view images from files.

        Args:
            results (dict): Result dict containing multi-view image filenames.

        Returns:
            dict: The result dict containing the multi-view image data.
            Added keys and values are described below.

                - filename (str): Multi-view image filenames.
                - img (np.ndarray): Multi-view image arrays.
                - img_shape (tuple[int]): Shape of multi-view image arrays.
                - ori_shape (tuple[int]): Shape of original image arrays.
                - pad_shape (tuple[int]): Shape of padded image arrays.
                - scale_factor (float): Scale factor.
                - img_norm_cfg (dict): Normalization configuration of images.
        """
        # TODO: consider split the multi-sweep part out of this pipeline
        # Derive the mask and transform for loading of multi-sweep data
        if self.num_ref_frames > 0:
            # init choice with the current frame
            init_choice = np.array([0], dtype=np.int64)
            num_frames = len(results['img_filename']) // self.num_views - 1
            if num_frames == 0:  # no previous frame, then copy cur frames
                choices = np.random.choice(
                    1, self.num_ref_frames, replace=True)
            elif num_frames >= self.num_ref_frames:
                # NOTE: suppose the info is saved following the order
                # from latest to earlier frames
                if self.test_mode:
                    choices = np.arange(num_frames - self.num_ref_frames,
                                        num_frames) + 1
                # NOTE: +1 is for selecting previous frames
                else:
                    choices = np.random.choice(
                        num_frames, self.num_ref_frames, replace=False) + 1
            elif num_frames > 0 and num_frames < self.num_ref_frames:
                if self.test_mode:
                    base_choices = np.arange(num_frames) + 1
                    random_choices = np.random.choice(
                        num_frames,
                        self.num_ref_frames - num_frames,
                        replace=True) + 1
                    choices = np.concatenate([base_choices, random_choices])
                else:
                    choices = np.random.choice(
                        num_frames, self.num_ref_frames, replace=True) + 1
            else:
                raise NotImplementedError
            choices = np.concatenate([init_choice, choices])
            select_filename = []
            for choice in choices:
                select_filename += results['img_filename'][choice *
                                                           self.num_views:
                                                           (choice + 1) *
                                                           self.num_views]
            results['img_filename'] = select_filename
            for key in ['cam2img', 'lidar2cam']:
                if key in results:
                    select_results = []
                    for choice in choices:
                        select_results += results[key][choice *
                                                       self.num_views:(choice +
                                                                       1) *
                                                       self.num_views]
                    results[key] = select_results
            for key in ['ego2global']:
                if key in results:
                    select_results = []
                    for choice in choices:
                        select_results += [results[key][choice]]
                    results[key] = select_results
            # Transform lidar2cam to
            # [cur_lidar]2[prev_img] and [cur_lidar]2[prev_cam]
            for key in ['lidar2cam']:
                if key in results:
                    # only change matrices of previous frames
                    for choice_idx in range(1, len(choices)):
                        pad_prev_ego2global = np.eye(4)
                        prev_ego2global = results['ego2global'][choice_idx]
                        pad_prev_ego2global[:prev_ego2global.
                                            shape[0], :prev_ego2global.
                                            shape[1]] = prev_ego2global
                        pad_cur_ego2global = np.eye(4)
                        cur_ego2global = results['ego2global'][0]
                        pad_cur_ego2global[:cur_ego2global.
                                           shape[0], :cur_ego2global.
                                           shape[1]] = cur_ego2global
                        cur2prev = np.linalg.inv(pad_prev_ego2global).dot(
                            pad_cur_ego2global)
                        for result_idx in range(choice_idx * self.num_views,
                                                (choice_idx + 1) *
                                                self.num_views):
                            results[key][result_idx] = \
                                results[key][result_idx].dot(cur2prev)
        # Support multi-view images with different shapes
        # TODO: record the origin shape and padded shape
        filename, cam2img, lidar2cam = [], [], []
        for _, cam_item in results['images'].items():
            filename.append(cam_item['img_path'])
            cam2img.append(cam_item['cam2img'])
            lidar2cam.append(cam_item['lidar2cam'])
        results['filename'] = filename
        results['cam2img'] = cam2img
        results['lidar2cam'] = lidar2cam

        results['ori_cam2img'] = copy.deepcopy(results['cam2img'])

        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)

        # img is of shape (h, w, c, num_views)
        # h and w can be different for different views
        img_bytes = [self.file_client.get(name) for name in filename]
        imgs = [
            mmcv.imfrombytes(img_byte, flag=self.color_type)
            for img_byte in img_bytes
        ]
        # handle the image with different shape
        img_shapes = np.stack([img.shape for img in imgs], axis=0)
        img_shape_max = np.max(img_shapes, axis=0)
        img_shape_min = np.min(img_shapes, axis=0)
        assert img_shape_min[-1] == img_shape_max[-1]
        if not np.all(img_shape_max == img_shape_min):
            pad_shape = img_shape_max[:2]
        else:
            pad_shape = None
        if pad_shape is not None:
            imgs = [
                mmcv.impad(img, shape=pad_shape, pad_val=0) for img in imgs
            ]
        img = np.stack(imgs, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        # unravel to list, see `DefaultFormatBundle` in formating.py
        # which will transpose each image separately and then stack into array
        results['img'] = [img[..., i] for i in range(img.shape[-1])]
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape[:2]
        if self.set_default_scale:
            results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        results['num_views'] = self.num_views
        results['num_ref_frames'] = self.num_ref_frames
        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32}, '
        repr_str += f"color_type='{self.color_type}', "
        repr_str += f'num_views={self.num_views}, '
        repr_str += f'num_ref_frames={self.num_ref_frames}, '
        repr_str += f'test_mode={self.test_mode})'
        return repr_str
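
# A minimal usage sketch in a data pipeline, assuming a nuScenes-style
# six-camera setup; the values are illustrative only.
# train_pipeline = [
#     dict(type='LoadMultiViewImageFromFiles', to_float32=True, num_views=6),
#     ...
# ]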


@TRANSFORMS.register_module()
class LoadImageFromFileMono3D(LoadImageFromFile):
    """Load an image from file in monocular 3D object detection.

    Compared to 2D detection, additional camera parameters need to be loaded.

    Args:
        kwargs (dict): Arguments are the same as those in
            :class:`LoadImageFromFile`.
    """

    def transform(self, results: dict) -> dict:
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        # TODO: load different camera image from data info,
        # for kitti dataset, we load 'CAM2' image.
        # for nuscenes dataset, we load 'CAM_FRONT' image.

        if 'CAM2' in results['images']:
            filename = results['images']['CAM2']['img_path']
            results['cam2img'] = results['images']['CAM2']['cam2img']
        elif len(list(results['images'].keys())) == 1:
            camera_type = list(results['images'].keys())[0]
            filename = results['images'][camera_type]['img_path']
            results['cam2img'] = results['images'][camera_type]['cam2img']
        else:
            raise NotImplementedError(
                'Currently we only support load image from kitti and'
                'nuscenes datasets')

        try:
            if self.file_client_args is not None:
                file_client = fileio.FileClient.infer_client(
                    self.file_client_args, filename)
                img_bytes = file_client.get(filename)
            else:
                img_bytes = fileio.get(
                    filename, backend_args=self.backend_args)
            img = mmcv.imfrombytes(
                img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        except Exception as e:
            if self.ignore_empty:
                return None
            else:
                raise e
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]

        return results
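
# A minimal usage sketch for a monocular 3D pipeline (KITTI 'CAM2' or a
# single-camera dataset, as handled above); illustrative only.
# train_pipeline = [
#     dict(type='LoadImageFromFileMono3D'),
#     ...
# ]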


@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Load an image from ``results['img']``.
    Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
    :obj:`np.ndarray` in ``results['img']``. Can be used when loading image
    from webcam.
    Required Keys:
    - img
    Modified Keys:
    - img
    - img_path
    - img_shape
    - ori_shape
    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            results (dict): Result dict with Webcam read image in
                ``results['img']``.
        Returns:
            dict: The dict contains loaded image and meta information.
        """

        img = results['img']
        if self.to_float32:
            img = img.astype(np.float32)

        results['img_path'] = None
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results
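
# A minimal usage sketch: the decoded image is expected to already sit in
# ``results['img']`` (e.g. a frame grabbed from a webcam); the array below is
# a placeholder.
# loader = LoadImageFromNDArray()
# results = loader(dict(img=np.zeros((600, 800, 3), dtype=np.uint8)))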


@TRANSFORMS.register_module()
class LoadPointsFromMultiSweeps(BaseTransform):
    """Load points from multiple sweeps.

    This is usually used for the nuScenes dataset to utilize previous sweeps.

    Args:
        sweeps_num (int): Number of sweeps. Defaults to 10.
        load_dim (int): Dimension number of the loaded points. Defaults to 5.
        use_dim (list[int]): Which dimensions to use.
            Defaults to [0, 1, 2, 4].
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmengine.fileio.FileClient` for details.
            Defaults to dict(backend='disk').
        pad_empty_sweeps (bool): Whether to repeat keyframe when
            sweeps is empty. Defaults to False.
        remove_close (bool): Whether to remove close points. Defaults to False.
        test_mode (bool): If `test_mode=True`, it will not randomly sample
            sweeps but select the nearest N frames. Defaults to False.
    """

    def __init__(self,
                 sweeps_num: int = 10,
                 load_dim: int = 5,
                 use_dim: List[int] = [0, 1, 2, 4],
                 file_client_args: dict = dict(backend='disk'),
                 pad_empty_sweeps: bool = False,
                 remove_close: bool = False,
                 test_mode: bool = False) -> None:
        self.load_dim = load_dim
        self.sweeps_num = sweeps_num
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert max(use_dim) < load_dim, \
            f'Expect all used dimensions < {load_dim}, got {use_dim}'
        self.use_dim = use_dim
        self.file_client_args = file_client_args.copy()
        self.file_client = mmengine.FileClient(**self.file_client_args)
        self.pad_empty_sweeps = pad_empty_sweeps
        self.remove_close = remove_close
        self.test_mode = test_mode

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: An array containing point clouds data.
        """
        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)
        try:
            pts_bytes = self.file_client.get(pts_filename)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)
        return points
    def _remove_close(self,
                      points: Union[np.ndarray, BasePoints],
                      radius: float = 1.0) -> Union[np.ndarray, BasePoints]:
        """Remove points too close to the origin within a certain radius.

        Args:
            points (np.ndarray | :obj:`BasePoints`): Sweep points.
            radius (float): Radius below which points are removed.
                Defaults to 1.0.

        Returns:
            np.ndarray | :obj:`BasePoints`: Points after removing.
        """
        if isinstance(points, np.ndarray):
            points_numpy = points
        elif isinstance(points, BasePoints):
            points_numpy = points.tensor.numpy()
        else:
            raise NotImplementedError
        x_filt = np.abs(points_numpy[:, 0]) < radius
        y_filt = np.abs(points_numpy[:, 1]) < radius
        not_close = np.logical_not(np.logical_and(x_filt, y_filt))
        return points[not_close]

    def transform(self, results: dict) -> dict:
        """Call function to load multi-sweep point clouds from files.

        Args:
            results (dict): Result dict containing multi-sweep point cloud
                filenames.

        Returns:
            dict: The result dict containing the multi-sweep points data.
            Updated key and value are described below.

                - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
                  cloud arrays.
        """
        points = results['points']
        points.tensor[:, 4] = 0
        sweep_points_list = [points]
        ts = results['timestamp']
        if 'lidar_sweeps' not in results:
            if self.pad_empty_sweeps:
                for i in range(self.sweeps_num):
                    if self.remove_close:
                        sweep_points_list.append(self._remove_close(points))
                    else:
                        sweep_points_list.append(points)
        else:
            if len(results['lidar_sweeps']) <= self.sweeps_num:
                choices = np.arange(len(results['lidar_sweeps']))
            elif self.test_mode:
                choices = np.arange(self.sweeps_num)
            else:
                choices = np.random.choice(
                    len(results['lidar_sweeps']),
                    self.sweeps_num,
                    replace=False)
            for idx in choices:
                sweep = results['lidar_sweeps'][idx]
                points_sweep = self._load_points(
                    sweep['lidar_points']['lidar_path'])
                points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
                if self.remove_close:
                    points_sweep = self._remove_close(points_sweep)
                # bc-breaking: Timestamps are divided by 1e6 in pkl infos.
                sweep_ts = sweep['timestamp']
                lidar2sensor = np.array(sweep['lidar_points']['lidar2sensor'])
                points_sweep[:, :
                             3] = points_sweep[:, :3] @ lidar2sensor[:3, :3]
                points_sweep[:, :3] -= lidar2sensor[:3, 3]
                points_sweep[:, 4] = ts - sweep_ts
                points_sweep = points.new_point(points_sweep)
                sweep_points_list.append(points_sweep)

        points = points.cat(sweep_points_list)
        points = points[:, self.use_dim]
        results['points'] = points
        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
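
# A minimal usage sketch, assuming nuScenes-style 5-dim points where dim 3 is
# intensity and dim 4 holds the relative sweep timestamp; illustrative only.
# train_pipeline = [
#     dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5,
#          use_dim=5),
#     dict(type='LoadPointsFromMultiSweeps', sweeps_num=10,
#          use_dim=[0, 1, 2, 4]),
#     ...
# ]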


@TRANSFORMS.register_module()
class PointSegClassMapping(BaseTransform):
    """Map original semantic class to valid category ids.

    Required Keys:

    - seg_label_mapping (np.ndarray)
    - pts_semantic_mask (np.ndarray)

    Modified Keys:

    - pts_semantic_mask (np.ndarray)

    Map valid classes as 0~len(valid_cat_ids)-1 and
    others as len(valid_cat_ids).
    """

    def transform(self, results: dict) -> dict:
        """Call function to map original semantic class to valid category ids.

        Args:
            results (dict): Result dict containing point semantic masks.

        Returns:
            dict: The result dict containing the mapped category ids.
            Updated key and value are described below.

                - pts_semantic_mask (np.ndarray): Mapped semantic masks.
        """
        assert 'pts_semantic_mask' in results
        pts_semantic_mask = results['pts_semantic_mask']

        assert 'seg_label_mapping' in results
        label_mapping = results['seg_label_mapping']
        converted_pts_sem_mask = label_mapping[pts_semantic_mask]
        results['pts_semantic_mask'] = converted_pts_sem_mask

        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            assert 'pts_semantic_mask' in results['eval_ann_info']
            results['eval_ann_info']['pts_semantic_mask'] = \
                converted_pts_sem_mask

        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        return repr_str
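
# A minimal usage sketch: the dataset is expected to put `seg_label_mapping`
# and the raw `pts_semantic_mask` into `results` before this transform runs
# (see the Required Keys above); the pipeline entry itself takes no arguments.
# train_pipeline = [
#     ...,
#     dict(type='PointSegClassMapping'),
# ]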


@TRANSFORMS.register_module()
class NormalizePointsColor(BaseTransform):
    """Normalize color of points.

    Args:
        color_mean (list[float]): Mean color of the point cloud.
    """

    def __init__(self, color_mean: List[float]) -> None:
        self.color_mean = color_mean

    def transform(self, input_dict: dict) -> dict:
        """Call function to normalize color of points.

        Args:
            input_dict (dict): Result dict containing point clouds data.

        Returns:
            dict: The result dict containing the normalized points.
            Updated key and value are described below.

                - points (:obj:`BasePoints`): Points after color normalization.
        """
        points = input_dict['points']
        assert points.attribute_dims is not None and \
               'color' in points.attribute_dims.keys(), \
               'Expect points have color attribute'
        if self.color_mean is not None:
            points.color = points.color - \
                           points.color.new_tensor(self.color_mean)
        points.color = points.color / 255.0
        input_dict['points'] = points
        return input_dict
    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(color_mean={self.color_mean})'
        return repr_str
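
# A minimal usage sketch, assuming indoor point clouds with RGB attributes in
# the 0-255 range; the mean below is illustrative only.
# train_pipeline = [
#     ...,
#     dict(type='NormalizePointsColor', color_mean=[127.5, 127.5, 127.5]),
# ]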


@TRANSFORMS.register_module()
class LoadPointsFromFile(BaseTransform):
    """Load Points From File.

    Required Keys:

    - lidar_points (dict)

        - lidar_path (str)

    Added Keys:

    - points (np.float32)

    Args:
        coord_type (str): The type of coordinates of the point cloud.
            Available options include:

            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
            - 'CAMERA': Points in camera coordinates.
        load_dim (int): The dimension of the loaded points. Defaults to 6.
        use_dim (list[int] | int): Which dimensions of the points to use.
            Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
            or use_dim=[0, 1, 2, 3] to use the intensity dimension.
        shift_height (bool): Whether to use shifted height. Defaults to False.
        use_color (bool): Whether to use color features. Defaults to False.
        norm_intensity (bool): Whether to normalize the intensity. Defaults to
            False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmengine.fileio.FileClient` for details.
            Defaults to dict(backend='disk').
    """

    def __init__(
        self,
        coord_type: str,
        load_dim: int = 6,
        use_dim: Union[int, List[int]] = [0, 1, 2],
        shift_height: bool = False,
        use_color: bool = False,
        norm_intensity: bool = False,
        file_client_args: dict = dict(backend='disk')
    ) -> None:
        self.shift_height = shift_height
        self.use_color = use_color
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert max(use_dim) < load_dim, \
            f'Expect all used dimensions < {load_dim}, got {use_dim}'
        assert coord_type in ['CAMERA', 'LIDAR', 'DEPTH']
        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.norm_intensity = norm_intensity
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: An array containing point clouds data.
        """
        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)
        try:
            pts_bytes = self.file_client.get(pts_filename)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)
        return points

    def transform(self, results: dict) -> dict:
        """Method to load points data from file.

        Args:
            results (dict): Result dict containing point clouds data.

        Returns:
            dict: The result dict containing the point clouds data.
            Added key and value are described below.

                - points (:obj:`BasePoints`): Point clouds data.
        """
        pts_file_path = results['lidar_points']['lidar_path']
        points = self._load_points(pts_file_path)
        points = points.reshape(-1, self.load_dim)
        points = points[:, self.use_dim]
        if self.norm_intensity:
            assert len(self.use_dim) >= 4, \
                f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}'  # noqa: E501
            points[:, 3] = np.tanh(points[:, 3])
        attribute_dims = None

        if self.shift_height:
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3],
                 np.expand_dims(height, 1), points[:, 3:]], 1)
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            attribute_dims.update(
                dict(color=[
                    points.shape[1] - 3,
                    points.shape[1] - 2,
                    points.shape[1] - 1,
                ]))

        points_class = get_points_type(self.coord_type)
        points = points_class(
            points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
        results['points'] = points
        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__ + '('
        repr_str += f'shift_height={self.shift_height}, '
        repr_str += f'use_color={self.use_color}, '
        repr_str += f'file_client_args={self.file_client_args}, '
        repr_str += f'load_dim={self.load_dim}, '
        repr_str += f'use_dim={self.use_dim})'
        return repr_str
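
# A minimal usage sketch, assuming KITTI-style 4-dim points (x, y, z,
# intensity) stored as float32 .bin files; illustrative only.
# train_pipeline = [
#     dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4,
#          use_dim=[0, 1, 2, 3]),
#     ...
# ]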


@TRANSFORMS.register_module()
class LoadPointsFromDict(LoadPointsFromFile):
    """Load Points From Dict."""

    def transform(self, results: dict) -> dict:
        """Convert the type of points from ndarray to corresponding
        `point_class`.

        Args:
            results (dict): input result. The value of key `points` is a
                numpy array.

        Returns:
            dict: The processed results.
        """
        assert 'points' in results
        points_class = get_points_type(self.coord_type)
        points = results['points']
        results['points'] = points_class(
            points, points_dim=points.shape[-1], attribute_dims=None)
        return results
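
# A minimal usage sketch: unlike LoadPointsFromFile, the raw array is expected
# to already sit in ``results['points']``; the array below is a placeholder.
# loader = LoadPointsFromDict(coord_type='LIDAR', load_dim=4, use_dim=4)
# results = loader(dict(points=np.zeros((100, 4), dtype=np.float32)))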


@TRANSFORMS.register_module()
class LoadAnnotations3D(LoadAnnotations):
    """Load Annotations3D.

    Load instance mask and semantic mask of points and
    encapsulate the items into related fields.

    Required Keys:

    - ann_info (dict)

        - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
          :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
          3D ground truth bboxes. Only when `with_bbox_3d` is True
        - gt_labels_3d (np.int64): Labels of ground truths.
          Only when `with_label_3d` is True.
        - gt_bboxes (np.float32): 2D ground truth bboxes.
          Only when `with_bbox` is True.
        - gt_labels (np.ndarray): Labels of ground truths.
          Only when `with_label` is True.
        - depths (np.ndarray): Only when
          `with_bbox_depth` is True.
        - centers_2d (np.ndarray): Only when
          `with_bbox_depth` is True.
        - attr_labels (np.ndarray): Attribute labels of instances.
          Only when `with_attr_label` is True.

    - pts_instance_mask_path (str): Path of instance mask file.
      Only when `with_mask_3d` is True.
    - pts_semantic_mask_path (str): Path of semantic mask file.
      Only when `with_seg_3d` is True.

    Added Keys:

    - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
      :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
      3D ground truth bboxes. Only when `with_bbox_3d` is True
    - gt_labels_3d (np.int64): Labels of ground truths.
      Only when `with_label_3d` is True.
    - gt_bboxes (np.float32): 2D ground truth bboxes.
      Only when `with_bbox` is True.
    - gt_labels (np.int64): Labels of ground truths.
      Only when `with_label` is True.
    - depths (np.float32): Only when
      `with_bbox_depth` is True.
    - centers_2d (np.ndarray): Only when
      `with_bbox_depth` is True.
    - attr_labels (np.int64): Attribute labels of instances.
      Only when `with_attr_label` is True.
    - pts_instance_mask (np.int64): Instance mask of each point.
      Only when `with_mask_3d` is True.
    - pts_semantic_mask (np.int64): Semantic mask of each point.
      Only when `with_seg_3d` is True.

    Args:
        with_bbox_3d (bool): Whether to load 3D boxes. Defaults to True.
        with_label_3d (bool): Whether to load 3D labels. Defaults to True.
        with_attr_label (bool): Whether to load attribute label.
            Defaults to False.
        with_mask_3d (bool): Whether to load 3D instance masks for points.
            Defaults to False.
        with_seg_3d (bool): Whether to load 3D semantic masks for points.
            Defaults to False.
        with_bbox (bool): Whether to load 2D boxes. Defaults to False.
        with_label (bool): Whether to load 2D labels. Defaults to False.
        with_mask (bool): Whether to load 2D instance masks. Defaults to False.
        with_seg (bool): Whether to load 2D semantic masks. Defaults to False.
        with_bbox_depth (bool): Whether to load 2.5D boxes. Defaults to False.
        poly2mask (bool): Whether to convert polygon annotations to bitmasks.
            Defaults to True.
        seg_3d_dtype (dtype): Dtype of 3D semantic masks. Defaults to int64.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmengine.fileio.FileClient` for details.
            Defaults to dict(backend='disk').
    """

    def __init__(
        self,
        with_bbox_3d: bool = True,
        with_label_3d: bool = True,
        with_attr_label: bool = False,
        with_mask_3d: bool = False,
        with_seg_3d: bool = False,
        with_bbox: bool = False,
        with_label: bool = False,
        with_mask: bool = False,
        with_seg: bool = False,
        with_bbox_depth: bool = False,
        poly2mask: bool = True,
        seg_3d_dtype: np.dtype = np.int64,
        file_client_args: dict = dict(backend='disk')
    ) -> None:
        super().__init__(
            with_bbox=with_bbox,
            with_label=with_label,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            file_client_args=file_client_args)
        self.with_bbox_3d = with_bbox_3d
        self.with_bbox_depth = with_bbox_depth
        self.with_label_3d = with_label_3d
        self.with_attr_label = with_attr_label
        self.with_mask_3d = with_mask_3d
        self.with_seg_3d = with_seg_3d
        self.seg_3d_dtype = seg_3d_dtype
        self.file_client = None
    def _load_bboxes_3d(self, results: dict) -> dict:
        """Private function to move the 3D bounding box annotation from
        `ann_info` field to the root of `results`.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D bounding box annotations.
        """
        results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d']
        return results

    def _load_bboxes_depth(self, results: dict) -> dict:
        """Private function to load 2.5D bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 2.5D bounding box annotations.
        """
        results['depths'] = results['ann_info']['depths']
        results['centers_2d'] = results['ann_info']['centers_2d']
        return results

    def _load_labels_3d(self, results: dict) -> dict:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded label annotations.
        """
        results['gt_labels_3d'] = results['ann_info']['gt_labels_3d']
        return results

    def _load_attr_labels(self, results: dict) -> dict:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded label annotations.
        """
        results['attr_labels'] = results['ann_info']['attr_labels']
        return results

    def _load_masks_3d(self, results: dict) -> dict:
        """Private function to load 3D mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D mask annotations.
        """
        pts_instance_mask_path = results['pts_instance_mask_path']

        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)
        try:
            mask_bytes = self.file_client.get(pts_instance_mask_path)
            pts_instance_mask = np.frombuffer(mask_bytes, dtype=np.int64)
        except ConnectionError:
            mmengine.check_file_exist(pts_instance_mask_path)
            pts_instance_mask = np.fromfile(
                pts_instance_mask_path, dtype=np.int64)

        results['pts_instance_mask'] = pts_instance_mask
        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_instance_mask'] = pts_instance_mask
        return results

    def _load_semantic_seg_3d(self, results: dict) -> dict:
        """Private function to load 3D semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing the semantic segmentation annotations.
        """
        pts_semantic_mask_path = results['pts_semantic_mask_path']

        if self.file_client is None:
            self.file_client = mmengine.FileClient(**self.file_client_args)
        try:
            mask_bytes = self.file_client.get(pts_semantic_mask_path)
            # add .copy() to fix read-only bug
            pts_semantic_mask = np.frombuffer(
                mask_bytes, dtype=self.seg_3d_dtype).copy()
        except ConnectionError:
            mmengine.check_file_exist(pts_semantic_mask_path)
            pts_semantic_mask = np.fromfile(
                pts_semantic_mask_path, dtype=np.int64)

        results['pts_semantic_mask'] = pts_semantic_mask
        # 'eval_ann_info' will be passed to evaluator
        if 'eval_ann_info' in results:
            results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
        return results

    def _load_bboxes(self, results: dict) -> None:
        """Private function to load bounding box annotations.

        The only difference is that it removes the processing of
        `ignore_flag`.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """

        results['gt_bboxes'] = results['ann_info']['gt_bboxes']

    def _load_labels(self, results: dict) -> None:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        results['gt_bboxes_labels'] = results['ann_info']['gt_bboxes_labels']
    def transform(self, results: dict) -> dict:
        """Function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.

        Returns:
            dict: The dict containing loaded 3D bounding box, label, mask and
            semantic segmentation annotations.
        """
        results = super().transform(results)
        if self.with_bbox_3d:
            results = self._load_bboxes_3d(results)
        if self.with_bbox_depth:
            results = self._load_bboxes_depth(results)
        if self.with_label_3d:
            results = self._load_labels_3d(results)
        if self.with_attr_label:
            results = self._load_attr_labels(results)
        if self.with_mask_3d:
            results = self._load_masks_3d(results)
        if self.with_seg_3d:
            results = self._load_semantic_seg_3d(results)

        return results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        indent_str = '    '
        repr_str = self.__class__.__name__ + '(\n'
        repr_str += f'{indent_str}with_bbox_3d={self.with_bbox_3d}, '
        repr_str += f'{indent_str}with_label_3d={self.with_label_3d}, '
        repr_str += f'{indent_str}with_attr_label={self.with_attr_label}, '
        repr_str += f'{indent_str}with_mask_3d={self.with_mask_3d}, '
        repr_str += f'{indent_str}with_seg_3d={self.with_seg_3d}, '
        repr_str += f'{indent_str}with_bbox={self.with_bbox}, '
        repr_str += f'{indent_str}with_label={self.with_label}, '
        repr_str += f'{indent_str}with_mask={self.with_mask}, '
        repr_str += f'{indent_str}with_seg={self.with_seg}, '
        repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, '
        repr_str += f'{indent_str}poly2mask={self.poly2mask})'
        return repr_str
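
# A minimal usage sketch for a LiDAR-only detection pipeline; flags other than
# those shown keep the defaults documented above; illustrative only.
# train_pipeline = [
#     ...,
#     dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
# ]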


@TRANSFORMS.register_module()
class LidarDet3DInferencerLoader(BaseTransform):
    """Load point cloud in the Inferencer's pipeline.

    Added keys:
      - points
      - timestamp
      - axis_align_matrix
      - box_type_3d
      - box_mode_3d
    """

    def __init__(self, coord_type='LIDAR', **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadPointsFromFile', coord_type=coord_type, **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='LoadPointsFromDict', coord_type=coord_type, **kwargs))
        self.box_type_3d, self.box_mode_3d = get_box_type(coord_type)

    def transform(self, single_input: dict) -> dict:
        """Transform function to add image meta information.
        Args:
            single_input (dict): Single input.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        assert 'points' in single_input, "key 'points' must be in input dict"
        if isinstance(single_input['points'], str):
            inputs = dict(
                lidar_points=dict(lidar_path=single_input['points']),
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=self.box_type_3d,
                box_mode_3d=self.box_mode_3d)
        elif isinstance(single_input['points'], np.ndarray):
            inputs = dict(
                points=single_input['points'],
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=self.box_type_3d,
                box_mode_3d=self.box_mode_3d)
        else:
            raise ValueError('Unsupported input points type: '
                             f"{type(single_input['points'])}")

        if 'points' in inputs:
            return self.from_ndarray(inputs)
        return self.from_file(inputs)
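
# A minimal input sketch: as the branches above show, ``single_input['points']``
# may be either a path to a point cloud file or an already-loaded np.ndarray;
# the file name below is a placeholder.
# loader = LidarDet3DInferencerLoader(coord_type='LIDAR', load_dim=4,
#                                     use_dim=4)
# results = loader(dict(points='path/to/points.bin'))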


@TRANSFORMS.register_module()
class MonoDet3DInferencerLoader(BaseTransform):
    """Load an image from ``results['images']['CAMX']['img']``. Similar with
    :obj:`LoadImageFromFileMono3D`, but the image has been loaded as
    :obj:`np.ndarray` in ``results['images']['CAMX']['img']``.

    Added keys:
      - img
      - cam2img
      - box_type_3d
      - box_mode_3d

    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFileMono3D', **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='LoadImageFromNDArray', **kwargs))

    def transform(self, single_input: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            single_input (dict): Result dict with Webcam read image in
                ``results['images']['CAMX']['img']``.
        Returns:
            dict: The dict contains loaded image and meta information.
        """
        box_type_3d, box_mode_3d = get_box_type('camera')
        assert 'calib' in single_input and 'img' in single_input, \
            "key 'calib' and 'img' must be in input dict"
        if isinstance(single_input['calib'], str):
            calib_path = single_input['calib']
            with open(calib_path, 'r') as f:
                lines = f.readlines()
            cam2img = np.array([
                float(info) for info in lines[0].split(' ')[0:16]
            ]).reshape([4, 4])
        elif isinstance(single_input['calib'], np.ndarray):
            cam2img = single_input['calib']
        else:
            raise ValueError('Unsupported input calib type: '
                             f"{type(single_input['calib'])}")

        if isinstance(single_input['img'], str):
            inputs = dict(
                images=dict(
                    CAM_FRONT=dict(
                        img_path=single_input['img'], cam2img=cam2img)),
                box_mode_3d=box_mode_3d,
                box_type_3d=box_type_3d)
        elif isinstance(single_input['img'], np.ndarray):
            inputs = dict(
                img=single_input['img'],
                cam2img=cam2img,
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            raise ValueError('Unsupported input image type: '
                             f"{type(single_input['img'])}")

        if 'img' in inputs:
            return self.from_ndarray(inputs)
        return self.from_file(inputs)
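
# A minimal input sketch: as the branches above show, ``single_input`` needs an
# 'img' entry (image path or np.ndarray) and a 'calib' entry (a text file whose
# first line holds the 16 values of a 4x4 cam2img matrix, or the matrix itself
# as an np.ndarray); the paths below are placeholders.
# loader = MonoDet3DInferencerLoader()
# results = loader(dict(img='path/to/img.png', calib='path/to/calib.txt'))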