Unverified Commit 965cc3ee authored by Ayushman Kumar, committed by GitHub

Merge pull request #7 from tensorflow/master

updated
parents 1f3247f4 1f685c54
......@@ -24,8 +24,8 @@ from __future__ import division
from __future__ import print_function
import math
import tensorflow.compat.v2 as tf
from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
import tensorflow as tf
from typing import Any, Dict, List, Optional, Text, Tuple
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
......@@ -66,7 +66,7 @@ def to_4d(image: tf.Tensor) -> tf.Tensor:
return tf.reshape(image, new_shape)
def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
"""Converts a 4D image back to `ndims` rank."""
shape = tf.shape(image)
begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
......@@ -75,8 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
return tf.reshape(image, new_shape)
def _convert_translation_to_transform(
translations: Iterable[int]) -> tf.Tensor:
def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
"""Converts translations to a projective transform.
The translation matrix looks like this:
......@@ -122,9 +121,9 @@ def _convert_translation_to_transform(
def _convert_angles_to_transform(
angles: Union[Iterable[float], float],
image_width: int,
image_height: int) -> tf.Tensor:
angles: tf.Tensor,
image_width: tf.Tensor,
image_height: tf.Tensor) -> tf.Tensor:
"""Converts an angle or angles to a projective transform.
Args:
......@@ -166,8 +165,7 @@ def _convert_angles_to_transform(
)
def transform(image: tf.Tensor,
transforms: Iterable[float]) -> tf.Tensor:
def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
......@@ -181,8 +179,7 @@ def transform(image: tf.Tensor,
return from_4d(image, original_ndims)
def translate(image: tf.Tensor,
translations: Iterable[int]) -> tf.Tensor:
def translate(image: tf.Tensor, translations) -> tf.Tensor:
"""Translates image(s) by provided vectors.
Args:
......@@ -212,7 +209,7 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
"""
# Convert from degrees to radians.
degrees_to_radians = math.pi / 180.0
radians = degrees * degrees_to_radians
radians = tf.cast(degrees * degrees_to_radians, tf.float32)
original_ndims = tf.rank(image)
image = to_4d(image)
......@@ -577,7 +574,7 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
return image
def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor:
def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.vision.image_classification import augment
......@@ -52,14 +52,21 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(augment.transform(image, transforms=[1]*8),
[[4, 4], [4, 4]])
def disable_test_translate(self, dtype):
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 1, 0],
[0, 1, 0, 1]],
dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image,
translations=translations)
expected = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]
expected = [
[1, 0, 1, 1],
[0, 1, 0, 0],
[1, 0, 1, 1],
[1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
......@@ -133,5 +140,4 @@ class AutoaugmentTest(tf.test.TestCase):
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -19,18 +20,24 @@ from __future__ import division
from __future__ import print_function
import os
from typing import Any, List, MutableMapping, Text
from absl import logging
import tensorflow as tf
from typing import Any, List, MutableMapping, Text
from official.utils.misc import keras_utils
from official.vision.image_classification import optimizer_factory
def get_callbacks(model_checkpoint: bool = True,
include_tensorboard: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
apply_moving_average: bool = False,
initial_step: int = 0,
model_dir: Text = None) -> List[tf.keras.callbacks.Callback]:
batch_size: int = 0,
log_steps: int = 0,
model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
callbacks = []
......@@ -39,11 +46,29 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks.append(tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if include_tensorboard:
callbacks.append(CustomTensorBoard(
log_dir=model_dir,
track_lr=track_lr,
initial_step=initial_step,
write_images=write_model_weights))
callbacks.append(
CustomTensorBoard(
log_dir=model_dir,
track_lr=track_lr,
initial_step=initial_step,
write_images=write_model_weights))
if time_history:
callbacks.append(
keras_utils.TimeHistory(
batch_size,
log_steps,
logdir=model_dir if include_tensorboard else None))
if apply_moving_average:
# Save moving average model to a different file so that
# we can resume training from a checkpoint
ckpt_full_path = os.path.join(
model_dir, 'average', 'model.ckpt-{epoch:04d}')
callbacks.append(AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
return callbacks
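
For reference, a minimal usage sketch of the expanded `get_callbacks` signature; the literal values and the model directory path below are illustrative, not part of this change:

    from official.vision.image_classification import callbacks as custom_callbacks

    callbacks = custom_callbacks.get_callbacks(
        model_checkpoint=True,
        include_tensorboard=True,
        time_history=True,
        track_lr=True,
        write_model_weights=False,
        apply_moving_average=True,  # adds AverageModelCheckpoint + MovingAverageCallback
        initial_step=0,
        batch_size=128,   # global batch size, forwarded to keras_utils.TimeHistory
        log_steps=100,    # interval of steps between batch-level stat logging
        model_dir='/tmp/model_dir')  # hypothetical path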
......@@ -63,18 +88,19 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
- Global learning rate
Attributes:
log_dir: the path of the directory where to save the log files to be
parsed by TensorBoard.
log_dir: the path of the directory where to save the log files to be parsed
by TensorBoard.
track_lr: `bool`, whether or not to track the global learning rate.
initial_step: the initial step, used for preemption recovery.
**kwargs: Additional arguments for backwards compatibility. Possible key
is `period`.
**kwargs: Additional arguments for backwards compatibility. Possible key is
`period`.
"""
# TODO(b/146499062): track params, flops, log lr, l2 loss,
# classification loss
def __init__(self,
log_dir: Text,
log_dir: str,
track_lr: bool = False,
initial_step: int = 0,
**kwargs):
......@@ -84,7 +110,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_batch_begin(self,
epoch: int,
logs: MutableMapping[Text, Any] = None) -> None:
logs: MutableMapping[str, Any] = None) -> None:
self.step += 1
if logs is None:
logs = {}
......@@ -93,7 +119,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_begin(self,
epoch: int,
logs: MutableMapping[Text, Any] = None) -> None:
logs: MutableMapping[str, Any] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
......@@ -104,25 +130,24 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_end(self,
epoch: int,
logs: MutableMapping[Text, Any] = None) -> None:
logs: MutableMapping[str, Any] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
def _calculate_metrics(self) -> MutableMapping[Text, Any]:
def _calculate_metrics(self) -> MutableMapping[str, Any]:
logs = {}
if self._track_lr:
logs['learning_rate'] = self._calculate_lr()
# TODO(b/149030439): disable LR reporting.
# if self._track_lr:
# logs['learning_rate'] = self._calculate_lr()
return logs
def _calculate_lr(self) -> int:
"""Calculates the learning rate given the current step."""
lr = self._get_base_optimizer().lr
if callable(lr):
lr = lr(self.step)
return get_scalar_from_tensor(lr)
return get_scalar_from_tensor(
self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32)) # pylint:disable=protected-access
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Get the base optimizer used by the current model."""
......@@ -134,3 +159,100 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
optimizer = optimizer._optimizer # pylint:disable=protected-access
return optimizer
class MovingAverageCallback(tf.keras.callbacks.Callback):
"""A Callback to be used with a `MovingAverage` optimizer.
Applies moving average weights to the model during validation time to test
and predict on the averaged weights rather than the current model weights.
Once training is complete, the model weights will be overwritten with the
averaged weights if `overwrite_weights_on_train_end` is set to `True`.
Attributes:
overwrite_weights_on_train_end: Whether to overwrite the current model
weights with the averaged weights from the moving average optimizer.
**kwargs: Any additional callback arguments.
"""
def __init__(self,
overwrite_weights_on_train_end: bool = False,
**kwargs):
super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer,
optimizer_factory.MovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
self.model.optimizer.swap_weights()
def on_test_end(self, logs: MutableMapping[Text, Any] = None):
self.model.optimizer.swap_weights()
def on_train_end(self, logs: MutableMapping[Text, Any] = None):
if self.overwrite_weights_on_train_end:
self.model.optimizer.assign_average_vars(self.model.variables)
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
"""Saves and, optionally, assigns the averaged weights.
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
update_weights: If True, assign the moving average weights
to the model, and save them. If False, keep the old
non-averaged weights, but the saved model uses the
average weights.
See `tf.keras.callbacks.ModelCheckpoint` for the other args.
"""
def __init__(
self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
self.update_weights = update_weights
super().__init__(
filepath,
monitor,
verbose,
save_best_only,
save_weights_only,
mode,
save_freq,
**kwargs)
def set_model(self, model):
if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
raise TypeError(
'AverageModelCheckpoint is only used when training '
'with MovingAverage')
return super().set_model(model)
def _save_model(self, epoch, logs):
assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs)
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
non_avg_weights = self.model.get_weights()
self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(epoch, logs)
self.model.set_weights(non_avg_weights)
return result
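
A sketch of how the two new callbacks are intended to be combined with a `MovingAverage` optimizer; the `MovingAverage` constructor arguments and the checkpoint path here are assumptions for illustration only:

    optimizer = optimizer_factory.MovingAverage(
        tf.keras.optimizers.RMSprop(), average_decay=0.99)  # assumed signature
    model.compile(optimizer=optimizer, loss='categorical_crossentropy')
    model.fit(
        train_dataset,
        epochs=10,
        callbacks=[
            # Swaps averaged weights in for evaluate()/predict(), then swaps back.
            MovingAverageCallback(),
            # Checkpoints the averaged weights without overwriting the live weights.
            AverageModelCheckpoint(
                update_weights=False,
                filepath='/tmp/average/model.ckpt-{epoch:04d}',  # hypothetical path
                save_weights_only=True,
                verbose=1),
        ])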
......@@ -27,12 +27,11 @@ from typing import Any, Tuple, Text, Optional, Mapping
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.modeling import performance
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks
......@@ -44,10 +43,24 @@ from official.vision.image_classification.efficientnet import efficientnet_model
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_model
MODELS = {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
"""Returns the mapping from dtype string representations to TF dtypes."""
return {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
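
Turning the former module-level `MODELS` dict into factory functions defers the lookup to call time; a small usage sketch (the keyword argument passed to the constructor is illustrative):

    model_fn = get_models()['resnet']        # -> resnet_model.resnet50
    dtype = get_dtype_map()['bf16']          # -> tf.bfloat16
    model = model_fn(num_classes=1000)       # illustrative keyword argument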
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
......@@ -87,19 +100,20 @@ def get_image_size_from_model(
def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy,
one_hot: bool
) -> Tuple[Any, Any, Any]:
"""Create and return train, validation, and test dataset builders."""
) -> Tuple[Any, Any]:
"""Create and return train and validation dataset builders."""
if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
else:
logging.warning('label_smoothing not applied, so datasets will not be one '
'hot encoded.')
num_devices = strategy.num_replicas_in_sync
num_devices = strategy.num_replicas_in_sync if strategy else 1
image_size = get_image_size_from_model(params)
dataset_configs = [
params.train_dataset, params.validation_dataset, params.test_dataset
params.train_dataset, params.validation_dataset
]
builders = []
......@@ -120,12 +134,13 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
def get_loss_scale(params: base_configs.ExperimentConfig,
fp16_default: float = 128.) -> float:
"""Returns the loss scale for initializations."""
loss_scale = params.model.loss.loss_scale
loss_scale = params.runtime.loss_scale
if loss_scale == 'dynamic':
return loss_scale
elif loss_scale is not None:
return float(loss_scale)
elif params.train_dataset.dtype == 'float32':
elif (params.train_dataset.dtype == 'float32' or
params.train_dataset.dtype == 'bfloat16'):
return 1.
else:
assert params.train_dataset.dtype == 'float16'
......@@ -145,7 +160,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'name': model,
},
'runtime': {
'enable_eager': flags_obj.enable_eager,
'run_eagerly': flags_obj.run_eagerly,
'tpu': flags_obj.tpu,
},
'train_dataset': {
......@@ -154,8 +169,10 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'validation_dataset': {
'data_dir': flags_obj.data_dir,
},
'test_dataset': {
'data_dir': flags_obj.data_dir,
'train': {
'time_history': {
'log_steps': flags_obj.log_steps,
},
},
}
......@@ -169,8 +186,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
for param in overriding_configs:
logging.info('Overriding params: %s', param)
# Set is_strict to false because we can have dynamic dict parameters.
params = params_dict.override_params_dict(params, param, is_strict=False)
params = params_dict.override_params_dict(params, param, is_strict=True)
params.validate()
params.lock()
......@@ -212,24 +228,21 @@ def resume_from_checkpoint(model: tf.keras.Model,
return int(initial_epoch)
def initialize(params: base_configs.ExperimentConfig):
def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(
enable_eager=params.runtime.enable_eager,
enable_xla=params.runtime.enable_xla)
if params.runtime.gpu_threads_enabled:
if params.runtime.gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads)
dataset = params.train_dataset or params.validation_dataset
performance.set_mixed_precision_policy(dataset.dtype)
if dataset.data_format:
data_format = dataset.data_format
elif tf.config.list_physical_devices('GPU'):
performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params))
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
data_format = 'channels_last'
......@@ -237,7 +250,7 @@ def initialize(params: base_configs.ExperimentConfig):
distribution_utils.configure_cluster(
params.runtime.worker_hosts,
params.runtime.task_index)
if params.runtime.enable_eager:
if params.runtime.run_eagerly:
# Enable eager execution to allow step-by-step debugging
tf.config.experimental_run_functions_eagerly(True)
......@@ -254,7 +267,7 @@ def define_classifier_flags():
default=None,
help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
flags.DEFINE_bool(
'enable_eager',
'run_eagerly',
default=None,
help='Use eager execution and disable autograph for debugging.')
flags.DEFINE_string(
......@@ -265,6 +278,10 @@ def define_classifier_flags():
'dataset',
default=None,
help='The name of the dataset, e.g. ImageNet, etc.')
flags.DEFINE_integer(
'log_steps',
default=100,
help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig,
......@@ -291,27 +308,31 @@ def train_and_eval(
strategy_scope = distribution_utils.get_strategy_scope(strategy)
logging.info('Detected %d devices.', strategy.num_replicas_in_sync)
logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1)
label_smoothing = params.model.loss.label_smoothing
one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot)
datasets = [builder.build() if builder else None for builder in builders]
datasets = [builder.build(strategy)
if builder else None for builder in builders]
# Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder, test_builder = builders # pylint: disable=unbalanced-tuple-unpacking
train_dataset, validation_dataset, test_dataset = datasets
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
train_dataset, validation_dataset = datasets
train_epochs = params.train.epochs
train_steps = params.train.steps or train_builder.num_steps
validation_steps = params.evaluation.steps or validation_builder.num_steps
initialize(params, train_builder)
logging.info('Global batch size: %d', train_builder.global_batch_size)
with strategy_scope:
model_params = params.model.model_params.as_dict()
model = MODELS[params.model.name](**model_params)
model = get_models()[params.model.name](**model_params)
learning_rate = optimizer_factory.build_learning_rate(
params=params.model.learning_rate,
batch_size=train_builder.global_batch_size,
......@@ -332,7 +353,7 @@ def train_and_eval(
model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
run_eagerly=params.runtime.enable_eager)
experimental_steps_per_execution=params.train.steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
......@@ -340,15 +361,27 @@ def train_and_eval(
model_dir=params.model_dir,
train_steps=train_steps)
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
time_history=params.train.callbacks.enable_time_history,
track_lr=params.train.tensorboard.track_lr,
write_model_weights=params.train.tensorboard.write_model_weights,
initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir)
serialize_config(params=params, model_dir=params.model_dir)
# TODO(dankondratyuk): callbacks significantly slow down training
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
track_lr=params.train.tensorboard.track_lr,
write_model_weights=params.train.tensorboard.write_model_weights,
initial_step=initial_epoch * train_steps,
model_dir=params.model_dir)
if params.evaluation.skip_eval:
validation_kwargs = {}
else:
validation_kwargs = {
'validation_data': validation_dataset,
'validation_steps': validation_steps,
'validation_freq': params.evaluation.epochs_between_evals,
}
history = model.fit(
train_dataset,
......@@ -356,15 +389,15 @@ def train_and_eval(
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
validation_data=validation_dataset,
validation_steps=validation_steps,
validation_freq=params.evaluation.epochs_between_evals)
verbose=2,
**validation_kwargs)
validation_output = model.evaluate(
validation_dataset, steps=validation_steps, verbose=2)
validation_output = None
if not params.evaluation.skip_eval:
validation_output = model.evaluate(
validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history,
validation_output,
callbacks)
......@@ -375,7 +408,7 @@ def export(params: base_configs.ExperimentConfig):
"""Runs the model export functionality."""
logging.info('Exporting model.')
model_params = params.model.model_params.as_dict()
model = MODELS[params.model.name](**model_params)
model = get_models()[params.model.name](**model_params)
checkpoint = params.export.checkpoint
if checkpoint is None:
logging.info('No export checkpoint was provided. Using the latest '
......@@ -398,8 +431,6 @@ def run(flags_obj: flags.FlagValues,
Dictionary of training/eval stats
"""
params = _get_params_from_flags(flags_obj)
initialize(params)
if params.mode == 'train_and_eval':
return train_and_eval(params, strategy_override)
elif params.mode == 'export_only':
......@@ -409,8 +440,7 @@ def run(flags_obj: flags.FlagValues,
def main(_):
with logger.benchmark_context(flags.FLAGS):
stats = run(flags.FLAGS)
stats = run(flags.FLAGS)
if stats:
logging.info('Run stats:\n%s', stats)
......@@ -423,5 +453,4 @@ if __name__ == '__main__':
flags.mark_flag_as_required('model_type')
flags.mark_flag_as_required('dataset')
assert tf.version.VERSION.startswith('2.')
app.run(main)
......@@ -30,7 +30,7 @@ from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, T
from absl import flags
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
......@@ -67,7 +67,7 @@ def get_params_override(params_override: Mapping[str, Any]) -> str:
return '--params_override=' + json.dumps(params_override)
def basic_params_override() -> MutableMapping[str, Any]:
def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
"""Returns a basic parameter configuration for testing."""
return {
'train_dataset': {
......@@ -75,18 +75,14 @@ def basic_params_override() -> MutableMapping[str, Any]:
'use_per_replica_batch_size': True,
'batch_size': 1,
'image_size': 224,
'dtype': dtype,
},
'validation_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
},
'test_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
'dtype': dtype,
},
'train': {
'steps': 1,
......@@ -152,7 +148,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
tf.io.gfile.rmtree(self.get_temp_dir())
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_train_and_eval_export(self, distribution, model, dataset):
def test_end_to_end_train_and_eval(self, distribution, model, dataset):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
......@@ -168,6 +164,41 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval',
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
],
model=[
'efficientnet',
'resnet',
],
mode='eager',
dataset='imagenet',
dtype='float16',
))
def test_gpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.get_temp_dir()
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
export_params = basic_params_override()
export_path = os.path.join(model_dir, 'export')
export_params['export'] = {}
......@@ -187,6 +218,41 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir)
self.assertTrue(os.path.exists(export_path))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.tpu_strategy,
],
model=[
'efficientnet',
'resnet',
],
mode='eager',
dataset='imagenet',
dtype='bfloat16',
))
def test_tpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.get_temp_dir()
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset):
"""Test the Keras EfficientNet model with `strategy`."""
......@@ -239,8 +305,8 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
)
def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig(
model=base_configs.ModelConfig(
loss=base_configs.LossConfig(loss_scale=loss_scale)),
runtime=base_configs.RuntimeConfig(
loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected)
......@@ -252,19 +318,23 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
def test_initialize(self, dtype):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
enable_eager=False,
run_eagerly=False,
enable_xla=False,
gpu_threads_enabled=True,
per_gpu_thread_count=1,
gpu_thread_mode='gpu_private',
num_gpus=1,
dataset_num_private_threads=1,
),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
model=base_configs.ModelConfig(
loss=base_configs.LossConfig(loss_scale='dynamic')),
model=base_configs.ModelConfig(),
)
classifier_trainer.initialize(config)
class EmptyClass:
pass
fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass()
classifier_trainer.initialize(config, fake_ds_builder)
def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint."""
......@@ -313,5 +383,4 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -58,6 +58,17 @@ class MetricsConfig(base_config.Config):
top_5: bool = None
@dataclasses.dataclass
class TimeHistoryConfig(base_config.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: int = None
@dataclasses.dataclass
class TrainConfig(base_config.Config):
"""Configuration for training.
......@@ -71,14 +82,18 @@ class TrainConfig(base_config.Config):
callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorboardConfig.
steps_per_loop: The number of batches to run during each `tf.function`
call during training, which can increase training speed.
"""
resume_checkpoint: bool = None
epochs: int = None
steps: int = None
callbacks: CallbacksConfig = CallbacksConfig()
metrics: List[str] = None
metrics: MetricsConfig = None
tensorboard: TensorboardConfig = TensorboardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig()
steps_per_loop: int = None
@dataclasses.dataclass
......@@ -91,10 +106,12 @@ class EvalConfig(base_config.Config):
steps: The number of eval steps to run during evaluation. If None, this will
be inferred based on the number of images and batch size. Defaults to
None.
skip_eval: Whether or not to skip evaluation.
"""
epochs_between_evals: int = None
steps: int = None
skip_eval: bool = False
@dataclasses.dataclass
......@@ -103,13 +120,11 @@ class LossConfig(base_config.Config):
Attributes:
name: The name of the loss. Defaults to None.
loss_scale: The type of loss scale
label_smoothing: Whether or not to apply label smoothing to the loss. This
only applies to 'categorical_cross_entropy'.
"""
name: str = None
loss_scale: str = None
label_smoothing: float = None
......@@ -164,6 +179,7 @@ class LearningRateConfig(base_config.Config):
multipliers: multipliers used in piecewise constant decay with warmup.
scale_by_batch_size: Scale the learning rate by a fraction of the batch
size. Set to 0 for no scaling (default).
staircase: Apply exponential decay at discrete values instead of continuous.
"""
name: str = None
......@@ -175,6 +191,7 @@ class LearningRateConfig(base_config.Config):
boundaries: List[int] = None
multipliers: List[float] = None
scale_by_batch_size: float = 0.
staircase: bool = None
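
The new `staircase` flag presumably feeds the `staircase` argument of `tf.keras.optimizers.schedules.ExponentialDecay`; a minimal sketch under that assumption (the initial learning rate and steps-per-epoch values are illustrative, while decay_epochs=2.4 and decay_rate=0.97 mirror the EfficientNet config later in this change):

    import tensorflow as tf

    steps_per_epoch = 1251  # illustrative
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.008,             # illustrative
        decay_steps=int(2.4 * steps_per_epoch),  # decay_epochs=2.4
        decay_rate=0.97,
        staircase=True)  # decay at discrete boundaries instead of continuously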
@dataclasses.dataclass
......@@ -190,7 +207,7 @@ class ModelConfig(base_config.Config):
"""
name: str = None
model_params: Mapping[str, Any] = None
model_params: base_config.Config = None
num_classes: int = None
loss: LossConfig = None
optimizer: OptimizerConfig = None
......@@ -216,7 +233,6 @@ class ExperimentConfig(base_config.Config):
runtime: RuntimeConfig = None
train_dataset: Any = None
validation_dataset: Any = None
test_dataset: Any = None
train: TrainConfig = None
evaluation: EvalConfig = None
model: ModelConfig = None
......
......@@ -45,8 +45,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
dataset_factory.ImageNetConfig(split='train')
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation')
test_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation')
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=500,
......@@ -54,8 +52,10 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False))
write_model_weights=False),
steps_per_loop=1)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
......@@ -78,11 +78,6 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
one_hot=False,
mean_subtract=True,
standardize=True)
test_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation',
one_hot=False,
mean_subtract=True,
standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
......@@ -90,8 +85,10 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False))
write_model_weights=False),
steps_per_loop=1)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
......
......@@ -3,8 +3,6 @@
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
mode: 'train_and_eval'
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
......@@ -36,10 +34,13 @@ model:
num_classes: 1000
batch_norm: 'default'
dtype: 'float32'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
......
......@@ -3,8 +3,6 @@
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
mode: 'train_and_eval'
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
......@@ -35,11 +33,12 @@ model:
num_classes: 1000
batch_norm: 'tpu'
dtype: 'bfloat16'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
......
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
mode: 'train_and_eval'
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
......@@ -12,6 +10,7 @@ train_dataset:
num_classes: 1000
num_examples: 1281167
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
validation_dataset:
name: 'imagenet2012'
......@@ -21,6 +20,7 @@ validation_dataset:
num_classes: 1000
num_examples: 50000
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
model:
model_params:
......@@ -29,10 +29,13 @@ model:
num_classes: 1000
batch_norm: 'default'
dtype: 'float32'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
......
......@@ -2,8 +2,6 @@
# Takes ~3 minutes, 15 seconds per epoch for v3-32.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
mode: 'train_and_eval'
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
......@@ -34,10 +32,13 @@ model:
num_classes: 1000
batch_norm: 'tpu'
dtype: 'bfloat16'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
......
# Training configuration for ResNet trained on ImageNet on GPUs.
# Takes ~3 minutes, 15 seconds per epoch for 8 V100s.
# Reaches ~76.1% within 90 epochs.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
mode: 'train_and_eval'
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'train'
image_size: 224
num_classes: 1000
......@@ -23,7 +20,7 @@ train_dataset:
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'validation'
image_size: 224
num_classes: 1000
......@@ -34,7 +31,7 @@ validation_dataset:
mean_subtract: True
standardize: True
model:
model_name: 'resnet'
name: 'resnet'
model_params:
rescale_inputs: False
optimizer:
......
# Training configuration for ResNet trained on ImageNet on TPUs.
# Takes ~2 minutes, 43 seconds per epoch for a v3-32.
# Reaches ~76.1% within 90 epochs.
# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
mode: 'train_and_eval'
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'train'
one_hot: False
image_size: 224
......@@ -23,7 +21,7 @@ train_dataset:
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'validation'
one_hot: False
image_size: 224
......@@ -35,7 +33,7 @@ validation_dataset:
standardize: True
dtype: 'bfloat16'
model:
model_name: 'resnet'
name: 'resnet'
model_params:
rescale_inputs: False
optimizer:
......
......@@ -23,7 +23,7 @@ import os
from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf
import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import base_config
......@@ -84,11 +84,10 @@ class DatasetConfig(base_config.Config):
use_per_replica_batch_size: Whether to scale the batch size based on
available resources. If set to `True`, the dataset builder will return
batch_size multiplied by `num_devices`, the number of device replicas
(e.g., the number of GPUs or TPU cores).
(e.g., the number of GPUs or TPU cores). This setting should be `True` if
the strategy argument is passed to `build()` and `num_devices > 1`.
num_devices: The number of replica devices to use. This should be set by
`strategy.num_replicas_in_sync` when using a distribution strategy.
data_format: The data format of the images. Should be 'channels_last' or
'channels_first'.
dtype: The desired dtype of the dataset. This will be set during
preprocessing.
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
......@@ -118,9 +117,8 @@ class DatasetConfig(base_config.Config):
num_channels: Union[int, str] = 'infer'
num_examples: Union[int, str] = 'infer'
batch_size: int = 128
use_per_replica_batch_size: bool = False
use_per_replica_batch_size: bool = True
num_devices: int = 1
data_format: str = 'channels_last'
dtype: str = 'float32'
one_hot: bool = True
augmenter: AugmentConfig = AugmentConfig()
......@@ -188,14 +186,22 @@ class DatasetBuilder:
def batch_size(self) -> int:
"""The batch size, multiplied by the number of replicas (if configured)."""
if self.config.use_per_replica_batch_size:
return self.global_batch_size
return self.config.batch_size * self.config.num_devices
else:
return self.config.batch_size
@property
def global_batch_size(self):
"""The global batch size across all replicas."""
return self.config.batch_size * self.config.num_devices
return self.batch_size
@property
def local_batch_size(self):
"""The base unscaled batch size."""
if self.config.use_per_replica_batch_size:
return self.config.batch_size
else:
return self.config.batch_size // self.config.num_devices
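
A worked example of the three batch-size properties, assuming `config.batch_size = 128` and `config.num_devices = 8`:

    # use_per_replica_batch_size=True: 128 is the per-replica size.
    #   local_batch_size  -> 128
    #   global_batch_size -> 128 * 8 = 1024
    # use_per_replica_batch_size=False: 128 is already the global size.
    #   local_batch_size  -> 128 // 8 = 16
    #   global_batch_size -> 128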
@property
def num_steps(self) -> int:
......@@ -203,6 +209,30 @@ class DatasetBuilder:
# Always divide by the global batch size to get the correct # of steps
return self.num_examples // self.global_batch_size
@property
def dtype(self) -> tf.dtypes.DType:
"""Converts the config's dtype string to a tf dtype.
Returns:
The `tf.dtypes.DType` that corresponds to the config's dtype string.
Raises:
ValueError if the config's dtype is not supported.
"""
dtype_map = {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
try:
return dtype_map[self.config.dtype]
except KeyError:
raise ValueError('Invalid DType provided. Supported types: {}'.format(
dtype_map.keys()))
@property
def image_size(self) -> int:
"""The size of each image (can be inferred from the dataset)."""
......@@ -243,19 +273,42 @@ class DatasetBuilder:
self.builder_info = tfds.builder(self.config.name).info
return self.builder_info
def build(self, input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it using an optional strategy.
Args:
strategy: a strategy that, if passed, will distribute the dataset
according to that strategy. If passed and `num_devices > 1`,
`use_per_replica_batch_size` must be set to `True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if strategy:
if strategy.num_replicas_in_sync != self.config.num_devices:
logging.warning('Passed a strategy with %d devices, but expected '
'%d devices.',
strategy.num_replicas_in_sync,
self.config.num_devices)
dataset = strategy.experimental_distribute_datasets_from_function(
self._build)
else:
dataset = self._build()
return dataset
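
Callers now pass a distribution strategy directly instead of an `InputContext`; a usage sketch of the new call pattern (the builder construction shown is an assumption for illustration):

    strategy = tf.distribute.MirroredStrategy()
    builder = dataset_factory.DatasetBuilder(params.train_dataset)  # assumed constructor
    # Distributes via experimental_distribute_datasets_from_function; each replica
    # receives batches of local_batch_size examples.
    train_dataset = builder.build(strategy)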
def _build(self, input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it.
Args:
input_context: An optional context provided by `tf.distribute` for
cross-replica training. This isn't necessary if using Keras
compile/fit.
cross-replica training.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
builders = {
'tfds': self.load_tfds,
'records': self.load_records,
......@@ -326,7 +379,7 @@ class DatasetBuilder:
def generate_data(_):
image = tf.zeros([self.image_size, self.image_size, self.num_channels],
dtype=self.config.dtype)
dtype=self.dtype)
label = tf.zeros([1], dtype=tf.int32)
return image, label
......@@ -345,8 +398,8 @@ class DatasetBuilder:
Args:
dataset: A `tf.data.Dataset` that loads raw files.
input_context: An optional context provided by `tf.distribute` for
cross-replica training. This isn't necessary if using Keras
compile/fit.
cross-replica training. If set with more than one replica, this
function assumes `use_per_replica_batch_size=True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
......@@ -366,8 +419,6 @@ class DatasetBuilder:
cycle_length=16,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(self.global_batch_size)
if self.config.cache:
dataset = dataset.cache()
......@@ -383,13 +434,25 @@ class DatasetBuilder:
dataset = dataset.map(preprocess,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)
# Note: we could do image normalization here, but we defer it to the model
# which can perform it much faster on a GPU/TPU
# TODO(dankondratyuk): if we fix prefetching, we can do it here
if input_context and self.config.num_devices > 1:
if not self.config.use_per_replica_batch_size:
raise ValueError(
'The builder does not support a global batch size with more than '
'one replica. Got {} replicas. Please set a '
'`per_replica_batch_size` and enable '
'`use_per_replica_batch_size=True`.'.format(
self.config.num_devices))
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
dataset = dataset.batch(self.local_batch_size,
drop_remainder=self.is_training)
else:
dataset = dataset.batch(self.global_batch_size,
drop_remainder=self.is_training)
if self.is_training and self.config.deterministic_train is not None:
if self.is_training:
options = tf.data.Options()
options.experimental_deterministic = self.config.deterministic_train
options.experimental_slack = self.config.use_slack
......@@ -400,9 +463,7 @@ class DatasetBuilder:
dataset = dataset.with_options(options)
# Prefetch overlaps in-feed with training
# Note: autotune here is not recommended, as this can lead to memory leaks.
# Instead, use a constant prefetch size like the number of devices.
dataset = dataset.prefetch(self.config.num_devices)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......@@ -451,7 +512,7 @@ class DatasetBuilder:
image_size=self.image_size,
mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize,
dtype=self.config.dtype,
dtype=self.dtype,
augmenter=self.augmenter)
else:
image = preprocessing.preprocess_for_eval(
......@@ -460,7 +521,7 @@ class DatasetBuilder:
num_channels=self.num_channels,
mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize,
dtype=self.config.dtype)
dtype=self.dtype)
label = tf.cast(label, tf.int32)
if self.config.one_hot:
......
......@@ -19,15 +19,14 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
from typing import Text, Optional
from tensorflow.python.tpu import tpu_function
@tf.keras.utils.register_keras_serializable(package='Text')
@tf.keras.utils.register_keras_serializable(package='Vision')
class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
"""Cross replica batch normalization."""
......@@ -98,3 +97,21 @@ def count_params(model, trainable_only=True):
else:
return int(np.sum([tf.keras.backend.count_params(p)
for p in model.trainable_weights]))
def load_weights(model: tf.keras.Model,
model_weights_path: Text,
weights_format: Text = 'saved_model'):
"""Load model weights from the given file path.
Args:
model: the model to load weights into
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
"""
if weights_format == 'saved_model':
loaded_model = tf.keras.models.load_model(model_weights_path)
model.set_weights(loaded_model.get_weights())
else:
model.load_weights(model_weights_path)
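
Usage sketch for the new helper; the model builder and path are placeholders:

    model = build_model()  # any tf.keras.Model; hypothetical helper
    # 'saved_model' expects a SavedModel directory; 'h5' or 'checkpoint' expect a
    # weights file produced by model.save_weights() or model.save().
    load_weights(model, '/path/to/saved_model_dir', weights_format='saved_model')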
......@@ -22,6 +22,7 @@ from typing import Any, Mapping
import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs
......@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
configuration.
learning_rate: The configuration for learning rate. Defaults to an
exponential configuration.
"""
name: str = 'EfficientNet'
num_classes: int = 1000
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
'model_name': 'efficientnet-b0',
'model_weights_path': '',
'copy_to_local': False,
'overrides': {
'batch_norm': 'default',
'rescale_input': True,
'num_classes': 1000,
}
})
model_params: base_config.Config = dataclasses.field(
default_factory=lambda: {
'model_name': 'efficientnet-b0',
'model_weights_path': '',
'weights_format': 'saved_model',
'overrides': {
'batch_norm': 'default',
'rescale_input': True,
'num_classes': 1000,
'activation': 'swish',
'dtype': 'float32',
}
})
loss: base_configs.LossConfig = base_configs.LossConfig(
name='categorical_crossentropy',
label_smoothing=0.1)
name='categorical_crossentropy', label_smoothing=0.1)
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
name='rmsprop',
decay=0.9,
......@@ -72,4 +74,5 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
decay_epochs=2.4,
decay_rate=0.97,
warmup_epochs=5,
scale_by_batch_size=1. / 128.)
scale_by_batch_size=1. / 128.,
staircase=True)
......@@ -30,7 +30,7 @@ from typing import Any, Dict, Optional, Text, Tuple
from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
......@@ -104,6 +104,8 @@ MODEL_CONFIGS = {
'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
'efficientnet-b8': ModelConfig.from_args(2.2, 3.6, 672, 0.5),
'efficientnet-l2': ModelConfig.from_args(4.3, 5.3, 800, 0.5),
}
CONV_KERNEL_INITIALIZER = {
......@@ -166,7 +168,7 @@ def conv2d_block(inputs: tf.Tensor,
batch_norm = common_modules.get_batch_norm(config.batch_norm)
bn_momentum = config.bn_momentum
bn_epsilon = config.bn_epsilon
data_format = config.data_format
data_format = tf.keras.backend.image_data_format()
weight_decay = config.weight_decay
name = name or ''
......@@ -223,7 +225,7 @@ def mb_conv_block(inputs: tf.Tensor,
use_se = config.use_se
activation = tf_utils.get_activation(config.activation)
drop_connect_rate = config.drop_connect_rate
data_format = config.data_format
data_format = tf.keras.backend.image_data_format()
use_depthwise = block.conv_type != 'no_depthwise'
prefix = prefix or ''
......@@ -346,12 +348,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
num_classes = config.num_classes
input_channels = config.input_channels
rescale_input = config.rescale_input
data_format = config.data_format
data_format = tf.keras.backend.image_data_format()
dtype = config.dtype
weight_decay = config.weight_decay
x = image_input
if data_format == 'channels_first':
# Happens on GPU/TPU if available.
x = tf.keras.layers.Permute((3, 1, 2))(x)
if rescale_input:
x = preprocessing.normalize_images(x,
num_channels=input_channels,
......@@ -463,7 +467,7 @@ class EfficientNet(tf.keras.Model):
def from_name(cls,
model_name: Text,
model_weights_path: Text = None,
copy_to_local: bool = False,
weights_format: Text = 'saved_model',
overrides: Dict[Text, Any] = None):
"""Construct an EfficientNet model from a predefined model name.
......@@ -472,7 +476,8 @@ class EfficientNet(tf.keras.Model):
Args:
model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir)
copy_to_local: copy the weights to a local tmp dir
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
overrides: (optional) a dict containing keys that can override config
Returns:
......@@ -492,12 +497,8 @@ class EfficientNet(tf.keras.Model):
model = cls(config=config, overrides=overrides)
if model_weights_path:
if copy_to_local:
tmp_file = os.path.join('/tmp', model_name + '.h5')
model_weights_file = os.path.join(model_weights_path, 'model.h5')
tf.io.gfile.copy(model_weights_file, tmp_file, overwrite=True)
model_weights_path = tmp_file
model.load_weights(model_weights_path)
common_modules.load_weights(model,
model_weights_path,
weights_format=weights_format)
return model
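
A sketch of the updated call, with `weights_format` replacing `copy_to_local`; the weights path is a placeholder:

    model = EfficientNet.from_name(
        'efficientnet-b0',
        model_weights_path='/path/to/weights.h5',
        weights_format='h5',  # one of 'saved_model', 'h5', or 'checkpoint'
        overrides={'rescale_input': True, 'num_classes': 1000})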
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.vision.image_classification.efficientnet import efficientnet_model
FLAGS = flags.FLAGS
flags.DEFINE_string("model_name", None,
"EfficientNet model name.")
flags.DEFINE_string("model_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path to export.")
def export_tfhub(model_path, hub_destination, model_name):
"""Restores a tf.keras.Model and saves for TF-Hub."""
model = efficientnet_model.EfficientNet.from_name(model_name)
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(model_path).assert_existing_objects_matched()
image_input = tf.keras.layers.Input(
shape=(None, None, 3), name="image_input", dtype=tf.float32)
x = image_input * 255.0
outputs = model(x)
hub_model = tf.keras.Model(image_input, outputs)
# Exports a SavedModel.
hub_model.save(
os.path.join(hub_destination, "classification"), include_optimizer=False)
feature_vector_output = hub_model.get_layer(name="efficientnet").get_layer(
name="top_pool").get_output_at(0)
hub_model2 = tf.keras.Model(model.inputs, feature_vector_output)
# Exports a SavedModel.
hub_model2.save(
os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
if __name__ == "__main__":
app.run(main)
......@@ -20,7 +20,7 @@ from __future__ import print_function
from typing import Any, List, Mapping
import tensorflow.compat.v2 as tf
import tensorflow as tf
BASE_LEARNING_RATE = 0.1
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.vision.image_classification import learning_rate
......@@ -86,5 +86,4 @@ class LearningRateTests(tf.test.TestCase):
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()