Commit d4053d80 authored by Simon Geisler's avatar Simon Geisler
Browse files

fix issues with color jitter and random erase

parent 34c6530a
...@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config): ...@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config):
magnitude: float = 10 magnitude: float = 10
cutout_const: float = 40 cutout_const: float = 40
translate_const: float = 10 translate_const: float = 10
magnitude_std: float = 0.0
prob_to_apply: Optional[float] = None prob_to_apply: Optional[float] = None
exclude_ops: List[str] = dataclasses.field(default_factory=list) exclude_ops: List[str] = dataclasses.field(default_factory=list)
......
...@@ -196,6 +196,11 @@ class Parser(parser.Parser): ...@@ -196,6 +196,11 @@ class Parser(parser.Parser):
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR) image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image.set_shape([self._output_size[0], self._output_size[1], 3]) image.set_shape([self._output_size[0], self._output_size[1], 3])
# Color jitter.
if self._color_jitter > 0:
image = preprocess_ops.color_jitter(
image, self._color_jitter, self._color_jitter, self._color_jitter)
# Apply autoaug or randaug. # Apply autoaug or randaug.
if self._augmenter is not None: if self._augmenter is not None:
image = self._augmenter.distort(image) image = self._augmenter.distort(image)
...@@ -205,6 +210,10 @@ class Parser(parser.Parser): ...@@ -205,6 +210,10 @@ class Parser(parser.Parser):
offset=MEAN_RGB, offset=MEAN_RGB,
scale=STDDEV_RGB) scale=STDDEV_RGB)
# Random erasing after the image has been normalized
if self._random_erasing is not None:
image = self._random_erasing.distort(image)
# Convert image to self._dtype. # Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype) image = tf.image.convert_image_dtype(image, self._dtype)
...@@ -231,20 +240,11 @@ class Parser(parser.Parser): ...@@ -231,20 +240,11 @@ class Parser(parser.Parser):
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR) image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image.set_shape([self._output_size[0], self._output_size[1], 3]) image.set_shape([self._output_size[0], self._output_size[1], 3])
# Color jitter.
if self._color_jitter > 0:
image = preprocess_ops.color_jitter(
image, self._color_jitter, self._color_jitter, self._color_jitter)
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB, offset=MEAN_RGB,
scale=STDDEV_RGB) scale=STDDEV_RGB)
# Random erasing after the image has been normalized
if self._random_erasing is not None:
image = self._random_erasing.distort(image)
# Convert image to self._dtype. # Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype) image = tf.image.convert_image_dtype(image, self._dtype)
......
...@@ -1359,7 +1359,7 @@ class RandomErasing(ImageAugment): ...@@ -1359,7 +1359,7 @@ class RandomErasing(ImageAugment):
""" """
uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0) uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
mirror_cond = tf.less(uniform_random, .5) mirror_cond = tf.less(uniform_random, .5)
tf.cond(mirror_cond, self._erase, lambda: image) tf.cond(mirror_cond, lambda: self._erase(image), lambda: image)
return image return image
@tf.function @tf.function
...@@ -1374,31 +1374,34 @@ class RandomErasing(ImageAugment): ...@@ -1374,31 +1374,34 @@ class RandomErasing(ImageAugment):
area = tf.cast(image_width * image_height, tf.float32) area = tf.cast(image_width * image_height, tf.float32)
for _ in range(count): for _ in range(count):
# Work around since break is not supported in tf.function
is_trial_successfull = False
for _ in range(self._trials): for _ in range(self._trials):
erase_area = tf.random.uniform(shape=[], if not is_trial_successfull:
minval=area * self._min_area, erase_area = tf.random.uniform(shape=[],
maxval=area * self._max_area) minval=area * self._min_area,
aspect_ratio = tf.math.exp(tf.random.uniform( maxval=area * self._max_area)
shape=[], minval=self._min_log_aspect, aspect_ratio = tf.math.exp(tf.random.uniform(
maxval=self._max_log_aspect)) shape=[], minval=self._min_log_aspect,
maxval=self._max_log_aspect))
half_height = tf.cast(tf.math.round(tf.math.sqrt(
erase_area * aspect_ratio) / 2), dtype=tf.int32) half_height = tf.cast(tf.math.round(tf.math.sqrt(
half_width = tf.cast(tf.math.round(tf.math.sqrt( erase_area * aspect_ratio) / 2), dtype=tf.int32)
erase_area / aspect_ratio) / 2), dtype=tf.int32) half_width = tf.cast(tf.math.round(tf.math.sqrt(
erase_area / aspect_ratio) / 2), dtype=tf.int32)
if 2 * half_height < image_height and 2 * half_width < image_width:
center_height = tf.random.uniform( if 2 * half_height < image_height and 2 * half_width < image_width:
shape=[], minval=0, maxval=int(image_height - 2 * half_height), center_height = tf.random.uniform(
dtype=tf.int32) shape=[], minval=0, maxval=int(image_height - 2 * half_height),
center_width = tf.random.uniform( dtype=tf.int32)
shape=[], minval=0, maxval=int(image_width - 2 * half_width), center_width = tf.random.uniform(
dtype=tf.int32) shape=[], minval=0, maxval=int(image_width - 2 * half_width),
dtype=tf.int32)
image = _fill_rectangle(image, center_width, center_height,
half_width, half_height, replace=None) image = _fill_rectangle(image, center_width, center_height,
half_width, half_height, replace=None)
break
is_trial_successfull = True
return image return image
......
...@@ -566,7 +566,7 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0., ...@@ -566,7 +566,7 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter. """Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args: Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image. image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0. Defaults to 0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0. contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
...@@ -575,8 +575,9 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0., ...@@ -575,8 +575,9 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
seed (int, optional): Random seed. Defaults to None. seed (int, optional): Random seed. Defaults to None.
Returns: Returns:
tf.Tensor: The augmented version of `image`. tf.Tensor: The augmented `image` of type uint8.
""" """
image = tf.cast(image, dtype=tf.uint8)
image = random_brightness(image, brightness, seed=seed) image = random_brightness(image, brightness, seed=seed)
image = random_contrast(image, contrast, seed=seed) image = random_contrast(image, contrast, seed=seed)
image = random_saturation(image, saturation, seed=seed) image = random_saturation(image, saturation, seed=seed)
...@@ -588,17 +589,17 @@ def random_brightness(image: tf.Tensor, brightness: Optional[float] = 0., ...@@ -588,17 +589,17 @@ def random_brightness(image: tf.Tensor, brightness: Optional[float] = 0.,
"""Jitters brightness of an image, similarly to torchvision`s ColorJitter. """Jitters brightness of an image, similarly to torchvision`s ColorJitter.
Args: Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image. image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0. Defaults to 0.
seed (int, optional): Random seed. Defaults to None. seed (int, optional): Random seed. Defaults to None.
Returns: Returns:
tf.Tensor: The augmented version of `image`. tf.Tensor: The augmented `image` of type uint8.
""" """
assert brightness >= 0 and brightness <= 1., '`brightness` must be in [0, 1]' assert brightness >= 0, '`brightness` must be positive'
brightness = tf.random.uniform( brightness = tf.random.uniform(
[], max(0, 1 - brightness), 1 + brightness, seed=seed) [], max(0, 1 - brightness), 1 + brightness, seed=seed, dtype=tf.float32)
return augment.brightness(image, brightness) return augment.brightness(image, brightness)
...@@ -607,17 +608,17 @@ def random_contrast(image: tf.Tensor, contrast: Optional[float] = 0., ...@@ -607,17 +608,17 @@ def random_contrast(image: tf.Tensor, contrast: Optional[float] = 0.,
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter. """Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args: Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image. image (tf.Tensor): Of shape [height, width, 3] and type uint8.
contrast (float, optional): Magnitude for contrast jitter. contrast (float, optional): Magnitude for contrast jitter.
Defaults to 0. Defaults to 0.
seed (int, optional): Random seed. Defaults to None. seed (int, optional): Random seed. Defaults to None.
Returns: Returns:
tf.Tensor: The augmented version of `image`. tf.Tensor: The augmented `image` of type uint8.
""" """
assert contrast >= 0 and contrast <= 1., '`contrast` must be in [0, 1]' assert contrast >= 0, '`contrast` must be positive'
contrast = tf.random.uniform( contrast = tf.random.uniform(
[], max(0, 1 - contrast), 1 + contrast, seed=seed) [], max(0, 1 - contrast), 1 + contrast, seed=seed, dtype=tf.float32)
return augment.contrast(image, contrast) return augment.contrast(image, contrast)
...@@ -626,15 +627,16 @@ def random_saturation(image: tf.Tensor, saturation: Optional[float] = 0., ...@@ -626,15 +627,16 @@ def random_saturation(image: tf.Tensor, saturation: Optional[float] = 0.,
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter. """Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args: Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image. image (tf.Tensor): Of shape [height, width, 3] and type uint8.
saturation (float, optional): Magnitude for saturation jitter. saturation (float, optional): Magnitude for saturation jitter.
Defaults to 0. Defaults to 0.
seed (int, optional): Random seed. Defaults to None. seed (int, optional): Random seed. Defaults to None.
Returns: Returns:
tf.Tensor: The augmented version of `image`. tf.Tensor: The augmented `image` of type uint8.
""" """
assert saturation >= 0 and saturation <= 1., '`saturation` must be in [0, 1]' assert saturation >= 0, '`saturation` must be positive'
saturation = tf.random.uniform( saturation = tf.random.uniform(
[], max(0, 1 - saturation), 1 + saturation, seed=seed) [], max(0, 1 - saturation), 1 + saturation, seed=seed, dtype=tf.float32)
return augment.blend(tf.image.rgb_to_grayscale(image), image, saturation) return augment.blend(
tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1), image, saturation)
...@@ -225,6 +225,17 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase): ...@@ -225,6 +225,17 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
_ = preprocess_ops.random_crop_image_v2( _ = preprocess_ops.random_crop_image_v2(
image_bytes, tf.constant([input_height, input_width, 3], tf.int32)) image_bytes, tf.constant([input_height, input_width, 3], tf.int32))
@parameterized.parameters(
(400, 600, 0), (400, 600, 0.4), (600, 400, 1.4)
)
def testColorJitter(self, input_height, input_width, color_jitter):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
jittered_image = preprocess_ops.color_jitter(
image, color_jitter, color_jitter, color_jitter)
assert jittered_image.shape == image.shape
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment