Commit 242f4098 authored by Chaochao Yan's avatar Chaochao Yan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 477486513
parent 051f1c96
......@@ -15,7 +15,7 @@
"""Image classification configuration definition."""
import dataclasses
import os
from typing import List, Optional
from typing import List, Optional, Tuple
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -37,6 +37,7 @@ class DataConfig(cfg.DataConfig):
is_multilabel: bool = False
aug_rand_hflip: bool = True
aug_crop: Optional[bool] = True
crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0)
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0.
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Classification decoder and parser."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
# Import libraries
import tensorflow as tf
......@@ -54,8 +54,8 @@ class Decoder(decoder.Decoder):
self._keys_to_features = keys_to_features
def decode(self, serialized_example):
return tf.io.parse_single_example(
serialized_example, self._keys_to_features)
return tf.io.parse_single_example(serialized_example,
self._keys_to_features)
class Parser(parser.Parser):
......@@ -73,7 +73,8 @@ class Parser(parser.Parser):
color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None,
is_multilabel: bool = False,
dtype: str = 'float32'):
dtype: str = 'float32',
crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0)):
"""Initializes parameters for parsing annotations in the dataset.
Args:
......@@ -84,8 +85,8 @@ class Parser(parser.Parser):
label_field_key: `str`, the key name to label in tf.Example.
decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and
......@@ -98,6 +99,10 @@ class Parser(parser.Parser):
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
crop_area_range: An optional `tuple` of (min_area, max_area) for image
random crop function to constraint crop operation. The cropped areas
of the image must contain a fraction of the input image within this
range. The default area range is (0.08, 1.0).
"""
self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip
......@@ -147,6 +152,7 @@ class Parser(parser.Parser):
self._random_erasing = None
self._is_multilabel = is_multilabel
self._decode_jpeg_only = decode_jpeg_only
self._crop_area_range = crop_area_range
def _parse_train_data(self, decoded_tensors):
"""Parses data for training."""
......@@ -177,7 +183,7 @@ class Parser(parser.Parser):
# Crops image.
cropped_image = preprocess_ops.random_crop_image_v2(
image_bytes, image_shape)
image_bytes, image_shape, area_range=self._crop_area_range)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
......@@ -189,7 +195,8 @@ class Parser(parser.Parser):
# Crops image.
if self._aug_crop:
cropped_image = preprocess_ops.random_crop_image(image)
cropped_image = preprocess_ops.random_crop_image(
image, area_range=self._crop_area_range)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
......@@ -215,9 +222,8 @@ class Parser(parser.Parser):
image = self._augmenter.distort(image)
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
image = preprocess_ops.normalize_image(
image, offset=MEAN_RGB, scale=STDDEV_RGB)
# Random erasing after the image has been normalized
if self._random_erasing is not None:
......@@ -251,9 +257,8 @@ class Parser(parser.Parser):
image.set_shape([self._output_size[0], self._output_size[1], 3])
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
image = preprocess_ops.normalize_image(
image, offset=MEAN_RGB, scale=STDDEV_RGB)
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment