Commit 242f4098 authored by Chaochao Yan's avatar Chaochao Yan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 477486513
parent 051f1c96
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Image classification configuration definition.""" """Image classification configuration definition."""
import dataclasses import dataclasses
import os import os
from typing import List, Optional from typing import List, Optional, Tuple
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
...@@ -37,6 +37,7 @@ class DataConfig(cfg.DataConfig): ...@@ -37,6 +37,7 @@ class DataConfig(cfg.DataConfig):
is_multilabel: bool = False is_multilabel: bool = False
aug_rand_hflip: bool = True aug_rand_hflip: bool = True
aug_crop: Optional[bool] = True aug_crop: Optional[bool] = True
crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0)
aug_type: Optional[ aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment. common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0. color_jitter: float = 0.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Classification decoder and parser.""" """Classification decoder and parser."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -54,8 +54,8 @@ class Decoder(decoder.Decoder): ...@@ -54,8 +54,8 @@ class Decoder(decoder.Decoder):
self._keys_to_features = keys_to_features self._keys_to_features = keys_to_features
def decode(self, serialized_example): def decode(self, serialized_example):
return tf.io.parse_single_example( return tf.io.parse_single_example(serialized_example,
serialized_example, self._keys_to_features) self._keys_to_features)
class Parser(parser.Parser): class Parser(parser.Parser):
...@@ -73,7 +73,8 @@ class Parser(parser.Parser): ...@@ -73,7 +73,8 @@ class Parser(parser.Parser):
color_jitter: float = 0., color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None, random_erasing: Optional[common.RandomErasing] = None,
is_multilabel: bool = False, is_multilabel: bool = False,
dtype: str = 'float32'): dtype: str = 'float32',
crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0)):
"""Initializes parameters for parsing annotations in the dataset. """Initializes parameters for parsing annotations in the dataset.
Args: Args:
...@@ -84,8 +85,8 @@ class Parser(parser.Parser): ...@@ -84,8 +85,8 @@ class Parser(parser.Parser):
label_field_key: `str`, the key name to label in tf.Example. label_field_key: `str`, the key name to label in tf.Example.
decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
faster than decoding other types. Default is True. faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random aug_rand_hflip: `bool`, if True, augment training with random horizontal
horizontal flip. flip.
aug_crop: `bool`, if True, perform random cropping during training and aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation. center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and aug_type: An optional Augmentation object to choose from AutoAugment and
...@@ -98,6 +99,10 @@ class Parser(parser.Parser): ...@@ -98,6 +99,10 @@ class Parser(parser.Parser):
is_multilabel: A `bool`, whether or not each example has multiple labels. is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16', dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'. or 'bfloat16'.
crop_area_range: An optional `tuple` of (min_area, max_area) for image
random crop function to constraint crop operation. The cropped areas
of the image must contain a fraction of the input image within this
range. The default area range is (0.08, 1.0).
""" """
self._output_size = output_size self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip self._aug_rand_hflip = aug_rand_hflip
...@@ -147,6 +152,7 @@ class Parser(parser.Parser): ...@@ -147,6 +152,7 @@ class Parser(parser.Parser):
self._random_erasing = None self._random_erasing = None
self._is_multilabel = is_multilabel self._is_multilabel = is_multilabel
self._decode_jpeg_only = decode_jpeg_only self._decode_jpeg_only = decode_jpeg_only
self._crop_area_range = crop_area_range
def _parse_train_data(self, decoded_tensors): def _parse_train_data(self, decoded_tensors):
"""Parses data for training.""" """Parses data for training."""
...@@ -177,7 +183,7 @@ class Parser(parser.Parser): ...@@ -177,7 +183,7 @@ class Parser(parser.Parser):
# Crops image. # Crops image.
cropped_image = preprocess_ops.random_crop_image_v2( cropped_image = preprocess_ops.random_crop_image_v2(
image_bytes, image_shape) image_bytes, image_shape, area_range=self._crop_area_range)
image = tf.cond( image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)), tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape), lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
...@@ -189,7 +195,8 @@ class Parser(parser.Parser): ...@@ -189,7 +195,8 @@ class Parser(parser.Parser):
# Crops image. # Crops image.
if self._aug_crop: if self._aug_crop:
cropped_image = preprocess_ops.random_crop_image(image) cropped_image = preprocess_ops.random_crop_image(
image, area_range=self._crop_area_range)
image = tf.cond( image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))), tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
...@@ -215,9 +222,8 @@ class Parser(parser.Parser): ...@@ -215,9 +222,8 @@ class Parser(parser.Parser):
image = self._augmenter.distort(image) image = self._augmenter.distort(image)
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(
offset=MEAN_RGB, image, offset=MEAN_RGB, scale=STDDEV_RGB)
scale=STDDEV_RGB)
# Random erasing after the image has been normalized # Random erasing after the image has been normalized
if self._random_erasing is not None: if self._random_erasing is not None:
...@@ -251,9 +257,8 @@ class Parser(parser.Parser): ...@@ -251,9 +257,8 @@ class Parser(parser.Parser):
image.set_shape([self._output_size[0], self._output_size[1], 3]) image.set_shape([self._output_size[0], self._output_size[1], 3])
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(
offset=MEAN_RGB, image, offset=MEAN_RGB, scale=STDDEV_RGB)
scale=STDDEV_RGB)
# Convert image to self._dtype. # Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype) image = tf.image.convert_image_dtype(image, self._dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment