utils.py 6.05 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import mmcv
VVsssssk's avatar
VVsssssk committed
3
import numpy as np
4
from mmcv.transforms import LoadImageFromFile
VVsssssk's avatar
VVsssssk committed
5
from pyquaternion import Quaternion
6

7
# yapf: disable
jshilong's avatar
jshilong committed
8
from mmdet3d.datasets.pipelines import (LoadAnnotations3D,
9
                                        LoadImageFromFileMono3D,
10
11
                                        LoadMultiViewImageFromFiles,
                                        LoadPointsFromFile,
12
                                        LoadPointsFromMultiSweeps,
jshilong's avatar
jshilong committed
13
                                        MultiScaleFlipAug3D, Pack3DDetInputs,
14
                                        PointSegClassMapping)
15
# yapf: enable
16
17
from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.pipelines import MultiScaleFlipAug
18
19


20
21
22
def is_loading_function(transform):
    """Judge whether a transform function is a loading function.

    Note: `MultiScaleFlipAug3D` (and mmdet's `MultiScaleFlipAug`) is a wrapper
    for multiple pipeline functions, so we need to search if its inner
    transforms contain any loading function.

    Args:
        transform (dict | :obj:`Pipeline`): A transform config or a function.

    Returns:
        bool | None: Whether it is a loading function. None means can't judge.
            When transform is `MultiScaleFlipAug3D` or `MultiScaleFlipAug`,
            we return None.
    """
    # TODO: use more elegant way to distinguish loading modules
    loading_functions = (LoadImageFromFile, LoadPointsFromFile,
                         LoadAnnotations3D, LoadMultiViewImageFromFiles,
                         LoadPointsFromMultiSweeps, Pack3DDetInputs,
                         LoadImageFromFileMono3D, PointSegClassMapping)
    aug_wrappers = (MultiScaleFlipAug3D, MultiScaleFlipAug)
    if isinstance(transform, dict):
        obj_cls = TRANSFORMS.get(transform['type'])
        # Unknown types come back as None; any other non-class entry can be
        # neither a loading transform nor a wrapper class.
        if not isinstance(obj_cls, type):
            return False
        # issubclass keeps configs consistent with the callable branch
        # below, where isinstance also accepts subclass instances.
        if issubclass(obj_cls, loading_functions):
            return True
        if issubclass(obj_cls, aug_wrappers):
            return None
    elif callable(transform):
        if isinstance(transform, loading_functions):
            return True
        if isinstance(transform, aug_wrappers):
            return None
    return False


54
55
56
57
def get_loading_pipeline(pipeline):
    """Only keep loading image, points and annotations related configuration.

    Wrapper transforms such as `MultiScaleFlipAug3D` are searched
    recursively. Any loading transforms found inside them are kept and
    flattened into the returned list.

    Args:
        pipeline (list[dict] | list[:obj:`Pipeline`]):
            Data pipeline configs or list of pipeline functions.

    Returns:
        list[dict] | list[:obj:`Pipeline`]: The new pipeline list that only
            keeps loading image, points and annotations related
            configuration.

    Raises:
        AssertionError: If no loading step is found anywhere in
            ``pipeline``.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadPointsFromFile',
        ...         coord_type='LIDAR', load_dim=4, use_dim=4),
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations3D',
        ...         with_bbox=True, with_label_3d=True),
        ...    dict(type='Resize',
        ...         img_scale=[(640, 192), (2560, 768)], keep_ratio=True),
        ...    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
        ...    dict(type='PointsRangeFilter',
        ...         point_cloud_range=point_cloud_range),
        ...    dict(type='ObjectRangeFilter',
        ...         point_cloud_range=point_cloud_range),
        ...    dict(type='PointShuffle'),
        ...    dict(type='Normalize', **img_norm_cfg),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='Pack3DDetInputs',
        ...         keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadPointsFromFile',
        ...         coord_type='LIDAR', load_dim=4, use_dim=4),
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations3D',
        ...         with_bbox=True, with_label_3d=True),
        ...    dict(type='Pack3DDetInputs',
        ...         keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...    ]
        >>> assert expected_pipelines == \
        ...        get_loading_pipeline(pipelines)
    """
    loading_pipeline = _collect_loading_transforms(pipeline)
    # Validate only once, on the fully flattened result: an individual
    # wrapper's inner pipeline may legitimately contain no loading step.
    assert len(loading_pipeline) > 0, \
        'The data pipeline in your config file must include ' \
        'loading step.'
    return loading_pipeline


def _collect_loading_transforms(pipeline):
    """Recursively gather loading transforms without validating the result."""
    loading_pipeline = []
    for transform in pipeline:
        is_loading = is_loading_function(transform)
        if is_loading is None:  # MultiScaleFlipAug3D / MultiScaleFlipAug
            # extract its inner pipeline
            if isinstance(transform, dict):
                inner_pipeline = transform.get('transforms', [])
            else:
                inner_pipeline = transform.transforms.transforms
            loading_pipeline.extend(
                _collect_loading_transforms(inner_pipeline))
        elif is_loading:
            loading_pipeline.append(transform)
    return loading_pipeline
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129


def extract_result_dict(results, key):
    """Extract and return the data corresponding to key in result dict.

    ``results`` is a dict output from `pipeline(input_dict)`, which is the
    loaded data from ``Dataset`` class. The data terms inside may be wrapped
    in list, tuple and DataContainer, so this function essentially extracts
    data from these wrappers.

    Args:
        results (dict): Data loaded using pipeline.
        key (str): Key of the desired data.

    Returns:
        np.ndarray | torch.Tensor | None: Data term, or None if ``key`` is
            not in ``results``.
    """
    if key not in results:
        return None
    # results[key] may be data or list[data] or tuple[data]; for a sequence
    # only its first element is returned.
    data = results[key]
    if isinstance(data, (list, tuple)):
        data = data[0]
    # data may be wrapped inside mmcv's DataContainer. ``mmcv.parallel`` is
    # absent in mmcv 2.x and is not imported by a bare ``import mmcv`` in
    # 1.x, so look it up defensively instead of raising AttributeError.
    parallel = getattr(mmcv, 'parallel', None)
    data_container_cls = getattr(parallel, 'DataContainer', None)
    if data_container_cls is not None and isinstance(data,
                                                     data_container_cls):
        data = data._data
    return data
VVsssssk's avatar
VVsssssk committed
142
143
144
145
146
147
148
149
150
151
152


def convert_quaternion_to_matrix(quaternion: list,
                                 translation: list = None) -> list:
    """Build a 4x4 homogeneous transform from a rotation and a translation.

    Args:
        quaternion (list): Rotation, passed unchanged to
            ``pyquaternion.Quaternion``.
        translation (list, optional): Vector written into the first three
            rows of the last column. Defaults to None, in which case that
            column keeps its identity values (zero translation).

    Returns:
        list: The 4x4 matrix as nested Python lists, with values rounded
            through float32.
    """
    rotation = Quaternion(quaternion).rotation_matrix
    matrix = np.eye(4)
    matrix[:3, :3] = rotation
    if translation is not None:
        matrix[:3, 3] = np.array(translation)
    return matrix.astype(np.float32).tolist()