Commit 2f3d9b1f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 465437870
parent 0f6555a4
......@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig):
cycle_length: int = 10
is_multilabel: bool = False
aug_rand_hflip: bool = True
aug_crop: Optional[bool] = True
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0.
......
......@@ -68,6 +68,7 @@ class Parser(parser.Parser):
label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True,
aug_crop: Optional[bool] = True,
aug_type: Optional[common.Augmentation] = None,
color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None,
......@@ -85,6 +86,8 @@ class Parser(parser.Parser):
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to
......@@ -98,6 +101,7 @@ class Parser(parser.Parser):
"""
self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip
self._aug_crop = aug_crop
self._num_classes = num_classes
self._image_field_key = image_field_key
if dtype == 'float32':
......@@ -168,7 +172,7 @@ class Parser(parser.Parser):
"""Parses image data for training."""
image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only:
if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image.
......@@ -184,6 +188,7 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3])
# Crops image.
if self._aug_crop:
cropped_image = preprocess_ops.random_crop_image(image)
image = tf.cond(
......@@ -227,7 +232,7 @@ class Parser(parser.Parser):
"""Parses image data for evaluation."""
image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only:
if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops.
......@@ -238,6 +243,7 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3])
# Center crops.
if self._aug_crop:
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize(
......
......@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_crop=params.aug_crop,
aug_type=params.aug_type,
color_jitter=params.color_jitter,
random_erasing=params.random_erasing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment