# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocessing ops."""
import functools
import tensorflow as tf

CROP_PROPORTION = 0.875  # Standard for ImageNet.


def random_apply(func, p, x):
  """Randomly apply function func to x with probability p.

  Args:
    func: A callable taking `x` and returning the transformed value.
    p: Probability of applying `func`; it is cast to float32.
    x: The value passed to `func`, or returned unchanged otherwise.

  Returns:
    `func(x)` when a uniform sample from [0, 1) is less than `p`, else `x`.
  """
  return tf.cond(
      tf.less(
          tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
          tf.cast(p, tf.float32)), lambda: func(x), lambda: x)


def random_brightness(image, max_delta, impl='simclrv2'):
  """A multiplicative vs additive change of brightness.

  Args:
    image: The input image tensor.
    max_delta: A float controlling the strength of the brightness change.
    impl: 'simclrv1' or 'simclrv2'. With 'simclrv2' the image is multiplied by
      a factor drawn uniformly between `max(1 - max_delta, 0)` and
      `1 + max_delta`; with 'simclrv1' `tf.image.random_brightness` is used.

  Returns:
    The image tensor with its brightness changed.

  Raises:
    ValueError: If `impl` is neither 'simclrv1' nor 'simclrv2'.
  """
  if impl == 'simclrv2':
    factor = tf.random.uniform([], tf.maximum(1.0 - max_delta, 0),
                               1.0 + max_delta)
    image = image * factor
  elif impl == 'simclrv1':
    image = tf.image.random_brightness(image, max_delta=max_delta)
  else:
    raise ValueError('Unknown impl {} for random brightness.'.format(impl))
  return image


def to_grayscale(image, keep_channels=True):
  """Converts an RGB image to grayscale.

  Args:
    image: The input RGB image tensor.
    keep_channels: A bool. If True, the single grayscale channel is tiled three
      times along the last axis of a rank-3 image.

  Returns:
    The grayscale image tensor.
  """
  image = tf.image.rgb_to_grayscale(image)
  if keep_channels:
    image = tf.tile(image, [1, 1, 3])
  return image


def color_jitter_nonrand(image,
                         brightness=0,
                         contrast=0,
                         saturation=0,
                         hue=0,
                         impl='simclrv2'):
  """Distorts the color of the image (jittering order is fixed).

  The transformations are applied in the order brightness, contrast,
  saturation, hue. A transformation whose strength is 0 is skipped, and the
  image is clipped to [0, 1] after every step.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x, brightness, contrast, saturation, hue):
      """Apply the i-th transformation."""
      if brightness != 0 and i == 0:
        x = random_brightness(x, max_delta=brightness, impl=impl)
      elif contrast != 0 and i == 1:
        x = tf.image.random_contrast(
            x, lower=1 - contrast, upper=1 + contrast)
      elif saturation != 0 and i == 2:
        x = tf.image.random_saturation(
            x, lower=1 - saturation, upper=1 + saturation)
      elif hue != 0 and i == 3:
        # Hue is tied to step 3 only. Without the index check, any step whose
        # own strength is 0 would fall through to this branch and hue jitter
        # would be applied more than once.
        x = tf.image.random_hue(x, max_delta=hue)
      return x

    for i in range(4):
      image = apply_transform(i, image, brightness, contrast, saturation, hue)
      image = tf.clip_by_value(image, 0., 1.)
    return image


def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0,
                      impl='simclrv2'):
  """Distorts the color of the image (jittering order is random).

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x):
      """Apply the i-th transformation."""

      def brightness_foo():
        if brightness == 0:
          return x
        else:
          return random_brightness(x, max_delta=brightness, impl=impl)

      def contrast_foo():
        if contrast == 0:
          return x
        else:
          return tf.image.random_contrast(x, lower=1 - contrast,
                                          upper=1 + contrast)

      def saturation_foo():
        if saturation == 0:
          return x
        else:
          return tf.image.random_saturation(
              x, lower=1 - saturation, upper=1 + saturation)

      def hue_foo():
        if hue == 0:
          return x
        else:
          return tf.image.random_hue(x, max_delta=hue)

      # `i` is a tensor here, so the branch is selected with nested tf.cond:
      # 0 -> brightness, 1 -> contrast, 2 -> saturation, 3 -> hue.
      x = tf.cond(tf.less(i, 2),
                  lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
                  lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
      return x

    # A random permutation of the four transformation indices.
    perm = tf.random.shuffle(tf.range(4))
    for i in range(4):
      image = apply_transform(perm[i], image)
      image = tf.clip_by_value(image, 0., 1.)
    return image


def color_jitter(image, strength, random_order=True, impl='simclrv2'):
  """Distorts the color of the image.

  Args:
    image: The input image tensor.
    strength: the floating number for the strength of the color augmentation.
    random_order: A bool, specifying whether to randomize the jittering order.
    impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  brightness = 0.8 * strength
  contrast = 0.8 * strength
  saturation = 0.8 * strength
  hue = 0.2 * strength
  if random_order:
    return color_jitter_rand(
        image, brightness, contrast, saturation, hue, impl=impl)
  else:
    return color_jitter_nonrand(
        image, brightness, contrast, saturation, hue, impl=impl)


def random_color_jitter(image, p=1.0, color_jitter_strength=1.0,
                        impl='simclrv2'):
  """Perform random color jitter.

  Args:
    image: The input image tensor.
    p: Probability of applying the whole transformation.
    color_jitter_strength: The strength passed to `color_jitter`.
    impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The image tensor, possibly color-jittered and/or converted to grayscale.
  """

  def _transform(image):
    color_jitter_t = functools.partial(
        color_jitter, strength=color_jitter_strength, impl=impl)
    # Color jitter is applied with probability 0.8, then grayscale conversion
    # with probability 0.2.
    image = random_apply(color_jitter_t, p=0.8, x=image)
    return random_apply(to_grayscale, p=0.2, x=image)

  return random_apply(_transform, p=p, x=image)


def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with separable convolution.

  Args:
    image: Tensor of shape [height, width, channels] and dtype float to blur.
    kernel_size: Integer Tensor for the size of the blur kernel. This should
      be an odd number. If it is an even number, the actual kernel size will be
      size + 1.
    sigma: Sigma value for gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image.
  """
  radius = tf.cast(kernel_size / 2, dtype=tf.int32)
  kernel_size = radius * 2 + 1
  x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
  blur_filter = tf.exp(-tf.pow(x, 2.0) /
                       (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
  # Normalize so the filter weights sum to 1.
  blur_filter /= tf.reduce_sum(blur_filter)
  # One vertical and one horizontal filter.
  blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
  blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
  num_channels = tf.shape(image)[-1]
  blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
  blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
  expand_batch_dim = image.shape.ndims == 3
  if expand_batch_dim:
    # Tensorflow requires batched input to convolutions, which we can fake with
    # an extra dimension.
    image = tf.expand_dims(image, axis=0)
  blurred = tf.nn.depthwise_conv2d(
      image, blur_h, strides=[1, 1, 1, 1], padding=padding)
  blurred = tf.nn.depthwise_conv2d(
      blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
  if expand_batch_dim:
    blurred = tf.squeeze(blurred, axis=0)
  return blurred


def random_blur(image, height, width, p=0.5):
  """Randomly blur an image.

  The blur kernel size is `height // 10` and sigma is drawn uniformly from
  [0.1, 2.0).

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image. Unused.
    p: probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width

  def _transform(image):
    sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    return gaussian_blur(
        image, kernel_size=height // 10, sigma=sigma, padding='SAME')

  return random_apply(_transform, p=p, x=image)


def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: `Tensor` of image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
    scope: Optional `str` for name scope.

  Returns:
    The cropped image `Tensor`.
  """
  with tf.name_scope(scope or 'distorted_bounding_box_crop'):
    shape = tf.shape(image)
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    # The third output (the distorted box itself) is not used.
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    image = tf.image.crop_to_bounding_box(
        image, offset_y, offset_x, target_height, target_width)

    return image


def crop_and_resize(image, height, width):
  """Make a random crop and resize it to height `height` and width `width`.

  Args:
    image: Tensor representing the image.
    height: Desired image height.
    width: Desired image width.

  Returns:
    A `height` x `width` x channels Tensor holding a random crop of `image`.
  """
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  aspect_ratio = width / height
  image = distorted_bounding_box_crop(
      image,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
      area_range=(0.08, 1.0),
      max_attempts=100,
      scope=None)
  # The crop is wrapped in a list to form a batch of one for the resize, and
  # the single result is unwrapped again.
  return tf.image.resize([image], [height, width],
                         method=tf.image.ResizeMethod.BICUBIC)[0]


def random_crop_with_resize(image, height, width, p=1.0):
  """Randomly crop and resize an image.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: Probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """

  def _transform(image):  # pylint: disable=missing-docstring
    image = crop_and_resize(image, height, width)
    return image

  return random_apply(_transform, p=p, x=image)