Commit 069bdd28 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 377340715
parent 3fc55e9e
@@ -43,6 +43,7 @@ class DataConfig(cfg.DataConfig):
   file_type: str = 'tfrecord'
   image_field_key: str = 'image/encoded'
   label_field_key: str = 'image/class/label'
+  decode_jpeg_only: bool = True
   # Keep for backward compatibility.
   aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
...
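For context, a minimal sketch of how the new `decode_jpeg_only` field might be set when building the data config. Only `decode_jpeg_only` itself comes from this commit; the import path and the surrounding field values are illustrative assumptions:

# Illustrative only: the import path and the other field values are assumptions,
# not part of this commit.
from official.vision.beta.configs import image_classification as classification_cfg

# decode_jpeg_only=True (the default) keeps the fast JPEG-only decode path;
# set it to False when the dataset contains non-JPEG images such as PNG.
train_data = classification_cfg.DataConfig(
    input_path='/path/to/train*.tfrecord',
    is_training=True,
    global_batch_size=128,
    decode_jpeg_only=False)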
@@ -66,6 +66,7 @@ class Parser(parser.Parser):
                num_classes: float,
                image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
                label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
+               decode_jpeg_only: bool = True,
                aug_rand_hflip: bool = True,
                aug_type: Optional[common.Augmentation] = None,
                is_multilabel: bool = False,
@@ -78,6 +79,8 @@ class Parser(parser.Parser):
       num_classes: `float`, number of classes.
       image_field_key: `str`, the key name to encoded image in tf.Example.
       label_field_key: `str`, the key name to label in tf.Example.
+      decode_jpeg_only: `bool`, if True, only the JPEG format is decoded; this
+        is faster than decoding other image types. Default is True.
       aug_rand_hflip: `bool`, if True, augment training with random
         horizontal flip.
       aug_type: An optional Augmentation object to choose from AutoAugment and
@@ -118,6 +121,7 @@ class Parser(parser.Parser):
     self._augmenter = None
     self._label_field_key = label_field_key
     self._is_multilabel = is_multilabel
+    self._decode_jpeg_only = decode_jpeg_only
 
   def _parse_train_data(self, decoded_tensors):
     """Parses data for training."""
@@ -142,16 +146,29 @@ class Parser(parser.Parser):
   def _parse_train_image(self, decoded_tensors):
     """Parses image data for training."""
     image_bytes = decoded_tensors[self._image_field_key]
-    image_shape = tf.image.extract_jpeg_shape(image_bytes)
 
-    # Crops image.
-    # TODO(pengchong): support image format other than JPEG.
-    cropped_image = preprocess_ops.random_crop_image_v2(
-        image_bytes, image_shape)
-    image = tf.cond(
-        tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
-        lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
-        lambda: cropped_image)
+    if self._decode_jpeg_only:
+      image_shape = tf.image.extract_jpeg_shape(image_bytes)
+
+      # Crops image.
+      cropped_image = preprocess_ops.random_crop_image_v2(
+          image_bytes, image_shape)
+      image = tf.cond(
+          tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
+          lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
+          lambda: cropped_image)
+    else:
+      # Decodes image.
+      image = tf.io.decode_image(image_bytes, channels=3)
+      image.set_shape([None, None, 3])
+
+      # Crops image.
+      cropped_image = preprocess_ops.random_crop_image(image)
+      image = tf.cond(
+          tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
+          lambda: preprocess_ops.center_crop_image(image),
+          lambda: cropped_image)
 
     if self._aug_rand_hflip:
       image = tf.image.random_flip_left_right(image)
@@ -159,6 +176,7 @@ class Parser(parser.Parser):
     # Resizes image.
     image = tf.image.resize(
         image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
+    image.set_shape([self._output_size[0], self._output_size[1], 3])
 
     # Apply autoaug or randaug.
     if self._augmenter is not None:
@@ -177,15 +195,23 @@ class Parser(parser.Parser):
   def _parse_eval_image(self, decoded_tensors):
     """Parses image data for evaluation."""
     image_bytes = decoded_tensors[self._image_field_key]
-    image_shape = tf.image.extract_jpeg_shape(image_bytes)
 
-    # Center crops and resizes image.
-    image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
+    if self._decode_jpeg_only:
+      image_shape = tf.image.extract_jpeg_shape(image_bytes)
+
+      # Center crops.
+      image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
+    else:
+      # Decodes image.
+      image = tf.io.decode_image(image_bytes, channels=3)
+      image.set_shape([None, None, 3])
+
+      # Center crops.
+      image = preprocess_ops.center_crop_image(image)
 
     image = tf.image.resize(
         image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
-
-    image = tf.reshape(image, [self._output_size[0], self._output_size[1], 3])
+    image.set_shape([self._output_size[0], self._output_size[1], 3])
 
     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(image,
...
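To make the trade-off behind `decode_jpeg_only` concrete, here is a self-contained sketch in plain TensorFlow. It does not use the `preprocess_ops` helpers above, and the function name is invented for illustration. With JPEG input, `tf.image.extract_jpeg_shape` reads only the header and `tf.image.decode_and_crop_jpeg` decodes only the crop window, whereas the general path must decode every pixel before cropping:

import tensorflow as tf

def decode_and_center_crop(image_bytes: tf.Tensor,
                           crop_size: int,
                           jpeg_only: bool) -> tf.Tensor:
  """Toy version of the two decode paths; not the Model Garden implementation."""
  if jpeg_only:
    # Fast path: read the JPEG header for the shape, then decode only the crop window.
    shape = tf.image.extract_jpeg_shape(image_bytes)
    offset_h = (shape[0] - crop_size) // 2
    offset_w = (shape[1] - crop_size) // 2
    crop_window = tf.stack([offset_h, offset_w, crop_size, crop_size])
    return tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  # General path: handles PNG/BMP/GIF as well, but decodes the full image first.
  image = tf.io.decode_image(image_bytes, channels=3, expand_animations=False)
  return tf.image.crop_to_bounding_box(
      image,
      (tf.shape(image)[0] - crop_size) // 2,
      (tf.shape(image)[1] - crop_size) // 2,
      crop_size, crop_size)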
@@ -127,10 +127,12 @@ def _encode_image(image_array: np.ndarray, fmt: str) -> bytes:
 def create_classification_example(
     image_height: int,
     image_width: int,
+    image_format: str = 'JPEG',
     is_multilabel: bool = False) -> tf.train.Example:
   """Creates image and labels for image classification input pipeline."""
   image = _encode_image(
-      np.uint8(np.random.rand(image_height, image_width, 3) * 255), fmt='JPEG')
+      np.uint8(np.random.rand(image_height, image_width, 3) * 255),
+      fmt=image_format)
   labels = [0, 1] if is_multilabel else [0]
   serialized_example = tf.train.Example(
       features=tf.train.Features(
...
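A hedged sketch of how a test might combine the two pieces above: generate a PNG-encoded example via the new `image_format` argument and build a parser with `decode_jpeg_only=False` so the general decode branch is exercised. Only the argument names shown in this diff come from the commit; the module alias, `output_size`, and the remaining values are assumptions:

# Assumptions: create_classification_example and classification_input.Parser are
# importable in the test module; constructor arguments other than those shown in
# the diff may differ.
example = create_classification_example(
    image_height=224, image_width=224, image_format='PNG')

parser = classification_input.Parser(
    output_size=[224, 224],   # assumed constructor argument
    num_classes=10,
    decode_jpeg_only=False,   # exercise the tf.io.decode_image branch
    aug_rand_hflip=False)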
@@ -104,6 +104,7 @@ class ImageClassificationTask(base_task.Task):
         num_classes=num_classes,
         image_field_key=image_field_key,
         label_field_key=label_field_key,
+        decode_jpeg_only=params.decode_jpeg_only,
         aug_rand_hflip=params.aug_rand_hflip,
         aug_type=params.aug_type,
         is_multilabel=is_multilabel,
...