mosaic.py 16 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Vishnu Banna's avatar
Vishnu Banna committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mosaic op."""
Vishnu Banna's avatar
config  
Vishnu Banna committed
16
import random
Abdullah Rashwan's avatar
Abdullah Rashwan committed
17

Vishnu Banna's avatar
config  
Vishnu Banna committed
18
19
20
import tensorflow as tf
import tensorflow_addons as tfa

Abdullah Rashwan's avatar
Abdullah Rashwan committed
21
from official.projects.yolo.ops import preprocessing_ops
Abdullah Rashwan's avatar
Abdullah Rashwan committed
22
23
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
24

Vishnu Banna's avatar
config  
Vishnu Banna committed
25

Vishnu Banna's avatar
Vishnu Banna committed
26
class Mosaic:
Vishnu Banna's avatar
config  
Vishnu Banna committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  """Stitch together sets of 4 images to generate samples with more boxes."""

  def __init__(self,
               output_size,
               mosaic_frequency=1.0,
               mixup_frequency=0.0,
               letter_box=True,
               jitter=0.0,
               mosaic_crop_mode='scale',
               mosaic_center=0.25,
               aug_scale_min=1.0,
               aug_scale_max=1.0,
               aug_rand_angle=0.0,
               aug_rand_perspective=0.0,
               aug_rand_translate=0.0,
               random_pad=False,
               random_flip=False,
               area_thresh=0.1,
               pad_value=preprocessing_ops.PAD_VALUE,
               seed=None):
    """Initializes parameters for mosaic.

    Args:
      output_size: `Tensor` or `List` for [height, width] of output image.
      mosaic_frequency: `float` indicating how often to apply mosaic.
      mixup_frequency: `float` indicating how often to apply mixup.
      letter_box: `boolean` indicating whether, at the start of the data
        pipeline and regardless of the preprocessing ops that are used, the
        aspect ratio of the images should be preserved.
      jitter: `float` for the maximum change in aspect ratio expected in each
        preprocessing step.
      mosaic_crop_mode: `str` the type of mosaic to apply. The options are
        {crop, scale, None}: crop constructs a mosaic by slicing images
        together, scale creates a mosaic by concatenating and shifting the
        image, and None defaults to scale and applies no post processing to
        the created mosaic.
      mosaic_center: `float` indicating how much to randomly deviate from the
        center of the image when creating a mosaic.
      aug_scale_min: `float` indicating the minimum scaling value for image
        scale jitter.
      aug_scale_max: `float` indicating the maximum scaling value for image
        scale jitter.
      aug_rand_angle: `float` indicating the maximum rotation angle; the angle
        is varied between 0 and this value.
      aug_rand_perspective: `float` ranging from 0.000 to 0.001 indicating how
        much to perspective warp the image.
      aug_rand_translate: `float` ranging from 0 to 1 indicating the maximum
        amount to randomly translate an image.
      random_pad: `bool` indicating whether to use padding to apply random
        translation; true for darknet yolo, false for scaled yolo.
      random_flip: `bool` whether or not to randomly flip the image.
      area_thresh: `float` for the minimum area of a box to allow to pass
        through for optimization.
      pad_value: `int` padding value.
      seed: `int` the seed for random number generation. When None, a random
        seed is drawn and the dataset maps run with `deterministic=False`.
    """
    # Seeding: remember whether the caller asked for reproducibility before
    # falling back to a randomly drawn seed.
    self._deterministic = seed is not None
    if seed is None:
      seed = random.randint(0, 2**30)
    self._seed = seed

    # Output geometry and box filtering.
    self._output_size = output_size
    self._area_thresh = area_thresh
    self._pad_value = pad_value

    # How often each augmentation is applied.
    self._mosaic_frequency = mosaic_frequency
    self._mixup_frequency = mixup_frequency

    # Per-image preprocessing applied before stitching. `jitter` is stored as
    # `_random_crop` and later forwarded as the `jitter` argument.
    self._letter_box = letter_box
    self._random_crop = jitter
    self._random_flip = random_flip

    # Mosaic layout.
    self._mosaic_crop_mode = mosaic_crop_mode
    self._mosaic_center = mosaic_center

    # Affine warp applied to the stitched mosaic.
    self._aug_scale_min = aug_scale_min
    self._aug_scale_max = aug_scale_max
    self._aug_rand_angle = aug_rand_angle
    self._aug_rand_perspective = aug_rand_perspective
    self._aug_rand_translate = aug_rand_translate
    self._random_pad = random_pad

  def _generate_cut(self):
    """Pick the cut point and the canvas shape used to stitch the images.

    Returns:
      A `(cut, ishape)` pair. In 'crop' mode, `cut` is `[cut_y, cut_x]`, each
      drawn between `mosaic_center` and `1 - mosaic_center` of the matching
      output dimension, and `ishape` is `[height, width, 3]`. In every other
      mode `cut` is None and `ishape` is `[2 * height, 2 * width, 3]`.
    """
    height, width = self._output_size[0], self._output_size[1]

    if self._mosaic_crop_mode != 'crop':
      # No slicing: the four images are concatenated onto a double-size canvas.
      return None, tf.convert_to_tensor([height * 2, width * 2, 3])

    margin = self._mosaic_center
    cut_x = preprocessing_ops.random_uniform_strong(
        width * margin, width * (1 - margin), seed=self._seed)
    cut_y = preprocessing_ops.random_uniform_strong(
        height * margin, height * (1 - margin), seed=self._seed)
    return [cut_y, cut_x], tf.convert_to_tensor([height, width, 3])

Vishnu Banna's avatar
Vishnu Banna committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
  def scale_boxes(self, patch, ishape, boxes, classes, xs, ys):
    """Re-express the boxes of one patch in the frame of the stitched canvas.

    Args:
      patch: image tensor the boxes currently refer to.
      ishape: shape of the canvas the patch will be placed on.
      boxes: boxes normalized to `patch`.
      classes: classes matching `boxes`; returned unchanged.
      xs: fraction of the horizontal size difference to shift the boxes by.
      ys: fraction of the vertical size difference to shift the boxes by.

    Returns:
      The boxes normalized to the canvas, and `classes`.
    """
    dtype = boxes.dtype
    xs = tf.cast(xs, dtype)
    ys = tf.cast(ys, dtype)
    patch_shape = tf.cast(tf.shape(patch), dtype)
    canvas_shape = tf.cast(ishape, dtype)
    # Room left on the canvas once the patch has been placed on it.
    slack = tf.cast((canvas_shape - patch_shape), dtype)

    boxes = box_ops.denormalize_boxes(boxes, patch_shape[:2])
    shift_y = slack[0] * ys
    shift_x = slack[1] * xs
    boxes = boxes + tf.cast([shift_y, shift_x, shift_y, shift_x], boxes.dtype)
    boxes = box_ops.normalize_boxes(boxes, canvas_shape[:2])
    return boxes, classes

Vishnu Banna's avatar
Vishnu Banna committed
145
146
147
148
149
150
  def _select_ind(self, inds, *args):
    """Gather the entries at `inds` from every tensor in `args`.

    Args:
      inds: indices forwarded to `tf.gather`.
      *args: tensors to gather from.

    Returns:
      A list holding one gathered tensor per entry of `args`, in order.
    """
    return [tf.gather(tensor, inds) for tensor in args]

Vishnu Banna's avatar
config  
Vishnu Banna committed
151
152
153
154
155
156
157
158
159
160
  def _augment_image(self,
                     image,
                     boxes,
                     classes,
                     is_crowd,
                     area,
                     xs=0.0,
                     ys=0.0,
                     cut=None):
    """Augment one image and its annotations before it is stitched.

    Args:
      image: image tensor to augment.
      boxes: boxes belonging to `image`.
      classes: classes matching `boxes`.
      is_crowd: crowd flags matching `boxes`.
      area: areas matching `boxes`.
      xs: forwarded as `shiftx` to the resize-and-jitter op.
      ys: forwarded as `shifty` to the resize-and-jitter op.
      cut: cut point forwarded to the resize-and-jitter op.

    Returns:
      A tuple `(image, boxes, classes, is_crowd, area, crop_points)` where the
      annotations have been reduced to the boxes that survived clipping.
    """
    if self._random_flip:
      # Randomly flip the image horizontally.
      image, boxes, _ = preprocess_ops.random_horizontal_flip(
          image, boxes, seed=self._seed)

    # Jitter the image into the output size; random padding is always off here.
    target_size = [self._output_size[0], self._output_size[1]]
    image, infos, crop_points = preprocessing_ops.resize_and_jitter_image(
        image,
        target_size,
        random_pad=False,
        letter_box=self._letter_box,
        jitter=self._random_crop,
        shiftx=xs,
        shifty=ys,
        cut=cut,
        seed=self._seed)

    # Apply the same transform to the boxes, then drop the annotations of any
    # box that was filtered out.
    boxes, kept = preprocessing_ops.transform_and_clip_boxes(
        boxes,
        infos,
        area_thresh=self._area_thresh,
        shuffle_boxes=False,
        filter_and_clip_boxes=True,
        seed=self._seed)
    classes, is_crowd, area = self._select_ind(kept, classes, is_crowd, area)  # pylint:disable=unbalanced-tuple-unpacking
    return image, boxes, classes, is_crowd, area, crop_points

  def _mosaic_crop_image(self, image, boxes, classes, is_crowd, area):
    """Post-process a stitched mosaic into the final output sample.

    Outside of 'crop' mode the whole mosaic is first translated by a random
    offset. In every mode it is then affine-warped and resized to the output
    size, and the boxes are transformed and filtered to match.

    Args:
      image: the stitched mosaic image.
      boxes: normalized boxes of the mosaic.
      classes: classes matching `boxes`.
      is_crowd: crowd flags matching `boxes`.
      area: areas matching `boxes`.

    Returns:
      A tuple `(image, boxes, classes, is_crowd, area, area)`. `area` appears
      twice; the caller in this class discards the last element.
    """
    out_size = [self._output_size[0], self._output_size[1]]

    if self._mosaic_crop_mode != 'crop':
      canvas = tf.cast(preprocessing_ops.get_image_shape(image), tf.float32)
      max_shift = canvas * self._mosaic_center

      # Draw a whole-pixel offset (vertical first, then horizontal) and move
      # the entire image by it.
      shift_y = tf.math.round(
          preprocessing_ops.random_uniform_strong(
              -max_shift[0], max_shift[0], seed=self._seed))
      shift_x = tf.math.round(
          preprocessing_ops.random_uniform_strong(
              -max_shift[1], max_shift[1], seed=self._seed))
      image = tfa.image.translate(
          image, [shift_x, shift_y], fill_value=self._pad_value)

      # Move the boxes by the same offset, clip them to the canvas and keep
      # only the non-empty ones.
      extent = canvas[:2]
      boxes = box_ops.denormalize_boxes(boxes, extent)
      boxes = boxes + tf.cast([shift_y, shift_x, shift_y, shift_x],
                              boxes.dtype)
      boxes = box_ops.clip_boxes(boxes, extent)
      kept = box_ops.get_non_empty_box_indices(boxes)
      boxes = box_ops.normalize_boxes(boxes, extent)
      boxes, classes, is_crowd, area = self._select_ind(kept, boxes, classes,  # pylint:disable=unbalanced-tuple-unpacking
                                                        is_crowd, area)

    # Warp and scale the fully stitched sample.
    image, _, affine = preprocessing_ops.affine_warp_image(
        image,
        out_size,
        scale_min=self._aug_scale_min,
        scale_max=self._aug_scale_max,
        translate=self._aug_rand_translate,
        degrees=self._aug_rand_angle,
        perspective=self._aug_rand_perspective,
        random_pad=self._random_pad,
        seed=self._seed)
    image = tf.image.resize(image, (out_size[0], out_size[1]))

    # Apply the warp to the boxes and drop the annotations of filtered boxes.
    boxes, kept = preprocessing_ops.transform_and_clip_boxes(
        boxes,
        None,
        affine=affine,
        area_thresh=self._area_thresh,
        seed=self._seed)
    classes, is_crowd, area = self._select_ind(kept, classes, is_crowd, area)  # pylint:disable=unbalanced-tuple-unpacking
    return image, boxes, classes, is_crowd, area, area

  # mosaic full frequency doubles model speed
  def _process_image(self, sample, shiftx, shifty, cut, ishape):
    """Process and augment each image."""
    (image, boxes, classes, is_crowd, area, crop_points) = self._augment_image(
        sample['image'], sample['groundtruth_boxes'],
        sample['groundtruth_classes'], sample['groundtruth_is_crowd'],
        sample['groundtruth_area'], shiftx, shifty, cut)

    (boxes, classes) = self.scale_boxes(image, ishape, boxes, classes,
                                        1 - shiftx, 1 - shifty)

    sample['image'] = image
    sample['groundtruth_boxes'] = boxes
    sample['groundtruth_classes'] = classes
    sample['groundtruth_is_crowd'] = is_crowd
    sample['groundtruth_area'] = area
    sample['shiftx'] = shiftx
    sample['shifty'] = shifty
    sample['crop_points'] = crop_points
    return sample

  def _patch2(self, one, two):
    """Join two samples side by side into a single sample.

    Args:
      one: `dict` sample placed first; it is modified in place and returned.
      two: `dict` sample placed second.

    Returns:
      `one`, with the images concatenated along axis -2 and every groundtruth
      entry concatenated along axis 0.
    """
    merged = one  # Alias, not a copy: the first sample is overwritten.
    merged['image'] = tf.concat([one['image'], two['image']], axis=-2)

    groundtruth_keys = ('groundtruth_boxes', 'groundtruth_classes',
                        'groundtruth_is_crowd', 'groundtruth_area')
    for key in groundtruth_keys:
      merged[key] = tf.concat([one[key], two[key]], axis=0)
    return merged

  def _patch(self, one, two):
    """Build the full 4 patch of images from sets of 2 images.

    Args:
      one: `dict` sample holding the first pair of stitched images; it is
        modified in place and returned.
      two: `dict` sample holding the second pair of stitched images.

    Returns:
      `one`, holding the stacked (and, unless the crop mode is None,
      post-processed) mosaic, its annotations, the refreshed `width`,
      `height` and `num_detections` entries, and `is_mosaic` set to True. The
      temporary `shiftx`, `shifty` and `crop_points` entries are removed.
    """
    # Stack the two pairs along axis -3 and merge their annotations.
    image = tf.concat([one['image'], two['image']], axis=-3)
    boxes = tf.concat([one['groundtruth_boxes'], two['groundtruth_boxes']],
                      axis=0)
    classes = tf.concat(
        [one['groundtruth_classes'], two['groundtruth_classes']], axis=0)
    is_crowd = tf.concat(
        [one['groundtruth_is_crowd'], two['groundtruth_is_crowd']], axis=0)
    area = tf.concat([one['groundtruth_area'], two['groundtruth_area']], axis=0)

    if self._mosaic_crop_mode is not None:
      image, boxes, classes, is_crowd, area, _ = self._mosaic_crop_image(
          image, boxes, classes, is_crowd, area)

    sample = one
    height, width = preprocessing_ops.get_image_shape(image)
    sample['image'] = tf.cast(image, tf.uint8)
    sample['groundtruth_boxes'] = boxes
    sample['groundtruth_area'] = area
    sample['groundtruth_classes'] = tf.cast(classes,
                                            sample['groundtruth_classes'].dtype)
    sample['groundtruth_is_crowd'] = tf.cast(is_crowd, tf.bool)
    sample['width'] = tf.cast(width, sample['width'].dtype)
    sample['height'] = tf.cast(height, sample['height'].dtype)
    # Boxes are concatenated along axis 0, so that axis holds the number of
    # detections (axis 1 is the per-box coordinate axis). This matches how
    # `_add_param` counts detections for skipped samples.
    sample['num_detections'] = tf.shape(sample['groundtruth_boxes'])[0]
    sample['is_mosaic'] = tf.cast(1.0, tf.bool)

    # Drop the bookkeeping entries that were only needed while stitching.
    del sample['shiftx']
    del sample['shifty']
    del sample['crop_points']
    return sample

  def _mosaic(self, one, two, three, four):
    """Stitch together 4 images to build a mosaic."""
    if self._mosaic_frequency >= 1.0:
      domo = 1.0
    else:
Vishnu Banna's avatar
Vishnu Banna committed
311
      domo = preprocessing_ops.random_uniform_strong(
Vishnu Banna's avatar
config  
Vishnu Banna committed
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
          0.0, 1.0, dtype=tf.float32, seed=self._seed)
      noop = one.copy()

    if domo >= (1 - self._mosaic_frequency):
      cut, ishape = self._generate_cut()
      one = self._process_image(one, 1.0, 1.0, cut, ishape)
      two = self._process_image(two, 0.0, 1.0, cut, ishape)
      three = self._process_image(three, 1.0, 0.0, cut, ishape)
      four = self._process_image(four, 0.0, 0.0, cut, ishape)
      patch1 = self._patch2(one, two)
      patch2 = self._patch2(three, four)
      stitched = self._patch(patch1, patch2)
      return stitched
    else:
      return self._add_param(noop)

Vishnu Banna's avatar
Vishnu Banna committed
328
  def _beta(self, alpha, beta):
    """Generates a random number using the beta distribution.

    The value is built from two gamma draws as `b / (a + b)`, with `a` drawn
    with shape `alpha` and `b` with shape `beta`.

    Args:
      alpha: shape parameter of the first gamma draw.
      beta: shape parameter of the second gamma draw.

    Returns:
      A scalar tensor between 0 and 1.
    """
    # Seed the draws so a seeded pipeline reproduces its mixup ratios. The two
    # ops get different seeds: with the same op-level seed they would yield
    # the same random sequence.
    a = tf.random.gamma([], alpha, seed=self._seed)
    b = tf.random.gamma([], beta, seed=self._seed + 1)
    return b / (a + b)

Vishnu Banna's avatar
config  
Vishnu Banna committed
334
335
336
337
338
  def _mixup(self, one, two):
    """Blend together 2 images for the mixup data augmentation.

    Args:
      one: `dict` sample; it is modified in place and returned when mixup is
        applied.
      two: `dict` sample blended into `one`.

    Returns:
      The blended sample holding the annotations of both inputs, or an
      untouched copy of `one` tagged as skipped when the random draw decides
      not to apply mixup.
    """
    if self._mixup_frequency >= 1.0:
      # Mixup is always applied; no draw and no passthrough copy are needed.
      domo = 1.0
    else:
      domo = preprocessing_ops.random_uniform_strong(
          0.0, 1.0, dtype=tf.float32, seed=self._seed)
      # Shallow copy taken before `one` is modified in place below.
      noop = one.copy()

    if domo >= (1 - self._mixup_frequency):
      sample = one
      otype = one['image'].dtype

      # Blend in float32, then return to the original image dtype.
      r = self._beta(8.0, 8.0)
      blended = (
          r * tf.cast(one['image'], tf.float32) +
          (1 - r) * tf.cast(two['image'], tf.float32))
      sample['image'] = tf.cast(blended, otype)

      sample['groundtruth_boxes'] = tf.concat(
          [one['groundtruth_boxes'], two['groundtruth_boxes']], axis=0)
      sample['groundtruth_classes'] = tf.concat(
          [one['groundtruth_classes'], two['groundtruth_classes']], axis=0)
      sample['groundtruth_is_crowd'] = tf.concat(
          [one['groundtruth_is_crowd'], two['groundtruth_is_crowd']], axis=0)
      sample['groundtruth_area'] = tf.concat(
          [one['groundtruth_area'], two['groundtruth_area']], axis=0)
      # The sample now carries the boxes of both inputs, so refresh the count
      # inherited from `one`.
      sample['num_detections'] = tf.shape(sample['groundtruth_boxes'])[0]
      return sample
    else:
      # NOTE(review): `_add_param` also resets `is_mosaic` to False, even for a
      # sample the mosaic stage already stitched - confirm this is intended.
      return self._add_param(noop)

  def _add_param(self, sample):
    """Tag a sample that skipped augmentation with the bookkeeping entries.

    Args:
      sample: `dict` sample; it is updated in place.

    Returns:
      The same `sample`, with `is_mosaic` set to False and `num_detections`
      set to the size of the first axis of its groundtruth boxes.
    """
    sample.update(
        is_mosaic=tf.cast(0.0, tf.bool),
        num_detections=tf.shape(sample['groundtruth_boxes'])[0])
    return sample

  def _apply(self, dataset):
    """Build the mosaic (and optionally mixup) pipeline on top of a dataset.

    Args:
      dataset: `tf.data.Dataset` of samples.

    Returns:
      A dataset in which every element is produced from four differently
      shuffled views of the input and, when the mixup frequency is positive,
      from two differently shuffled views of that result.
    """

    def _shuffled(source, offset):
      # Each view gets its own seed so the views differ from one another.
      return source.shuffle(
          100, seed=self._seed + offset, reshuffle_each_iteration=True)

    map_options = dict(
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=self._deterministic)

    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    quads = tuple(_shuffled(dataset, offset) for offset in range(4))
    dataset = tf.data.Dataset.zip(quads).map(self._mosaic, **map_options)

    if self._mixup_frequency > 0:
      pairs = (_shuffled(dataset, 4), _shuffled(dataset, 5))
      dataset = tf.data.Dataset.zip(pairs).map(self._mixup, **map_options)
    return dataset

  def _skip(self, dataset):
    """Tag every sample of a dataset as skipped, without augmenting it.

    Args:
      dataset: `tf.data.Dataset` of samples.

    Returns:
      The dataset mapped through `_add_param`.
    """
    return dataset.map(
        self._add_param,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=self._deterministic)

  def mosaic_fn(self, is_training=True):
    """Pick the dataset transform to use for training or evaluation.

    Args:
      is_training: `bool` indicating whether the model is training.

    Returns:
      `_apply` when training with a positive mosaic frequency, otherwise
      `_skip`.
    """
    use_mosaic = is_training and self._mosaic_frequency > 0.0
    return self._apply if use_mosaic else self._skip