Commit 2f3d9b1f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 465437870
parent 0f6555a4
...@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig): ...@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig):
cycle_length: int = 10 cycle_length: int = 10
is_multilabel: bool = False is_multilabel: bool = False
aug_rand_hflip: bool = True aug_rand_hflip: bool = True
aug_crop: Optional[bool] = True
aug_type: Optional[ aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment. common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0. color_jitter: float = 0.
......
...@@ -68,6 +68,7 @@ class Parser(parser.Parser): ...@@ -68,6 +68,7 @@ class Parser(parser.Parser):
label_field_key: str = DEFAULT_LABEL_FIELD_KEY, label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
decode_jpeg_only: bool = True, decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True, aug_rand_hflip: bool = True,
aug_crop: Optional[bool] = True,
aug_type: Optional[common.Augmentation] = None, aug_type: Optional[common.Augmentation] = None,
color_jitter: float = 0., color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None, random_erasing: Optional[common.RandomErasing] = None,
...@@ -85,6 +86,8 @@ class Parser(parser.Parser): ...@@ -85,6 +86,8 @@ class Parser(parser.Parser):
faster than decoding other types. Default is True. faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip. horizontal flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to color_jitter: Magnitude of color jitter. If > 0, the value is used to
...@@ -98,6 +101,7 @@ class Parser(parser.Parser): ...@@ -98,6 +101,7 @@ class Parser(parser.Parser):
""" """
self._output_size = output_size self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip self._aug_rand_hflip = aug_rand_hflip
self._aug_crop = aug_crop
self._num_classes = num_classes self._num_classes = num_classes
self._image_field_key = image_field_key self._image_field_key = image_field_key
if dtype == 'float32': if dtype == 'float32':
...@@ -168,7 +172,7 @@ class Parser(parser.Parser): ...@@ -168,7 +172,7 @@ class Parser(parser.Parser):
"""Parses image data for training.""" """Parses image data for training."""
image_bytes = decoded_tensors[self._image_field_key] image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only: if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes) image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image. # Crops image.
...@@ -184,6 +188,7 @@ class Parser(parser.Parser): ...@@ -184,6 +188,7 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3]) image.set_shape([None, None, 3])
# Crops image. # Crops image.
if self._aug_crop:
cropped_image = preprocess_ops.random_crop_image(image) cropped_image = preprocess_ops.random_crop_image(image)
image = tf.cond( image = tf.cond(
...@@ -227,7 +232,7 @@ class Parser(parser.Parser): ...@@ -227,7 +232,7 @@ class Parser(parser.Parser):
"""Parses image data for evaluation.""" """Parses image data for evaluation."""
image_bytes = decoded_tensors[self._image_field_key] image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only: if self._decode_jpeg_only and self._aug_crop:
image_shape = tf.image.extract_jpeg_shape(image_bytes) image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops. # Center crops.
...@@ -238,6 +243,7 @@ class Parser(parser.Parser): ...@@ -238,6 +243,7 @@ class Parser(parser.Parser):
image.set_shape([None, None, 3]) image.set_shape([None, None, 3])
# Center crops. # Center crops.
if self._aug_crop:
image = preprocess_ops.center_crop_image(image) image = preprocess_ops.center_crop_image(image)
image = tf.image.resize( image = tf.image.resize(
......
...@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key=label_field_key, label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only, decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip, aug_rand_hflip=params.aug_rand_hflip,
aug_crop=params.aug_crop,
aug_type=params.aug_type, aug_type=params.aug_type,
color_jitter=params.color_jitter, color_jitter=params.color_jitter,
random_erasing=params.random_erasing, random_erasing=params.random_erasing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment