"official/projects/token_dropping/train.py" did not exist on "8be7de9120350b93855591198b04747eb75823e5"
formating.py 7.98 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangshilong's avatar
zhangshilong committed
2
from typing import List, Sequence, Union
jshilong's avatar
jshilong committed
3

zhangshilong's avatar
zhangshilong committed
4
import mmcv
zhangwenwei's avatar
zhangwenwei committed
5
import numpy as np
zhangshilong's avatar
zhangshilong committed
6
import torch
jshilong's avatar
jshilong committed
7
8
from mmcv import BaseTransform
from mmengine import InstanceData
zhangshilong's avatar
zhangshilong committed
9
from numpy import dtype
zhangwenwei's avatar
zhangwenwei committed
10

ZCMax's avatar
ZCMax committed
11
from mmdet3d.core import Det3DDataSample, PointData
12
from mmdet3d.core.bbox import BaseInstance3DBoxes
13
from mmdet3d.core.points import BasePoints
14
from mmdet3d.registry import TRANSFORMS
zhangwenwei's avatar
zhangwenwei committed
15
16


zhangshilong's avatar
zhangshilong committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def to_tensor(
    data: Union[torch.Tensor, np.ndarray, Sequence, int,
                float]) -> torch.Tensor:
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: the converted data. A tensor input is returned as is,
        a float64 array is cast to float32 before conversion, and a scalar
        becomes a one-element tensor.

    Raises:
        TypeError: If ``data`` is a string or any other unsupported type.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        # Compare dtypes by value rather than identity: two equivalent
        # float64 dtype objects are not guaranteed to be the same object,
        # and an identity test would then silently skip the down-cast.
        if data.dtype == np.float64:
            data = data.astype(np.float32)
        return torch.from_numpy(data)
    # Strings are Sequences too, but must not be turned into tensors.
    if isinstance(data, Sequence) and not isinstance(data, str):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')


49
@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack pipeline results into model inputs and a ``Det3DDataSample``.

    The entries of ``results`` named in ``keys`` are routed, according to the
    class-level key lists below, either into the ``inputs`` dict or into the
    3D instance, 2D instance or point segmentation field of the data sample.
    The entries named in ``meta_keys`` are stored as meta information of the
    data sample.

    Args:
        keys (Sequence[str]): Keys of ``results`` to pack. Every key that is
            present in ``results`` must belong to one of ``INPUTS_KEYS``,
            ``INSTANCEDATA_3D_KEYS``, ``INSTANCEDATA_2D_KEYS`` or
            ``SEG_KEYS``.
        meta_keys (Sequence[str]): Keys of ``results`` copied into the meta
            information of the data sample when present.
    """

    # Keys forwarded to the model through the ``inputs`` dict.
    INPUTS_KEYS = ['points', 'img']
    # Keys stored in ``data_sample.gt_instances_3d``.
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    # Keys stored in ``data_sample.gt_instances``.
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_bboxes_labels',
    ]

    # Keys stored in ``data_sample.gt_pts_seg``.
    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    def __init__(
        self,
        keys: Sequence[str],
        meta_keys: Sequence[str] = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
                           'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
                           'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
                           'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
                           'pcd_trans', 'sample_idx', 'pcd_scale_factor',
                           'pcd_rotation', 'pcd_rotation_angle', 'lidar_path',
                           'transformation_3d_flow', 'trans_mat',
                           'affine_aug')):
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        """Strip a leading ``'gt_'`` from ``key`` if it has one."""
        if key.startswith('gt_'):
            key = key[3:]
        return key

    @staticmethod
    def _img_to_chw(img: np.ndarray) -> np.ndarray:
        """Move the last axis of ``img`` to the front.

        A 2-D image first gets a trailing axis of length one so that the
        transpose is always applied to a 3-D array.
        """
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        return img.transpose(2, 0, 1)

    def transform(self, results: Union[dict,
                                       List[dict]]) -> Union[dict, List[dict]]:
        """Method to pack the input data. when the value in this dict is a
        list, it usually is in Augmentations Testing.

        Args:
            results (dict | list[dict]): Result dict from the data pipeline.

        Returns:
            dict | List[dict]: One packed dict for a dict input or a
            one-element list, otherwise a list with one packed dict per
            element. Each packed dict contains

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
              sample.

        Raises:
            NotImplementedError: If ``results`` is neither a dict nor a list.
        """
        # augtest
        if isinstance(results, list):
            if len(results) == 1:
                # simple test: unwrap the single element instead of returning
                # a one-element list
                return self.pack_single_results(results[0])
            return [
                self.pack_single_results(single_result)
                for single_result in results
            ]
        # norm training and simple testing
        if isinstance(results, dict):
            return self.pack_single_results(results)
        raise NotImplementedError

    def pack_single_results(self, results: dict) -> dict:
        """Method to pack the single input data. when the value in this dict is
        a list, it usually is in Augmentations Testing.

        Note that ``results`` is modified in place while its entries are
        converted.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: A dict contains

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
              sample.

        Raises:
            NotImplementedError: If a key of ``self.keys`` is present in
                ``results`` but belongs to none of the known key groups.
        """
        # Format 3D data
        if 'points' in results:
            if isinstance(results['points'], BasePoints):
                results['points'] = results['points'].tensor

        if 'img' in results:
            if isinstance(results['img'], list):
                # process multiple imgs in single frame; 2-D images are
                # handled the same way as in the single-image branch
                imgs = [self._img_to_chw(img) for img in results['img']]
                imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
                results['img'] = to_tensor(imgs)
            else:
                results['img'] = to_tensor(
                    np.ascontiguousarray(self._img_to_chw(results['img'])))

        for key in (
                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
                'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask',
                'pts_semantic_mask', 'centers_2d', 'depths', 'gt_labels_3d'
        ):
            if key not in results:
                continue
            if isinstance(results[key], list):
                results[key] = [to_tensor(res) for res in results[key]]
            else:
                results[key] = to_tensor(results[key])
        if 'gt_bboxes_3d' in results:
            # box objects are kept as they are, anything else becomes a tensor
            if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
                results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])

        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = to_tensor(
                results['gt_semantic_seg'][None])
        if 'gt_seg_map' in results:
            # only a leading axis is added here, no tensor conversion
            results['gt_seg_map'] = results['gt_seg_map'][None, ...]

        data_sample = Det3DDataSample()
        gt_instances_3d = InstanceData()
        gt_instances = InstanceData()
        gt_pts_seg = PointData()

        img_metas = {
            key: results[key]
            for key in self.meta_keys if key in results
        }
        data_sample.set_metainfo(img_metas)

        inputs = {}
        for key in self.keys:
            if key not in results:
                continue
            if key in self.INPUTS_KEYS:
                inputs[key] = results[key]
            elif key in self.INSTANCEDATA_3D_KEYS:
                gt_instances_3d[self._remove_prefix(key)] = results[key]
            elif key in self.INSTANCEDATA_2D_KEYS:
                if key == 'gt_bboxes_labels':
                    # stored as 'labels', not as 'bboxes_labels'
                    gt_instances['labels'] = results[key]
                else:
                    gt_instances[self._remove_prefix(key)] = results[key]
            elif key in self.SEG_KEYS:
                gt_pts_seg[self._remove_prefix(key)] = results[key]
            else:
                raise NotImplementedError(f'Please modified '
                                          f'`Pack3DDetInputs` '
                                          f'to put {key} to '
                                          f'corresponding field')

        data_sample.gt_instances_3d = gt_instances_3d
        data_sample.gt_instances = gt_instances
        data_sample.gt_pts_seg = gt_pts_seg
        if 'eval_ann_info' in results:
            data_sample.eval_ann_info = results['eval_ann_info']
        else:
            data_sample.eval_ann_info = None

        packed_results = dict()
        packed_results['data_sample'] = data_sample
        packed_results['inputs'] = inputs

        return packed_results

    def __repr__(self) -> str:
        """Return the class name followed by ``keys`` and ``meta_keys``."""
        repr_str = self.__class__.__name__
        repr_str += f'(keys={self.keys})'
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str