Commit 0bcf460a authored by Rajagopal Ananthanarayanan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 372996541
parent ebf268b6
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """Classification decoder and parser."""
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 # Import libraries
 import tensorflow as tf
@@ -26,27 +26,34 @@ from official.vision.beta.ops import preprocess_ops
 MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
 STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
 
+DEFAULT_IMAGE_FIELD_KEY = 'image/encoded'
+DEFAULT_LABEL_FIELD_KEY = 'image/class/label'
+
 
 class Decoder(decoder.Decoder):
   """A tf.Example decoder for classification task."""
 
   def __init__(self,
-               image_field_key: str = 'image/encoded',
-               label_field_key: str = 'image/class/label',
-               is_multilabel: bool = False):
-    self._keys_to_features = {
-        image_field_key: tf.io.FixedLenFeature((), tf.string, default_value=''),
-    }
-    if is_multilabel:
-      self._keys_to_features.update(
-          {label_field_key: tf.io.VarLenFeature(dtype=tf.int64)})
-    else:
-      self._keys_to_features.update({
-          label_field_key: tf.io.FixedLenFeature((), tf.int64, default_value=-1)
-      })
+               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
+               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
+               is_multilabel: bool = False,
+               keys_to_features: Optional[Dict[str, Any]] = None):
+    if not keys_to_features:
+      keys_to_features = {
+          image_field_key:
+              tf.io.FixedLenFeature((), tf.string, default_value=''),
+      }
+      if is_multilabel:
+        keys_to_features.update(
+            {label_field_key: tf.io.VarLenFeature(dtype=tf.int64)})
+      else:
+        keys_to_features.update({
+            label_field_key:
+                tf.io.FixedLenFeature((), tf.int64, default_value=-1)
+        })
+    self._keys_to_features = keys_to_features
 
-  def decode(self,
-             serialized_example: tf.train.Example) -> Dict[str, tf.Tensor]:
+  def decode(self, serialized_example):
     return tf.io.parse_single_example(
         serialized_example, self._keys_to_features)
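A minimal sketch of how the new keys_to_features override might be used, assuming the module imports as official.vision.beta.dataloaders.classification_input; the extra 'image/height' field and the toy tf.Example below are illustrative assumptions, not part of this change:

import tensorflow as tf

from official.vision.beta.dataloaders import classification_input

# Hypothetical feature map: keeps the default image/label fields and adds an
# extra 'image/height' field (illustrative only).
custom_keys_to_features = {
    'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
    'image/class/label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
}
decoder = classification_input.Decoder(keys_to_features=custom_keys_to_features)

# Build a toy serialized tf.Example with a JPEG-encoded image.
image_bytes = tf.io.encode_jpeg(tf.zeros([8, 8, 3], dtype=tf.uint8)).numpy()
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_bytes])),
    'image/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[3])),
    'image/height': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[8])),
})).SerializeToString()

decoded = decoder.decode(example)  # dict of tensors keyed by feature name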
@@ -57,8 +64,8 @@ class Parser(parser.Parser):
   def __init__(self,
                output_size: List[int],
                num_classes: float,
-               image_field_key: str = 'image/encoded',
-               label_field_key: str = 'image/class/label',
+               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
+               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
                aug_rand_hflip: bool = True,
                aug_type: Optional[common.Augmentation] = None,
                is_multilabel: bool = False,
@@ -69,8 +76,8 @@ class Parser(parser.Parser):
       output_size: `Tensor` or `list` for [height, width] of output image. The
         output_size should be divided by the largest feature stride 2^max_level.
       num_classes: `float`, number of classes.
-      image_field_key: A `str` of the key name to encoded image in TFExample.
-      label_field_key: A `str` of the key name to label in TFExample.
+      image_field_key: `str`, the key name to encoded image in tf.Example.
+      label_field_key: `str`, the key name to label in tf.Example.
       aug_rand_hflip: `bool`, if True, augment training with random
         horizontal flip.
       aug_type: An optional Augmentation object to choose from AutoAugment and
@@ -83,9 +90,6 @@ class Parser(parser.Parser):
     self._aug_rand_hflip = aug_rand_hflip
     self._num_classes = num_classes
     self._image_field_key = image_field_key
-    self._label_field_key = label_field_key
-    self._is_multilabel = is_multilabel
-
     if dtype == 'float32':
       self._dtype = tf.float32
     elif dtype == 'float16':
@@ -111,10 +115,31 @@ class Parser(parser.Parser):
             aug_type.type))
     else:
       self._augmenter = None
+    self._label_field_key = label_field_key
+    self._is_multilabel = is_multilabel
 
   def _parse_train_data(self, decoded_tensors):
     """Parses data for training."""
+    image = self._parse_train_image(decoded_tensors)
+    label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
+    if self._is_multilabel:
+      if isinstance(label, tf.sparse.SparseTensor):
+        label = tf.sparse.to_dense(label)
+      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
+    return image, label
+
+  def _parse_eval_data(self, decoded_tensors):
+    """Parses data for evaluation."""
+    image = self._parse_eval_image(decoded_tensors)
     label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
+    if self._is_multilabel:
+      if isinstance(label, tf.sparse.SparseTensor):
+        label = tf.sparse.to_dense(label)
+      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
+    return image, label
+
+  def _parse_train_image(self, decoded_tensors):
+    """Parses image data for training."""
     image_bytes = decoded_tensors[self._image_field_key]
     image_shape = tf.image.extract_jpeg_shape(image_bytes)
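The multi-label branch above turns a variable-length list of class ids into a dense multi-hot vector. A standalone sketch of that transformation, with a made-up SparseTensor and num_classes value for illustration:

import tensorflow as tf

num_classes = 5  # illustrative value

# VarLenFeature parsing yields a SparseTensor of class ids, e.g. classes 1 and 3.
label = tf.sparse.SparseTensor(
    indices=[[0], [1]],
    values=tf.constant([1, 3], dtype=tf.int32),
    dense_shape=[2])

if isinstance(label, tf.sparse.SparseTensor):
  label = tf.sparse.to_dense(label)
# One-hot each id and sum over ids to get a multi-hot vector: [0., 1., 0., 1., 0.]
label = tf.reduce_sum(tf.one_hot(label, num_classes), axis=0)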
@@ -146,16 +171,10 @@
 
     # Convert image to self._dtype.
     image = tf.image.convert_image_dtype(image, self._dtype)
-    if self._is_multilabel:
-      if isinstance(label, tf.sparse.SparseTensor):
-        label = tf.sparse.to_dense(label)
-      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
-    return image, label
+    return image
 
-  def _parse_eval_data(self, decoded_tensors):
-    """Parses data for evaluation."""
-    label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
+  def _parse_eval_image(self, decoded_tensors):
+    """Parses image data for evaluation."""
     image_bytes = decoded_tensors[self._image_field_key]
     image_shape = tf.image.extract_jpeg_shape(image_bytes)
@@ -175,9 +194,4 @@
 
     # Convert image to self._dtype.
     image = tf.image.convert_image_dtype(image, self._dtype)
-    if self._is_multilabel:
-      if isinstance(label, tf.sparse.SparseTensor):
-        label = tf.sparse.to_dense(label)
-      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
-    return image, label
+    return image
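For context, a sketch of how this dataloader is typically wired into a tf.data pipeline. The parse_fn(is_training) helper on the parser base class and the TFRecord path below are assumptions for illustration; they are not shown in this diff:

import tensorflow as tf

from official.vision.beta.dataloaders import classification_input

decoder = classification_input.Decoder(is_multilabel=False)
parser = classification_input.Parser(
    output_size=[224, 224], num_classes=1000, aug_rand_hflip=True)

dataset = (
    tf.data.TFRecordDataset('path/to/train.tfrecord')  # hypothetical path
    .map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
    # parse_fn(is_training=True) is assumed to dispatch to _parse_train_data /
    # _parse_train_image defined above (common Model Garden parser API).
    .map(parser.parse_fn(is_training=True),
         num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE))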