"vscode:/vscode.git/clone" did not exist on "84147254c91dd4c96b764dfb959322786bc7ab43"
formating.py 6.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
jshilong's avatar
jshilong committed
2
3
from typing import List, Union

zhangwenwei's avatar
zhangwenwei committed
4
import numpy as np
jshilong's avatar
jshilong committed
5
from mmcv import BaseTransform
6
from mmcv.transforms import to_tensor
jshilong's avatar
jshilong committed
7
from mmengine import InstanceData
zhangwenwei's avatar
zhangwenwei committed
8

jshilong's avatar
jshilong committed
9
from mmdet3d.core import Det3DDataSample
10
from mmdet3d.core.bbox import BaseInstance3DBoxes
11
from mmdet3d.core.points import BasePoints
12
from mmdet3d.registry import TRANSFORMS
zhangwenwei's avatar
zhangwenwei committed
13
14


15
@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack pipeline results into model inputs and a ``Det3DDataSample``.

    Every key listed in ``keys`` that is present in the results is routed,
    according to the class-level key lists below, into one of: the
    ``inputs`` dict, ``data_sample.gt_instances_3d``,
    ``data_sample.gt_instances`` or ``data_sample.seg_data``. A leading
    ``'gt_'`` is stripped from the key before it is stored in one of the
    annotation containers.

    Args:
        keys (list[str] | tuple[str]): Keys of the results to pack. A key
            that is present in the results but belongs to none of the
            class-level key lists raises ``NotImplementedError``.
        meta_keys (list[str] | tuple[str]): Keys of the results that are
            copied into the meta information of the data sample when
            they are present.
    """

    # Keys forwarded unchanged into the ``inputs`` dict.
    INPUTS_KEYS = ['points', 'img']
    # Keys stored in ``data_sample.gt_instances_3d``.
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    # Keys stored in ``data_sample.gt_instances``.
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_labels',
    ]

    # Keys stored in ``data_sample.seg_data``.
    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    def __init__(
        self,
        keys: Union[List[str], tuple],
        meta_keys: Union[List[str], tuple] = (
            'filename', 'ori_shape', 'img_shape', 'lidar2img', 'depth2img',
            'cam2img', 'pad_shape', 'scale_factor', 'flip',
            'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
            'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
            'pcd_scale_factor', 'pcd_rotation', 'pcd_rotation_angle',
            'pts_filename', 'transformation_3d_flow', 'trans_mat',
            'affine_aug')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        """Return ``key`` without a leading ``'gt_'``, if it has one."""
        if key.startswith('gt_'):
            key = key[3:]
        return key

    @staticmethod
    def _to_channel_first(img: np.ndarray) -> np.ndarray:
        """Move the last axis of an image array to the front.

        An array with fewer than three dimensions first gets a trailing
        axis appended so that the ``(2, 0, 1)`` transpose is valid.
        """
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        return img.transpose(2, 0, 1)

    def transform(self, results: Union[dict,
                                       List[dict]]) -> Union[dict, List[dict]]:
        """Pack one result dict, or each result dict of a list.

        A list is handled element by element (the original comments
        associate this case with test-time augmentation).

        Args:
            results (dict | list[dict]): Result dict from the data pipeline.

        Returns:
            dict | List[dict]: One packed dict per input dict, each with:

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
              sample.

        Raises:
            NotImplementedError: If ``results`` is neither a list nor a dict.
        """
        # augtest
        if isinstance(results, list):
            return [
                self.pack_single_results(single_result)
                for single_result in results
            ]
        # norm training and simple testing
        if isinstance(results, dict):
            return self.pack_single_results(results)
        raise NotImplementedError(
            f'`results` must be a dict or a list of dicts, '
            f'but got {type(results)}')

    def pack_single_results(self, results: dict) -> dict:
        """Pack a single result dict.

        Note that ``results`` is modified in place: the formatted values
        are written back under their original keys before being packed.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: A dict contains

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
              sample.

        Raises:
            NotImplementedError: If a key of ``self.keys`` is present in
                ``results`` but is listed in none of ``INPUTS_KEYS``,
                ``INSTANCEDATA_3D_KEYS``, ``INSTANCEDATA_2D_KEYS`` or
                ``SEG_KEYS``.
        """
        # Format 3D data
        if 'points' in results:
            if isinstance(results['points'], BasePoints):
                results['points'] = results['points'].tensor

        if 'img' in results:
            if isinstance(results['img'], list):
                # process multiple imgs in single frame
                imgs = np.stack(
                    [self._to_channel_first(img) for img in results['img']],
                    axis=0)
                results['img'] = to_tensor(np.ascontiguousarray(imgs))
            else:
                # Converted to a tensor as well, so that both branches
                # hand the same type to the model.
                img = self._to_channel_first(results['img'])
                results['img'] = to_tensor(np.ascontiguousarray(img))

        # 'centers_2d' is the spelling used by INSTANCEDATA_3D_KEYS;
        # 'centers2d' is still accepted for backward compatibility.
        tensor_keys = (
            'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
            'gt_labels_3d', 'attr_labels', 'pts_instance_mask',
            'pts_semantic_mask', 'centers2d', 'centers_2d', 'depths',
        )
        for key in tensor_keys:
            if key not in results:
                continue
            if isinstance(results[key], list):
                results[key] = [to_tensor(res) for res in results[key]]
            else:
                results[key] = to_tensor(results[key])

        if 'gt_bboxes_3d' in results:
            # Box objects are kept as they are; anything else is converted.
            if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
                results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])

        if 'gt_semantic_seg' in results:
            # Prepend an axis, then convert to a tensor.
            results['gt_semantic_seg'] = to_tensor(
                results['gt_semantic_seg'][None])
        if 'gt_seg_map' in results:
            # Only an axis is prepended here; no tensor conversion.
            results['gt_seg_map'] = results['gt_seg_map'][None, ...]

        data_sample = Det3DDataSample()
        gt_instances_3d = InstanceData()
        gt_instances = InstanceData()
        seg_data = dict()

        # Copy the requested meta keys that are actually available.
        img_metas = {
            key: results[key]
            for key in self.meta_keys if key in results
        }
        data_sample.set_metainfo(img_metas)

        inputs = {}
        for key in self.keys:
            if key not in results:
                continue
            if key in self.INPUTS_KEYS:
                inputs[key] = results[key]
            elif key in self.INSTANCEDATA_3D_KEYS:
                gt_instances_3d[self._remove_prefix(key)] = results[key]
            elif key in self.INSTANCEDATA_2D_KEYS:
                gt_instances[self._remove_prefix(key)] = results[key]
            elif key in self.SEG_KEYS:
                seg_data[self._remove_prefix(key)] = results[key]
            else:
                raise NotImplementedError(f'Please modified '
                                          f'`Pack3DDetInputs` '
                                          f'to put {key} to '
                                          f'corresponding field')

        data_sample.gt_instances_3d = gt_instances_3d
        data_sample.gt_instances = gt_instances
        data_sample.seg_data = seg_data
        # ``None`` when the pipeline provided no evaluation annotations.
        data_sample.eval_ann_info = results.get('eval_ann_info', None)

        packed_results = dict()
        packed_results['data_sample'] = data_sample
        packed_results['inputs'] = inputs

        return packed_results

    def __repr__(self) -> str:
        """Return the class name followed by ``keys`` and ``meta_keys``."""
        return (f'{self.__class__.__name__}('
                f'keys={self.keys}, meta_keys={self.meta_keys})')