# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import numpy as np
from mmcv import BaseTransform
from mmcv.transforms import to_tensor
from mmengine import InstanceData

from mmdet3d.core import Det3DDataSample, PointData
from mmdet3d.core.bbox import BaseInstance3DBoxes
from mmdet3d.core.points import BasePoints
from mmdet3d.registry import TRANSFORMS


@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack the data produced by a 3D detection pipeline for the model.

    Entries of ``results`` named in ``keys`` are routed as follows:

    - ``INPUTS_KEYS`` go into the ``inputs`` dict.
    - ``INSTANCEDATA_3D_KEYS`` go into ``data_sample.gt_instances_3d``.
    - ``INSTANCEDATA_2D_KEYS`` go into ``data_sample.gt_instances``.
    - ``SEG_KEYS`` go into ``data_sample.gt_pts_seg``.

    When a key is stored in one of the data elements, a leading ``gt_``
    prefix is stripped from the field name. Any other key listed in
    ``keys`` raises ``NotImplementedError``. Entries named in ``meta_keys``
    that exist in ``results`` become the meta information of the data
    sample.

    Args:
        keys (tuple | list): Keys of ``results`` to pack.
        meta_keys (tuple | list): Keys of ``results`` to copy into the
            meta information of the data sample when present.
    """
    INPUTS_KEYS = ['points', 'img']
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_labels',
    ]

    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    # Keys converted with ``to_tensor`` (element-wise when the value is a
    # list). ``centers_2d`` is the spelling packed via INSTANCEDATA_3D_KEYS;
    # the legacy ``centers2d`` spelling is kept for backward compatibility.
    TO_TENSOR_KEYS = [
        'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
        'gt_labels_3d', 'attr_labels', 'pts_instance_mask',
        'pts_semantic_mask', 'centers_2d', 'centers2d', 'depths'
    ]

    def __init__(
        self,
        keys: tuple,
        meta_keys: tuple = ('filename', 'ori_shape', 'img_shape', 'lidar2img',
                            'depth2img', 'cam2img', 'pad_shape',
                            'scale_factor', 'flip', 'pcd_horizontal_flip',
                            'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
                            'img_norm_cfg', 'pcd_trans', 'sample_idx',
                            'pcd_scale_factor', 'pcd_rotation',
                            'pcd_rotation_angle', 'pts_filename',
                            'transformation_3d_flow', 'trans_mat',
                            'affine_aug')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        """Strip a leading ``gt_`` from ``key``; other keys are unchanged."""
        if key.startswith('gt_'):
            key = key[3:]
        return key

    @staticmethod
    def _img_to_chw(img: np.ndarray) -> np.ndarray:
        """Move the last axis of ``img`` to the front.

        A 2-D image first gets a trailing singleton axis, so the result is
        always 3-D. Presumably the input layout is (H, W[, C]) and the
        output is (C, H, W) -- confirm against the loading transforms.
        """
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        return img.transpose(2, 0, 1)

    def transform(self, results: Union[dict,
                                       List[dict]]) -> Union[dict, List[dict]]:
        """Pack the input data.

        A list input usually comes from test-time augmentation. Each element
        is packed independently. A single-element list is unwrapped and
        packed as a plain dict.

        Args:
            results (dict | list[dict]): Result dict(s) from the data
                pipeline.

        Returns:
            dict | List[dict]: Packed result(s), each containing:

            - 'inputs' (dict): The forward data of models. It usually
              contains the following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of
              the sample.

        Raises:
            NotImplementedError: If ``results`` is neither a dict nor a list.
        """
        # augtest
        if isinstance(results, list):
            if len(results) == 1:
                # simple test
                return self.pack_single_results(results[0])
            return [self.pack_single_results(res) for res in results]
        # normal training and simple testing
        if isinstance(results, dict):
            return self.pack_single_results(results)
        raise NotImplementedError(
            f'Unsupported type of results: {type(results)}')

    def pack_single_results(self, results: dict) -> dict:
        """Pack a single result dict.

        Note that ``results`` is modified in place: selected entries are
        replaced by their tensor-converted versions before packing.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: A dict containing

            - 'inputs' (dict): The forward data of models. It usually
              contains the following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of
              the sample.

        Raises:
            NotImplementedError: If a key in ``self.keys`` present in
                ``results`` belongs to none of the known key groups.
        """
        # Format 3D data
        if 'points' in results:
            if isinstance(results['points'], BasePoints):
                results['points'] = results['points'].tensor

        if 'img' in results:
            if isinstance(results['img'], list):
                # process multiple imgs in single frame; each image gets the
                # same treatment as the single-image branch below
                imgs = [self._img_to_chw(img) for img in results['img']]
                imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
                results['img'] = to_tensor(imgs)
            else:
                results['img'] = to_tensor(
                    np.ascontiguousarray(self._img_to_chw(results['img'])))

        for key in self.TO_TENSOR_KEYS:
            if key not in results:
                continue
            if isinstance(results[key], list):
                results[key] = [to_tensor(res) for res in results[key]]
            else:
                results[key] = to_tensor(results[key])

        # Box structures are kept as-is; anything else is converted.
        if 'gt_bboxes_3d' in results:
            if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
                results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])

        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = to_tensor(
                results['gt_semantic_seg'][None])
        if 'gt_seg_map' in results:
            results['gt_seg_map'] = results['gt_seg_map'][None, ...]

        data_sample = Det3DDataSample()
        gt_instances_3d = InstanceData()
        gt_instances = InstanceData()
        gt_pts_seg = PointData()

        img_metas = {
            key: results[key]
            for key in self.meta_keys if key in results
        }
        data_sample.set_metainfo(img_metas)

        inputs = {}
        for key in self.keys:
            if key not in results:
                continue
            if key in self.INPUTS_KEYS:
                inputs[key] = results[key]
            elif key in self.INSTANCEDATA_3D_KEYS:
                gt_instances_3d[self._remove_prefix(key)] = results[key]
            elif key in self.INSTANCEDATA_2D_KEYS:
                gt_instances[self._remove_prefix(key)] = results[key]
            elif key in self.SEG_KEYS:
                gt_pts_seg[self._remove_prefix(key)] = results[key]
            else:
                raise NotImplementedError(f'Please modified '
                                          f'`Pack3DDetInputs` '
                                          f'to put {key} to '
                                          f'corresponding field')

        data_sample.gt_instances_3d = gt_instances_3d
        data_sample.gt_instances = gt_instances
        data_sample.gt_pts_seg = gt_pts_seg
        # None when the pipeline did not provide evaluation annotations.
        data_sample.eval_ann_info = results.get('eval_ann_info')

        return dict(data_sample=data_sample, inputs=inputs)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(keys={self.keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str