data_augmentor.py 13 KB
Newer Older
1
from functools import partial
Shaoshuai Shi's avatar
Shaoshuai Shi committed
2

3
import numpy as np
4
from PIL import Image
Shaoshuai Shi's avatar
Shaoshuai Shi committed
5

6
from ...utils import common_utils
Shaoshuai Shi's avatar
Shaoshuai Shi committed
7
from . import augmentor_utils, database_sampler
8
9
10
11
12
13
14


class DataAugmentor(object):
    def __init__(self, root_path, augmentor_configs, class_names, logger=None):
        self.root_path = root_path
        self.class_names = class_names
        self.logger = logger
15

16
        self.data_augmentor_queue = []
17
18
        aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
            else augmentor_configs.AUG_CONFIG_LIST
19

20
21
22
23
        for cur_cfg in aug_config_list:
            if not isinstance(augmentor_configs, list):
                if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
                    continue
24
            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
25
            self.data_augmentor_queue.append(cur_augmentor)
26

27
    def disable_augmentation(self, augmentor_configs):
28
29
30
31
32
33
34
35
36
37
38
        self.data_augmentor_queue = []
        aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
            else augmentor_configs.AUG_CONFIG_LIST

        for cur_cfg in aug_config_list:
            if not isinstance(augmentor_configs, list):
                if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
                    continue
            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
            self.data_augmentor_queue.append(cur_augmentor)
             
39
40
41
42
43
44
45
46
    def gt_sampling(self, config=None):
        """Create and return a `DataBaseSampler` configured by `config` and this instance's settings."""
        sampler_cls = database_sampler.DataBaseSampler
        sampler_kwargs = dict(
            root_path=self.root_path,
            sampler_cfg=config,
            class_names=self.class_names,
            logger=self.logger,
        )
        return sampler_cls(**sampler_kwargs)
47

Gus-Guo's avatar
Gus-Guo committed
48
49
50
51
    def __getstate__(self):
        d = dict(self.__dict__)
        del d['logger']
        return d
52

Gus-Guo's avatar
Gus-Guo committed
53
54
    def __setstate__(self, d):
        self.__dict__.update(d)
55

56
    def random_world_flip(self, data_dict=None, config=None):
Gus-Guo's avatar
Gus-Guo committed
57
58
        if data_dict is None:
            return partial(self.random_world_flip, config=config)
59
60
61
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['x', 'y']
yukang.chen's avatar
yukang.chen committed
62
63
            gt_boxes, points, enable = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
                gt_boxes, points, return_flip=True
64
            )
yukang.chen's avatar
yukang.chen committed
65
            data_dict['flip_%s'%cur_axis] = enable
66
67
68
69
70
71
            if 'roi_boxes' in data_dict.keys():
                num_frame, num_rois,dim = data_dict['roi_boxes'].shape
                roi_boxes, _, _ = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
                data_dict['roi_boxes'].reshape(-1,dim), np.zeros([1,3]), return_flip=True, enable=enable
                )
                data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim)
yukang.chen's avatar
yukang.chen committed
72

73
74
75
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
76

77
    def random_world_rotation(self, data_dict=None, config=None):
Gus-Guo's avatar
Gus-Guo committed
78
79
        if data_dict is None:
            return partial(self.random_world_rotation, config=config)
80
81
82
        rot_range = config['WORLD_ROT_ANGLE']
        if not isinstance(rot_range, list):
            rot_range = [-rot_range, rot_range]
yukang.chen's avatar
yukang.chen committed
83
84
        gt_boxes, points, noise_rot = augmentor_utils.global_rotation(
            data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True
85
        )
86
87
88
89
90
        if 'roi_boxes' in data_dict.keys():
            num_frame, num_rois,dim = data_dict['roi_boxes'].shape
            roi_boxes, _, _ = augmentor_utils.global_rotation(
            data_dict['roi_boxes'].reshape(-1, dim), np.zeros([1, 3]), rot_range=rot_range, return_rot=True, noise_rotation=noise_rot)
            data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim)
91

92
93
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
yukang.chen's avatar
yukang.chen committed
94
        data_dict['noise_rot'] = noise_rot
95
        return data_dict
96

97
    def random_world_scaling(self, data_dict=None, config=None):
Gus-Guo's avatar
Gus-Guo committed
98
99
        if data_dict is None:
            return partial(self.random_world_scaling, config=config)
100
101
102
103
104
105
106
107
108
109
        
        if 'roi_boxes' in data_dict.keys():
            gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes(
                data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
            )
            data_dict['roi_boxes'] = roi_boxes
        else:
            gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
                data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
            )
110

111
112
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
yukang.chen's avatar
yukang.chen committed
113
        data_dict['noise_scale'] = noise_scale
114
        return data_dict
115

116
117
118
119
120
121
122
123
124
125
126
127
128
    def random_image_flip(self, data_dict=None, config=None):
        if data_dict is None:
            return partial(self.random_image_flip, config=config)
        images = data_dict["images"]
        depth_maps = data_dict["depth_maps"]
        gt_boxes = data_dict['gt_boxes']
        gt_boxes2d = data_dict["gt_boxes2d"]
        calib = data_dict["calib"]
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['horizontal']
            images, depth_maps, gt_boxes = getattr(augmentor_utils, 'random_image_flip_%s' % cur_axis)(
                images, depth_maps, gt_boxes, calib,
            )
129

130
131
132
133
        data_dict['images'] = images
        data_dict['depth_maps'] = depth_maps
        data_dict['gt_boxes'] = gt_boxes
        return data_dict
134

135
136
    def random_world_translation(self, data_dict=None, config=None):
        if data_dict is None:
137
            return partial(self.random_world_translation, config=config)
138
        noise_translate_std = config['NOISE_TRANSLATE_STD']
139
140
141
142
143
144
        assert len(noise_translate_std) == 3
        noise_translate = np.array([
            np.random.normal(0, noise_translate_std[0], 1),
            np.random.normal(0, noise_translate_std[1], 1),
            np.random.normal(0, noise_translate_std[2], 1),
        ], dtype=np.float32).T
145

146
147
148
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        points[:, :3] += noise_translate
        gt_boxes[:, :3] += noise_translate
149
150
151
152
                
        if 'roi_boxes' in data_dict.keys():
            data_dict['roi_boxes'][:, :3] += noise_translate
        
153
154
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
155
        data_dict['noise_translate'] = noise_translate
156
        return data_dict
157

158
159
160
161
162
163
164
165
166
167
168
169
170
    def random_local_translation(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.
        """
        if data_dict is None:
            return partial(self.random_local_translation, config=config)
        offset_range = config['LOCAL_TRANSLATION_RANGE']
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['x', 'y', 'z']
            gt_boxes, points = getattr(augmentor_utils, 'random_local_translation_along_%s' % cur_axis)(
                gt_boxes, points, offset_range,
            )
171

172
173
174
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
175

176
177
178
179
180
181
182
183
184
185
186
187
    def random_local_rotation(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.
        """
        if data_dict is None:
            return partial(self.random_local_rotation, config=config)
        rot_range = config['LOCAL_ROT_ANGLE']
        if not isinstance(rot_range, list):
            rot_range = [-rot_range, rot_range]
        gt_boxes, points = augmentor_utils.local_rotation(
            data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range
        )
188

189
190
191
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
192

193
194
195
196
197
198
199
200
201
    def random_local_scaling(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.
        """
        if data_dict is None:
            return partial(self.random_local_scaling, config=config)
        gt_boxes, points = augmentor_utils.local_scaling(
            data_dict['gt_boxes'], data_dict['points'], config['LOCAL_SCALE_RANGE']
        )
202

203
204
205
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
206

207
208
209
210
211
212
    def random_world_frustum_dropout(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.
        """
        if data_dict is None:
            return partial(self.random_world_frustum_dropout, config=config)
213

214
215
216
217
218
219
220
        intensity_range = config['INTENSITY_RANGE']
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for direction in config['DIRECTION']:
            assert direction in ['top', 'bottom', 'left', 'right']
            gt_boxes, points = getattr(augmentor_utils, 'global_frustum_dropout_%s' % direction)(
                gt_boxes, points, intensity_range,
            )
221

222
223
224
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
225

226
227
228
229
230
231
    def random_local_frustum_dropout(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.
        """
        if data_dict is None:
            return partial(self.random_local_frustum_dropout, config=config)
232

233
234
235
236
237
238
239
        intensity_range = config['INTENSITY_RANGE']
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for direction in config['DIRECTION']:
            assert direction in ['top', 'bottom', 'left', 'right']
            gt_boxes, points = getattr(augmentor_utils, 'local_frustum_dropout_%s' % direction)(
                gt_boxes, points, intensity_range,
            )
240

241
242
243
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
244

245
246
    def random_local_pyramid_aug(self, data_dict=None, config=None):
        """
247
        Refer to the paper:
248
249
250
251
            SE-SSD: Self-Ensembling Single-Stage Object Detector From Point Cloud
        """
        if data_dict is None:
            return partial(self.random_local_pyramid_aug, config=config)
252

253
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
254

255
256
257
258
259
260
261
262
263
264
265
266
        gt_boxes, points, pyramids = augmentor_utils.local_pyramid_dropout(gt_boxes, points, config['DROP_PROB'])
        gt_boxes, points, pyramids = augmentor_utils.local_pyramid_sparsify(gt_boxes, points,
                                                                            config['SPARSIFY_PROB'],
                                                                            config['SPARSIFY_MAX_NUM'],
                                                                            pyramids)
        gt_boxes, points = augmentor_utils.local_pyramid_swap(gt_boxes, points,
                                                                 config['SWAP_PROB'],
                                                                 config['SWAP_MAX_NUM'],
                                                                 pyramids)
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
267

268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
    def imgaug(self, data_dict=None, config=None):
        if data_dict is None:
            return partial(self.imgaug, config=config)
        imgs = data_dict["camera_imgs"]
        img_process_infos = data_dict['img_process_infos']
        new_imgs = []
        for img, img_process_info in zip(imgs, img_process_infos):
            flip = False
            if config.RAND_FLIP and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*config.ROT_LIM)
            # aug images
            if flip:
                img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
            img = img.rotate(rotate)
            img_process_info[2] = flip
            img_process_info[3] = rotate
            new_imgs.append(img)

        data_dict["camera_imgs"] = new_imgs
        return data_dict

290
291
292
293
294
295
296
297
298
299
300
301
302
    def forward(self, data_dict):
        """
        Run every queued augmentor on `data_dict`, then post-process the boxes.

        Args:
            data_dict:
                points: (N, 3 + C_in)
                gt_boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
                gt_names: optional, (N), string
                ...
                NOTE(review): 'gt_boxes' is indexed unconditionally below, so it is
                effectively required here.

        Returns:
            The augmented data_dict. Column 6 of 'gt_boxes' has been passed through
            `common_utils.limit_period(offset=0.5, period=2 * np.pi)`; 'road_plane' and
            'gt_boxes_mask' are removed; if 'gt_boxes_mask' was present, 'gt_boxes',
            'gt_names' and (when present) 'gt_boxes2d' are indexed by it.
        """
        for augment in self.data_augmentor_queue:
            data_dict = augment(data_dict=data_dict)

        boxes = data_dict['gt_boxes']
        boxes[:, 6] = common_utils.limit_period(boxes[:, 6], offset=0.5, period=2 * np.pi)

        # 'calib' stays in data_dict; the code that used to pop it is disabled.
        data_dict.pop('road_plane', None)

        if 'gt_boxes_mask' not in data_dict:
            return data_dict

        keep = data_dict['gt_boxes_mask']
        for key in ('gt_boxes', 'gt_names'):
            data_dict[key] = data_dict[key][keep]
        if 'gt_boxes2d' in data_dict:
            data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][keep]

        data_dict.pop('gt_boxes_mask')
        return data_dict