preprocessing_ops.py 35.4 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocessing ops for yolo."""
Vishnu Banna's avatar
config  
Vishnu Banna committed
16
17
import random

18
19
import numpy as np
import tensorflow as tf
Vishnu Banna's avatar
config  
Vishnu Banna committed
20
import tensorflow_addons as tfa
21

Abdullah Rashwan's avatar
Abdullah Rashwan committed
22
from official.vision.ops import box_ops as bbox_ops
Vishnu Banna's avatar
config  
Vishnu Banna committed
23
24
25
26

# Constant fill value used when padding images (e.g. the `tf.pad` call in
# `resize_and_jitter_image`).
PAD_VALUE = 114
# Flipped to True by `set_random_seeds` once a non-None seed is applied; while
# True, `random_uniform_strong` ignores its per-call `seed` argument.
GLOBAL_SEED_SET = False

27

Vishnu Banna's avatar
config  
Vishnu Banna committed
28
29
30
def set_random_seeds(seed=0):
  """Seeds every global random number generator used by this module.

  Python's `random` module, TensorFlow's global seed and NumPy's global seed
  are all set from `seed`. Setting the global seed is not the same as passing
  the seed as a variable to each call to tf.random. For more, see
  https://www.tensorflow.org/api_docs/python/tf/random/set_seed. Note that
  passing the seed to each random number generator will not give you the
  expected behavior if you use more than one generator in a single function.

  When `seed` is not None, the module flag `GLOBAL_SEED_SET` is also set to
  True, which makes `random_uniform_strong` ignore its per-call `seed`.

  Args:
    seed: `Optional[int]` representing the seed you want to use. `None` leaves
      Python's `random` module and `GLOBAL_SEED_SET` untouched.
  """
  global GLOBAL_SEED_SET
  seeded = seed is not None
  if seeded:
    random.seed(seed)
    GLOBAL_SEED_SET = True
  # TensorFlow and NumPy are (re)seeded unconditionally, even with None.
  for seed_fn in (tf.random.set_seed, np.random.seed):
    seed_fn(seed)

47
48
49
50
51
52
53
54

def random_uniform_strong(minval,
                          maxval,
                          dtype=tf.float32,
                          seed=None,
                          shape=None):
  """A unified function for consistent random number generation.

  Equivalent to tf.random.uniform, except that minval and maxval are flipped if
  minval is greater than maxval. Seed Safe random number generator: once
  `set_random_seeds` has set the global seeds, the per-call `seed` is ignored.

  Args:
    minval: An `int` for a lower or upper endpoint of the interval from which to
      choose the random number.
    maxval: An `int` for the other endpoint.
    dtype: The output type of the tensor.
    seed: An `int` used to set the seed.
    shape: List or 1D tf.Tensor, output shape of the random generator. `None`
      produces a scalar.

  Returns:
    A random tensor of type `dtype` that falls between `minval` and `maxval`
    excluding the larger one.
  """
  # With the global seed already set, drop the op-level seed (see the note in
  # `set_random_seeds` about using per-op seeds with several generators).
  if GLOBAL_SEED_SET:
    seed = None

  if minval > maxval:
    minval, maxval = maxval, minval

  # Compare against None explicitly: `shape or []` would call bool() on a
  # tf.Tensor shape, which raises for multi-element or symbolic tensors.
  if shape is None:
    shape = []
  return tf.random.uniform(
      shape=shape, minval=minval, maxval=maxval, seed=seed, dtype=dtype)
Vishnu Banna's avatar
config  
Vishnu Banna committed
77
78


Vishnu Banna's avatar
Vishnu Banna committed
79
def random_scale(val, dtype=tf.float32, seed=None):
  """Draws a random factor for scaling a parameter by multiplication.

  A magnitude is drawn uniformly between 1.0 and `val` (the larger endpoint
  excluded). A fair coin then decides whether that magnitude or its reciprocal
  is returned, so growing and shrinking are equally likely. This mirrors the
  original implementation:
  https://github.com/AlexeyAB/darknet/blob/a3714d0a/src/utils.c#L708-L713

  Args:
    val: A float representing the maximum scaling allowed.
    dtype: The output type of the tensor.
    seed: An `int` used to set the seed.

  Returns:
    The random scale.
  """
  # Draw order (magnitude first, then coin) matters for seeded reproducibility.
  magnitude = random_uniform_strong(1.0, val, dtype=dtype, seed=seed)
  coin = random_uniform_strong(minval=0, maxval=2, dtype=tf.int32, seed=seed)
  return magnitude if coin == 1 else 1.0 / magnitude


def pad_max_instances(value, instances, pad_value=0, pad_axis=0):
  """Pad or clip the tensor value to a fixed length along a given axis.

  Pads a dimension of the tensor to have a maximum number of instances filling
  additional entries with the `pad_value`. Allows for selection of the padding
  axis.

  Args:
    value: An input tensor.
    instances: An `int` representing the maximum number of instances.
    pad_value: An `int` representing the value used for padding until the
      maximum number of instances is obtained.
    pad_axis: An `int` representing the axis index to pad. Negative values
      count from the last axis.

  Returns:
    The output tensor whose dimensions match the input tensor except with the
    size along the `pad_axis` replaced by `instances`.
  """
  value = tf.convert_to_tensor(value)

  # Get the dynamic shape of value.
  shape = tf.shape(value)

  # Resolve a negative padding axis. Prefer the static rank so pad_axis stays
  # a Python int: it is later used to index a Python list when setting the
  # static shape, which fails for a symbolic tensor in graph mode.
  static_rank = value.shape.rank
  if pad_axis < 0:
    if static_rank is not None:
      pad_axis = static_rank + pad_axis
    else:
      pad_axis = tf.rank(value) + pad_axis

  # Determine how much of the tensor value to keep.
  dim1 = shape[pad_axis]
  take = tf.math.reduce_min([instances, dim1])
  value, _ = tf.split(value, [take, -1], axis=pad_axis)

  # Pad the clipped tensor to the right shape.
  pad = tf.convert_to_tensor([tf.math.reduce_max([instances - dim1, 0])])
  nshape = tf.concat([shape[:pad_axis], pad, shape[(pad_axis + 1):]], axis=0)
  pad_tensor = tf.fill(nshape, tf.cast(pad_value, dtype=value.dtype))
  value = tf.concat([value, pad_tensor], axis=pad_axis)

  # Record a static length when it is known; as_list() requires a known rank.
  if isinstance(instances, int) and value.shape.rank is not None:
    vshape = value.get_shape().as_list()
    vshape[pad_axis] = instances
    value.set_shape(vshape)
  return value


def get_image_shape(image):
  """Consistently gets the width and height of the image.

  Gets the shape of the image regardless of if the image is in the
  (batch_size, height, width, c) format or the (height, width, c) format.

  Args:
    image: A tensor who has either 3 or 4 dimensions.

  Returns:
    A tuple (height, width), where height is the height of the image
    and width is the width of the image.
  """
  shape = tf.shape(image)
  # Static length of the shape vector, i.e. the rank of `image`; None when the
  # rank is not known at trace time.
  rank = shape.get_shape().as_list()[0]
  if rank == 4:
    height, width = shape[1], shape[2]
  elif rank is not None:
    height, width = shape[0], shape[1]
  else:
    # Unknown static rank: index from the end, which picks (height, width) for
    # both the 3D and the batched 4D layout instead of guessing the 3D one.
    height, width = shape[-3], shape[-2]
  return height, width


def _augment_hsv_darknet(image, rh, rs, rv, seed=None):
  """Randomize the hue, saturation, and brightness via the darknet method.

  Hue is shifted by an additive delta drawn from [-rh, rh). Saturation and
  brightness are scaled by factors from `random_scale`. Each adjustment is
  skipped when its range is not positive, and the result is clipped to
  [0.0, 1.0].
  """
  if rh > 0.0:
    hue_delta = random_uniform_strong(-rh, rh, seed=seed)
    image = tf.image.adjust_hue(image, hue_delta)
  if rs > 0.0:
    image = tf.image.adjust_saturation(image, random_scale(rs, seed=seed))
  if rv > 0.0:
    brightness = tf.cast(random_scale(rv, seed=seed), image.dtype)
    image = image * brightness

  # Keep the pixel values inside [0.0, 1.0].
  return tf.clip_by_value(image, 0.0, 1.0)


def _augment_hsv_torch(image, rh, rs, rv, seed=None):
  """Randomize the hue, saturation, and brightness via the pytorch method.

  The image is converted to HSV and quantized to hue in [0, 180) and
  saturation/value in [0, 255]. Each channel is multiplied by a random gain
  `1 + g * u`, with u drawn from [-1, 1) and g taken from (rh, rs, rv). Hue
  wraps modulo 180 and saturation/value are clipped. The image is then
  converted back to RGB in the input dtype.
  """
  input_dtype = image.dtype
  hsv = tf.image.rgb_to_hsv(tf.cast(image, tf.float32))
  gains = tf.cast([rh, rs, rv], hsv.dtype)
  ranges = tf.cast([180, 255, 255], hsv.dtype)
  multiplier = 1 + gains * random_uniform_strong(
      -1, 1, shape=[3], dtype=hsv.dtype, seed=seed)

  # Quantize to integer-valued channels, then apply the per-channel gains.
  quantized = tf.math.floor(hsv * ranges)
  quantized = tf.math.floor(quantized * multiplier)

  hue, sat, val = tf.split(quantized, 3, axis=-1)
  hue = hue % 180
  sat = tf.clip_by_value(sat, 0, 255)
  val = tf.clip_by_value(val, 0, 255)

  hsv = tf.concat([hue, sat, val], axis=-1) / ranges
  return tf.cast(tf.image.hsv_to_rgb(hsv), input_dtype)


def image_rand_hsv(image, rh, rs, rv, seed=None, darknet=False):
  """Randomly alters the hue, saturation, and brightness of an image.

  Args:
    image: `Tensor` of shape [None, None, 3] that needs to be altered.
    rh: `float32` bounding the random hue change (an additive delta in the
      Darknet method, a multiplicative gain in the PyTorch method).
    rs: `float32` bounding the random saturation change.
    rv: `float32` bounding the random brightness change.
    seed: `Optional[int]` for the seed to use in the random number generation.
    darknet: `bool` selecting the Darknet-style augmentation when True and the
      PyTorch-style augmentation otherwise.

  Returns:
    The HSV altered image in the same datatype as the input image.
  """
  if not darknet:
    return _augment_hsv_torch(image, rh, rs, rv, seed=seed)
  return _augment_hsv_darknet(image, rh, rs, rv, seed=seed)


def mosaic_cut(image, original_width, original_height, width, height, center,
               ptop, pleft, pbottom, pright, shiftx, shifty):
  """Cuts the piece of an image that borders a given mosaic center.

  Given a center location, cuts the input image into a slice that will be
  concatenated with other slices with the same center in order to construct
  a final mosaicked image. The crop is fully determined by the arguments; no
  random numbers are drawn here.

  If (shiftx, shifty) is not one of the four 0.0/1.0 combinations, the crop
  offset stays at zero and the size at -1, so the whole image is returned.

  Args:
    image: `Tensor` of shape [None, None, 3] that needs to be altered.
    original_width: `float` value indicating the original width of the image.
    original_height: `float` value indicating the original height of the image.
    width: `float` tensor indicating the final width of the image. Its dtype is
      used for the intermediate computations (it must expose `.dtype`).
    height: `float` value indicating the final height of the image.
    center: two-element value giving the desired center of the final patched
      image, ordered [y, x] (center[0] is the cut row, center[1] the cut
      column).
    ptop: `float` top offset. Only max(0, -ptop), rescaled from the original
      to the final height, is used to shift the crop.
    pleft: `float` left offset, used like `ptop` along the width.
    pbottom: `float` bottom offset, used like `ptop`.
    pright: `float` right offset, used like `pleft`.
    shiftx: `float` 0.0 or 1.0 value indicating if the image is on the left or
      right.
    shifty: `float` 0.0 or 1.0 value indicating if the image is at the top or
      bottom.

  Returns:
    image: The cropped image in the same datatype as the input image.
    crop_info: `float32` tensor of shape [4, 2] holding
      [[input height, input width], [cropped height, cropped width], [1, 1],
      [y offset, x offset]]. It is applied to the boxes in order to select
      the boxes still contained within the image.
  """

  def cast(values, dtype):
    # Casts every element of a Python list to `dtype`.
    return [tf.cast(value, dtype) for value in values]

  with tf.name_scope('mosaic_cut'):
    center = tf.cast(center, width.dtype)
    zero = tf.cast(0.0, width.dtype)
    cut_x, cut_y = center[1], center[0]

    # Select the crop of the image to use. Each shift is the negated offset
    # rescaled from original to final size, floored at zero and capped so it
    # cannot exceed the distance to either side of the cut.
    left_shift = tf.minimum(
        tf.minimum(cut_x, tf.maximum(zero, -pleft * width / original_width)),
        width - cut_x)
    top_shift = tf.minimum(
        tf.minimum(cut_y, tf.maximum(zero, -ptop * height / original_height)),
        height - cut_y)
    right_shift = tf.minimum(
        tf.minimum(width - cut_x,
                   tf.maximum(zero, -pright * width / original_width)), cut_x)
    bot_shift = tf.minimum(
        tf.minimum(height - cut_y,
                   tf.maximum(zero, -pbottom * height / original_height)),
        cut_y)

    # NOTE(review): the crop lists below mix these float32 values with
    # cut_x/cut_y (which keep width.dtype), so width presumably has to be
    # float32 for the lists to be packed by tf.cast — confirm with callers.
    (left_shift, top_shift, right_shift, bot_shift,
     zero) = cast([left_shift, top_shift, right_shift, bot_shift, zero],
                  tf.float32)
    # Build a crop offset and a crop size tensor to use for slicing.
    # Defaults: offset 0 and size -1 (tf.slice: "everything remaining").
    crop_offset = [zero, zero, zero]
    crop_size = [zero - 1, zero - 1, zero - 1]
    if shiftx == 0.0 and shifty == 0.0:
      # Top-left piece: cut_y x cut_x.
      crop_offset = [top_shift, left_shift, zero]
      crop_size = [cut_y, cut_x, zero - 1]
    elif shiftx == 1.0 and shifty == 0.0:
      # Top-right piece: cut_y x (width - cut_x).
      crop_offset = [top_shift, cut_x - right_shift, zero]
      crop_size = [cut_y, width - cut_x, zero - 1]
    elif shiftx == 0.0 and shifty == 1.0:
      # Bottom-left piece: (height - cut_y) x cut_x.
      crop_offset = [cut_y - bot_shift, left_shift, zero]
      crop_size = [height - cut_y, cut_x, zero - 1]
    elif shiftx == 1.0 and shifty == 1.0:
      # Bottom-right piece: (height - cut_y) x (width - cut_x).
      crop_offset = [cut_y - bot_shift, cut_x - right_shift, zero]
      crop_size = [height - cut_y, width - cut_x, zero - 1]

    # Contain and crop the image: never request more rows/cols than exist.
    ishape = tf.cast(tf.shape(image)[:2], crop_size[0].dtype)
    crop_size[0] = tf.minimum(crop_size[0], ishape[0])
    crop_size[1] = tf.minimum(crop_size[1], ishape[1])

    crop_offset = tf.cast(crop_offset, tf.int32)
    crop_size = tf.cast(crop_size, tf.int32)

    image = tf.slice(image, crop_offset, crop_size)
    crop_info = tf.stack([
        tf.cast(ishape, tf.float32),
        tf.cast(tf.shape(image)[:2], dtype=tf.float32),
        tf.ones_like(ishape, dtype=tf.float32),
        tf.cast(crop_offset[:2], tf.float32)
    ])

  return image, crop_info


def resize_and_jitter_image(image,
                            desired_size,
                            jitter=0.0,
                            letter_box=None,
                            random_pad=True,
                            crop_only=False,
                            shiftx=0.5,
                            shifty=0.5,
                            cut=None,
                            method=tf.image.ResizeMethod.BILINEAR,
                            seed=None):
  """Resize, Pad, and distort a given input image.

  The steps are:
    1. Each of the four image edges is moved by a random amount, giving a
       jittered window.
    2. Optionally, the window is letterboxed to the aspect ratio of
       `desired_size`.
    3. The image is cropped to the part of the window that lies inside it.
    4. The crop is padded with `PAD_VALUE` to roughly the window size.
    5. The result is resized to `desired_size`.
  When `cut` is given, the result is additionally cut with `mosaic_cut`.

  Args:
    image: a `Tensor` of shape [height, width, 3] representing an image.
    desired_size: a `Tensor` or `int` list/tuple of two elements representing
      [height, width] of the desired actual output image size.
    jitter: a `float` in [0, 0.5]. Each edge moves by up to this fraction of
      the original width/height.
    letter_box: a `bool` representing if letterboxing should be applied.
    random_pad: a `bool` representing if random padding should be applied.
      When False, the crop is placed in the padded window according to
      `shiftx` and `shifty`.
    crop_only: a `bool` representing if only cropping will be applied.
    shiftx: a `float` indicating if the image is in the left or right.
    shifty: a `float` value indicating if the image is in the top or bottom.
    cut: the desired [y, x] center of the final patched image, forwarded to
      `mosaic_cut`. `None` disables the cut.
    method: resize method, used only in the `crop_only` path.
      NOTE(review): the non-crop_only path always resizes with BILINEAR (and
      AREA when shrinking) and ignores this argument. Confirm this is intended.
    seed: seed for random scale jittering.

  Returns:
    image_: a `Tensor` of shape [height, width, 3] where [height, width]
      equals to `desired_size` (in the `crop_only` path, the cropped and
      optionally resized image instead).
    infos: a `list` of 2D `Tensor`s. Each encodes the information of the image
      and one applied preprocessing step. It is in the format of
      [[original_height, original_width], [desired_height, desired_width],
        [y_scale, x_scale], [y_offset, x_offset]], where [desired_height,
      desired_width] is the actual scaled image size, and [y_scale, x_scale] is
      the scaling factor, which is the ratio of
      scaled dimension / original dimension. The first entry describes the
      crop. The second describes the pad-and-resize and is absent in the
      `crop_only` path. A third comes from `mosaic_cut` when `cut` is given.
    A list [original_width, original_height, width, height, ptop, pleft,
      pbottom, pright] cast to tf.float32 (tf.int32 in the `crop_only` path).
      ptop/pleft/pbottom/pright are the crop's placement within the padded
      window.

  Raises:
    ValueError: if `jitter` is outside [0, 0.5].
  """

  def intersection(a, b):
    """Finds the intersection between 2 crops.

    Called with crops ordered [top, left, bottom, right], despite the x/y
    local names.
    """
    minx = tf.maximum(a[0], b[0])
    miny = tf.maximum(a[1], b[1])
    maxx = tf.minimum(a[2], b[2])
    maxy = tf.minimum(a[3], b[3])
    return tf.convert_to_tensor([minx, miny, maxx, maxy])

  def cast(values, dtype):
    # Casts every element of a Python list to `dtype`.
    return [tf.cast(value, dtype) for value in values]

  if jitter > 0.5 or jitter < 0:
    raise ValueError('maximum change in aspect ratio must be between 0 and 0.5')

  with tf.name_scope('resize_and_jitter_image'):
    # Cast all parameters to a usable float data type.
    jitter = tf.cast(jitter, tf.float32)
    original_dtype, original_dims = image.dtype, tf.shape(image)[:2]

    # original width, original height, desired width, desired height
    original_width, original_height, width, height = cast(
        [original_dims[1], original_dims[0], desired_size[1], desired_size[0]],
        tf.float32)

    # Compute the random delta width and height etc. and randomize the
    # location of the corner points. Positive values move an edge inward
    # (crop); negative values move it outward (pad).
    jitter_width = original_width * jitter
    jitter_height = original_height * jitter
    pleft = random_uniform_strong(
        -jitter_width, jitter_width, jitter_width.dtype, seed=seed)
    pright = random_uniform_strong(
        -jitter_width, jitter_width, jitter_width.dtype, seed=seed)
    ptop = random_uniform_strong(
        -jitter_height, jitter_height, jitter_height.dtype, seed=seed)
    pbottom = random_uniform_strong(
        -jitter_height, jitter_height, jitter_height.dtype, seed=seed)

    # Letter box the image: grow the window symmetrically along one axis so
    # its aspect ratio matches desired_size.
    if letter_box:
      (image_aspect_ratio,
       input_aspect_ratio) = original_width / original_height, width / height
      distorted_aspect = image_aspect_ratio / input_aspect_ratio

      delta_h, delta_w = 0.0, 0.0
      pullin_h, pullin_w = 0.0, 0.0
      if distorted_aspect > 1:
        delta_h = ((original_width / input_aspect_ratio) - original_height) / 2
      else:
        delta_w = ((original_height * input_aspect_ratio) - original_width) / 2

      ptop = ptop - delta_h - pullin_h
      pbottom = pbottom - delta_h - pullin_h
      pright = pright - delta_w - pullin_w
      pleft = pleft - delta_w - pullin_w

    # Compute the width and height to crop or pad to, and clip all crops to
    # be contained within the image.
    swidth = original_width - pleft - pright
    sheight = original_height - ptop - pbottom
    src_crop = intersection([ptop, pleft, sheight + ptop, swidth + pleft],
                            [0, 0, original_height, original_width])

    # Random padding used for mosaic: choose where the crop (h_ x w_) sits
    # inside the sheight x swidth window.
    h_ = src_crop[2] - src_crop[0]
    w_ = src_crop[3] - src_crop[1]
    if random_pad:
      # Follow the random outward movement of the top/left edges.
      rmh = tf.maximum(0.0, -ptop)
      rmw = tf.maximum(0.0, -pleft)
    else:
      # Distribute the leftover space according to shiftx/shifty.
      rmw = (swidth - w_) * shiftx
      rmh = (sheight - h_) * shifty

    # Cast cropping params to usable dtype.
    src_crop = tf.cast(src_crop, tf.int32)

    # Compute padding parameters. From here on ptop/pleft/pbottom/pright hold
    # the crop's placement [top, left, bottom, right] inside the padded window.
    # `pad` becomes [top, left, bottom, right] pad amounts, so the padded image
    # is about sheight x swidth.
    dst_shape = [rmh, rmw, rmh + h_, rmw + w_]
    ptop, pleft, pbottom, pright = dst_shape
    pad = dst_shape * tf.cast([1, 1, -1, -1], ptop.dtype)
    pad += tf.cast([0, 0, sheight, swidth], ptop.dtype)
    pad = tf.cast(pad, tf.int32)

    infos = []

    # Crop the image to desired size.
    cropped_image = tf.slice(
        image, [src_crop[0], src_crop[1], 0],
        [src_crop[2] - src_crop[0], src_crop[3] - src_crop[1], -1])
    crop_info = tf.stack([
        tf.cast(original_dims, tf.float32),
        tf.cast(tf.shape(cropped_image)[:2], dtype=tf.float32),
        tf.ones_like(original_dims, dtype=tf.float32),
        tf.cast(src_crop[:2], tf.float32)
    ])
    infos.append(crop_info)

    if crop_only:
      if not letter_box:
        # Resize the crop to its proportional share of desired_size.
        h_, w_ = cast(get_image_shape(cropped_image), width.dtype)
        width = tf.cast(tf.round((w_ * width) / swidth), tf.int32)
        height = tf.cast(tf.round((h_ * height) / sheight), tf.int32)
        cropped_image = tf.image.resize(
            cropped_image, [height, width], method=method)
        cropped_image = tf.cast(cropped_image, original_dtype)
      return cropped_image, infos, cast([
          original_width, original_height, width, height, ptop, pleft, pbottom,
          pright
      ], tf.int32)

    # Pad the image to desired size.
    image_ = tf.pad(
        cropped_image, [[pad[0], pad[2]], [pad[1], pad[3]], [0, 0]],
        constant_values=PAD_VALUE)

    # Pad and scale info
    isize = tf.cast(tf.shape(image_)[:2], dtype=tf.float32)
    osize = tf.cast((desired_size[0], desired_size[1]), dtype=tf.float32)
    pad_info = tf.stack([
        tf.cast(tf.shape(cropped_image)[:2], tf.float32),
        osize,
        osize/isize,
        (-tf.cast(pad[:2], tf.float32)*osize/isize)
    ])
    infos.append(pad_info)

    # Shrink any dimension larger than desired_size with AREA resampling first
    # (presumably to limit aliasing), then bilinear-resize to the exact size.
    temp = tf.shape(image_)[:2]
    cond = temp > tf.cast(desired_size, temp.dtype)
    if tf.reduce_any(cond):
      size = tf.cast(desired_size, temp.dtype)
      size = tf.where(cond, size, temp)
      image_ = tf.image.resize(
          image_, (size[0], size[1]), method=tf.image.ResizeMethod.AREA)
      image_ = tf.cast(image_, original_dtype)

    image_ = tf.image.resize(
        image_, (desired_size[0], desired_size[1]),
        method=tf.image.ResizeMethod.BILINEAR,
        antialias=False)

    image_ = tf.cast(image_, original_dtype)
    if cut is not None:
      image_, crop_info = mosaic_cut(image_, original_width, original_height,
                                     width, height, cut, ptop, pleft, pbottom,
                                     pright, shiftx, shifty)
      infos.append(crop_info)
    return image_, infos, cast([
        original_width, original_height, width, height, ptop, pleft, pbottom,
        pright
    ], tf.float32)


def _build_transform(image,
                     perspective=0.00,
                     degrees=0.0,
                     scale_min=1.0,
                     scale_max=1.0,
                     translate=0.0,
                     random_pad=False,
                     desired_size=None,
                     seed=None):
  """Builds a unified affine transformation to spatially augment the image.

  Samples a random rotation, perspective warp, scale and translation. These
  are composed, together with a centering transform, into 3x3 matrices.

  Args:
    image: image `Tensor`; only its height and width (from `get_image_shape`)
      are used.
    perspective: maximum magnitude of the two random perspective terms.
    degrees: maximum magnitude of the random rotation, in degrees.
    scale_min: one endpoint of the range the scale factor `s` is drawn from.
    scale_max: the other endpoint of that range.
    translate: maximum deviation of the random translation fraction around
      0.5 (used when the random_pad placement branch is not taken).
    random_pad: when True and the scaled image fits inside the output window,
      centering is skipped and the image is placed at a random offset inside
      the window.
    desired_size: optional [height, width] of the output window; defaults to
      the image size.
    seed: optional seed forwarded to `random_uniform_strong`.

  Returns:
    affine: 3x3 float32 matrix for warping the image.
    affine_boxes: 3x3 float32 matrix for transforming boxes. Every component is
      the exact inverse of its image counterpart (s vs 1/s, negated offsets,
      transposed rotation, negated perspective terms), multiplied in reverse
      order, so affine_boxes is the matrix inverse of affine.
    s: the sampled scale factor.
  """

  height, width = get_image_shape(image)
  ch = height = tf.cast(height, tf.float32)
  cw = width = tf.cast(width, tf.float32)
  deg_to_rad = lambda x: tf.cast(x, tf.float32) * np.pi / 180.0

  # (ch, cw) is the output window size: desired_size if given, else the image.
  if desired_size is not None:
    desired_size = tf.cast(desired_size, tf.float32)
    ch = desired_size[0]
    cw = desired_size[1]

  # Compute the center of the image in the output resolution.
  center = tf.eye(3, dtype=tf.float32)
  center = tf.tensor_scatter_nd_update(center, [[0, 2], [1, 2]],
                                       [-cw / 2, -ch / 2])
  center_boxes = tf.tensor_scatter_nd_update(center, [[0, 2], [1, 2]],
                                             [cw / 2, ch / 2])

  # Compute a random rotation to apply. The boxes matrix is the transpose.
  rotation = tf.eye(3, dtype=tf.float32)
  a = deg_to_rad(random_uniform_strong(-degrees, degrees, seed=seed))
  cos = tf.math.cos(a)
  sin = tf.math.sin(a)
  rotation = tf.tensor_scatter_nd_update(rotation,
                                         [[0, 0], [0, 1], [1, 0], [1, 1]],
                                         [cos, -sin, sin, cos])
  rotation_boxes = tf.tensor_scatter_nd_update(rotation,
                                               [[0, 0], [0, 1], [1, 0], [1, 1]],
                                               [cos, sin, -sin, cos])

  # Compute a random perspective change to apply (bottom-row terms px, py).
  prespective_warp = tf.eye(3)
  px = random_uniform_strong(-perspective, perspective, seed=seed)
  py = random_uniform_strong(-perspective, perspective, seed=seed)
  prespective_warp = tf.tensor_scatter_nd_update(prespective_warp,
                                                 [[2, 0], [2, 1]], [px, py])
  prespective_warp_boxes = tf.tensor_scatter_nd_update(prespective_warp,
                                                       [[2, 0], [2, 1]],
                                                       [-px, -py])

  # Compute a random scaling to apply. The image matrix uses 1 / s, the boxes
  # matrix uses s.
  scale = tf.eye(3, dtype=tf.float32)
  s = random_uniform_strong(scale_min, scale_max, seed=seed)
  scale = tf.tensor_scatter_nd_update(scale, [[0, 0], [1, 1]], [1 / s, 1 / s])
  scale_boxes = tf.tensor_scatter_nd_update(scale, [[0, 0], [1, 1]], [s, s])

  # Compute a random Translation to apply.
  translation = tf.eye(3)
  if (random_pad and height * s < ch and width * s < cw):
    # The scaled image fits inside the output window: skip centering and
    # translate it to a random location within the window.
    center = center_boxes = tf.eye(3, dtype=tf.float32)
    tx = random_uniform_strong(-1, 0, seed=seed) * (cw / s - width)
    ty = random_uniform_strong(-1, 0, seed=seed) * (ch / s - height)
  else:
    # The image can be translated outside of the output resolution window
    # but the image is translated relative to the output resolution not the
    # input image resolution.
    tx = random_uniform_strong(0.5 - translate, 0.5 + translate, seed=seed)
    ty = random_uniform_strong(0.5 - translate, 0.5 + translate, seed=seed)

    # Center and Scale the image such that the window of translation is
    # contained to the output resolution.
    dx, dy = (width - cw / s) / width, (height - ch / s) / height
    sx, sy = 1 - dx, 1 - dy
    bx, by = dx / 2, dy / 2
    tx, ty = bx + (sx * tx), by + (sy * ty)

    # Scale the translation to width and height of the image.
    tx *= width
    ty *= height

  translation = tf.tensor_scatter_nd_update(translation, [[0, 2], [1, 2]],
                                            [tx, ty])
  translation_boxes = tf.tensor_scatter_nd_update(translation, [[0, 2], [1, 2]],
                                                  [-tx, -ty])

  # Combine all the transformations with matrix products; the order matters.
  # The image matrix is translation @ scale @ rotation @ center @ perspective.
  # The boxes matrix multiplies the inverse components in the reverse order,
  # which makes it the inverse of the image matrix.
  affine = (translation @ scale @ rotation @ center @ prespective_warp)
  affine_boxes = (
      prespective_warp_boxes @ center_boxes @ rotation_boxes @ scale_boxes
      @ translation_boxes)
  return affine, affine_boxes, s


def affine_warp_image(image,
                      desired_size,
                      perspective=0.00,
                      degrees=0.0,
                      scale_min=1.0,
                      scale_max=1.0,
                      translate=0.0,
                      random_pad=False,
                      seed=None):
  """Warps `image` with a randomly sampled spatial transform.

  The transform (perspective, rotation, scale and translation) is sampled by
  `_build_transform` and applied with `tfa.image.transform`. Output pixels
  that map outside the source image are filled with `PAD_VALUE`.

  Args:
    image: A `Tensor` holding the image to augment.
    desired_size: The output size of the warped image. It is passed as
      `output_shape` to `tfa.image.transform`.
    perspective: A `float` bounding the random perspective change.
    degrees: A `float` bounding the random rotation, in degrees.
    scale_min: A `float`, the smallest random scaling factor.
    scale_max: A `float`, the largest random scaling factor.
    translate: A `float` bounding the random translation.
    random_pad: A `bool`, forwarded to `_build_transform` to enable random
      padding.
    seed: An `Optional[int]` seed for random number generation.

  Returns:
    image: A `Tensor`, the warped image.
    affine_matrix: A `Tensor`, the matrix used to warp the image.
    affine_info: A `list` `[image_size, desired_size, affine_boxes]`.
      `image_size` is the float32 shape of the input image. `desired_size` is
      the float32 output size. `affine_boxes` is the matrix to apply to the
      boxes.
  """
  # Capture the source image size. It is needed later to denormalize boxes.
  source_size = tf.cast(get_image_shape(image), tf.float32)

  # Sample the image transform together with its matching box transform.
  image_matrix, box_matrix, _ = _build_transform(
      image,
      perspective=perspective,
      degrees=degrees,
      scale_min=scale_min,
      scale_max=scale_max,
      translate=translate,
      random_pad=random_pad,
      desired_size=desired_size,
      seed=seed)

  # tfa.image.transform takes 8 projective parameters. Drop the final entry of
  # the flattened matrix, which that API treats as 1.
  flat_params = tf.cast(tf.reshape(image_matrix, [-1])[:-1], tf.float32)

  warped = tfa.image.transform(
      image,
      flat_params,
      fill_value=PAD_VALUE,
      output_shape=desired_size,
      interpolation='bilinear')

  affine_info = [source_size, tf.cast(desired_size, tf.float32), box_matrix]
  return warped, image_matrix, affine_info
Vishnu Banna's avatar
config  
Vishnu Banna committed
682
683
684


def affine_warp_boxes(affine, boxes, output_size, box_history):
  """Warps boxes with a precomputed transform matrix and clips the result.

  Each box is expanded to its four (x, y) corners. The corners are mapped
  through `affine` in homogeneous coordinates and divided by the homogeneous
  coordinate. The axis-aligned box enclosing the mapped corners becomes the
  new box. `box_history` goes through the same warp but is not clipped.

  Args:
    affine: A `Tensor` holding the matrix to apply to the boxes.
    boxes: A `Tensor` of boxes in [ymin, xmin, ymax, xmax] order.
    output_size: A `list` of two integers, a two-element vector or a tensor such
      that all but the last dimensions are `broadcastable` to `boxes`. The last
      dimension is 2, which represents [height, width].
    box_history: A `Tensor` of boxes that undergo the same augmentations as
      `boxes` but are never clipped. It tracks how much each box has changed.

  Returns:
    clipped_boxes: A `Tensor` of the warped boxes clipped to `output_size`.
    box_history: A `Tensor` of the warped, unclipped box history.
  """

  def _box_corners(box):
    """Returns the (x, y) corners TL, BL, TR, BR of each box, concatenated."""
    y_lo, x_lo, y_hi, x_hi = tf.split(box, 4, axis=-1)
    top_left = tf.concat([x_lo, y_lo], axis=-1)
    bottom_left = tf.concat([x_lo, y_hi], axis=-1)
    top_right = tf.concat([x_hi, y_lo], axis=-1)
    bottom_right = tf.concat([x_hi, y_hi], axis=-1)
    return tf.concat([top_left, bottom_left, top_right, bottom_right], axis=-1)

  def _enclosing_boxes(points):
    """Returns the [ymin, xmin, ymax, xmax] box around each group of 4 corners."""
    points = tf.reshape(points, [-1, 4, 2])
    ys = points[..., 1]
    xs = points[..., 0]
    return tf.stack([
        tf.reduce_min(ys, axis=-1),
        tf.reduce_min(xs, axis=-1),
        tf.reduce_max(ys, axis=-1),
        tf.reduce_max(xs, axis=-1),
    ], axis=-1)

  def _warp(matrix, box):
    """Maps every box corner through `matrix` and re-fits axis-aligned boxes."""
    points = tf.reshape(_box_corners(box), [-1, 4, 2])
    # Lift the corners to homogeneous coordinates (x, y, 1) -> (N, 4, 3).
    ones = tf.expand_dims(tf.ones_like(points[..., 1]), axis=-1)
    points = tf.concat([points, ones], axis=-1)

    # Multiply every corner by `matrix`, then restore the (N, 4, 3) layout.
    mapped = tf.transpose(
        tf.matmul(matrix, points, transpose_b=True), perm=(0, 2, 1))

    # Perspective divide by the homogeneous coordinate.
    xy, w = tf.split(mapped, [2, 1], axis=-1)
    xy = xy / w
    return _enclosing_boxes(tf.reshape(xy, [-1, 8]))

  warped_boxes = _warp(affine, boxes)
  warped_history = _warp(affine, box_history)
  return bbox_ops.clip_boxes(warped_boxes, output_size), warped_history


def boxes_candidates(clipped_boxes,
                     box_history,
                     wh_thr=2,
                     ar_thr=20,
                     area_thr=0.1):
  """Selects the boxes that pass the width/height, area and aspect filters.

  A box is kept only if, after clipping, all of the following hold:
    * its width and height are both at least `wh_thr`;
    * its area divided by the area of the matching unclipped box in
      `box_history` is greater than `abs(area_thr)`;
    * its aspect ratio max(w / h, h / w) is below `ar_thr`.
  When `area_thr == 0`, the width/height and aspect-ratio filters are disabled
  (`wh_thr` is forced to 0 and `ar_thr` to infinity).

  Args:
    clipped_boxes: A `Tensor` of clipped boxes in [ymin, xmin, ymax, xmax]
      order, one box per row.
    box_history: A `Tensor` of the same boxes after the same augmentations but
      without clipping. It provides the reference area for each box.
    wh_thr: A number, the minimum width and height of a kept box.
    ar_thr: A number, the exclusive upper bound on a kept box's aspect ratio.
    area_thr: A `float`, the exclusive lower bound on the ratio of clipped
      area to unclipped area.

  Returns:
    A 1-D `Tensor` of the row indices of the boxes that pass every filter.
  """
  if area_thr == 0.0:
    wh_thr = 0
    ar_thr = np.inf
  area_thr = tf.math.abs(area_thr)

  def _height_and_width(box_set):
    """Returns the non-negative (height, width) of [ymin, xmin, ymax, xmax] rows."""
    height = tf.maximum(box_set[:, 2] - box_set[:, 0], 0.0)
    width = tf.maximum(box_set[:, 3] - box_set[:, 1], 0.0)
    return height, width

  ref_h, ref_w = _height_and_width(box_history)
  clip_h, clip_w = _height_and_width(clipped_boxes)

  # Symmetric aspect ratio. The epsilon guards against division by zero.
  aspect = tf.maximum(clip_w / (clip_h + 1e-16), clip_h / (clip_w + 1e-16))

  wide_enough = clip_w >= wh_thr
  tall_enough = clip_h >= wh_thr

  # Fraction of the unclipped area that survived clipping.
  kept_fraction = (clip_h * clip_w) / (ref_w * ref_h + 1e-16)
  enough_area = kept_fraction > area_thr

  sane_aspect = aspect < ar_thr

  keep = tf.logical_and(
      tf.logical_and(wide_enough, tall_enough),
      tf.logical_and(enough_area, sane_aspect))

  # tf.where on the (N, 1) mask yields (row, col) pairs; keep only the row.
  return tf.where(tf.expand_dims(keep, axis=-1))[:, 0]


def resize_and_crop_boxes(boxes, image_scale, output_size, offset, box_history):
  """Scales and shifts boxes (and their history), then clips the boxes.

  Every box is multiplied by `image_scale` and then has `offset` subtracted.
  `box_history` gets the same arithmetic but is not clipped.

  Args:
    boxes: A `Tensor` of boxes in [ymin, xmin, ymax, xmax] order.
    image_scale: A `Tensor` with the image scaling factor. It is presumably a
      2-element [y, x] vector, since it is tiled to match the 4 box
      coordinates.
    output_size: A `list` of two integers, a two-element vector or a tensor such
      that all but the last dimensions are `broadcastable` to `boxes`. The last
      dimension is 2, which represents [height, width].
    offset: A `Tensor` with the translation applied to the image, laid out like
      `image_scale`.
    box_history: A `Tensor` of boxes that undergo the same augmentations as
      `boxes` but are never clipped. It tracks how much each box has changed.

  Returns:
    clipped_boxes: A `Tensor` of the transformed boxes clipped to `output_size`.
    box_history: A `Tensor` of the transformed, unclipped box history.
  """
  # Repeat the 2-element factors as [a, b, a, b] so they line up with the
  # [ymin, xmin, ymax, xmax] coordinates.
  scale_4 = tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
  offset_4 = tf.tile(tf.expand_dims(offset, axis=0), [1, 2])

  def _rescale(box_set):
    """Applies `box * scale - offset` to every coordinate."""
    return box_set * scale_4 - offset_4

  moved_boxes = _rescale(boxes)
  moved_history = _rescale(box_history)
  return bbox_ops.clip_boxes(moved_boxes, output_size), moved_history


Vishnu Banna's avatar
Vishnu Banna committed
840
def transform_and_clip_boxes(boxes,
                             infos,
                             affine=None,
                             shuffle_boxes=False,
                             area_thresh=0.1,
                             seed=None,
                             filter_and_clip_boxes=True):
  """Applies resize/crop and affine transforms to boxes, then filters them.

  The boxes are expected to be normalized to [0, 1]. For each `info` in
  `infos`, the boxes are:
    1. denormalized with `info[0]`;
    2. rescaled and shifted by `resize_and_crop_boxes`;
    3. renormalized with `info[1]`.
  If `affine` is given, the boxes are then denormalized with `affine[0]`,
  warped with `affine_warp_boxes` and renormalized with `affine[1]`.

  A copy of the boxes (the history) receives every transform without
  clipping, so boxes that shrank too much can be filtered out. Boxes that
  become empty at any stage are zeroed out before the final selection.

  Args:
    boxes: A `Tensor` of normalized boxes in [ymin, xmin, ymax, xmax] order.
    infos: An optional `list` of image infos. For each `info`:
      - `info[0]` is the size used to denormalize;
      - `info[1]` is the output size;
      - `info[2]` is the scale factor;
      - `info[3]` is the offset.
    affine: An optional `list` `[source_size, output_size, box_matrix]`, such
      as the `affine_info` returned by `affine_warp_image`.
    shuffle_boxes: A `bool`. If True, the kept indices are shuffled. It is
      only honored when `filter_and_clip_boxes` is True.
    area_thresh: A `float` area threshold passed to `boxes_candidates`.
    seed: An `Optional[int]` seed for the shuffle.
    filter_and_clip_boxes: A `bool`. If True, boxes are clipped to [0, 1]
      first and filtered with `boxes_candidates` at the end. Otherwise only
      non-empty boxes are kept.

  Returns:
    boxes: A `Tensor` of the surviving, transformed boxes.
    ind: A `Tensor` of the indices (into the input rows) of the surviving
      boxes.
  """

  def _non_empty(box_set):
    """Returns a boolean mask of boxes with positive height and width."""
    heights = box_set[:, 2] - box_set[:, 0]
    widths = box_set[:, 3] - box_set[:, 1]
    return tf.logical_and(tf.greater(heights, 0), tf.greater(widths, 0))

  def _denormalize_pair(box_set, history, size):
    """Denormalizes both the boxes and their history to pixel units."""
    return (bbox_ops.denormalize_boxes(box_set, size),
            bbox_ops.denormalize_boxes(history, size))

  def _normalize_pair(box_set, history, size):
    """Normalizes both the boxes and their history back to [0, 1] units."""
    return (bbox_ops.normalize_boxes(box_set, size),
            bbox_ops.normalize_boxes(history, size))

  # The history starts as the raw input and is never clipped.
  box_history = boxes

  # Size of the last applied transform's output. It stays None when no
  # transform runs, in which case filtering works in normalized units.
  output_size = None
  if filter_and_clip_boxes:
    boxes = tf.math.maximum(tf.math.minimum(boxes, 1.0), 0.0)
  valid = _non_empty(boxes)

  for info in (infos if infos is not None else []):
    boxes, box_history = _denormalize_pair(boxes, box_history, info[0])
    boxes, box_history = resize_and_crop_boxes(
        boxes, info[2, :], info[1, :], info[3, :], box_history=box_history)
    valid = tf.logical_and(_non_empty(boxes), valid)
    output_size = info[1]
    boxes, box_history = _normalize_pair(boxes, box_history, output_size)

  if affine is not None:
    boxes, box_history = _denormalize_pair(boxes, box_history, affine[0])
    boxes, box_history = affine_warp_boxes(
        affine[2], boxes, affine[1], box_history=box_history)
    valid = tf.logical_and(_non_empty(boxes), valid)
    output_size = affine[1]
    boxes, box_history = _normalize_pair(boxes, box_history, output_size)

  # Zero out every box that became empty at some stage.
  boxes = boxes * tf.cast(tf.expand_dims(valid, axis=-1), boxes.dtype)

  if not filter_and_clip_boxes:
    inds = bbox_ops.get_non_empty_box_indices(boxes)
  else:
    if output_size is None:
      # No pixel size is known, so skip the pixel-based width/height check.
      inds = boxes_candidates(
          boxes, box_history, wh_thr=0.0, area_thr=area_thresh)
    else:
      inds = boxes_candidates(
          bbox_ops.denormalize_boxes(boxes, output_size),
          bbox_ops.denormalize_boxes(box_history, output_size),
          area_thr=area_thresh)
    if shuffle_boxes:
      inds = tf.random.shuffle(inds, seed=seed)

  return tf.gather(boxes, inds), inds