formating.py 9.92 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangshilong's avatar
zhangshilong committed
2
from typing import List, Sequence, Union
jshilong's avatar
jshilong committed
3

4
import mmengine
zhangwenwei's avatar
zhangwenwei committed
5
import numpy as np
zhangshilong's avatar
zhangshilong committed
6
import torch
jshilong's avatar
jshilong committed
7
from mmcv import BaseTransform
8
from mmengine.structures import InstanceData
zhangshilong's avatar
zhangshilong committed
9
from numpy import dtype
zhangwenwei's avatar
zhangwenwei committed
10

11
from mmdet3d.registry import TRANSFORMS
zhangshilong's avatar
zhangshilong committed
12
13
from mmdet3d.structures import BaseInstance3DBoxes, Det3DDataSample, PointData
from mmdet3d.structures.points import BasePoints
zhangwenwei's avatar
zhangwenwei committed
14
15


zhangshilong's avatar
zhangshilong committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def to_tensor(
    data: Union[torch.Tensor, np.ndarray, Sequence, int,
                float]) -> torch.Tensor:
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`. ``float64`` numpy
    arrays are down-cast to ``float32``. Plain ``int`` / ``float`` scalars
    are wrapped into one-element ``int64`` / ``float32`` tensors.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: The converted data. A :class:`torch.Tensor` input is
        returned as-is; a non-float64 numpy array shares memory with the
        returned tensor (``torch.from_numpy``).

    Raises:
        TypeError: If ``data`` is a string or of any other unsupported type.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        # Compare dtypes by equality, not identity: a float64 dtype object
        # that is not numpy's cached singleton (e.g. one restored from a
        # pickle) would make ``is`` fail and silently skip the cast.
        if data.dtype == np.float64:
            data = data.astype(np.float32)
        return torch.from_numpy(data)
    # Strings are Sequences too, but are not meaningful tensor data.
    if isinstance(data, Sequence) and not isinstance(data, str):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')


48
@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack the output of a 3D detection data pipeline for the model.

    Entries of ``results`` named in ``keys`` are routed, according to the
    class-level key groups below, into the ``inputs`` dict or into the
    ``gt_instances_3d`` / ``gt_instances`` / ``gt_pts_seg`` fields of a
    :obj:`Det3DDataSample`. Entries named in ``meta_keys`` become the data
    sample's metainfo.

    Args:
        keys (tuple): Keys of ``results`` to pack. Each key present in
            ``results`` must belong to one of ``INPUTS_KEYS``,
            ``INSTANCEDATA_3D_KEYS``, ``INSTANCEDATA_2D_KEYS`` or
            ``SEG_KEYS``; otherwise ``NotImplementedError`` is raised while
            packing.
        meta_keys (tuple): Keys copied into the data sample's metainfo. Each
            is looked up at the top level of ``results`` first, then under
            ``results['images']`` if that exists, otherwise under
            ``results['lidar_points']``.
    """
    # Keys placed into the ``inputs`` dict.
    INPUTS_KEYS = ['points', 'img']
    # Keys placed into ``gt_instances_3d`` (with the 'gt_' prefix removed).
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    # Keys placed into ``gt_instances``.
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_bboxes_labels',
    ]

    # Keys placed into ``gt_pts_seg`` (with the 'gt_' prefix removed).
    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    def __init__(
        self,
        keys: tuple,
        meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
                            'depth2img', 'cam2img', 'pad_shape',
                            'scale_factor', 'flip', 'pcd_horizontal_flip',
                            'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
                            'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
                            'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
                            'pcd_rotation_angle', 'lidar_path',
                            'transformation_3d_flow', 'trans_mat',
                            'affine_aug', 'sweep_img_metas', 'ori_cam2img',
                            'cam2global', 'crop_offset', 'img_crop_offset',
                            'resize_img_shape', 'lidar2cam', 'ori_lidar2img',
                            'num_ref_frames', 'num_views', 'ego2global',
                            'axis_align_matrix')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        """Strip a leading ``'gt_'`` from ``key``; other keys are unchanged."""
        if key.startswith('gt_'):
            key = key[3:]
        return key

    def transform(self, results: Union[dict,
                                       List[dict]]) -> Union[dict, List[dict]]:
        """Method to pack the input data. When ``results`` is a list, it
        usually comes from Augmentation Testing.

        Args:
            results (dict | list[dict]): Result dict from the data pipeline.

        Returns:
            dict | List[dict]: One packed dict per input dict. A one-element
            list is unwrapped and packed into a single dict. Each packed
            dict contains:

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_samples' (:obj:`Det3DDataSample`): The annotation info of
              the sample.

        Raises:
            NotImplementedError: If ``results`` is neither a dict nor a list.
        """
        # augtest
        if isinstance(results, list):
            if len(results) == 1:
                # simple test
                return self.pack_single_results(results[0])
            pack_results = []
            for single_result in results:
                pack_results.append(self.pack_single_results(single_result))
            return pack_results
        # norm training and simple testing
        elif isinstance(results, dict):
            return self.pack_single_results(results)
        else:
            raise NotImplementedError

    def pack_single_results(self, results: dict) -> dict:
        """Pack the result dict of a single sample.

        Note that several entries of ``results`` (points, images and the
        annotation arrays) are converted in place before being packed.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: A dict contains

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_samples' (:obj:`Det3DDataSample`): The annotation info
              of the sample.

        Raises:
            NotImplementedError: If a key in ``self.keys`` is present in
                ``results`` but belongs to none of the known key groups.
        """
        # Format 3D data
        if 'points' in results:
            if isinstance(results['points'], BasePoints):
                results['points'] = results['points'].tensor

        if 'img' in results:
            if isinstance(results['img'], list):
                # process multiple imgs in single frame; stacked as
                # (N, ...) then channels moved from the last axis to axis 1
                imgs = np.stack(results['img'], axis=0)
                if imgs.flags.c_contiguous:
                    imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous()
                else:
                    imgs = to_tensor(
                        np.ascontiguousarray(imgs.transpose(0, 3, 1, 2)))
                results['img'] = imgs
            else:
                img = results['img']
                # Give 2D images an explicit trailing channel axis.
                if len(img.shape) < 3:
                    img = np.expand_dims(img, -1)
                # To improve the computational speed by 3-5 times, apply:
                # `torch.permute()` rather than `np.transpose()`.
                # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
                # for more details
                if img.flags.c_contiguous:
                    img = to_tensor(img).permute(2, 0, 1).contiguous()
                else:
                    img = to_tensor(
                        np.ascontiguousarray(img.transpose(2, 0, 1)))
                results['img'] = img

        for key in [
                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
                'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask',
                'pts_semantic_mask', 'centers_2d', 'depths', 'gt_labels_3d'
        ]:
            if key not in results:
                continue
            if isinstance(results[key], list):
                results[key] = [to_tensor(res) for res in results[key]]
            else:
                results[key] = to_tensor(results[key])
        if 'gt_bboxes_3d' in results:
            # Box structures are kept as-is; raw arrays become tensors.
            if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
                results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])

        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = to_tensor(
                results['gt_semantic_seg'][None])
        if 'gt_seg_map' in results:
            # Only a leading axis is added; unlike 'gt_semantic_seg' this is
            # not converted to a tensor here.
            results['gt_seg_map'] = results['gt_seg_map'][None, ...]

        data_sample = Det3DDataSample()
        gt_instances_3d = InstanceData()
        gt_instances = InstanceData()
        gt_pts_seg = PointData()

        # Collect meta information. A key is taken from the top level of
        # ``results`` if present; otherwise, when ``results['images']``
        # exists, from its per-camera dicts (a single value for one camera,
        # a list over the cameras that have it for several); otherwise
        # from ``results['lidar_points']``.
        data_metas = {}
        for key in self.meta_keys:
            if key in results:
                data_metas[key] = results[key]
            elif 'images' in results:
                if len(results['images'].keys()) == 1:
                    cam_type = list(results['images'].keys())[0]
                    # single-view image
                    if key in results['images'][cam_type]:
                        data_metas[key] = results['images'][cam_type][key]
                else:
                    # multi-view image
                    img_metas = []
                    cam_types = list(results['images'].keys())
                    for cam_type in cam_types:
                        if key in results['images'][cam_type]:
                            img_metas.append(results['images'][cam_type][key])
                    if len(img_metas) > 0:
                        data_metas[key] = img_metas
            elif 'lidar_points' in results:
                if key in results['lidar_points']:
                    data_metas[key] = results['lidar_points'][key]
        data_sample.set_metainfo(data_metas)

        # Route every requested key to its destination field.
        inputs = {}
        for key in self.keys:
            if key in results:
                if key in self.INPUTS_KEYS:
                    inputs[key] = results[key]
                elif key in self.INSTANCEDATA_3D_KEYS:
                    gt_instances_3d[self._remove_prefix(key)] = results[key]
                elif key in self.INSTANCEDATA_2D_KEYS:
                    # 2D labels are stored under the plain name 'labels'.
                    if key == 'gt_bboxes_labels':
                        gt_instances['labels'] = results[key]
                    else:
                        gt_instances[self._remove_prefix(key)] = results[key]
                elif key in self.SEG_KEYS:
                    gt_pts_seg[self._remove_prefix(key)] = results[key]
                else:
                    raise NotImplementedError(f'Please modified '
                                              f'`Pack3DDetInputs` '
                                              f'to put {key} to '
                                              f'corresponding field')

        data_sample.gt_instances_3d = gt_instances_3d
        data_sample.gt_instances = gt_instances
        data_sample.gt_pts_seg = gt_pts_seg
        # None when the pipeline provided no evaluation annotations.
        data_sample.eval_ann_info = results.get('eval_ann_info')

        packed_results = dict()
        packed_results['data_samples'] = data_sample
        packed_results['inputs'] = inputs

        return packed_results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        # Both fields go inside a single pair of parentheses.
        repr_str += f'(keys={self.keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str