"vscode:/vscode.git/clone" did not exist on "b9cca5b3e4d9cfdc7b1381769893f75b0eef5afb"
Commit b7af5b2d authored by Fan Yang, committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 367522154
parent da8a5778
......@@ -16,8 +16,10 @@
"""Common configurations."""
# Import libraries
import dataclasses
from official.core import config_definitions as cfg
from official.modeling import hyperparams
......@@ -30,7 +32,7 @@ class NormActivation(hyperparams.Config):
@dataclasses.dataclass
class PseudoLabelDataConfig(hyperparams.Config):
class PseudoLabelDataConfig(cfg.DataConfig):
"""Pseudo Label input config for training."""
input_path: str = ''
data_ratio: float = 1.0 # Per-batch ratio of pseudo-labeled to labeled data
......
......@@ -37,6 +37,8 @@ class DataConfig(cfg.DataConfig):
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'
randaug_magnitude: Optional[int] = 10
file_type: str = 'tfrecord'
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
@dataclasses.dataclass
......@@ -75,6 +77,8 @@ class ImageClassificationTask(cfg.TaskConfig):
evaluation: Evaluation = Evaluation()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
model_output_keys: Optional[List[int]] = dataclasses.field(
default_factory=list)
@exp_factory.register_config_factory('image_classification')
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Classification decoder and parser."""
from typing import List, Optional
from typing import Dict, List, Optional
# Import libraries
import tensorflow as tf
......@@ -29,14 +29,16 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task."""
def __init__(self):
def __init__(self,
image_field_key: str = 'image/encoded',
label_field_key: str = 'image/class/label'):
self._keys_to_features = {
'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/class/label': (
tf.io.FixedLenFeature((), tf.int64, default_value=-1))
image_field_key: tf.io.FixedLenFeature((), tf.string, default_value=''),
label_field_key: (tf.io.FixedLenFeature((), tf.int64, default_value=-1))
}
def decode(self, serialized_example):
def decode(self,
serialized_example: tf.train.Example) -> Dict[str, tf.Tensor]:
return tf.io.parse_single_example(
serialized_example, self._keys_to_features)
......@@ -47,6 +49,8 @@ class Parser(parser.Parser):
def __init__(self,
output_size: List[int],
num_classes: float,
image_field_key: str = 'image/encoded',
label_field_key: str = 'image/class/label',
aug_rand_hflip: bool = True,
aug_policy: Optional[str] = None,
randaug_magnitude: Optional[int] = 10,
......@@ -57,6 +61,8 @@ class Parser(parser.Parser):
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divisible by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
image_field_key: A `str` of the key name to encoded image in TFExample.
label_field_key: A `str` of the key name to label in TFExample.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
......@@ -67,6 +73,9 @@ class Parser(parser.Parser):
self._output_size = output_size
self._aug_rand_hflip = aug_rand_hflip
self._num_classes = num_classes
self._image_field_key = image_field_key
self._label_field_key = label_field_key
if dtype == 'float32':
self._dtype = tf.float32
elif dtype == 'float16':
......@@ -89,9 +98,8 @@ class Parser(parser.Parser):
def _parse_train_data(self, decoded_tensors):
"""Parses data for training."""
label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
image_bytes = decoded_tensors['image/encoded']
label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
image_bytes = decoded_tensors[self._image_field_key]
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image.
......@@ -126,8 +134,8 @@ class Parser(parser.Parser):
def _parse_eval_data(self, decoded_tensors):
"""Parses data for evaluation."""
label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
image_bytes = decoded_tensors['image/encoded']
label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
image_bytes = decoded_tensors[self._image_field_key]
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops and resizes image.
......
......@@ -80,6 +80,8 @@ class ImageClassificationTask(base_task.Task):
num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size
image_field_key = self.task_config.train_data.image_field_key
label_field_key = self.task_config.train_data.label_field_key
if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
......@@ -88,11 +90,14 @@ class ImageClassificationTask(base_task.Task):
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder = classification_input.Decoder()
decoder = classification_input.Decoder(
image_field_key=image_field_key, label_field_key=label_field_key)
parser = classification_input.Parser(
output_size=input_size[:2],
num_classes=num_classes,
image_field_key=image_field_key,
label_field_key=label_field_key,
aug_policy=params.aug_policy,
randaug_magnitude=params.randaug_magnitude,
dtype=params.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment