Unverified commit 965cc3ee authored by Ayushman Kumar, committed by GitHub

Merge pull request #7 from tensorflow/master

updated
parents 1f3247f4 1f685c54
@@ -24,8 +24,8 @@ from __future__ import division
from __future__ import print_function
import math
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
-from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
+from typing import Any, Dict, List, Optional, Text, Tuple
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
@@ -66,7 +66,7 @@ def to_4d(image: tf.Tensor) -> tf.Tensor:
return tf.reshape(image, new_shape)
-def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
+def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
"""Converts a 4D image back to `ndims` rank."""
shape = tf.shape(image)
begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
@@ -75,8 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
return tf.reshape(image, new_shape)
-def _convert_translation_to_transform(
-translations: Iterable[int]) -> tf.Tensor:
+def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
"""Converts translations to a projective transform.
The translation matrix looks like this:
@@ -122,9 +121,9 @@ def _convert_translation_to_transform(
def _convert_angles_to_transform(
-angles: Union[Iterable[float], float],
+angles: tf.Tensor,
-image_width: int,
+image_width: tf.Tensor,
-image_height: int) -> tf.Tensor:
+image_height: tf.Tensor) -> tf.Tensor:
"""Converts an angle or angles to a projective transform.
Args:
@@ -166,8 +165,7 @@ def _convert_angles_to_transform(
)
-def transform(image: tf.Tensor,
-transforms: Iterable[float]) -> tf.Tensor:
+def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
@@ -181,8 +179,7 @@ def transform(image: tf.Tensor,
return from_4d(image, original_ndims)
-def translate(image: tf.Tensor,
-translations: Iterable[int]) -> tf.Tensor:
+def translate(image: tf.Tensor, translations) -> tf.Tensor:
"""Translates image(s) by provided vectors.
Args:
@@ -212,7 +209,7 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
"""
# Convert from degrees to radians.
degrees_to_radians = math.pi / 180.0
-radians = degrees * degrees_to_radians
+radians = tf.cast(degrees * degrees_to_radians, tf.float32)
original_ndims = tf.rank(image)
image = to_4d(image)
@@ -577,7 +574,7 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
return image
-def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor:
+def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
......
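Editor's note: the hunks above loosen the augment.py signatures from Python `Iterable` annotations to plain tensor-friendly arguments, so symbolic tensors can flow through. A minimal, hedged usage sketch; it assumes the tf-models repo is importable as `official`, and the function names come from the hunks above:

```python
import tensorflow as tf

from official.vision.image_classification import augment  # assumes tf-models is on PYTHONPATH

image = tf.zeros([32, 32, 3], dtype=tf.uint8)
# translations/transforms no longer need to be Python Iterables; lists or
# tensors are accepted and converted internally with tf.convert_to_tensor.
shifted = augment.translate(image, translations=[-1, -1])
rotated = augment.rotate(image, degrees=30.0)
print(shifted.shape, rotated.shape)
```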
@@ -21,7 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
from official.vision.image_classification import augment
@@ -52,14 +52,21 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(augment.transform(image, transforms=[1]*8),
[[4, 4], [4, 4]])
-def disable_test_translate(self, dtype):
+def test_translate(self, dtype):
image = tf.constant(
-[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
+[[1, 0, 1, 0],
+[0, 1, 0, 1],
+[1, 0, 1, 0],
+[0, 1, 0, 1]],
dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image,
translations=translations)
-expected = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]
+expected = [
+[1, 0, 1, 1],
+[0, 1, 0, 0],
+[1, 0, 1, 1],
+[1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
@@ -133,5 +140,4 @@ class AutoaugmentTest(tf.test.TestCase):
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
-assert tf.version.VERSION.startswith('2.')
tf.test.main()
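Editor's note: the hunk re-enables `test_translate` (formerly `disable_test_translate`) and updates its expected output; the dtype parameterization itself lies outside the hunk. A self-contained sketch of how such a dtype-parameterized TF test is typically declared (class and test names here are illustrative, not taken from the file):

```python
from absl.testing import parameterized
import tensorflow as tf


class TranslateSketchTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.named_parameters(('uint8', tf.uint8), ('float32', tf.float32))
  def test_shape_is_preserved(self, dtype):
    # A translation-style op should not change the image shape for any dtype.
    image = tf.zeros([4, 4, 3], dtype=dtype)
    self.assertAllEqual(tf.shape(image), [4, 4, 3])


if __name__ == '__main__':
  tf.test.main()
```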
+# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,18 +20,24 @@ from __future__ import division
from __future__ import print_function
import os
-from typing import Any, List, MutableMapping, Text
from absl import logging
import tensorflow as tf
+from typing import Any, List, MutableMapping, Text
+from official.utils.misc import keras_utils
+from official.vision.image_classification import optimizer_factory
def get_callbacks(model_checkpoint: bool = True,
include_tensorboard: bool = True,
+time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
+apply_moving_average: bool = False,
initial_step: int = 0,
-model_dir: Text = None) -> List[tf.keras.callbacks.Callback]:
+batch_size: int = 0,
+log_steps: int = 0,
+model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
callbacks = []
@@ -39,11 +46,29 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks.append(tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if include_tensorboard:
-callbacks.append(CustomTensorBoard(
-log_dir=model_dir,
-track_lr=track_lr,
-initial_step=initial_step,
-write_images=write_model_weights))
+callbacks.append(
+CustomTensorBoard(
+log_dir=model_dir,
+track_lr=track_lr,
+initial_step=initial_step,
+write_images=write_model_weights))
if time_history:
callbacks.append(
keras_utils.TimeHistory(
batch_size,
log_steps,
logdir=model_dir if include_tensorboard else None))
if apply_moving_average:
# Save moving average model to a different file so that
# we can resume training from a checkpoint
ckpt_full_path = os.path.join(
model_dir, 'average', 'model.ckpt-{epoch:04d}')
callbacks.append(AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
return callbacks
@@ -63,18 +88,19 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
- Global learning rate
Attributes:
-log_dir: the path of the directory where to save the log files to be
-parsed by TensorBoard.
+log_dir: the path of the directory where to save the log files to be parsed
+by TensorBoard.
track_lr: `bool`, whether or not to track the global learning rate.
initial_step: the initial step, used for preemption recovery.
-**kwargs: Additional arguments for backwards compatibility. Possible key
-is `period`.
+**kwargs: Additional arguments for backwards compatibility. Possible key is
+`period`.
"""
# TODO(b/146499062): track params, flops, log lr, l2 loss,
# classification loss
def __init__(self,
-log_dir: Text,
+log_dir: str,
track_lr: bool = False,
initial_step: int = 0,
**kwargs):
@@ -84,7 +110,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_batch_begin(self,
epoch: int,
-logs: MutableMapping[Text, Any] = None) -> None:
+logs: MutableMapping[str, Any] = None) -> None:
self.step += 1
if logs is None:
logs = {}
@@ -93,7 +119,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_begin(self,
epoch: int,
-logs: MutableMapping[Text, Any] = None) -> None:
+logs: MutableMapping[str, Any] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
@@ -104,25 +130,24 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_end(self,
epoch: int,
-logs: MutableMapping[Text, Any] = None) -> None:
+logs: MutableMapping[str, Any] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
-def _calculate_metrics(self) -> MutableMapping[Text, Any]:
+def _calculate_metrics(self) -> MutableMapping[str, Any]:
logs = {}
-if self._track_lr:
-logs['learning_rate'] = self._calculate_lr()
+# TODO(b/149030439): disable LR reporting.
+# if self._track_lr:
+# logs['learning_rate'] = self._calculate_lr()
return logs
def _calculate_lr(self) -> int:
"""Calculates the learning rate given the current step."""
-lr = self._get_base_optimizer().lr
-if callable(lr):
-lr = lr(self.step)
-return get_scalar_from_tensor(lr)
+return get_scalar_from_tensor(
+self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))  # pylint:disable=protected-access
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Get the base optimizer used by the current model."""
@@ -134,3 +159,100 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
optimizer = optimizer._optimizer  # pylint:disable=protected-access
return optimizer
class MovingAverageCallback(tf.keras.callbacks.Callback):
"""A Callback to be used with a `MovingAverage` optimizer.
Applies moving average weights to the model during validation time to test
and predict on the averaged weights rather than the current model weights.
Once training is complete, the model weights will be overwritten with the
averaged weights (by default).
Attributes:
overwrite_weights_on_train_end: Whether to overwrite the current model
weights with the averaged weights from the moving average optimizer.
**kwargs: Any additional callback arguments.
"""
def __init__(self,
overwrite_weights_on_train_end: bool = False,
**kwargs):
super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer,
optimizer_factory.MovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
self.model.optimizer.swap_weights()
def on_test_end(self, logs: MutableMapping[Text, Any] = None):
self.model.optimizer.swap_weights()
def on_train_end(self, logs: MutableMapping[Text, Any] = None):
if self.overwrite_weights_on_train_end:
self.model.optimizer.assign_average_vars(self.model.variables)
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
"""Saves and, optionally, assigns the averaged weights.
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
update_weights: If True, assign the moving average weights
to the model, and save them. If False, keep the old
non-averaged weights, but the saved model uses the
average weights.
See `tf.keras.callbacks.ModelCheckpoint` for the other args.
"""
def __init__(
self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
self.update_weights = update_weights
super().__init__(
filepath,
monitor,
verbose,
save_best_only,
save_weights_only,
mode,
save_freq,
**kwargs)
def set_model(self, model):
if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
raise TypeError(
'AverageModelCheckpoint is only used when training '
'with MovingAverage')
return super().set_model(model)
def _save_model(self, epoch, logs):
assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs)
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
non_avg_weights = self.model.get_weights()
self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(epoch, logs)
self.model.set_weights(non_avg_weights)
return result
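Editor's note: a hedged sketch of wiring the new moving-average pieces together. The `get_callbacks` arguments match the updated signature above; the model must be compiled with an `optimizer_factory.MovingAverage`-wrapped optimizer, otherwise `MovingAverageCallback.set_model` / `AverageModelCheckpoint.set_model` will fail. The `model_dir` path and batch/log values are illustrative.

```python
from official.vision.image_classification import callbacks as custom_callbacks

# Assumes the tf-models repo is on PYTHONPATH and `model` was compiled with an
# optimizer_factory.MovingAverage optimizer before model.fit() is called.
callbacks = custom_callbacks.get_callbacks(
    model_checkpoint=True,
    include_tensorboard=True,
    time_history=True,
    track_lr=True,
    write_model_weights=False,
    apply_moving_average=True,  # adds AverageModelCheckpoint + MovingAverageCallback
    initial_step=0,
    batch_size=128,
    log_steps=100,
    model_dir='/tmp/classifier')  # illustrative path
# model.fit(train_dataset, callbacks=callbacks, ...)
```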
@@ -27,12 +27,11 @@ from typing import Any, Tuple, Text, Optional, Mapping
from absl import app
from absl import flags
from absl import logging
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
from official.modeling import performance
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
-from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks
@@ -44,10 +43,24 @@ from official.vision.image_classification.efficientnet import efficientnet_model
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_model
-MODELS = {
-'efficientnet': efficientnet_model.EfficientNet.from_name,
-'resnet': resnet_model.resnet50,
-}
+def get_models() -> Mapping[str, tf.keras.Model]:
+"""Returns the mapping from model type name to Keras model."""
+return {
+'efficientnet': efficientnet_model.EfficientNet.from_name,
+'resnet': resnet_model.resnet50,
+}
+def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
+"""Returns the mapping from dtype string representations to TF dtypes."""
+return {
+'float32': tf.float32,
+'bfloat16': tf.bfloat16,
+'float16': tf.float16,
+'fp32': tf.float32,
+'bf16': tf.bfloat16,
+}
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
@@ -87,19 +100,20 @@ def get_image_size_from_model(
def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy,
one_hot: bool
-) -> Tuple[Any, Any, Any]:
+) -> Tuple[Any, Any]:
-"""Create and return train, validation, and test dataset builders."""
+"""Create and return train and validation dataset builders."""
if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
else:
logging.warning('label_smoothing not applied, so datasets will not be one '
'hot encoded.')
-num_devices = strategy.num_replicas_in_sync
+num_devices = strategy.num_replicas_in_sync if strategy else 1
image_size = get_image_size_from_model(params)
dataset_configs = [
-params.train_dataset, params.validation_dataset, params.test_dataset
+params.train_dataset, params.validation_dataset
]
builders = []
@@ -120,12 +134,13 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
def get_loss_scale(params: base_configs.ExperimentConfig,
fp16_default: float = 128.) -> float:
"""Returns the loss scale for initializations."""
-loss_scale = params.model.loss.loss_scale
+loss_scale = params.runtime.loss_scale
if loss_scale == 'dynamic':
return loss_scale
elif loss_scale is not None:
return float(loss_scale)
-elif params.train_dataset.dtype == 'float32':
+elif (params.train_dataset.dtype == 'float32' or
+params.train_dataset.dtype == 'bfloat16'):
return 1.
else:
assert params.train_dataset.dtype == 'float16'
@@ -145,7 +160,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'name': model,
},
'runtime': {
-'enable_eager': flags_obj.enable_eager,
+'run_eagerly': flags_obj.run_eagerly,
'tpu': flags_obj.tpu,
},
'train_dataset': {
@@ -154,8 +169,10 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'validation_dataset': {
'data_dir': flags_obj.data_dir,
},
-'test_dataset': {
-'data_dir': flags_obj.data_dir,
+'train': {
+'time_history': {
+'log_steps': flags_obj.log_steps,
+},
},
}
@@ -169,8 +186,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
for param in overriding_configs:
logging.info('Overriding params: %s', param)
-# Set is_strict to false because we can have dynamic dict parameters.
-params = params_dict.override_params_dict(params, param, is_strict=False)
+params = params_dict.override_params_dict(params, param, is_strict=True)
params.validate()
params.lock()
@@ -212,24 +228,21 @@ def resume_from_checkpoint(model: tf.keras.Model,
return int(initial_epoch)
-def initialize(params: base_configs.ExperimentConfig):
+def initialize(params: base_configs.ExperimentConfig,
+dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(
-enable_eager=params.runtime.enable_eager,
enable_xla=params.runtime.enable_xla)
-if params.runtime.gpu_threads_enabled:
+if params.runtime.gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads)
-dataset = params.train_dataset or params.validation_dataset
-performance.set_mixed_precision_policy(dataset.dtype)
+performance.set_mixed_precision_policy(dataset_builder.dtype,
+get_loss_scale(params))
-if dataset.data_format:
-data_format = dataset.data_format
-elif tf.config.list_physical_devices('GPU'):
+if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
data_format = 'channels_last'
@@ -237,7 +250,7 @@ def initialize(params: base_configs.ExperimentConfig):
distribution_utils.configure_cluster(
params.runtime.worker_hosts,
params.runtime.task_index)
-if params.runtime.enable_eager:
+if params.runtime.run_eagerly:
# Enable eager execution to allow step-by-step debugging
tf.config.experimental_run_functions_eagerly(True)
@@ -254,7 +267,7 @@ def define_classifier_flags():
default=None,
help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
flags.DEFINE_bool(
-'enable_eager',
+'run_eagerly',
default=None,
help='Use eager execution and disable autograph for debugging.')
flags.DEFINE_string(
@@ -265,6 +278,10 @@ def define_classifier_flags():
'dataset',
default=None,
help='The name of the dataset, e.g. ImageNet, etc.')
flags.DEFINE_integer(
'log_steps',
default=100,
help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig,
@@ -291,27 +308,31 @@ def train_and_eval(
strategy_scope = distribution_utils.get_strategy_scope(strategy)
-logging.info('Detected %d devices.', strategy.num_replicas_in_sync)
+logging.info('Detected %d devices.',
+strategy.num_replicas_in_sync if strategy else 1)
label_smoothing = params.model.loss.label_smoothing
one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot)
-datasets = [builder.build() if builder else None for builder in builders]
+datasets = [builder.build(strategy)
+if builder else None for builder in builders]
# Unpack datasets and builders based on train/val/test splits
-train_builder, validation_builder, test_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
-train_dataset, validation_dataset, test_dataset = datasets
+train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
+train_dataset, validation_dataset = datasets
train_epochs = params.train.epochs
train_steps = params.train.steps or train_builder.num_steps
validation_steps = params.evaluation.steps or validation_builder.num_steps
+initialize(params, train_builder)
logging.info('Global batch size: %d', train_builder.global_batch_size)
with strategy_scope:
model_params = params.model.model_params.as_dict()
-model = MODELS[params.model.name](**model_params)
+model = get_models()[params.model.name](**model_params)
learning_rate = optimizer_factory.build_learning_rate(
params=params.model.learning_rate,
batch_size=train_builder.global_batch_size,
@@ -332,7 +353,7 @@ def train_and_eval(
model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
-run_eagerly=params.runtime.enable_eager)
+experimental_steps_per_execution=params.train.steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
@@ -340,15 +361,27 @@ def train_and_eval(
model_dir=params.model_dir,
train_steps=train_steps)
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
time_history=params.train.callbacks.enable_time_history,
track_lr=params.train.tensorboard.track_lr,
write_model_weights=params.train.tensorboard.write_model_weights,
initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir)
serialize_config(params=params, model_dir=params.model_dir)
-# TODO(dankondratyuk): callbacks significantly slow down training
-callbacks = custom_callbacks.get_callbacks(
-model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
-include_tensorboard=params.train.callbacks.enable_tensorboard,
-track_lr=params.train.tensorboard.track_lr,
-write_model_weights=params.train.tensorboard.write_model_weights,
-initial_step=initial_epoch * train_steps,
-model_dir=params.model_dir)
+if params.evaluation.skip_eval:
+validation_kwargs = {}
+else:
+validation_kwargs = {
+'validation_data': validation_dataset,
+'validation_steps': validation_steps,
+'validation_freq': params.evaluation.epochs_between_evals,
+}
history = model.fit(
train_dataset,
@@ -356,15 +389,15 @@ def train_and_eval(
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
-validation_data=validation_dataset,
-validation_steps=validation_steps,
-validation_freq=params.evaluation.epochs_between_evals)
+verbose=2,
+**validation_kwargs)
-validation_output = model.evaluate(
-validation_dataset, steps=validation_steps, verbose=2)
+validation_output = None
+if not params.evaluation.skip_eval:
+validation_output = model.evaluate(
+validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history,
validation_output,
callbacks)
@@ -375,7 +408,7 @@ def export(params: base_configs.ExperimentConfig):
"""Runs the model export functionality."""
logging.info('Exporting model.')
model_params = params.model.model_params.as_dict()
-model = MODELS[params.model.name](**model_params)
+model = get_models()[params.model.name](**model_params)
checkpoint = params.export.checkpoint
if checkpoint is None:
logging.info('No export checkpoint was provided. Using the latest '
@@ -398,8 +431,6 @@ def run(flags_obj: flags.FlagValues,
Dictionary of training/eval stats
"""
params = _get_params_from_flags(flags_obj)
-initialize(params)
if params.mode == 'train_and_eval':
return train_and_eval(params, strategy_override)
elif params.mode == 'export_only':
@@ -409,8 +440,7 @@ def run(flags_obj: flags.FlagValues,
def main(_):
-with logger.benchmark_context(flags.FLAGS):
-stats = run(flags.FLAGS)
+stats = run(flags.FLAGS)
if stats:
logging.info('Run stats:\n%s', stats)
@@ -423,5 +453,4 @@ if __name__ == '__main__':
flags.mark_flag_as_required('model_type')
flags.mark_flag_as_required('dataset')
-assert tf.version.VERSION.startswith('2.')
app.run(main)
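Editor's note: `get_loss_scale` now reads `params.runtime.loss_scale` and treats bfloat16 like float32. A standalone sketch of that branching, for illustration only (the real function takes the full `ExperimentConfig`):

```python
from typing import Optional, Union


def pick_loss_scale(loss_scale: Optional[Union[str, float]],
                    dtype: str,
                    fp16_default: float = 128.) -> Union[str, float]:
  # Mirrors the branching in get_loss_scale() above.
  if loss_scale == 'dynamic':
    return loss_scale
  if loss_scale is not None:
    return float(loss_scale)
  if dtype in ('float32', 'bfloat16'):
    return 1.
  return fp16_default  # dtype == 'float16'


assert pick_loss_scale('dynamic', 'float16') == 'dynamic'
assert pick_loss_scale(None, 'bfloat16') == 1.
assert pick_loss_scale(None, 'float16') == 128.
```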
@@ -30,7 +30,7 @@ from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, T
from absl import flags
from absl.testing import parameterized
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
@@ -67,7 +67,7 @@ def get_params_override(params_override: Mapping[str, Any]) -> str:
return '--params_override=' + json.dumps(params_override)
-def basic_params_override() -> MutableMapping[str, Any]:
+def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
"""Returns a basic parameter configuration for testing."""
return {
'train_dataset': {
@@ -75,18 +75,14 @@ def basic_params_override() -> MutableMapping[str, Any]:
'use_per_replica_batch_size': True,
'batch_size': 1,
'image_size': 224,
+'dtype': dtype,
},
'validation_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
-},
-'test_dataset': {
-'builder': 'synthetic',
-'batch_size': 1,
-'use_per_replica_batch_size': True,
-'image_size': 224,
+'dtype': dtype,
},
'train': {
'steps': 1,
@@ -152,7 +148,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
tf.io.gfile.rmtree(self.get_temp_dir())
@combinations.generate(distribution_strategy_combinations())
-def test_end_to_end_train_and_eval_export(self, distribution, model, dataset):
+def test_end_to_end_train_and_eval(self, distribution, model, dataset):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
@@ -168,6 +164,41 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval',
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
],
model=[
'efficientnet',
'resnet',
],
mode='eager',
dataset='imagenet',
dtype='float16',
))
def test_gpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.get_temp_dir()
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
export_params = basic_params_override()
export_path = os.path.join(model_dir, 'export')
export_params['export'] = {}
@@ -187,6 +218,41 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir)
self.assertTrue(os.path.exists(export_path))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.tpu_strategy,
],
model=[
'efficientnet',
'resnet',
],
mode='eager',
dataset='imagenet',
dtype='bfloat16',
))
def test_tpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.get_temp_dir()
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset):
"""Test the Keras EfficientNet model with `strategy`."""
@@ -239,8 +305,8 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
)
def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig(
-model=base_configs.ModelConfig(
-loss=base_configs.LossConfig(loss_scale=loss_scale)),
+runtime=base_configs.RuntimeConfig(
+loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected)
@@ -252,19 +318,23 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
def test_initialize(self, dtype):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
-enable_eager=False,
+run_eagerly=False,
enable_xla=False,
-gpu_threads_enabled=True,
per_gpu_thread_count=1,
gpu_thread_mode='gpu_private',
num_gpus=1,
dataset_num_private_threads=1,
),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
-model=base_configs.ModelConfig(
-loss=base_configs.LossConfig(loss_scale='dynamic')),
+model=base_configs.ModelConfig(),
)
-classifier_trainer.initialize(config)
class EmptyClass:
pass
fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass()
classifier_trainer.initialize(config, fake_ds_builder)
def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint."""
@@ -313,5 +383,4 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__':
-assert tf.version.VERSION.startswith('2.')
tf.test.main()
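Editor's note: the tests drive the trainer through `--params_override`, which `get_params_override` builds as a JSON string over the nested config keys. A small sketch of assembling that flag by hand (the keys shown appear in the hunks above; the values are illustrative):

```python
import json

params_override = {
    'train_dataset': {'dtype': 'float16'},
    'validation_dataset': {'dtype': 'float16'},
    'runtime': {'loss_scale': 'dynamic'},
}
flag = '--params_override=' + json.dumps(params_override)
print(flag)
```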
@@ -58,6 +58,17 @@ class MetricsConfig(base_config.Config):
top_5: bool = None
@dataclasses.dataclass
class TimeHistoryConfig(base_config.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: int = None
@dataclasses.dataclass
class TrainConfig(base_config.Config):
"""Configuration for training.
@@ -71,14 +82,18 @@ class TrainConfig(base_config.Config):
callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorboardConfig.
+steps_per_loop: The number of batches to run during each `tf.function`
+call during training, which can increase training speed.
"""
resume_checkpoint: bool = None
epochs: int = None
steps: int = None
callbacks: CallbacksConfig = CallbacksConfig()
-metrics: List[str] = None
+metrics: MetricsConfig = None
tensorboard: TensorboardConfig = TensorboardConfig()
+time_history: TimeHistoryConfig = TimeHistoryConfig()
+steps_per_loop: int = None
@dataclasses.dataclass
@@ -91,10 +106,12 @@ class EvalConfig(base_config.Config):
steps: The number of eval steps to run during evaluation. If None, this will
be inferred based on the number of images and batch size. Defaults to
None.
+skip_eval: Whether or not to skip evaluation.
"""
epochs_between_evals: int = None
steps: int = None
+skip_eval: bool = False
@dataclasses.dataclass
@@ -103,13 +120,11 @@ class LossConfig(base_config.Config):
Attributes:
name: The name of the loss. Defaults to None.
-loss_scale: The type of loss scale
label_smoothing: Whether or not to apply label smoothing to the loss. This
only applies to 'categorical_cross_entropy'.
"""
name: str = None
-loss_scale: str = None
label_smoothing: float = None
@@ -164,6 +179,7 @@ class LearningRateConfig(base_config.Config):
multipliers: multipliers used in piecewise constant decay with warmup.
scale_by_batch_size: Scale the learning rate by a fraction of the batch
size. Set to 0 for no scaling (default).
+staircase: Apply exponential decay at discrete values instead of continuous.
"""
name: str = None
@@ -175,6 +191,7 @@ class LearningRateConfig(base_config.Config):
boundaries: List[int] = None
multipliers: List[float] = None
scale_by_batch_size: float = 0.
+staircase: bool = None
@dataclasses.dataclass
@@ -190,7 +207,7 @@ class ModelConfig(base_config.Config):
"""
name: str = None
-model_params: Mapping[str, Any] = None
+model_params: base_config.Config = None
num_classes: int = None
loss: LossConfig = None
optimizer: OptimizerConfig = None
@@ -216,7 +233,6 @@ class ExperimentConfig(base_config.Config):
runtime: RuntimeConfig = None
train_dataset: Any = None
validation_dataset: Any = None
-test_dataset: Any = None
train: TrainConfig = None
evaluation: EvalConfig = None
model: ModelConfig = None
......
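Editor's note: a hedged sketch of how the new `TrainConfig`/`EvalConfig` fields compose. Field names come from the dataclasses above; the import path assumes the usual tf-models layout and should be treated as an assumption.

```python
from official.vision.image_classification.configs import base_configs

train = base_configs.TrainConfig(
    resume_checkpoint=True,
    epochs=90,
    time_history=base_configs.TimeHistoryConfig(log_steps=100),
    steps_per_loop=1)
evaluation = base_configs.EvalConfig(
    epochs_between_evals=1,
    skip_eval=False)
```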
@@ -45,8 +45,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
dataset_factory.ImageNetConfig(split='train')
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation')
-test_dataset: dataset_factory.DatasetConfig = \
-dataset_factory.ImageNetConfig(split='validation')
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=500,
@@ -54,8 +52,10 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
+time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
-write_model_weights=False))
+write_model_weights=False),
+steps_per_loop=1)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
@@ -78,11 +78,6 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
one_hot=False,
mean_subtract=True,
standardize=True)
-test_dataset: dataset_factory.DatasetConfig = \
-dataset_factory.ImageNetConfig(split='validation',
-one_hot=False,
-mean_subtract=True,
-standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
@@ -90,8 +85,10 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
+time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
-write_model_weights=False))
+write_model_weights=False),
+steps_per_loop=1)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
......
@@ -3,8 +3,6 @@
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
-model_dir: null
-mode: 'train_and_eval'
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
@@ -36,10 +34,13 @@ model:
num_classes: 1000
batch_norm: 'default'
dtype: 'float32'
+activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
+moving_average_decay: 0.0
+lookahead: false
learning_rate:
name: 'exponential'
loss:
......
@@ -3,8 +3,6 @@
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
-model_dir: null
-mode: 'train_and_eval'
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
@@ -35,11 +33,12 @@ model:
num_classes: 1000
batch_norm: 'tpu'
dtype: 'bfloat16'
+activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
-moving_average_decay: 0.
+moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
......
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
-model_dir: null
-mode: 'train_and_eval'
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
@@ -12,6 +10,7 @@ train_dataset:
num_classes: 1000
num_examples: 1281167
batch_size: 32
+use_per_replica_batch_size: True
dtype: 'float32'
validation_dataset:
name: 'imagenet2012'
@@ -21,6 +20,7 @@ validation_dataset:
num_classes: 1000
num_examples: 50000
batch_size: 32
+use_per_replica_batch_size: True
dtype: 'float32'
model:
model_params:
@@ -29,10 +29,13 @@ model:
num_classes: 1000
batch_norm: 'default'
dtype: 'float32'
+activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
+moving_average_decay: 0.0
+lookahead: false
learning_rate:
name: 'exponential'
loss:
......
@@ -2,8 +2,6 @@
# Takes ~3 minutes, 15 seconds per epoch for v3-32.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
-model_dir: null
-mode: 'train_and_eval'
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
@@ -34,10 +32,13 @@ model:
num_classes: 1000
batch_norm: 'tpu'
dtype: 'bfloat16'
+activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
+moving_average_decay: 0.0
+lookahead: false
learning_rate:
name: 'exponential'
loss:
......
# Training configuration for ResNet trained on ImageNet on GPUs.
-# Takes ~3 minutes, 15 seconds per epoch for 8 V100s.
-# Reaches ~76.1% within 90 epochs.
+# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
-model_dir: null
-mode: 'train_and_eval'
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
name: 'imagenet2012'
data_dir: null
-builder: 'records'
+builder: 'tfds'
split: 'train'
image_size: 224
num_classes: 1000
@@ -23,7 +20,7 @@ train_dataset:
validation_dataset:
name: 'imagenet2012'
data_dir: null
-builder: 'records'
+builder: 'tfds'
split: 'validation'
image_size: 224
num_classes: 1000
@@ -34,7 +31,7 @@ validation_dataset:
mean_subtract: True
standardize: True
model:
-model_name: 'resnet'
+name: 'resnet'
model_params:
rescale_inputs: False
optimizer:
......
# Training configuration for ResNet trained on ImageNet on TPUs.
-# Takes ~2 minutes, 43 seconds per epoch for a v3-32.
-# Reaches ~76.1% within 90 epochs.
+# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
+# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
-model_dir: null
-mode: 'train_and_eval'
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
data_dir: null
-builder: 'records'
+builder: 'tfds'
split: 'train'
one_hot: False
image_size: 224
@@ -23,7 +21,7 @@ train_dataset:
validation_dataset:
name: 'imagenet2012'
data_dir: null
-builder: 'records'
+builder: 'tfds'
split: 'validation'
one_hot: False
image_size: 224
@@ -35,7 +33,7 @@ validation_dataset:
standardize: True
dtype: 'bfloat16'
model:
-model_name: 'resnet'
+name: 'resnet'
model_params:
rescale_inputs: False
optimizer:
......
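Editor's note: these YAML files are applied to the experiment config through `params_dict.override_params_dict`, now called with `is_strict=True` (see the classifier_trainer hunk). A hedged sketch of loading one; the starting config and the YAML path are illustrative, since the trainer itself starts from a model/dataset-specific default chosen in `_get_params_from_flags`:

```python
from official.modeling.hyperparams import params_dict
from official.vision.image_classification.configs import base_configs

params = base_configs.ExperimentConfig()  # illustrative starting point
params = params_dict.override_params_dict(
    params, 'path/to/resnet/imagenet/gpu.yaml', is_strict=True)  # illustrative path
params.validate()
params.lock()
```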
@@ -23,7 +23,7 @@ import os
from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging
from dataclasses import dataclass
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import base_config
@@ -84,11 +84,10 @@ class DatasetConfig(base_config.Config):
use_per_replica_batch_size: Whether to scale the batch size based on
available resources. If set to `True`, the dataset builder will return
batch_size multiplied by `num_devices`, the number of device replicas
-(e.g., the number of GPUs or TPU cores).
+(e.g., the number of GPUs or TPU cores). This setting should be `True` if
+the strategy argument is passed to `build()` and `num_devices > 1`.
num_devices: The number of replica devices to use. This should be set by
`strategy.num_replicas_in_sync` when using a distribution strategy.
-data_format: The data format of the images. Should be 'channels_last' or
-'channels_first'.
dtype: The desired dtype of the dataset. This will be set during
preprocessing.
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
@@ -118,9 +117,8 @@ class DatasetConfig(base_config.Config):
num_channels: Union[int, str] = 'infer'
num_examples: Union[int, str] = 'infer'
batch_size: int = 128
-use_per_replica_batch_size: bool = False
+use_per_replica_batch_size: bool = True
num_devices: int = 1
-data_format: str = 'channels_last'
dtype: str = 'float32'
one_hot: bool = True
augmenter: AugmentConfig = AugmentConfig()
@@ -188,14 +186,22 @@ class DatasetBuilder:
def batch_size(self) -> int:
"""The batch size, multiplied by the number of replicas (if configured)."""
if self.config.use_per_replica_batch_size:
-return self.global_batch_size
+return self.config.batch_size * self.config.num_devices
else:
return self.config.batch_size
@property
def global_batch_size(self):
"""The global batch size across all replicas."""
-return self.config.batch_size * self.config.num_devices
+return self.batch_size
@property
def local_batch_size(self):
"""The base unscaled batch size."""
if self.config.use_per_replica_batch_size:
return self.config.batch_size
else:
return self.config.batch_size // self.config.num_devices
@property @property
def num_steps(self) -> int: def num_steps(self) -> int:
...@@ -203,6 +209,30 @@ class DatasetBuilder: ...@@ -203,6 +209,30 @@ class DatasetBuilder:
# Always divide by the global batch size to get the correct # of steps # Always divide by the global batch size to get the correct # of steps
return self.num_examples // self.global_batch_size return self.num_examples // self.global_batch_size
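To make the relationship between the three batch-size notions concrete, here is a comment-style sketch with assumed values (none of these numbers come from the diff):

# Assume config.batch_size = 128, config.num_devices = 8,
# config.use_per_replica_batch_size = True.
#   builder.local_batch_size   -> 128            (the unscaled, per-replica batch)
#   builder.batch_size         -> 128 * 8 = 1024 (scaled by the replica count)
#   builder.global_batch_size  -> 1024           (now an alias of batch_size)
#   builder.num_steps          -> num_examples // 1024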
@property
def dtype(self) -> tf.dtypes.DType:
"""Converts the config's dtype string to a tf dtype.
Returns:
      The `tf.dtypes.DType` corresponding to the config's dtype string.
Raises:
ValueError if the config's dtype is not supported.
"""
dtype_map = {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
try:
return dtype_map[self.config.dtype]
    except KeyError:
raise ValueError('Invalid DType provided. Supported types: {}'.format(
dtype_map.keys()))
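A comment-style sketch of how the new property resolves the config string (builder names here are illustrative):

# config.dtype = 'bfloat16' or 'bf16'  ->  builder.dtype == tf.bfloat16
# config.dtype = 'float32'  or 'fp32'  ->  builder.dtype == tf.float32
# config.dtype = 'float16'             ->  builder.dtype == tf.float16
# config.dtype = 'int8'                ->  ValueError (not in dtype_map)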
@property @property
def image_size(self) -> int: def image_size(self) -> int:
"""The size of each image (can be inferred from the dataset).""" """The size of each image (can be inferred from the dataset)."""
...@@ -243,19 +273,42 @@ class DatasetBuilder: ...@@ -243,19 +273,42 @@ class DatasetBuilder:
self.builder_info = tfds.builder(self.config.name).info self.builder_info = tfds.builder(self.config.name).info
return self.builder_info return self.builder_info
def build(self, input_context: tf.distribute.InputContext = None def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
) -> tf.data.Dataset: """Construct a dataset end-to-end and return it using an optional strategy.
Args:
strategy: a strategy that, if passed, will distribute the dataset
according to that strategy. If passed and `num_devices > 1`,
`use_per_replica_batch_size` must be set to `True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if strategy:
if strategy.num_replicas_in_sync != self.config.num_devices:
        logging.warn('Passed a strategy with %d devices, but expected '
'%d devices.',
strategy.num_replicas_in_sync,
self.config.num_devices)
dataset = strategy.experimental_distribute_datasets_from_function(
self._build)
else:
dataset = self._build()
return dataset
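A hedged end-to-end usage sketch of the new `build(strategy)` entry point. The exact `DatasetConfig` fields and the `DatasetBuilder` constructor arguments are assumptions here, not taken from this diff:

# Sketch only: constructor details are assumed.
strategy = tf.distribute.MirroredStrategy()
config = DatasetConfig(batch_size=128,
                       use_per_replica_batch_size=True,
                       num_devices=strategy.num_replicas_in_sync)
builder = DatasetBuilder(config)
dataset = builder.build(strategy)  # distributed tf.data.Dataset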
def _build(self, input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it. """Construct a dataset end-to-end and return it.
Args: Args:
input_context: An optional context provided by `tf.distribute` for input_context: An optional context provided by `tf.distribute` for
cross-replica training. This isn't necessary if using Keras cross-replica training.
compile/fit.
Returns: Returns:
A TensorFlow dataset outputting batched images and labels. A TensorFlow dataset outputting batched images and labels.
""" """
builders = { builders = {
'tfds': self.load_tfds, 'tfds': self.load_tfds,
'records': self.load_records, 'records': self.load_records,
...@@ -326,7 +379,7 @@ class DatasetBuilder: ...@@ -326,7 +379,7 @@ class DatasetBuilder:
def generate_data(_): def generate_data(_):
image = tf.zeros([self.image_size, self.image_size, self.num_channels], image = tf.zeros([self.image_size, self.image_size, self.num_channels],
dtype=self.config.dtype) dtype=self.dtype)
label = tf.zeros([1], dtype=tf.int32) label = tf.zeros([1], dtype=tf.int32)
return image, label return image, label
...@@ -345,8 +398,8 @@ class DatasetBuilder: ...@@ -345,8 +398,8 @@ class DatasetBuilder:
Args: Args:
dataset: A `tf.data.Dataset` that loads raw files. dataset: A `tf.data.Dataset` that loads raw files.
input_context: An optional context provided by `tf.distribute` for input_context: An optional context provided by `tf.distribute` for
cross-replica training. This isn't necessary if using Keras cross-replica training. If set with more than one replica, this
compile/fit. function assumes `use_per_replica_batch_size=True`.
Returns: Returns:
A TensorFlow dataset outputting batched images and labels. A TensorFlow dataset outputting batched images and labels.
...@@ -366,8 +419,6 @@ class DatasetBuilder: ...@@ -366,8 +419,6 @@ class DatasetBuilder:
cycle_length=16, cycle_length=16,
num_parallel_calls=tf.data.experimental.AUTOTUNE) num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(self.global_batch_size)
if self.config.cache: if self.config.cache:
dataset = dataset.cache() dataset = dataset.cache()
...@@ -383,13 +434,25 @@ class DatasetBuilder: ...@@ -383,13 +434,25 @@ class DatasetBuilder:
dataset = dataset.map(preprocess, dataset = dataset.map(preprocess,
num_parallel_calls=tf.data.experimental.AUTOTUNE) num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training) if input_context and self.config.num_devices > 1:
if not self.config.use_per_replica_batch_size:
# Note: we could do image normalization here, but we defer it to the model raise ValueError(
# which can perform it much faster on a GPU/TPU 'The builder does not support a global batch size with more than '
# TODO(dankondratyuk): if we fix prefetching, we can do it here 'one replica. Got {} replicas. Please set a '
'`per_replica_batch_size` and enable '
'`use_per_replica_batch_size=True`.'.format(
self.config.num_devices))
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
dataset = dataset.batch(self.local_batch_size,
drop_remainder=self.is_training)
else:
dataset = dataset.batch(self.global_batch_size,
drop_remainder=self.is_training)
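In other words, the distributed branch relies on each replica consuming one locally-batched element per step, while the non-distributed branch batches once with the global size. A small comment-style sketch of the resulting arithmetic (replica count and batch size are illustrative):

# With experimental_distribute_datasets_from_function, every replica pulls
# one batch of local_batch_size per step, so:
#   effective global batch = local_batch_size * num_replicas
#                          = 128 * 8 = 1024
# On the non-distributed path, the dataset is batched once with
# global_batch_size (1024) directly.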
if self.is_training and self.config.deterministic_train is not None: if self.is_training:
options = tf.data.Options() options = tf.data.Options()
options.experimental_deterministic = self.config.deterministic_train options.experimental_deterministic = self.config.deterministic_train
options.experimental_slack = self.config.use_slack options.experimental_slack = self.config.use_slack
...@@ -400,9 +463,7 @@ class DatasetBuilder: ...@@ -400,9 +463,7 @@ class DatasetBuilder:
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
# Prefetch overlaps in-feed with training # Prefetch overlaps in-feed with training
# Note: autotune here is not recommended, as this can lead to memory leaks. dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    # Instead, use a constant prefetch size like the number of devices.
dataset = dataset.prefetch(self.config.num_devices)
return dataset return dataset
...@@ -451,7 +512,7 @@ class DatasetBuilder: ...@@ -451,7 +512,7 @@ class DatasetBuilder:
image_size=self.image_size, image_size=self.image_size,
mean_subtract=self.config.mean_subtract, mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize, standardize=self.config.standardize,
dtype=self.config.dtype, dtype=self.dtype,
augmenter=self.augmenter) augmenter=self.augmenter)
else: else:
image = preprocessing.preprocess_for_eval( image = preprocessing.preprocess_for_eval(
...@@ -460,7 +521,7 @@ class DatasetBuilder: ...@@ -460,7 +521,7 @@ class DatasetBuilder:
num_channels=self.num_channels, num_channels=self.num_channels,
mean_subtract=self.config.mean_subtract, mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize, standardize=self.config.standardize,
dtype=self.config.dtype) dtype=self.dtype)
label = tf.cast(label, tf.int32) label = tf.cast(label, tf.int32)
if self.config.one_hot: if self.config.one_hot:
......
...@@ -19,15 +19,14 @@ from __future__ import division ...@@ -19,15 +19,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1 import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
from typing import Text, Optional from typing import Text, Optional
from tensorflow.python.tpu import tpu_function from tensorflow.python.tpu import tpu_function
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Vision')
class TpuBatchNormalization(tf.keras.layers.BatchNormalization): class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
"""Cross replica batch normalization.""" """Cross replica batch normalization."""
...@@ -98,3 +97,21 @@ def count_params(model, trainable_only=True): ...@@ -98,3 +97,21 @@ def count_params(model, trainable_only=True):
else: else:
return int(np.sum([tf.keras.backend.count_params(p) return int(np.sum([tf.keras.backend.count_params(p)
for p in model.trainable_weights])) for p in model.trainable_weights]))
def load_weights(model: tf.keras.Model,
model_weights_path: Text,
weights_format: Text = 'saved_model'):
"""Load model weights from the given file path.
Args:
model: the model to load weights into
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
"""
if weights_format == 'saved_model':
loaded_model = tf.keras.models.load_model(model_weights_path)
model.set_weights(loaded_model.get_weights())
else:
model.load_weights(model_weights_path)
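A brief usage sketch for the new helper; the paths are placeholders:

# Restore weights exported as a SavedModel ...
load_weights(model, '/path/to/saved_model', weights_format='saved_model')
# ... or fall back to model.load_weights for HDF5/checkpoint files.
load_weights(model, '/path/to/weights.h5', weights_format='h5')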
...@@ -22,6 +22,7 @@ from typing import Any, Mapping ...@@ -22,6 +22,7 @@ from typing import Any, Mapping
import dataclasses import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs from official.vision.image_classification.configs import base_configs
...@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig): ...@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
configuration. configuration.
learning_rate: The configuration for learning rate. Defaults to an learning_rate: The configuration for learning rate. Defaults to an
exponential configuration. exponential configuration.
""" """
name: str = 'EfficientNet' name: str = 'EfficientNet'
num_classes: int = 1000 num_classes: int = 1000
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: { model_params: base_config.Config = dataclasses.field(
'model_name': 'efficientnet-b0', default_factory=lambda: {
'model_weights_path': '', 'model_name': 'efficientnet-b0',
'copy_to_local': False, 'model_weights_path': '',
'overrides': { 'weights_format': 'saved_model',
'batch_norm': 'default', 'overrides': {
'rescale_input': True, 'batch_norm': 'default',
'num_classes': 1000, 'rescale_input': True,
} 'num_classes': 1000,
}) 'activation': 'swish',
'dtype': 'float32',
}
})
loss: base_configs.LossConfig = base_configs.LossConfig( loss: base_configs.LossConfig = base_configs.LossConfig(
name='categorical_crossentropy', name='categorical_crossentropy', label_smoothing=0.1)
label_smoothing=0.1)
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig( optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
name='rmsprop', name='rmsprop',
decay=0.9, decay=0.9,
...@@ -72,4 +74,5 @@ class EfficientNetModelConfig(base_configs.ModelConfig): ...@@ -72,4 +74,5 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
decay_epochs=2.4, decay_epochs=2.4,
decay_rate=0.97, decay_rate=0.97,
warmup_epochs=5, warmup_epochs=5,
scale_by_batch_size=1. / 128.) scale_by_batch_size=1. / 128.,
staircase=True)
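A minimal sketch of how the restructured `model_params` might be consumed downstream. The field names follow the defaults above; whether the `base_config.Config` object supports dict-style access is an assumption made purely for illustration:

# Sketch only.
params = EfficientNetModelConfig().model_params
model = EfficientNet.from_name(
    model_name=params['model_name'],            # 'efficientnet-b0'
    model_weights_path=params['model_weights_path'] or None,
    weights_format=params['weights_format'],    # 'saved_model'
    overrides=params['overrides'])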
...@@ -30,7 +30,7 @@ from typing import Any, Dict, Optional, Text, Tuple ...@@ -30,7 +30,7 @@ from typing import Any, Dict, Optional, Text, Tuple
from absl import logging from absl import logging
from dataclasses import dataclass from dataclasses import dataclass
import tensorflow.compat.v2 as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
...@@ -104,6 +104,8 @@ MODEL_CONFIGS = { ...@@ -104,6 +104,8 @@ MODEL_CONFIGS = {
'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4), 'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5), 'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5), 'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
'efficientnet-b8': ModelConfig.from_args(2.2, 3.6, 672, 0.5),
'efficientnet-l2': ModelConfig.from_args(4.3, 5.3, 800, 0.5),
} }
CONV_KERNEL_INITIALIZER = { CONV_KERNEL_INITIALIZER = {
...@@ -166,7 +168,7 @@ def conv2d_block(inputs: tf.Tensor, ...@@ -166,7 +168,7 @@ def conv2d_block(inputs: tf.Tensor,
batch_norm = common_modules.get_batch_norm(config.batch_norm) batch_norm = common_modules.get_batch_norm(config.batch_norm)
bn_momentum = config.bn_momentum bn_momentum = config.bn_momentum
bn_epsilon = config.bn_epsilon bn_epsilon = config.bn_epsilon
data_format = config.data_format data_format = tf.keras.backend.image_data_format()
weight_decay = config.weight_decay weight_decay = config.weight_decay
name = name or '' name = name or ''
...@@ -223,7 +225,7 @@ def mb_conv_block(inputs: tf.Tensor, ...@@ -223,7 +225,7 @@ def mb_conv_block(inputs: tf.Tensor,
use_se = config.use_se use_se = config.use_se
activation = tf_utils.get_activation(config.activation) activation = tf_utils.get_activation(config.activation)
drop_connect_rate = config.drop_connect_rate drop_connect_rate = config.drop_connect_rate
data_format = config.data_format data_format = tf.keras.backend.image_data_format()
use_depthwise = block.conv_type != 'no_depthwise' use_depthwise = block.conv_type != 'no_depthwise'
prefix = prefix or '' prefix = prefix or ''
...@@ -346,12 +348,14 @@ def efficientnet(image_input: tf.keras.layers.Input, ...@@ -346,12 +348,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
num_classes = config.num_classes num_classes = config.num_classes
input_channels = config.input_channels input_channels = config.input_channels
rescale_input = config.rescale_input rescale_input = config.rescale_input
data_format = config.data_format data_format = tf.keras.backend.image_data_format()
dtype = config.dtype dtype = config.dtype
weight_decay = config.weight_decay weight_decay = config.weight_decay
x = image_input x = image_input
if data_format == 'channels_first':
# Happens on GPU/TPU if available.
x = tf.keras.layers.Permute((3, 1, 2))(x)
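For context on the new permute: when the Keras backend reports 'channels_first', the NHWC input is reordered to NCHW. A short sketch with illustrative shapes:

x = tf.zeros([8, 224, 224, 3])              # (batch, H, W, C)
y = tf.keras.layers.Permute((3, 1, 2))(x)   # (batch, C, H, W)
assert y.shape == (8, 3, 224, 224)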
if rescale_input: if rescale_input:
x = preprocessing.normalize_images(x, x = preprocessing.normalize_images(x,
num_channels=input_channels, num_channels=input_channels,
...@@ -463,7 +467,7 @@ class EfficientNet(tf.keras.Model): ...@@ -463,7 +467,7 @@ class EfficientNet(tf.keras.Model):
def from_name(cls, def from_name(cls,
model_name: Text, model_name: Text,
model_weights_path: Text = None, model_weights_path: Text = None,
copy_to_local: bool = False, weights_format: Text = 'saved_model',
overrides: Dict[Text, Any] = None): overrides: Dict[Text, Any] = None):
"""Construct an EfficientNet model from a predefined model name. """Construct an EfficientNet model from a predefined model name.
...@@ -472,7 +476,8 @@ class EfficientNet(tf.keras.Model): ...@@ -472,7 +476,8 @@ class EfficientNet(tf.keras.Model):
Args: Args:
model_name: the predefined model name model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir) model_weights_path: the path to the weights (h5 file or saved model dir)
copy_to_local: copy the weights to a local tmp dir weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
overrides: (optional) a dict containing keys that can override config overrides: (optional) a dict containing keys that can override config
Returns: Returns:
...@@ -492,12 +497,8 @@ class EfficientNet(tf.keras.Model): ...@@ -492,12 +497,8 @@ class EfficientNet(tf.keras.Model):
model = cls(config=config, overrides=overrides) model = cls(config=config, overrides=overrides)
if model_weights_path: if model_weights_path:
if copy_to_local: common_modules.load_weights(model,
tmp_file = os.path.join('/tmp', model_name + '.h5') model_weights_path,
model_weights_file = os.path.join(model_weights_path, 'model.h5') weights_format=weights_format)
tf.io.gfile.copy(model_weights_file, tmp_file, overwrite=True)
model_weights_path = tmp_file
model.load_weights(model_weights_path)
return model return model
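A hedged usage sketch of the updated constructor; the path and overrides are placeholders:

# Sketch only.
model = EfficientNet.from_name(
    'efficientnet-b0',
    model_weights_path='/path/to/saved_model',
    weights_format='saved_model',
    overrides={'rescale_input': True, 'num_classes': 1000})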
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.vision.image_classification.efficientnet import efficientnet_model
FLAGS = flags.FLAGS
flags.DEFINE_string("model_name", None,
"EfficientNet model name.")
flags.DEFINE_string("model_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path to export.")
def export_tfhub(model_path, hub_destination, model_name):
"""Restores a tf.keras.Model and saves for TF-Hub."""
model = efficientnet_model.EfficientNet.from_name(model_name)
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(model_path).assert_existing_objects_matched()
image_input = tf.keras.layers.Input(
shape=(None, None, 3), name="image_input", dtype=tf.float32)
x = image_input * 255.0
  outputs = model(x)
  hub_model = tf.keras.Model(image_input, outputs)
# Exports a SavedModel.
hub_model.save(
os.path.join(hub_destination, "classification"), include_optimizer=False)
feature_vector_output = hub_model.get_layer(name="efficientnet").get_layer(
name="top_pool").get_output_at(0)
hub_model2 = tf.keras.Model(model.inputs, feature_vector_output)
# Exports a SavedModel.
hub_model2.save(
os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
if __name__ == "__main__":
app.run(main)
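The script simply wires the three flags into export_tfhub; a programmatic equivalent of the flag-driven entry point, with placeholder paths, looks like this:

# Sketch only: paths and model name are placeholders.
export_tfhub(model_path='/path/to/checkpoint',
             hub_destination='/tmp/efficientnet_hub',
             model_name='efficientnet-b0')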
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
from typing import Any, List, Mapping from typing import Any, List, Mapping
import tensorflow.compat.v2 as tf import tensorflow as tf
BASE_LEARNING_RATE = 0.1 BASE_LEARNING_RATE = 0.1
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow.compat.v2 as tf import tensorflow as tf
from official.vision.image_classification import learning_rate from official.vision.image_classification import learning_rate
...@@ -86,5 +86,4 @@ class LearningRateTests(tf.test.TestCase): ...@@ -86,5 +86,4 @@ class LearningRateTests(tf.test.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()