# yapf: disable
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
                                        LoadAnnotations3D,
                                        LoadImageFromFileMono3D,
                                        LoadMultiViewImageFromFiles,
                                        LoadPointsFromFile,
                                        LoadPointsFromMultiSweeps,
                                        MultiScaleFlipAug3D,
                                        PointSegClassMapping)
# yapf: enable
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadImageFromFile


def is_loading_function(transform):
    """Judge whether a transform function is a loading function.

    Args:
        transform (dict | :obj:`Pipeline`): A transform config or a function.

    Returns:
        bool | None: Whether it is a loading function. None means can't judge.
            When transform is `MultiScaleFlipAug3D`, we return None.
            A config dict whose ``type`` is not registered in ``PIPELINES``
            returns False.
    """
    # TODO: use more elegant way to distinguish loading modules
    loading_functions = (LoadImageFromFile, LoadPointsFromFile,
                         LoadAnnotations3D, LoadMultiViewImageFromFiles,
                         LoadPointsFromMultiSweeps, DefaultFormatBundle3D,
                         Collect3D, LoadImageFromFileMono3D,
                         PointSegClassMapping)
    if isinstance(transform, dict):
        # Resolve the config's ``type`` string to the registered class.
        obj_cls = PIPELINES.get(transform['type'])
        if obj_cls is None:
            return False
        if obj_cls in loading_functions:
            return True
        # Identity check: the original ``obj_cls in (MultiScaleFlipAug3D)``
        # was not a tuple membership test (no trailing comma). It therefore
        # attempted ``in`` on the class itself and raised TypeError.
        if obj_cls is MultiScaleFlipAug3D:
            return None
    elif callable(transform):
        if isinstance(transform, loading_functions):
            return True
        if isinstance(transform, MultiScaleFlipAug3D):
            return None
    return False


def get_loading_pipeline(pipeline):
    """Only keep loading image, points and annotations related configuration.

    ``MultiScaleFlipAug3D`` entries are not kept themselves. Their inner
    pipelines are searched recursively, and any loading steps found there
    are spliced into the result in order.

    Args:
        pipeline (list[dict] | list[:obj:`Pipeline`]):
            Data pipeline configs or list of pipeline functions.

    Returns:
        list[dict] | list[:obj:`Pipeline`]: The new pipeline list with only
            keep loading image, points and annotations related configuration.

    Raises:
        AssertionError: If no loading step is found anywhere in ``pipeline``.

    Examples:
        >>> point_cloud_range = [0, -40, -3, 70.4, 40, 1]
        >>> img_norm_cfg = dict(
        ...     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],
        ...     to_rgb=True)
        >>> class_names = ['Car']
        >>> pipelines = [
        ...    dict(type='LoadPointsFromFile',
        ...         coord_type='LIDAR', load_dim=4, use_dim=4),
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations3D',
        ...         with_bbox=True, with_label_3d=True),
        ...    dict(type='Resize',
        ...         img_scale=[(640, 192), (2560, 768)], keep_ratio=True),
        ...    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
        ...    dict(type='PointsRangeFilter',
        ...         point_cloud_range=point_cloud_range),
        ...    dict(type='ObjectRangeFilter',
        ...         point_cloud_range=point_cloud_range),
        ...    dict(type='PointShuffle'),
        ...    dict(type='Normalize', **img_norm_cfg),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle3D', class_names=class_names),
        ...    dict(type='Collect3D',
        ...         keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadPointsFromFile',
        ...         coord_type='LIDAR', load_dim=4, use_dim=4),
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations3D',
        ...         with_bbox=True, with_label_3d=True),
        ...    dict(type='DefaultFormatBundle3D', class_names=class_names),
        ...    dict(type='Collect3D',
        ...         keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...    ]
        >>> assert expected_pipelines ==\
        ...        get_loading_pipeline(pipelines)
    """
    loading_pipeline = _collect_loading_transforms(pipeline)
    # Validate only the final, flattened result. A nested
    # MultiScaleFlipAug3D pipeline may legitimately contain no loading step
    # of its own.
    assert len(loading_pipeline) > 0, \
        'The data pipeline in your config file must include ' \
        'loading step.'
    return loading_pipeline


def _collect_loading_transforms(pipeline):
    """Recursively gather loading transforms from ``pipeline`` (no checks)."""
    loading_pipeline = []
    for transform in pipeline:
        is_loading = is_loading_function(transform)
        if is_loading is None:
            # MultiScaleFlipAug3D: extract its inner pipeline and recurse.
            if isinstance(transform, dict):
                inner_pipeline = transform.get('transforms', [])
            else:
                # The built object's ``transforms`` attribute is itself a
                # container (presumably a ``Compose``) holding the list.
                inner_pipeline = transform.transforms.transforms
            loading_pipeline.extend(
                _collect_loading_transforms(inner_pipeline))
        elif is_loading:
            loading_pipeline.append(transform)
    return loading_pipeline