Commit 9b47a723 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 465437870
parent 02d00c0c
......@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig):
cycle_length: int = 10
is_multilabel: bool = False
aug_rand_hflip: bool = True
aug_crop: Optional[bool] = True
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0.
......
......@@ -68,6 +68,7 @@ class Parser(parser.Parser):
label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True,
aug_crop: Optional[bool] = True,
aug_type: Optional[common.Augmentation] = None,
color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None,
......@@ -85,6 +86,8 @@ class Parser(parser.Parser):
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to
......@@ -98,6 +101,7 @@ class Parser(parser.Parser):
"""
self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip
self._aug_crop = aug_crop
self._num_classes = num_classes
self._image_field_key = image_field_key
if dtype == 'float32':
......@@ -168,7 +172,7 @@ class Parser(parser.Parser):
"""Parses image data for training."""
image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only:
if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image.
......@@ -184,12 +188,13 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3])
# Crops image.
cropped_image = preprocess_ops.random_crop_image(image)
if self._aug_crop:
cropped_image = preprocess_ops.random_crop_image(image)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
lambda: preprocess_ops.center_crop_image(image),
lambda: cropped_image)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
lambda: preprocess_ops.center_crop_image(image),
lambda: cropped_image)
if self._aug_rand_hflip:
image = tf.image.random_flip_left_right(image)
......@@ -227,7 +232,7 @@ class Parser(parser.Parser):
"""Parses image data for evaluation."""
image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only:
if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops.
......@@ -238,7 +243,8 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3])
# Center crops.
image = preprocess_ops.center_crop_image(image)
if self._aug_crop:
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
......
......@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_crop=params.aug_crop,
aug_type=params.aug_type,
color_jitter=params.color_jitter,
random_erasing=params.random_erasing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment