Commit 069bdd28 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377340715
parent 3fc55e9e
......@@ -43,6 +43,7 @@ class DataConfig(cfg.DataConfig):
file_type: str = 'tfrecord'
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
decode_jpeg_only: bool = True
# Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
......
......@@ -66,6 +66,7 @@ class Parser(parser.Parser):
num_classes: float,
image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True,
aug_type: Optional[common.Augmentation] = None,
is_multilabel: bool = False,
......@@ -78,6 +79,8 @@ class Parser(parser.Parser):
num_classes: `float`, number of classes.
image_field_key: `str`, the key name to encoded image in tf.Example.
label_field_key: `str`, the key name to label in tf.Example.
decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
......@@ -118,6 +121,7 @@ class Parser(parser.Parser):
self._augmenter = None
self._label_field_key = label_field_key
self._is_multilabel = is_multilabel
self._decode_jpeg_only = decode_jpeg_only
def _parse_train_data(self, decoded_tensors):
"""Parses data for training."""
......@@ -142,16 +146,29 @@ class Parser(parser.Parser):
def _parse_train_image(self, decoded_tensors):
"""Parses image data for training."""
image_bytes = decoded_tensors[self._image_field_key]
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image.
# TODO(pengchong): support image format other than JPEG.
cropped_image = preprocess_ops.random_crop_image_v2(
image_bytes, image_shape)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
lambda: cropped_image)
if self._decode_jpeg_only:
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image.
cropped_image = preprocess_ops.random_crop_image_v2(
image_bytes, image_shape)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
lambda: cropped_image)
else:
# Decodes image.
image = tf.io.decode_image(image_bytes, channels=3)
image.set_shape([None, None, 3])
# Crops image.
cropped_image = preprocess_ops.random_crop_image(image)
image = tf.cond(
tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
lambda: preprocess_ops.center_crop_image(image),
lambda: cropped_image)
if self._aug_rand_hflip:
image = tf.image.random_flip_left_right(image)
......@@ -159,6 +176,7 @@ class Parser(parser.Parser):
# Resizes image.
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image.set_shape([self._output_size[0], self._output_size[1], 3])
# Apply autoaug or randaug.
if self._augmenter is not None:
......@@ -177,15 +195,23 @@ class Parser(parser.Parser):
def _parse_eval_image(self, decoded_tensors):
"""Parses image data for evaluation."""
image_bytes = decoded_tensors[self._image_field_key]
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops and resizes image.
image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
if self._decode_jpeg_only:
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops.
image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
else:
# Decodes image.
image = tf.io.decode_image(image_bytes, channels=3)
image.set_shape([None, None, 3])
# Center crops.
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image = tf.reshape(image, [self._output_size[0], self._output_size[1], 3])
image.set_shape([self._output_size[0], self._output_size[1], 3])
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
......
......@@ -127,10 +127,12 @@ def _encode_image(image_array: np.ndarray, fmt: str) -> bytes:
def create_classification_example(
image_height: int,
image_width: int,
image_format: str = 'JPEG',
is_multilabel: bool = False) -> tf.train.Example:
"""Creates image and labels for image classification input pipeline."""
image = _encode_image(
np.uint8(np.random.rand(image_height, image_width, 3) * 255), fmt='JPEG')
np.uint8(np.random.rand(image_height, image_width, 3) * 255),
fmt=image_format)
labels = [0, 1] if is_multilabel else [0]
serialized_example = tf.train.Example(
features=tf.train.Features(
......
......@@ -104,6 +104,7 @@ class ImageClassificationTask(base_task.Task):
num_classes=num_classes,
image_field_key=image_field_key,
label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type,
is_multilabel=is_multilabel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment