Commit 9b47a723 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 465437870
parent 02d00c0c
...@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig): ...@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig):
cycle_length: int = 10 cycle_length: int = 10
is_multilabel: bool = False is_multilabel: bool = False
aug_rand_hflip: bool = True aug_rand_hflip: bool = True
aug_crop: Optional[bool] = True
aug_type: Optional[ aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment. common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0. color_jitter: float = 0.
......
...@@ -68,6 +68,7 @@ class Parser(parser.Parser): ...@@ -68,6 +68,7 @@ class Parser(parser.Parser):
label_field_key: str = DEFAULT_LABEL_FIELD_KEY, label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
decode_jpeg_only: bool = True, decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True, aug_rand_hflip: bool = True,
aug_crop: Optional[bool] = True,
aug_type: Optional[common.Augmentation] = None, aug_type: Optional[common.Augmentation] = None,
color_jitter: float = 0., color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None, random_erasing: Optional[common.RandomErasing] = None,
...@@ -85,6 +86,8 @@ class Parser(parser.Parser): ...@@ -85,6 +86,8 @@ class Parser(parser.Parser):
faster than decoding other types. Default is True. faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip. horizontal flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to color_jitter: Magnitude of color jitter. If > 0, the value is used to
...@@ -98,6 +101,7 @@ class Parser(parser.Parser): ...@@ -98,6 +101,7 @@ class Parser(parser.Parser):
""" """
self._output_size = output_size self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip self._aug_rand_hflip = aug_rand_hflip
self._aug_crop = aug_crop
self._num_classes = num_classes self._num_classes = num_classes
self._image_field_key = image_field_key self._image_field_key = image_field_key
if dtype == 'float32': if dtype == 'float32':
...@@ -168,7 +172,7 @@ class Parser(parser.Parser): ...@@ -168,7 +172,7 @@ class Parser(parser.Parser):
"""Parses image data for training.""" """Parses image data for training."""
image_bytes = decoded_tensors[self._image_field_key] image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only: if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes) image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image. # Crops image.
...@@ -184,12 +188,13 @@ class Parser(parser.Parser): ...@@ -184,12 +188,13 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3]) image.set_shape([None, None, 3])
# Crops image. # Crops image.
cropped_image = preprocess_ops.random_crop_image(image) if self._aug_crop:
cropped_image = preprocess_ops.random_crop_image(image)
image = tf.cond( image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))), tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
lambda: preprocess_ops.center_crop_image(image), lambda: preprocess_ops.center_crop_image(image),
lambda: cropped_image) lambda: cropped_image)
if self._aug_rand_hflip: if self._aug_rand_hflip:
image = tf.image.random_flip_left_right(image) image = tf.image.random_flip_left_right(image)
...@@ -227,7 +232,7 @@ class Parser(parser.Parser): ...@@ -227,7 +232,7 @@ class Parser(parser.Parser):
"""Parses image data for evaluation.""" """Parses image data for evaluation."""
image_bytes = decoded_tensors[self._image_field_key] image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only: if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes) image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops. # Center crops.
...@@ -238,7 +243,8 @@ class Parser(parser.Parser): ...@@ -238,7 +243,8 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3]) image.set_shape([None, None, 3])
# Center crops. # Center crops.
image = preprocess_ops.center_crop_image(image) if self._aug_crop:
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize( image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR) image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
......
...@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key=label_field_key, label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only, decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip, aug_rand_hflip=params.aug_rand_hflip,
aug_crop=params.aug_crop,
aug_type=params.aug_type, aug_type=params.aug_type,
color_jitter=params.color_jitter, color_jitter=params.color_jitter,
random_erasing=params.random_erasing, random_erasing=params.random_erasing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment