utils.py 5.97 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import mmcv
VVsssssk's avatar
VVsssssk committed
3
import numpy as np
4
from mmcv.transforms import LoadImageFromFile
VVsssssk's avatar
VVsssssk committed
5
from pyquaternion import Quaternion
6

7
# yapf: disable
zhangshilong's avatar
zhangshilong committed
8
9
10
11
12
13
14
from mmdet3d.datasets.transforms import (LoadAnnotations3D,
                                         LoadImageFromFileMono3D,
                                         LoadMultiViewImageFromFiles,
                                         LoadPointsFromFile,
                                         LoadPointsFromMultiSweeps,
                                         MultiScaleFlipAug3D, Pack3DDetInputs,
                                         PointSegClassMapping)
15
# yapf: enable
16
from mmdet3d.registry import TRANSFORMS
17
18


19
20
21
def is_loading_function(transform):
    """Judge whether a transform function is a loading function.

    Note: `MultiScaleFlipAug3D` is a wrapper for multiple pipeline functions,
    so we need to search if its inner transforms contain any loading function.

    Args:
        transform (dict | :obj:`Pipeline`): A transform config or a function.
            For a config, ``transform['type']`` may be either a registered
            transform name or the transform class itself.

    Returns:
        bool: Whether it is a loading function. None means can't judge.
            When transform is `MultiScaleFlipAug3D` (or a subclass of it),
            we return None.
    """
    # TODO: use more elegant way to distinguish loading modules
    loading_functions = (LoadImageFromFile, LoadPointsFromFile,
                         LoadAnnotations3D, LoadMultiViewImageFromFiles,
                         LoadPointsFromMultiSweeps, Pack3DDetInputs,
                         LoadImageFromFileMono3D, PointSegClassMapping)
    if isinstance(transform, dict):
        obj_type = transform['type']
        # A config may reference the class directly; only a name needs to be
        # resolved through the registry.
        if isinstance(obj_type, type):
            obj_cls = obj_type
        else:
            obj_cls = TRANSFORMS.get(obj_type)
        # Unknown names (lookup returned None) and registered non-class
        # callables can never be one of the loading classes above.
        if not isinstance(obj_cls, type):
            return False
        # issubclass (rather than exact membership) keeps configs consistent
        # with the isinstance checks used for built transforms below.
        if issubclass(obj_cls, loading_functions):
            return True
        if issubclass(obj_cls, MultiScaleFlipAug3D):
            return None
    elif callable(transform):
        if isinstance(transform, loading_functions):
            return True
        if isinstance(transform, MultiScaleFlipAug3D):
            return None
    return False


53
54
55
56
def get_loading_pipeline(pipeline):
    """Only keep loading image, points and annotations related configuration.

    Args:
        pipeline (list[dict] | list[:obj:`Pipeline`]):
            Data pipeline configs or list of pipeline functions.

    Returns:
        list[dict] | list[:obj:`Pipeline`]: The new pipeline list with only
            keep loading image, points and annotations related configuration.
            Loading steps found inside `MultiScaleFlipAug3D` wrappers are
            flattened into the returned list in pipeline order.

    Raises:
        AssertionError: If neither ``pipeline`` nor any wrapper nested in it
            contains a loading step.

    Examples:
        >>> transforms = [
        ...    dict(type='LoadPointsFromFile',
        ...         coord_type='LIDAR', load_dim=4, use_dim=4),
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations3D',
        ...         with_bbox=True, with_label_3d=True),
        ...    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
        ...    dict(type='PointsRangeFilter',
        ...         point_cloud_range=point_cloud_range),
        ...    dict(type='ObjectRangeFilter',
        ...         point_cloud_range=point_cloud_range),
        ...    dict(type='PointShuffle'),
        ...    dict(type='Pack3DDetInputs',
        ...         keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadPointsFromFile',
        ...         coord_type='LIDAR', load_dim=4, use_dim=4),
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations3D',
        ...         with_bbox=True, with_label_3d=True),
        ...    dict(type='Pack3DDetInputs',
        ...         keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...    ]
        >>> assert expected_pipelines == \
        ...        get_loading_pipeline(transforms)
    """
    loading_pipeline = _collect_loading_transforms(pipeline)
    # Validate only once, on the fully flattened result.
    # NOTE(review): `assert` is stripped under `python -O`; kept so callers
    # that expect AssertionError are unaffected.
    assert len(loading_pipeline) > 0, \
        'The data pipeline in your config file must include ' \
        'loading step.'
    return loading_pipeline


def _collect_loading_transforms(pipeline):
    """Recursively gather loading transforms, without validating the result.

    Args:
        pipeline (list[dict] | list[:obj:`Pipeline`]): Transforms to scan.

    Returns:
        list[dict] | list[:obj:`Pipeline`]: Loading transforms in order,
            possibly empty.
    """
    loading_pipeline = []
    for transform in pipeline:
        is_loading = is_loading_function(transform)
        if is_loading is None:  # MultiScaleFlipAug3D
            # extract its inner pipeline
            if isinstance(transform, dict):
                inner_pipeline = transform.get('transforms', [])
            else:
                # NOTE(review): assumes a built wrapper keeps a Compose-like
                # object in `.transforms` that itself exposes `.transforms`.
                inner_pipeline = transform.transforms.transforms
            # Recurse through this non-asserting helper: a wrapper whose inner
            # pipeline has no loading step must not abort the search while
            # the outer pipeline still provides one.
            loading_pipeline.extend(
                _collect_loading_transforms(inner_pipeline))
        elif is_loading:
            loading_pipeline.append(transform)
    return loading_pipeline
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128


def extract_result_dict(results, key):
    """Extract and return the data corresponding to key in result dict.

    ``results`` is a dict output from `pipeline(input_dict)`, which is the
    loaded data from ``Dataset`` class. The data terms inside may be wrapped
    in list, tuple and DataContainer, so this function essentially extracts
    data from these wrappers.

    Args:
        results (dict): Data loaded using pipeline.
        key (str): Key of the desired data.

    Returns:
        np.ndarray | torch.Tensor | None: Data term, or None when ``key`` is
            not present in ``results``.
    """
    if key not in results:
        return None
    # results[key] may be data or list[data] or tuple[data]; for a sequence
    # only its first element is returned.
    data = results[key]
    if isinstance(data, (list, tuple)):
        data = data[0]
    # data may be wrapped inside DataContainer. ``mmcv.parallel`` (home of
    # DataContainer) is not available in mmcv 2.x, which this module targets
    # via ``mmcv.transforms``; look it up lazily so that a missing module
    # means "nothing to unwrap" instead of an AttributeError.
    parallel = getattr(mmcv, 'parallel', None)
    data_container_cls = getattr(parallel, 'DataContainer', None)
    if data_container_cls is not None and isinstance(data,
                                                     data_container_cls):
        data = data._data
    return data
VVsssssk's avatar
VVsssssk committed
141
142
143
144
145
146
147
148
149
150
151


def convert_quaternion_to_matrix(quaternion: list,
                                 translation: list = None) -> list:
    """Build a 4x4 homogeneous transform from a quaternion and a translation.

    Args:
        quaternion (list): Rotation, forwarded unchanged to
            ``pyquaternion.Quaternion``; its ``rotation_matrix`` fills the
            upper-left 3x3 block.
        translation (list, optional): Values written into the first three
            rows of the last column. Defaults to None, which keeps a zero
            translation.

    Returns:
        list: The 4x4 matrix as nested Python lists, with values cast
            through float32.
    """
    transform = np.identity(4)
    transform[:3, :3] = Quaternion(quaternion).rotation_matrix
    if translation is not None:
        transform[:3, 3] = np.asarray(translation)
    as_float32 = transform.astype(np.float32)
    return as_float32.tolist()