formating.py 5.23 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
import numpy as np
jshilong's avatar
jshilong committed
3
from mmcv import BaseTransform
4
from mmcv.transforms import to_tensor
jshilong's avatar
jshilong committed
5
from mmengine import InstanceData
zhangwenwei's avatar
zhangwenwei committed
6

jshilong's avatar
jshilong committed
7
from mmdet3d.core import Det3DDataSample
8
from mmdet3d.core.bbox import BaseInstance3DBoxes
9
from mmdet3d.core.points import BasePoints
10
from mmdet3d.registry import TRANSFORMS
zhangwenwei's avatar
zhangwenwei committed
11
12


13
@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack pipeline results into model inputs and a ``Det3DDataSample``.

    Every key listed in ``keys`` that is present in the results dict is
    routed to one destination according to the class-level key lists:

    - ``INPUTS_KEYS`` are stored in the returned ``inputs`` dict.
    - ``INSTANCEDATA_3D_KEYS`` are stored in ``data_sample.gt_instances_3d``.
    - ``INSTANCEDATA_2D_KEYS`` are stored in ``data_sample.gt_instances``.
    - ``SEG_KEYS`` are stored in ``data_sample.seg_data``.

    For the last three destinations a leading ``gt_`` is stripped from the
    key before it is stored. A requested key that is present in the results
    but belongs to none of the lists raises ``NotImplementedError``.

    Args:
        keys (tuple): Keys of the results to be packed. Only iterated, so
            any iterable of strings works.
        meta_keys (tuple): Keys of the results that are copied, when
            present, into the meta information of the data sample.
    """

    INPUTS_KEYS = ['points', 'img']
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_labels',
    ]

    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    def __init__(
        self,
        keys: tuple,
        meta_keys: tuple = ('filename', 'ori_shape', 'img_shape', 'lidar2img',
                            'depth2img', 'cam2img', 'pad_shape',
                            'scale_factor', 'flip', 'pcd_horizontal_flip',
                            'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
                            'img_norm_cfg', 'pcd_trans', 'sample_idx',
                            'pcd_scale_factor', 'pcd_rotation',
                            'pcd_rotation_angle', 'pts_filename',
                            'transformation_3d_flow', 'trans_mat',
                            'affine_aug')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        """Return ``key`` without a leading ``'gt_'``, if it has one."""
        if key.startswith('gt_'):
            key = key[3:]
        return key

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Note:
            ``results`` is modified in place: the formatted values are
            written back under their original keys before being packed.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (dict): The forward data of models. It usually contains
              following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
              sample.

        Raises:
            NotImplementedError: If a key in ``self.keys`` is present in
                ``results`` but is not listed in any of the class-level
                key lists.
        """
        packed_results = dict()

        # Format 3D data
        if 'points' in results:
            assert isinstance(results['points'], BasePoints)
            results['points'] = results['points'].tensor

        if 'img' in results:
            if isinstance(results['img'], list):
                # process multiple imgs in single frame; each image has its
                # last axis moved to the front (presumably HWC -> CHW) and
                # the images are stacked along a new leading axis
                imgs = [img.transpose(2, 0, 1) for img in results['img']]
                imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
                results['img'] = to_tensor(imgs)
            else:
                img = results['img']
                if len(img.shape) < 3:
                    # add a trailing axis so the transpose below is valid
                    img = np.expand_dims(img, -1)
                # Convert to a tensor here as well, so that a single image
                # is packed with the same type as the multi-image case.
                results['img'] = to_tensor(
                    np.ascontiguousarray(img.transpose(2, 0, 1)))

        # 'centers_2d' is the key that is actually packed (see
        # INSTANCEDATA_3D_KEYS); 'centers2d' is kept so that results still
        # carrying that spelling are converted as before.
        for key in (
                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
                'gt_labels_3d', 'attr_labels', 'pts_instance_mask',
                'pts_semantic_mask', 'centers_2d', 'centers2d', 'depths'
        ):
            if key not in results:
                continue
            if isinstance(results[key], list):
                results[key] = [to_tensor(res) for res in results[key]]
            else:
                results[key] = to_tensor(results[key])
        if 'gt_bboxes_3d' in results:
            # box objects are packed as they are; anything else is converted
            if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
                results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])

        if 'gt_semantic_seg' in results:
            # prepend a leading axis before the tensor conversion
            results['gt_semantic_seg'] = to_tensor(
                results['gt_semantic_seg'][None])
        if 'gt_seg_map' in results:
            # only a leading axis is added; no tensor conversion is done here
            results['gt_seg_map'] = results['gt_seg_map'][None, ...]

        data_sample = Det3DDataSample()
        gt_instances_3d = InstanceData()
        gt_instances = InstanceData()
        seg_data = dict()

        # Copy only the meta keys that are actually present in the results.
        img_metas = {
            key: results[key]
            for key in self.meta_keys if key in results
        }
        data_sample.set_metainfo(img_metas)

        inputs = {}
        for key in self.keys:
            if key not in results:
                # requested keys that are absent are silently skipped
                continue
            if key in self.INPUTS_KEYS:
                inputs[key] = results[key]
            elif key in self.INSTANCEDATA_3D_KEYS:
                gt_instances_3d[self._remove_prefix(key)] = results[key]
            elif key in self.INSTANCEDATA_2D_KEYS:
                gt_instances[self._remove_prefix(key)] = results[key]
            elif key in self.SEG_KEYS:
                seg_data[self._remove_prefix(key)] = results[key]
            else:
                raise NotImplementedError(f'Please modified '
                                          f'`Pack3DDetInputs` '
                                          f'to put {key} to '
                                          f'corresponding field')

        data_sample.gt_instances_3d = gt_instances_3d
        data_sample.gt_instances = gt_instances
        data_sample.seg_data = seg_data
        packed_results['data_sample'] = data_sample
        packed_results['inputs'] = inputs

        return packed_results

    def __repr__(self) -> str:
        """Return the class name followed by its constructor arguments."""
        repr_str = self.__class__.__name__
        repr_str += f'(keys={self.keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str