Commit c57e975a authored by saberkun's avatar saberkun
Browse files

Merge pull request #10338 from srihari-humbarwadi:readme

PiperOrigin-RevId: 413033276
parents 7fb4f3cd acf4156e
......@@ -60,7 +60,7 @@ In the near future, we will add:
|-------|-------------------|
| [ALBERT (A Lite BERT)](nlp/albert) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
| [BERT (Bidirectional Encoder Representations from Transformers)](nlp/bert) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
| [NHNet (News Headline generation model)](nlp/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
| [Transformer](nlp/transformer) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
| [XLNet](nlp/xlnet) | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) |
| [MobileBERT](nlp/projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
......
......@@ -141,8 +141,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off":
if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy "
"flag cannot be set to `off`.".format(num_gpus))
raise ValueError(f"When {num_gpus} GPUs are specified, "
"distribution_strategy flag cannot be set to `off`.")
# Return the default distribution strategy.
return tf.distribute.get_strategy()
......
......@@ -12,39 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tests for distribution util functions."""
"""Tests for distribution util functions."""
import sys
import tensorflow as tf
from official.common import distribute_utils
TPU_TEST = 'test_tpu' in sys.argv[0]
class GetDistributionStrategyTest(tf.test.TestCase):
"""Tests for get_distribution_strategy."""
class DistributeUtilsTest(tf.test.TestCase):
"""Tests for distribute util functions."""
def test_invalid_args(self):
with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
_ = distribute_utils.get_distribution_strategy(num_gpus=-1)
with self.assertRaisesRegex(ValueError,
'.*If you meant to pass the string .*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy=False, num_gpus=0)
with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='off', num_gpus=2)
with self.assertRaisesRegex(ValueError,
'`OneDeviceStrategy` can not be used.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='one_device', num_gpus=2)
def test_one_device_strategy_cpu(self):
ds = distribute_utils.get_distribution_strategy(num_gpus=0)
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('CPU', ds.extended.worker_devices[0])
def test_one_device_strategy_gpu(self):
ds = distribute_utils.get_distribution_strategy(num_gpus=1)
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('GPU', ds.extended.worker_devices[0])
def test_mirrored_strategy(self):
# CPU only.
_ = distribute_utils.get_distribution_strategy(num_gpus=0)
# 5 GPUs.
ds = distribute_utils.get_distribution_strategy(num_gpus=5)
self.assertEquals(ds.num_replicas_in_sync, 5)
self.assertEquals(len(ds.extended.worker_devices), 5)
for device in ds.extended.worker_devices:
self.assertIn('GPU', device)
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored',
num_gpus=2,
all_reduce_alg='nccl',
num_packs=2)
with self.assertRaisesRegex(
ValueError,
'When used with `mirrored`, valid values for all_reduce_alg are.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored',
num_gpus=2,
all_reduce_alg='dummy',
num_packs=2)
def test_mwms(self):
distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
ds = distribute_utils.get_distribution_strategy(
'multi_worker_mirrored', all_reduce_alg='nccl')
self.assertIsInstance(
ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
with self.assertRaisesRegex(
ValueError,
'When used with `multi_worker_mirrored`, valid values.*'):
_ = distribute_utils.get_distribution_strategy(
'multi_worker_mirrored', all_reduce_alg='dummy')
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
self.assertIs(ds, tf.distribute.get_strategy())
def test_tpu_strategy(self):
if not TPU_TEST:
self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
with self.assertRaises(ValueError):
_ = distribute_utils.get_distribution_strategy('tpu')
ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
self.assertIsInstance(
ds, tf.distribute.TPUStrategy)
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
ValueError,
......@@ -54,6 +113,12 @@ class GetDistributionStrategyTest(tf.test.TestCase):
ValueError, 'distribution_strategy must be a string but got: 1'):
distribute_utils.get_distribution_strategy(1)
def test_get_strategy_scope(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
with distribute_utils.get_strategy_scope(ds):
self.assertIs(tf.distribute.get_strategy(), ds)
with distribute_utils.get_strategy_scope(None):
self.assertIsNot(tf.distribute.get_strategy(), ds)
if __name__ == '__main__':
tf.test.main()
......@@ -28,7 +28,7 @@ from official.core import config_definitions
from official.modeling import optimization
class PruningActions:
class PruningAction:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
......@@ -66,7 +66,7 @@ class PruningActions:
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
output: The train output.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
......@@ -81,8 +81,11 @@ class EMACheckpointing:
than training.
"""
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1):
def __init__(self,
export_dir: str,
optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint,
max_to_keep: int = 1):
"""Initializes the instance.
Args:
......@@ -99,8 +102,7 @@ class EMACheckpointing:
'EMACheckpointing action')
export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs(
os.path.dirname(export_dir))
tf.io.gfile.makedirs(os.path.dirname(export_dir))
self._optimizer = optimizer
self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager(
......@@ -113,7 +115,7 @@ class EMACheckpointing:
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output to test.
output: The train or eval output.
"""
self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
......@@ -173,10 +175,9 @@ class RecoveryCondition:
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
def get_eval_actions(params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer."""
eval_actions = []
# Adds ema checkpointing action to save the average weights under
......@@ -202,7 +203,7 @@ def get_train_actions(
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
PruningAction(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
......
......@@ -27,14 +27,16 @@ from official.core import actions
from official.modeling import optimization
class TestModel(tf.Module):
class TestModel(tf.keras.Model):
def __init__(self):
self.value = tf.Variable(0)
super().__init__()
self.value = tf.Variable(0.0)
self.dense = tf.keras.layers.Dense(2)
_ = self.dense(tf.zeros((2, 2), tf.float32))
@tf.function(input_signature=[])
def __call__(self):
return self.value
def call(self, x, training=None):
return self.value + x
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
......@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.one_device_strategy,
],))
def test_ema_checkpointing(self, distribution):
with distribution.scope():
......@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(), 3)
self.assertEqual(model(0.), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(
tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
self.assertEqual(model(0.), 0)
# Raises an error for a normal optimizer.
with self.assertRaisesRegex(ValueError,
'Optimizer has to be instance of.*'):
_ = actions.EMACheckpointing(directory, tf.keras.optimizers.SGD(),
checkpoint)
@combinations.generate(
combinations.combine(
......@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(RuntimeError):
recover_condition(outputs)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.one_device_strategy,
],))
def test_pruning(self, distribution):
with distribution.scope():
directory = self.get_temp_dir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
pruning = actions.PruningAction(directory, model, optimizer)
pruning({})
if __name__ == '__main__':
tf.test.main()
......@@ -75,8 +75,8 @@ class Recovery:
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
"The loss value is NaN after training loop and it happens %d times." %
self.recover_counter)
"The loss value is NaN or out of range after training loop and "
f"this happens {self.recover_counter} times.")
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning(
......@@ -247,14 +247,12 @@ class Trainer(_AsyncTrainer):
self._validation_loss = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
model_metrics = model.metrics if hasattr(model, "metrics") else []
self._train_metrics = self.task.build_metrics(
training=True) + model_metrics
self._validation_metrics = self.task.build_metrics(
training=False) + model_metrics
self.init_async()
if train:
self._train_metrics = self.task.build_metrics(
training=True) + model_metrics
train_dataset = train_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.train_data)
orbit.StandardTrainer.__init__(
......@@ -266,6 +264,8 @@ class Trainer(_AsyncTrainer):
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
if evaluate:
self._validation_metrics = self.task.build_metrics(
training=False) + model_metrics
validation_dataset = validation_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.validation_data)
orbit.StandardEvaluator.__init__(
......@@ -370,16 +370,6 @@ class Trainer(_AsyncTrainer):
"""Accesses the training checkpoint."""
return self._checkpoint
# TODO(yejiayu): Remove this once all deps are fixed.
def add_recovery(self, params: TrainerConfig,
checkpoint_manager: tf.train.CheckpointManager):
if params.recovery_max_trials >= 0:
self._recovery = Recovery(
loss_upper_bound=params.loss_upper_bound,
recovery_begin_steps=params.recovery_begin_steps,
recovery_max_trials=params.recovery_max_trials,
checkpoint_manager=checkpoint_manager)
def train_loop_end(self):
"""See base class."""
self.join()
......
......@@ -150,6 +150,30 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
return self.eval_global_step.numpy()
class RecoveryTest(tf.test.TestCase):
def test_recovery_module(self):
ckpt = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
model_dir = self.get_temp_dir()
manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
recovery_module = trainer_lib.Recovery(
loss_upper_bound=1.0,
checkpoint_manager=manager,
recovery_begin_steps=1,
recovery_max_trials=1)
self.assertFalse(recovery_module.should_recover(1.1, 0))
self.assertFalse(recovery_module.should_recover(0.1, 1))
self.assertTrue(recovery_module.should_recover(1.1, 2))
# First triggers the recovery once.
recovery_module.maybe_recover(1.1, 10)
# Second time, it raises.
with self.assertRaisesRegex(
RuntimeError, 'The loss value is NaN .*'):
recovery_module.maybe_recover(1.1, 10)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
......
......@@ -28,13 +28,35 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self,
params,
model: Union[tf.Module, tf.keras.Model],
inference_step: Optional[Callable[..., Any]] = None):
inference_step: Optional[Callable[..., Any]] = None,
*,
preprocessor: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel.
Examples:
`inference_step` must be a function that has `model` as an kwarg or the
second positional argument.
```
def _inference_step(inputs, model=None):
return model(inputs, training=False)
module = ExportModule(params, model, inference_step=_inference_step)
```
`preprocessor` and `postprocessor` could be either functions or `tf.Module`.
The usages of preprocessor and postprocessor are managed by the
implementation of `serve()` method.
Args:
params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation.
inference_step: An optional callable to define how the model is called.
inference_step: An optional callable to forward-pass the model. If not
specified, it creates a parital function with `model` as an required
kwarg.
preprocessor: An optional callable to preprocess the inputs.
postprocessor: An optional callable to postprocess the model outputs.
"""
super().__init__(name=None)
self.model = model
......@@ -45,6 +67,8 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
else:
self.inference_step = functools.partial(
self.model.__call__, training=False)
self.preprocessor = preprocessor
self.postprocessor = postprocessor
@abc.abstractmethod
def serve(self) -> Mapping[Text, tf.Tensor]:
......
......@@ -25,7 +25,11 @@ class TestModule(export_base.ExportModule):
@tf.function
def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
return {'outputs': self.inference_step(inputs)}
x = inputs if self.preprocessor is None else self.preprocessor(
inputs=inputs)
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return {'outputs': x}
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
......@@ -83,6 +87,40 @@ class ExportBaseTest(tf.test.TestCase):
output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
def test_processors(self):
model = tf.Module()
inputs = tf.zeros((), tf.float32)
def _inference_step(inputs, model):
del model
return inputs + 1.0
def _preprocessor(inputs):
print(inputs)
return inputs + 0.1
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor)
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.1)
class _PostProcessor(tf.Module):
def __call__(self, inputs):
return inputs + 0.01
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor,
postprocessor=_PostProcessor())
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.11)
if __name__ == '__main__':
tf.test.main()
# Image Classification
**Warning:** the features in the `image_classification/` folder have been fully
integrated into vision/beta. Please use the [new code base](../../vision/beta/README.md).
This folder contains TF 2.0 model examples for image classification:
* [MNIST](#mnist)
* [Classifier Trainer](#classifier-trainer), a framework that uses the Keras
compile/fit methods for image classification models, including:
* ResNet
* EfficientNet[^1]
[^1]: Currently a work in progress. We cannot match "AutoAugment (AA)" in [the original version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet).
For more information about other types of models, please refer to this
[README file](../../README.md).
## Before you begin
Please make sure that you have the latest version of TensorFlow
installed and
[add the models folder to your Python path](/official/#running-the-models).
### ImageNet preparation
#### Using TFDS
`classifier_trainer.py` supports ImageNet with
[TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/overview).
Please see the following [example snippet](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/scripts/download_and_prepare.py)
for more information on how to use TFDS to download and prepare datasets, and
specifically the [TFDS ImageNet readme](https://github.com/tensorflow/datasets/blob/master/docs/catalog/imagenet2012.md)
for manual download instructions.
#### Legacy TFRecords
Download the ImageNet dataset and convert it to TFRecord format.
The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options.
Note that the legacy ResNet runners, e.g. [resnet/resnet_ctl_imagenet_main.py](resnet/resnet_ctl_imagenet_main.py)
require TFRecords whereas `classifier_trainer.py` can use both by setting the
builder to 'records' or 'tfds' in the configurations.
### Running on Cloud TPUs
Note: These models will **not** work with TPUs on Colab.
You can train image classification models on Cloud TPUs using
[tf.distribute.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf.distribute.TPUStrategy?version=nightly).
If you are not familiar with Cloud TPUs, it is strongly recommended that you go
through the
[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
create a TPU and GCE VM.
### Running on multiple GPU hosts
You can also train these models on multiple hosts, each with GPUs, using
[tf.distribute.experimental.MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy).
The easiest way to run multi-host benchmarks is to set the
[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on
2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and
host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker",
"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the
available GPUs at each host.
## MNIST
To download the data and run the MNIST sample model locally for the first time,
run one of the following command:
```bash
python3 mnist_main.py \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--train_epochs=10 \
--distribution_strategy=one_device \
--num_gpus=$NUM_GPUS \
--download
```
To train the model on a Cloud TPU, run the following command:
```bash
python3 mnist_main.py \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--train_epochs=10 \
--distribution_strategy=tpu \
--download
```
Note: the `--download` flag is only required the first time you run the model.
## Classifier Trainer
The classifier trainer is a unified framework for running image classification
models using Keras's compile/fit methods. Experiments should be provided in the
form of YAML files, some examples are included within the configs/examples
folder. Please see [configs/examples](./configs/examples) for more example
configurations.
The provided configuration files use a per replica batch size and is scaled
by the number of devices. For instance, if `batch size` = 64, then for 1 GPU
the global batch size would be 64 * 1 = 64. For 8 GPUs, the global batch size
would be 64 * 8 = 512. Similarly, for a v3-8 TPU, the global batch size would
be 64 * 8 = 512, and for a v3-32, the global batch size is 64 * 32 = 2048.
### ResNet50
#### On GPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=resnet \
--dataset=imagenet \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/resnet/imagenet/gpu.yaml \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
To train on multiple hosts, each with GPUs attached using
[MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
please update `runtime` section in gpu.yaml
(or override using `--params_override`) with:
```YAML
# gpu.yaml
runtime:
distribution_strategy: 'multi_worker_mirrored'
worker_hosts: '$HOST1:port,$HOST2:port'
num_gpus: $NUM_GPUS
task_index: 0
```
By having `task_index: 0` on the first host and `task_index: 1` on the second
and so on. `$HOST1` and `$HOST2` are the IP addresses of the hosts, and `port`
can be chosen any free port on the hosts. Only the first host will write
TensorBoard Summaries and save checkpoints.
#### On TPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=resnet \
--dataset=imagenet \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/resnet/imagenet/tpu.yaml
```
### EfficientNet
**Note: EfficientNet development is a work in progress.**
#### On GPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=efficientnet \
--dataset=imagenet \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
#### On TPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=efficientnet \
--dataset=imagenet \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
```
Note that the number of GPU devices can be overridden in the command line using
`--params_overrides`. The TPU does not need this override as the device is fixed
by providing the TPU address or name with the `--tpu` flag.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoAugment and RandAugment policies for enhanced image preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
"""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import math
from typing import Any, Dict, List, Optional, Text, Tuple
from keras.layers.preprocessing import image_preprocessing as image_ops
import tensorflow as tf
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.
def to_4d(image: tf.Tensor) -> tf.Tensor:
"""Converts an input Tensor to 4 dimensions.
4D image => [N, H, W, C] or [N, C, H, W]
3D image => [1, H, W, C] or [1, C, H, W]
2D image => [1, H, W, 1]
Args:
image: The 2/3/4D input tensor.
Returns:
A 4D image tensor.
Raises:
`TypeError` if `image` is not a 2/3/4D tensor.
"""
shape = tf.shape(image)
original_rank = tf.rank(image)
left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
new_shape = tf.concat(
[
tf.ones(shape=left_pad, dtype=tf.int32),
shape,
tf.ones(shape=right_pad, dtype=tf.int32),
],
axis=0,
)
return tf.reshape(image, new_shape)
def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
"""Converts a 4D image back to `ndims` rank."""
shape = tf.shape(image)
begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
end = 4 - tf.cast(tf.equal(ndims, 2), dtype=tf.int32)
new_shape = shape[begin:end]
return tf.reshape(image, new_shape)
def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
"""Converts translations to a projective transform.
The translation matrix looks like this:
[[1 0 -dx]
[0 1 -dy]
[0 0 1]]
Args:
translations: The 2-element list representing [dx, dy], or a matrix of
2-element lists representing [dx dy] to translate for each image. The
shape must be static.
Returns:
The transformation matrix of shape (num_images, 8).
Raises:
`TypeError` if
- the shape of `translations` is not known or
- the shape of `translations` is not rank 1 or 2.
"""
translations = tf.convert_to_tensor(translations, dtype=tf.float32)
if translations.get_shape().ndims is None:
raise TypeError('translations rank must be statically known')
elif len(translations.get_shape()) == 1:
translations = translations[None]
elif len(translations.get_shape()) != 2:
raise TypeError('translations should have rank 1 or 2.')
num_translations = tf.shape(translations)[0]
return tf.concat(
values=[
tf.ones((num_translations, 1), tf.dtypes.float32),
tf.zeros((num_translations, 1), tf.dtypes.float32),
-translations[:, 0, None],
tf.zeros((num_translations, 1), tf.dtypes.float32),
tf.ones((num_translations, 1), tf.dtypes.float32),
-translations[:, 1, None],
tf.zeros((num_translations, 2), tf.dtypes.float32),
],
axis=1,
)
def _convert_angles_to_transform(angles: tf.Tensor, image_width: tf.Tensor,
image_height: tf.Tensor) -> tf.Tensor:
"""Converts an angle or angles to a projective transform.
Args:
angles: A scalar to rotate all images, or a vector to rotate a batch of
images. This must be a scalar.
image_width: The width of the image(s) to be transformed.
image_height: The height of the image(s) to be transformed.
Returns:
A tensor of shape (num_images, 8).
Raises:
`TypeError` if `angles` is not rank 0 or 1.
"""
angles = tf.convert_to_tensor(angles, dtype=tf.float32)
if len(angles.get_shape()) == 0: # pylint:disable=g-explicit-length-test
angles = angles[None]
elif len(angles.get_shape()) != 1:
raise TypeError('Angles should have a rank 0 or 1.')
x_offset = ((image_width - 1) -
(tf.math.cos(angles) * (image_width - 1) - tf.math.sin(angles) *
(image_height - 1))) / 2.0
y_offset = ((image_height - 1) -
(tf.math.sin(angles) * (image_width - 1) + tf.math.cos(angles) *
(image_height - 1))) / 2.0
num_angles = tf.shape(angles)[0]
return tf.concat(
values=[
tf.math.cos(angles)[:, None],
-tf.math.sin(angles)[:, None],
x_offset[:, None],
tf.math.sin(angles)[:, None],
tf.math.cos(angles)[:, None],
y_offset[:, None],
tf.zeros((num_angles, 2), tf.dtypes.float32),
],
axis=1,
)
def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
if transforms.shape.rank == 1:
transforms = transforms[None]
image = to_4d(image)
image = image_ops.transform(
images=image, transforms=transforms, interpolation='nearest')
return from_4d(image, original_ndims)
def translate(image: tf.Tensor, translations) -> tf.Tensor:
"""Translates image(s) by provided vectors.
Args:
image: An image Tensor of type uint8.
translations: A vector or matrix representing [dx dy].
Returns:
The translated version of the image.
"""
transforms = _convert_translation_to_transform(translations)
return transform(image, transforms=transforms)
def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
"""Rotates the image by degrees either clockwise or counterclockwise.
Args:
image: An image Tensor of type uint8.
degrees: Float, a scalar angle in degrees to rotate all images by. If
degrees is positive the image will be rotated clockwise otherwise it will
be rotated counterclockwise.
Returns:
The rotated version of image.
"""
# Convert from degrees to radians.
degrees_to_radians = math.pi / 180.0
radians = tf.cast(degrees * degrees_to_radians, tf.float32)
original_ndims = tf.rank(image)
image = to_4d(image)
image_height = tf.cast(tf.shape(image)[1], tf.float32)
image_width = tf.cast(tf.shape(image)[2], tf.float32)
transforms = _convert_angles_to_transform(
angles=radians, image_width=image_width, image_height=image_height)
# In practice, we should randomize the rotation degrees by flipping
# it negatively half the time, but that's done on 'degrees' outside
# of the function.
image = transform(image, transforms=transforms)
return from_4d(image, original_ndims)
def blend(image1: tf.Tensor, image2: tf.Tensor, factor: float) -> tf.Tensor:
"""Blend image1 and image2 using 'factor'.
Factor can be above 0.0. A value of 0.0 means only image1 is used.
A value of 1.0 means only image2 is used. A value between 0.0 and
1.0 means we linearly interpolate the pixel values between the two
images. A value greater than 1.0 "extrapolates" the difference
between the two pixel values, and we clip the results to values
between 0 and 255.
Args:
image1: An image Tensor of type uint8.
image2: An image Tensor of type uint8.
factor: A floating point value above 0.0.
Returns:
A blended image Tensor of type uint8.
"""
if factor == 0.0:
return tf.convert_to_tensor(image1)
if factor == 1.0:
return tf.convert_to_tensor(image2)
image1 = tf.cast(image1, tf.float32)
image2 = tf.cast(image2, tf.float32)
difference = image2 - image1
scaled = factor * difference
# Do addition in float.
temp = tf.cast(image1, tf.float32) + scaled
# Interpolate
if factor > 0.0 and factor < 1.0:
# Interpolation means we always stay within 0 and 255.
return tf.cast(temp, tf.uint8)
# Extrapolate:
#
# We need to clip and then cast.
return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
"""Apply cutout (https://arxiv.org/abs/1708.04552) to image.
This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
a random location within `img`. The pixel values filled in will be of the
value `replace`. The located where the mask will be applied is randomly
chosen uniformly over the whole image.
Args:
image: An image Tensor of type uint8.
pad_size: Specifies how big the zero mask that will be generated is that is
applied to the image. The mask will be of size (2*pad_size x 2*pad_size).
replace: What pixel value to fill in the image in the area that has the
cutout mask applied to it.
Returns:
An image Tensor that is of type uint8.
"""
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height = tf.random.uniform(
shape=[], minval=0, maxval=image_height, dtype=tf.int32)
cutout_center_width = tf.random.uniform(
shape=[], minval=0, maxval=image_width, dtype=tf.int32)
lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
left_pad = tf.maximum(0, cutout_center_width - pad_size)
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
cutout_shape = [
image_height - (lower_pad + upper_pad),
image_width - (left_pad + right_pad)
]
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
mask = tf.pad(
tf.zeros(cutout_shape, dtype=image.dtype),
padding_dims,
constant_values=1)
mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3])
image = tf.where(
tf.equal(mask, 0),
tf.ones_like(image, dtype=image.dtype) * replace, image)
return image
def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
# For each pixel in the image, select the pixel
# if the value is less than the threshold.
# Otherwise, subtract 255 from the pixel.
return tf.where(image < threshold, image, 255 - image)
def solarize_add(image: tf.Tensor,
addition: int = 0,
threshold: int = 128) -> tf.Tensor:
# For each pixel in the image less than threshold
# we add 'addition' amount to it and then clip the
# pixel value to be between 0 and 255. The value
# of 'addition' is between -128 and 128.
added_image = tf.cast(image, tf.int64) + addition
added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
return tf.where(image < threshold, added_image, image)
def color(image: tf.Tensor, factor: float) -> tf.Tensor:
"""Equivalent of PIL Color."""
degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
return blend(degenerate, image, factor)
def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
"""Equivalent of PIL Contrast."""
degenerate = tf.image.rgb_to_grayscale(image)
# Cast before calling tf.histogram.
degenerate = tf.cast(degenerate, tf.int32)
# Compute the grayscale histogram, then compute the mean pixel value,
# and create a constant image size of that value. Use that as the
# blending degenerate target of the original image.
hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
return blend(degenerate, image, factor)
def brightness(image: tf.Tensor, factor: float) -> tf.Tensor:
"""Equivalent of PIL Brightness."""
degenerate = tf.zeros_like(image)
return blend(degenerate, image, factor)
def posterize(image: tf.Tensor, bits: int) -> tf.Tensor:
"""Equivalent of PIL Posterize."""
shift = 8 - bits
return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
def wrapped_rotate(image: tf.Tensor, degrees: float, replace: int) -> tf.Tensor:
"""Applies rotation with wrap/unwrap."""
image = rotate(wrap(image), degrees=degrees)
return unwrap(image, replace)
def translate_x(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
"""Equivalent of PIL Translate in X dimension."""
image = translate(wrap(image), [-pixels, 0])
return unwrap(image, replace)
def translate_y(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
"""Equivalent of PIL Translate in Y dimension."""
image = translate(wrap(image), [0, -pixels])
return unwrap(image, replace)
def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
"""Equivalent of PIL Shearing in X dimension."""
# Shear parallel to x axis is a projective transform
# with a matrix form of:
# [1 level
# 0 1].
image = transform(
image=wrap(image), transforms=[1., level, 0., 0., 1., 0., 0., 0.])
return unwrap(image, replace)
def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
"""Equivalent of PIL Shearing in Y dimension."""
# Shear parallel to y axis is a projective transform
# with a matrix form of:
# [1 0
# level 1].
image = transform(
image=wrap(image), transforms=[1., 0., 0., level, 1., 0., 0., 0.])
return unwrap(image, replace)
def autocontrast(image: tf.Tensor) -> tf.Tensor:
"""Implements Autocontrast function from PIL using TF ops.
Args:
image: A 3D uint8 tensor.
Returns:
The image after it has had autocontrast applied to it and will be of type
uint8.
"""
def scale_channel(image: tf.Tensor) -> tf.Tensor:
"""Scale the 2D image using the autocontrast rule."""
# A possibly cheaper version can be done using cumsum/unique_with_counts
# over the histogram values, rather than iterating over the entire image.
# to compute mins and maxes.
lo = tf.cast(tf.reduce_min(image), tf.float32)
hi = tf.cast(tf.reduce_max(image), tf.float32)
# Scale the image, making the lowest value 0 and the highest value 255.
def scale_values(im):
scale = 255.0 / (hi - lo)
offset = -lo * scale
im = tf.cast(im, tf.float32) * scale + offset
im = tf.clip_by_value(im, 0.0, 255.0)
return tf.cast(im, tf.uint8)
result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
return result
# Assumes RGB for now. Scales each channel independently
# and then stacks the result.
s1 = scale_channel(image[:, :, 0])
s2 = scale_channel(image[:, :, 1])
s3 = scale_channel(image[:, :, 2])
image = tf.stack([s1, s2, s3], 2)
return image
def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
"""Implements Sharpness function from PIL using TF ops."""
orig_image = image
image = tf.cast(image, tf.float32)
# Make image 4D for conv operation.
image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel.
kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
dtype=tf.float32,
shape=[3, 3, 1, 1]) / 13.
# Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1]
degenerate = tf.nn.depthwise_conv2d(
image, kernel, strides, padding='VALID', dilations=[1, 1])
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
# For the borders of the resulting image, fill in the values of the
# original image.
mask = tf.ones_like(degenerate)
padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
# Blend the final result.
return blend(result, orig_image, factor)
def equalize(image: tf.Tensor) -> tf.Tensor:
"""Implements Equalize function from PIL using TF ops."""
def scale_channel(im, c):
"""Scale the data in the channel to implement equalize."""
im = tf.cast(im[:, :, c], tf.int32)
# Compute the histogram of the image channel.
histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)
# For the purposes of computing the step, filter out the nonzeros.
nonzero = tf.where(tf.not_equal(histo, 0))
nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
def build_lut(histo, step):
# Compute the cumulative sum, shifting by step // 2
# and then normalization by step.
lut = (tf.cumsum(histo) + (step // 2)) // step
# Shift lut, prepending with 0.
lut = tf.concat([[0], lut[:-1]], 0)
# Clip the counts to be in range. This is done
# in the C code for image.point.
return tf.clip_by_value(lut, 0, 255)
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
result = tf.cond(
tf.equal(step, 0), lambda: im,
lambda: tf.gather(build_lut(histo, step), im))
return tf.cast(result, tf.uint8)
# Assumes RGB for now. Scales each channel independently
# and then stacks the result.
s1 = scale_channel(image, 0)
s2 = scale_channel(image, 1)
s3 = scale_channel(image, 2)
image = tf.stack([s1, s2, s3], 2)
return image
def invert(image: tf.Tensor) -> tf.Tensor:
"""Inverts the image pixels."""
image = tf.convert_to_tensor(image)
return 255 - image
def wrap(image: tf.Tensor) -> tf.Tensor:
"""Returns 'image' with an extra channel set to all 1s."""
shape = tf.shape(image)
extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype)
extended = tf.concat([image, extended_channel], axis=2)
return extended
def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
"""Unwraps an image produced by wrap.
Where there is a 0 in the last channel for every spatial position,
the rest of the three channels in that spatial dimension are grayed
(set to 128). Operations like translate and shear on a wrapped
Tensor will leave 0s in empty locations. Some transformations look
at the intensity of values to do preprocessing, and we want these
empty pixels to assume the 'average' value, rather than pure black.
Args:
image: A 3D Image Tensor with 4 channels.
replace: A one or three value 1D tensor to fill empty pixels.
Returns:
image: A 3D image Tensor with 3 channels.
"""
image_shape = tf.shape(image)
# Flatten the spatial dimensions.
flattened_image = tf.reshape(image, [-1, image_shape[2]])
# Find all pixels where the last channel is zero.
alpha_channel = tf.expand_dims(flattened_image[:, 3], axis=-1)
replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)
# Where they are zero, fill them in with 'replace'.
flattened_image = tf.where(
tf.equal(alpha_channel, 0),
tf.ones_like(flattened_image, dtype=image.dtype) * replace,
flattened_image)
image = tf.reshape(flattened_image, image_shape)
image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
return image
def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
return final_tensor
def _rotate_level_to_arg(level: float):
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate_tensor(level)
return (level,)
def _shrink_level_to_arg(level: float):
"""Converts level to ratio by which we shrink the image content."""
if level == 0:
return (1.0,) # if level is zero, do not shrink the image
# Maximum shrinking ratio is 2.9.
level = 2. / (_MAX_LEVEL / level) + 0.9
return (level,)
def _enhance_level_to_arg(level: float):
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level: float):
level = (level / _MAX_LEVEL) * 0.3
# Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level)
return (level,)
def _translate_level_to_arg(level: float, translate_const: float):
level = (level / _MAX_LEVEL) * float(translate_const)
# Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level)
return (level,)
def _mult_to_arg(level: float, multiplier: float = 1.):
return (int((level / _MAX_LEVEL) * multiplier),)
def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float):
"""Apply `func` to image w/ `args` as input with probability `prob`."""
assert isinstance(args, tuple)
# Apply the function with probability `prob`.
should_apply_op = tf.cast(
tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
lambda: image)
return augmented_image
def select_and_apply_random_policy(policies: Any, image: tf.Tensor):
"""Select a random policy from `policies` and apply it to `image`."""
policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
# Note that using tf.case instead of tf.conds would result in significantly
# larger graphs and would even break export for some larger policies.
for (i, policy) in enumerate(policies):
image = tf.cond(
tf.equal(i, policy_to_select),
lambda selected_policy=policy: selected_policy(image),
lambda: image)
return image
NAME_TO_FUNC = {
'AutoContrast': autocontrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': wrapped_rotate,
'Posterize': posterize,
'Solarize': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'Contrast': contrast,
'Brightness': brightness,
'Sharpness': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x,
'TranslateY': translate_y,
'Cutout': cutout,
}
# Functions that have a 'replace' parameter
REPLACE_FUNCS = frozenset({
'Rotate',
'TranslateX',
'ShearX',
'ShearY',
'TranslateY',
'Cutout',
})
def level_to_arg(cutout_const: float, translate_const: float):
"""Creates a dict mapping image operation names to their arguments."""
no_arg = lambda level: ()
posterize_arg = lambda level: _mult_to_arg(level, 4)
solarize_arg = lambda level: _mult_to_arg(level, 256)
solarize_add_arg = lambda level: _mult_to_arg(level, 110)
cutout_arg = lambda level: _mult_to_arg(level, cutout_const)
translate_arg = lambda level: _translate_level_to_arg(level, translate_const)
args = {
'AutoContrast': no_arg,
'Equalize': no_arg,
'Invert': no_arg,
'Rotate': _rotate_level_to_arg,
'Posterize': posterize_arg,
'Solarize': solarize_arg,
'SolarizeAdd': solarize_add_arg,
'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg,
'Brightness': _enhance_level_to_arg,
'Sharpness': _enhance_level_to_arg,
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'Cutout': cutout_arg,
'TranslateX': translate_arg,
'TranslateY': translate_arg,
}
return args
def _parse_policy_info(name: Text, prob: float, level: float,
replace_value: List[int], cutout_const: float,
translate_const: float) -> Tuple[Any, float, Any]:
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
args = level_to_arg(cutout_const, translate_const)[name](level)
if name in REPLACE_FUNCS:
# Add in replace arg if it is required for the function that is called.
args = tuple(list(args) + [replace_value])
return func, prob, args
class ImageAugment(object):
"""Image augmentation class for applying image distortions."""
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Given an image tensor, returns a distorted image with the same shape.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
Returns:
The augmented version of `image`.
"""
raise NotImplementedError()
class AutoAugment(ImageAugment):
"""Applies the AutoAugment policy to images.
AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
"""
def __init__(self,
augmentation_name: Text = 'v0',
policies: Optional[Dict[Text, Any]] = None,
cutout_const: float = 100,
translate_const: float = 250):
"""Applies the AutoAugment policy to images.
Args:
augmentation_name: The name of the AutoAugment policy to use. The
available options are `v0` and `test`. `v0` is the policy used for all
of the results in the paper and was found to achieve the best results on
the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
on the COCO dataset that have slight variation in what operations were
used during the search procedure along with how many operations are
applied in parallel to a single image (2 vs 3).
policies: list of lists of tuples in the form `(func, prob, level)`,
`func` is a string name of the augmentation function, `prob` is the
probability of applying the `func` operation, `level` is the input
argument for `func`.
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
"""
super(AutoAugment, self).__init__()
if policies is None:
self.available_policies = {
'v0': self.policy_v0(),
'test': self.policy_test(),
'simple': self.policy_simple(),
}
if augmentation_name not in self.available_policies:
raise ValueError(
'Invalid augmentation_name: {}'.format(augmentation_name))
self.augmentation_name = augmentation_name
self.policies = self.available_policies[augmentation_name]
self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const)
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies the AutoAugment policy to `image`.
AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
Returns:
A version of image that now has data augmentation applied to it based on
the `policies` pass into the function.
"""
input_image_type = image.dtype
if input_image_type != tf.uint8:
image = tf.clip_by_value(image, 0.0, 255.0)
image = tf.cast(image, dtype=tf.uint8)
replace_value = [128] * 3
# func is the string name of the augmentation function, prob is the
# probability of applying the operation and level is the parameter
# associated with the tf op.
# tf_policies are functions that take in an image and return an augmented
# image.
tf_policies = []
for policy in self.policies:
tf_policy = []
# Link string name to the correct python function and make sure the
# correct argument is passed into that function.
for policy_info in policy:
policy_info = list(policy_info) + [
replace_value, self.cutout_const, self.translate_const
]
tf_policy.append(_parse_policy_info(*policy_info))
# Now build the tf policy that will apply the augmentation procedue
# on image.
def make_final_policy(tf_policy_):
def final_policy(image_):
for func, prob, args in tf_policy_:
image_ = _apply_func_with_prob(func, image_, args, prob)
return image_
return final_policy
tf_policies.append(make_final_policy(tf_policy))
image = select_and_apply_random_policy(tf_policies, image)
image = tf.cast(image, dtype=input_image_type)
return image
@staticmethod
def policy_v0():
"""Autoaugment policy that was used in AutoAugment Paper.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
# TODO(dankondratyuk): tensorflow_addons defines custom ops, which
# for some reason are not included when building/linking
# This results in the error, "Op type not registered
# 'Addons>ImageProjectiveTransformV2' in binary" when running on borg TPUs
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
return policy
@staticmethod
def policy_simple():
"""Same as `policy_v0`, except with custom ops removed."""
policy = [
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
]
return policy
@staticmethod
def policy_test():
"""Autoaugment test policy for debugging."""
policy = [
[('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
]
return policy
class RandAugment(ImageAugment):
"""Applies the RandAugment policy to images.
RandAugment is from the paper https://arxiv.org/abs/1909.13719,
"""
def __init__(self,
num_layers: int = 2,
magnitude: float = 10.,
cutout_const: float = 40.,
translate_const: float = 100.):
"""Applies the RandAugment policy to images.
Args:
num_layers: Integer, the number of augmentation transformations to apply
sequentially to an image. Represented as (N) in the paper. Usually best
values will be in the range [1, 3].
magnitude: Integer, shared magnitude across all augmentation operations.
Represented as (M) in the paper. Usually best values are in the range
[5, 10].
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
"""
super(RandAugment, self).__init__()
self.num_layers = num_layers
self.magnitude = float(magnitude)
self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const)
self.available_ops = [
'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
]
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies the RandAugment policy to `image`.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
Returns:
The augmented version of `image`.
"""
input_image_type = image.dtype
if input_image_type != tf.uint8:
image = tf.clip_by_value(image, 0.0, 255.0)
image = tf.cast(image, dtype=tf.uint8)
replace_value = [128] * 3
min_prob, max_prob = 0.2, 0.8
for _ in range(self.num_layers):
op_to_select = tf.random.uniform([],
maxval=len(self.available_ops) + 1,
dtype=tf.int32)
branch_fns = []
for (i, op_name) in enumerate(self.available_ops):
prob = tf.random.uniform([],
minval=min_prob,
maxval=max_prob,
dtype=tf.float32)
func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
replace_value, self.cutout_const,
self.translate_const)
branch_fns.append((
i,
# pylint:disable=g-long-lambda
lambda selected_func=func, selected_args=args: selected_func(
image, *selected_args)))
# pylint:enable=g-long-lambda
image = tf.switch_case(
branch_index=op_to_select,
branch_fns=branch_fns,
default=lambda: tf.identity(image))
image = tf.cast(image, dtype=input_image_type)
return image
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for autoaugment."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from official.legacy.image_classification import augment
def get_dtype_test_cases():
return [
('uint8', tf.uint8),
('int32', tf.int32),
('float16', tf.float16),
('float32', tf.float32),
]
@parameterized.named_parameters(get_dtype_test_cases())
class TransformsTest(parameterized.TestCase, tf.test.TestCase):
"""Basic tests for fundamental transformations."""
def test_to_from_4d(self, dtype):
for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]:
original_ndims = len(shape)
image = tf.zeros(shape, dtype=dtype)
image_4d = augment.to_4d(image)
self.assertEqual(4, tf.rank(image_4d))
self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims))
def test_transform(self, dtype):
image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
self.assertAllEqual(
augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image, translations=translations)
expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
translation = [0, 0]
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.translate(image, translation))
def test_translate_invalid_translation(self, dtype):
image = tf.zeros((1, 1), dtype=dtype)
invalid_translation = [[[1, 1]]]
with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
_ = augment.translate(image, invalid_translation)
def test_rotate(self, dtype):
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
rotation = 90.
transformed = augment.rotate(image=image, degrees=rotation)
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
self.assertAllEqual(transformed, expected)
def test_rotate_shapes(self, dtype):
degrees = 0.
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.rotate(image, degrees))
class AutoaugmentTest(tf.test.TestCase):
def test_autoaugment(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.RandAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image = func(image, *args)
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Common modules for callbacks."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from typing import Any, List, MutableMapping, Optional, Text
from absl import logging
import tensorflow as tf
from official.modeling import optimization
from official.utils.misc import keras_utils
def get_callbacks(
model_checkpoint: bool = True,
include_tensorboard: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
apply_moving_average: bool = False,
initial_step: int = 0,
batch_size: int = 0,
log_steps: int = 0,
model_dir: Optional[str] = None,
backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
callbacks = []
if model_checkpoint:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(
tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if backup_and_restore:
backup_dir = os.path.join(model_dir, 'tmp')
callbacks.append(
tf.keras.callbacks.experimental.BackupAndRestore(backup_dir))
if include_tensorboard:
callbacks.append(
CustomTensorBoard(
log_dir=model_dir,
track_lr=track_lr,
initial_step=initial_step,
write_images=write_model_weights,
profile_batch=0))
if time_history:
callbacks.append(
keras_utils.TimeHistory(
batch_size,
log_steps,
logdir=model_dir if include_tensorboard else None))
if apply_moving_average:
# Save moving average model to a different file so that
# we can resume training from a checkpoint
ckpt_full_path = os.path.join(model_dir, 'average',
'model.ckpt-{epoch:04d}')
callbacks.append(
AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
return callbacks
def get_scalar_from_tensor(t: tf.Tensor) -> int:
"""Utility function to convert a Tensor to a scalar."""
t = tf.keras.backend.get_value(t)
if callable(t):
return t()
else:
return t
class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
"""A customized TensorBoard callback that tracks additional datapoints.
Metrics tracked:
- Global learning rate
Attributes:
log_dir: the path of the directory where to save the log files to be parsed
by TensorBoard.
track_lr: `bool`, whether or not to track the global learning rate.
initial_step: the initial step, used for preemption recovery.
**kwargs: Additional arguments for backwards compatibility. Possible key is
`period`.
"""
# TODO(b/146499062): track params, flops, log lr, l2 loss,
# classification loss
def __init__(self,
log_dir: str,
track_lr: bool = False,
initial_step: int = 0,
**kwargs):
super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
self.step = initial_step
self._track_lr = track_lr
def on_batch_begin(self,
epoch: int,
logs: Optional[MutableMapping[str, Any]] = None) -> None:
self.step += 1
if logs is None:
logs = {}
logs.update(self._calculate_metrics())
super(CustomTensorBoard, self).on_batch_begin(epoch, logs)
def on_epoch_begin(self,
epoch: int,
logs: Optional[MutableMapping[str, Any]] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
for k, v in metrics.items():
logging.info('Current %s: %f', k, v)
super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)
def on_epoch_end(self,
epoch: int,
logs: Optional[MutableMapping[str, Any]] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
def _calculate_metrics(self) -> MutableMapping[str, Any]:
logs = {}
# TODO(b/149030439): disable LR reporting.
# if self._track_lr:
# logs['learning_rate'] = self._calculate_lr()
return logs
def _calculate_lr(self) -> int:
"""Calculates the learning rate given the current step."""
return get_scalar_from_tensor(
self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32)) # pylint:disable=protected-access
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Get the base optimizer used by the current model."""
optimizer = self.model.optimizer
# The optimizer might be wrapped by another class, so unwrap it
while hasattr(optimizer, '_optimizer'):
optimizer = optimizer._optimizer # pylint:disable=protected-access
return optimizer
class MovingAverageCallback(tf.keras.callbacks.Callback):
"""A Callback to be used with a `ExponentialMovingAverage` optimizer.
Applies moving average weights to the model during validation time to test
and predict on the averaged weights rather than the current model weights.
Once training is complete, the model weights will be overwritten with the
averaged weights (by default).
Attributes:
overwrite_weights_on_train_end: Whether to overwrite the current model
weights with the averaged weights from the moving average optimizer.
**kwargs: Any additional callback arguments.
"""
def __init__(self, overwrite_weights_on_train_end: bool = False, **kwargs):
super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer,
optimization.ExponentialMovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: Optional[MutableMapping[Text, Any]] = None):
self.model.optimizer.swap_weights()
def on_test_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
self.model.optimizer.swap_weights()
def on_train_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
if self.overwrite_weights_on_train_end:
self.model.optimizer.assign_average_vars(self.model.variables)
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
"""Saves and, optionally, assigns the averaged weights.
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
update_weights: If True, assign the moving average weights to the model, and
save them. If False, keep the old non-averaged weights, but the saved
model uses the average weights. See `tf.keras.callbacks.ModelCheckpoint`
for the other args.
"""
def __init__(self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
self.update_weights = update_weights
super().__init__(filepath, monitor, verbose, save_best_only,
save_weights_only, mode, save_freq, **kwargs)
def set_model(self, model):
if not isinstance(model.optimizer, optimization.ExponentialMovingAverage):
raise TypeError('AverageModelCheckpoint is only used when training'
'with MovingAverage')
return super().set_model(model)
def _save_model(self, epoch, logs):
assert isinstance(self.model.optimizer,
optimization.ExponentialMovingAverage)
if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
non_avg_weights = self.model.get_weights()
self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
self.model.set_weights(non_avg_weights)
return result
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Runs an Image Classification model."""
import os
import pprint
from typing import Any, Mapping, Optional, Text, Tuple
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.legacy.image_classification import callbacks as custom_callbacks
from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification import optimizer_factory
from official.legacy.image_classification.configs import base_configs
from official.legacy.image_classification.configs import configs
from official.legacy.image_classification.efficientnet import efficientnet_model
from official.legacy.image_classification.resnet import common
from official.legacy.image_classification.resnet import resnet_model
from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
from official.utils.misc import keras_utils
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
"""Returns the mapping from dtype string representations to TF dtypes."""
return {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
"""Get a dict of available metrics to track."""
if one_hot:
return {
# (name, metric_fn)
'acc':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'accuracy':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_1':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.TopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
}
else:
return {
# (name, metric_fn)
'acc':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'accuracy':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_1':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
}
def get_image_size_from_model(
params: base_configs.ExperimentConfig) -> Optional[int]:
"""If the given model has a preferred image size, return it."""
if params.model_name == 'efficientnet':
efficientnet_name = params.model.model_params.model_name
if efficientnet_name in efficientnet_model.MODEL_CONFIGS:
return efficientnet_model.MODEL_CONFIGS[efficientnet_name].resolution
return None
def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy,
one_hot: bool) -> Tuple[Any, Any]:
"""Create and return train and validation dataset builders."""
if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
else:
logging.warning('label_smoothing not applied, so datasets will not be one '
'hot encoded.')
num_devices = strategy.num_replicas_in_sync if strategy else 1
image_size = get_image_size_from_model(params)
dataset_configs = [params.train_dataset, params.validation_dataset]
builders = []
for config in dataset_configs:
if config is not None and config.has_data:
builder = dataset_factory.DatasetBuilder(
config,
image_size=image_size or config.image_size,
num_devices=num_devices,
one_hot=one_hot)
else:
builder = None
builders.append(builder)
return builders
def get_loss_scale(params: base_configs.ExperimentConfig,
fp16_default: float = 128.) -> float:
"""Returns the loss scale for initializations."""
loss_scale = params.runtime.loss_scale
if loss_scale == 'dynamic':
return loss_scale
elif loss_scale is not None:
return float(loss_scale)
elif (params.train_dataset.dtype == 'float32' or
params.train_dataset.dtype == 'bfloat16'):
return 1.
else:
assert params.train_dataset.dtype == 'float16'
return fp16_default
def _get_params_from_flags(flags_obj: flags.FlagValues):
"""Get ParamsDict from flags."""
model = flags_obj.model_type.lower()
dataset = flags_obj.dataset.lower()
params = configs.get_config(model=model, dataset=dataset)
flags_overrides = {
'model_dir': flags_obj.model_dir,
'mode': flags_obj.mode,
'model': {
'name': model,
},
'runtime': {
'run_eagerly': flags_obj.run_eagerly,
'tpu': flags_obj.tpu,
},
'train_dataset': {
'data_dir': flags_obj.data_dir,
},
'validation_dataset': {
'data_dir': flags_obj.data_dir,
},
'train': {
'time_history': {
'log_steps': flags_obj.log_steps,
},
},
}
overriding_configs = (flags_obj.config_file, flags_obj.params_override,
flags_overrides)
pp = pprint.PrettyPrinter()
logging.info('Base params: %s', pp.pformat(params.as_dict()))
for param in overriding_configs:
logging.info('Overriding params: %s', param)
params = hyperparams.override_params_dict(params, param, is_strict=True)
params.validate()
params.lock()
logging.info('Final model parameters: %s', pp.pformat(params.as_dict()))
return params
def resume_from_checkpoint(model: tf.keras.Model, model_dir: str,
train_steps: int) -> int:
"""Resumes from the latest checkpoint, if possible.
Loads the model weights and optimizer settings from a checkpoint.
This function should be used in case of preemption recovery.
Args:
model: The model whose weights should be restored.
model_dir: The directory where model weights were saved.
train_steps: The number of steps to train.
Returns:
The epoch of the latest checkpoint, or 0 if not restoring.
"""
logging.info('Load from checkpoint is enabled.')
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
logging.info('latest_checkpoint: %s', latest_checkpoint)
if not latest_checkpoint:
logging.info('No checkpoint detected.')
return 0
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint)
model.load_weights(latest_checkpoint)
initial_epoch = model.optimizer.iterations // train_steps
logging.info('Completed loading from checkpoint.')
logging.info('Resuming from epoch %d', initial_epoch)
return int(initial_epoch)
def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype)
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
data_format = 'channels_last'
tf.keras.backend.set_image_data_format(data_format)
if params.runtime.run_eagerly:
# Enable eager execution to allow step-by-step debugging
tf.config.experimental_run_functions_eagerly(True)
if tf.config.list_physical_devices('GPU'):
if params.runtime.gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime
.dataset_num_private_threads) # pylint:disable=line-too-long
if params.runtime.batchnorm_spatial_persistent:
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
def define_classifier_flags():
"""Defines common flags for image classification."""
hyperparams_flags.initialize_common_flags()
flags.DEFINE_string(
'data_dir', default=None, help='The location of the input data.')
flags.DEFINE_string(
'mode',
default=None,
help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
flags.DEFINE_bool(
'run_eagerly',
default=None,
help='Use eager execution and disable autograph for debugging.')
flags.DEFINE_string(
'model_type',
default=None,
help='The type of the model, e.g. EfficientNet, etc.')
flags.DEFINE_string(
'dataset',
default=None,
help='The name of the dataset, e.g. ImageNet, etc.')
flags.DEFINE_integer(
'log_steps',
default=100,
help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig, model_dir: str):
"""Serializes and saves the experiment config."""
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def train_and_eval(
params: base_configs.ExperimentConfig,
strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
"""Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.')
distribute_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1)
label_smoothing = params.model.loss.label_smoothing
one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot)
datasets = [
builder.build(strategy) if builder else None for builder in builders
]
# Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
train_dataset, validation_dataset = datasets
train_epochs = params.train.epochs
train_steps = params.train.steps or train_builder.num_steps
validation_steps = params.evaluation.steps or validation_builder.num_steps
initialize(params, train_builder)
logging.info('Global batch size: %d', train_builder.global_batch_size)
with strategy_scope:
model_params = params.model.model_params.as_dict()
model = get_models()[params.model.name](**model_params)
learning_rate = optimizer_factory.build_learning_rate(
params=params.model.learning_rate,
batch_size=train_builder.global_batch_size,
train_epochs=train_epochs,
train_steps=train_steps)
optimizer = optimizer_factory.build_optimizer(
optimizer_name=params.model.optimizer.name,
base_learning_rate=learning_rate,
params=params.model.optimizer.as_dict(),
model=model)
optimizer = performance.configure_optimizer(
optimizer,
use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params))
metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics]
steps_per_loop = train_steps if params.train.set_epoch_loop else 1
if one_hot:
loss_obj = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=params.model.loss.label_smoothing)
else:
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(
optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
steps_per_execution=steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
initial_epoch = resume_from_checkpoint(
model=model, model_dir=params.model_dir, train_steps=train_steps)
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
time_history=params.train.callbacks.enable_time_history,
track_lr=params.train.tensorboard.track_lr,
write_model_weights=params.train.tensorboard.write_model_weights,
initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir,
backup_and_restore=params.train.callbacks.enable_backup_and_restore)
serialize_config(params=params, model_dir=params.model_dir)
if params.evaluation.skip_eval:
validation_kwargs = {}
else:
validation_kwargs = {
'validation_data': validation_dataset,
'validation_steps': validation_steps,
'validation_freq': params.evaluation.epochs_between_evals,
}
history = model.fit(
train_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
verbose=2,
**validation_kwargs)
validation_output = None
if not params.evaluation.skip_eval:
validation_output = model.evaluate(
validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history, validation_output, callbacks)
return stats
def export(params: base_configs.ExperimentConfig):
"""Runs the model export functionality."""
logging.info('Exporting model.')
model_params = params.model.model_params.as_dict()
model = get_models()[params.model.name](**model_params)
checkpoint = params.export.checkpoint
if checkpoint is None:
logging.info('No export checkpoint was provided. Using the latest '
'checkpoint from model_dir.')
checkpoint = tf.train.latest_checkpoint(params.model_dir)
model.load_weights(checkpoint)
model.save(params.export.destination)
def run(flags_obj: flags.FlagValues,
strategy_override: tf.distribute.Strategy = None) -> Mapping[str, Any]:
"""Runs Image Classification model using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Dictionary of training/eval stats
"""
params = _get_params_from_flags(flags_obj)
if params.mode == 'train_and_eval':
return train_and_eval(params, strategy_override)
elif params.mode == 'export_only':
export(params)
else:
raise ValueError('{} is not a valid mode.'.format(params.mode))
def main(_):
stats = run(flags.FLAGS)
if stats:
logging.info('Run stats:\n%s', stats)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_classifier_flags()
flags.mark_flag_as_required('data_dir')
flags.mark_flag_as_required('mode')
flags.mark_flag_as_required('model_type')
flags.mark_flag_as_required('dataset')
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
import sys
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.legacy.image_classification import classifier_trainer
from official.utils.flags import core as flags_core
classifier_trainer.define_classifier_flags()
def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
"""Returns the combinations of end-to-end tests to run."""
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
model=[
'efficientnet',
'resnet',
],
dataset=[
'imagenet',
],
)
def get_params_override(params_override: Mapping[str, Any]) -> str:
"""Converts params_override dict to string command."""
return '--params_override=' + json.dumps(params_override)
def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
"""Returns a basic parameter configuration for testing."""
return {
'train_dataset': {
'builder': 'synthetic',
'use_per_replica_batch_size': True,
'batch_size': 1,
'image_size': 224,
'dtype': dtype,
},
'validation_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
'dtype': dtype,
},
'train': {
'steps': 1,
'epochs': 1,
'callbacks': {
'enable_checkpoint_and_export': True,
'enable_tensorboard': False,
},
},
'evaluation': {
'steps': 1,
},
}
@flagsaver.flagsaver
def run_end_to_end(main: Callable[[Any], None],
extra_flags: Optional[Iterable[str]] = None,
model_dir: Optional[str] = None):
"""Runs the classifier trainer end-to-end."""
extra_flags = [] if extra_flags is None else extra_flags
args = [sys.argv[0], '--model_dir', model_dir] + extra_flags
flags_core.parse_flags(argv=args)
main(flags.FLAGS)
class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
"""Unit tests for Keras models."""
_tempdir = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(ClassifierTest, cls).setUpClass()
def tearDown(self):
super(ClassifierTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_train_and_eval(self, distribution, model, dataset):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.create_tempdir().full_path
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override()),
'--mode=train_and_eval',
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
],
model=[
'efficientnet',
'resnet',
],
dataset='imagenet',
dtype='float16',
))
def test_gpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.create_tempdir().full_path
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
export_params = basic_params_override()
export_path = os.path.join(model_dir, 'export')
export_params['export'] = {}
export_params['export']['destination'] = export_path
export_flags = base_flags + [
'--mode=export_only',
get_params_override(export_params)
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
run_end_to_end(main=run, extra_flags=export_flags, model_dir=model_dir)
self.assertTrue(os.path.exists(export_path))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
],
model=[
'efficientnet',
'resnet',
],
dataset='imagenet',
dtype='bfloat16',
))
def test_tpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.create_tempdir().full_path
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset):
"""Test the Keras EfficientNet model with `strategy`."""
model_dir = self.create_tempdir().full_path
extra_flags = [
'--data_dir=not_used',
'--mode=invalid_mode',
'--model_type=' + model,
'--dataset=' + dataset,
get_params_override(basic_params_override()),
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
with self.assertRaises(ValueError):
run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
from absl.testing import parameterized
import tensorflow as tf
from official.legacy.image_classification import classifier_trainer
from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification import test_utils
from official.legacy.image_classification.configs import base_configs
def get_trivial_model(num_classes: int) -> tf.keras.Model:
"""Creates and compiles trivial model for ImageNet dataset."""
model = test_utils.trivial_model(num_classes=num_classes)
lr = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
return model
def get_trivial_data() -> tf.data.Dataset:
"""Gets trivial data in the ImageNet size."""
def generate_data(_) -> tf.data.Dataset:
image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
return image, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=1).batch(1)
return dataset
class UtilTests(parameterized.TestCase, tf.test.TestCase):
"""Tests for individual utility functions within classifier_trainer.py."""
@parameterized.named_parameters(
('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
('resnet', 'resnet', '', None),
)
def test_get_model_size(self, model, model_name, expected):
config = base_configs.ExperimentConfig(
model_name=model,
model=base_configs.ModelConfig(
model_params={
'model_name': model_name,
},))
size = classifier_trainer.get_image_size_from_model(config)
self.assertEqual(size, expected)
@parameterized.named_parameters(
('dynamic', 'dynamic', None, 'dynamic'),
('scalar', 128., None, 128.),
('float32', None, 'float32', 1),
('float16', None, 'float16', 128),
)
def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected)
@parameterized.named_parameters(('float16', 'float16'),
('bfloat16', 'bfloat16'))
def test_initialize(self, dtype):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
run_eagerly=False,
enable_xla=False,
per_gpu_thread_count=1,
gpu_thread_mode='gpu_private',
num_gpus=1,
dataset_num_private_threads=1,
),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
model=base_configs.ModelConfig(),
)
class EmptyClass:
pass
fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass()
classifier_trainer.initialize(config, fake_ds_builder)
def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint."""
# Set the keras policy
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
# Get the model, datasets, and compile it.
model = get_trivial_model(10)
# Create the checkpoint
model_dir = self.create_tempdir().full_path
train_epochs = 1
train_steps = 10
ds = get_trivial_data()
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
save_weights_only=True)
]
model.fit(
ds,
callbacks=callbacks,
epochs=train_epochs,
steps_per_epoch=train_steps)
# Test load from checkpoint
clean_model = get_trivial_model(10)
weights_before_load = copy.deepcopy(clean_model.get_weights())
initial_epoch = classifier_trainer.resume_from_checkpoint(
model=clean_model, model_dir=model_dir, train_steps=train_steps)
self.assertEqual(initial_epoch, 1)
self.assertNotAllClose(weights_before_load, clean_model.get_weights())
tf.io.gfile.rmtree(model_dir)
def test_serialize_config(self):
"""Tests functionality for serializing data."""
config = base_configs.ExperimentConfig()
model_dir = self.create_tempdir().full_path
classifier_trainer.serialize_config(params=config, model_dir=model_dir)
saved_params_path = os.path.join(model_dir, 'params.yaml')
self.assertTrue(os.path.exists(saved_params_path))
tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Definitions for high level configuration groups.."""
import dataclasses
from typing import Any, List, Optional
from official.core import config_definitions
from official.modeling import hyperparams
RuntimeConfig = config_definitions.RuntimeConfig
@dataclasses.dataclass
class TensorBoardConfig(hyperparams.Config):
"""Configuration for TensorBoard.
Attributes:
track_lr: Whether or not to track the learning rate in TensorBoard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as images in
TensorBoard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
"""Configuration for exports.
Attributes:
checkpoint: the path to the checkpoint to export.
destination: the path to where the checkpoint should be exported.
"""
checkpoint: str = None
destination: str = None
@dataclasses.dataclass
class MetricsConfig(hyperparams.Config):
"""Configuration for Metrics.
Attributes:
accuracy: Whether or not to track accuracy as a Callback. Defaults to None.
top_5: Whether or not to track top_5_accuracy as a Callback. Defaults to
None.
"""
accuracy: bool = None
top_5: bool = None
@dataclasses.dataclass
class TimeHistoryConfig(hyperparams.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: int = None
@dataclasses.dataclass
class TrainConfig(hyperparams.Config):
"""Configuration for training.
Attributes:
resume_checkpoint: Whether or not to enable load checkpoint loading.
Defaults to None.
epochs: The number of training epochs to run. Defaults to None.
steps: The number of steps to run per epoch. If None, then this will be
inferred based on the number of images and batch size. Defaults to None.
callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorBoardConfig.
set_epoch_loop: Whether or not to set `steps_per_execution` to
equal the number of training steps in `model.compile`. This reduces the
number of callbacks run per epoch which significantly improves end-to-end
TPU training time.
"""
resume_checkpoint: bool = None
epochs: int = None
steps: int = None
callbacks: CallbacksConfig = CallbacksConfig()
metrics: MetricsConfig = None
tensorboard: TensorBoardConfig = TensorBoardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig()
set_epoch_loop: bool = False
@dataclasses.dataclass
class EvalConfig(hyperparams.Config):
"""Configuration for evaluation.
Attributes:
epochs_between_evals: The number of train epochs to run between evaluations.
Defaults to None.
steps: The number of eval steps to run during evaluation. If None, this will
be inferred based on the number of images and batch size. Defaults to
None.
skip_eval: Whether or not to skip evaluation.
"""
epochs_between_evals: int = None
steps: int = None
skip_eval: bool = False
@dataclasses.dataclass
class LossConfig(hyperparams.Config):
"""Configuration for Loss.
Attributes:
name: The name of the loss. Defaults to None.
label_smoothing: Whether or not to apply label smoothing to the loss. This
only applies to 'categorical_cross_entropy'.
"""
name: str = None
label_smoothing: float = None
@dataclasses.dataclass
class OptimizerConfig(hyperparams.Config):
"""Configuration for Optimizers.
Attributes:
name: The name of the optimizer. Defaults to None.
decay: Decay or rho, discounting factor for gradient. Defaults to None.
epsilon: Small value used to avoid 0 denominator. Defaults to None.
momentum: Plain momentum constant. Defaults to None.
nesterov: Whether or not to apply Nesterov momentum. Defaults to None.
moving_average_decay: The amount of decay to apply. If 0 or None, then
exponential moving average is not used. Defaults to None.
lookahead: Whether or not to apply the lookahead optimizer. Defaults to
None.
beta_1: The exponential decay rate for the 1st moment estimates. Used in the
Adam optimizers. Defaults to None.
beta_2: The exponential decay rate for the 2nd moment estimates. Used in the
Adam optimizers. Defaults to None.
epsilon: Small value used to avoid 0 denominator. Defaults to 1e-7.
"""
name: str = None
decay: float = None
epsilon: float = None
momentum: float = None
nesterov: bool = None
moving_average_decay: Optional[float] = None
lookahead: Optional[bool] = None
beta_1: float = None
beta_2: float = None
epsilon: float = None
@dataclasses.dataclass
class LearningRateConfig(hyperparams.Config):
"""Configuration for learning rates.
Attributes:
name: The name of the learning rate. Defaults to None.
initial_lr: The initial learning rate. Defaults to None.
decay_epochs: The number of decay epochs. Defaults to None.
decay_rate: The rate of decay. Defaults to None.
warmup_epochs: The number of warmup epochs. Defaults to None.
batch_lr_multiplier: The multiplier to apply to the base learning rate, if
necessary. Defaults to None.
examples_per_epoch: the number of examples in a single epoch. Defaults to
None.
boundaries: boundaries used in piecewise constant decay with warmup.
multipliers: multipliers used in piecewise constant decay with warmup.
scale_by_batch_size: Scale the learning rate by a fraction of the batch
size. Set to 0 for no scaling (default).
staircase: Apply exponential decay at discrete values instead of continuous.
"""
name: str = None
initial_lr: float = None
decay_epochs: float = None
decay_rate: float = None
warmup_epochs: int = None
examples_per_epoch: int = None
boundaries: List[int] = None
multipliers: List[float] = None
scale_by_batch_size: float = 0.
staircase: bool = None
@dataclasses.dataclass
class ModelConfig(hyperparams.Config):
"""Configuration for Models.
Attributes:
name: The name of the model. Defaults to None.
model_params: The parameters used to create the model. Defaults to None.
num_classes: The number of classes in the model. Defaults to None.
loss: A `LossConfig` instance. Defaults to None.
optimizer: An `OptimizerConfig` instance. Defaults to None.
"""
name: str = None
model_params: hyperparams.Config = None
num_classes: int = None
loss: LossConfig = None
optimizer: OptimizerConfig = None
@dataclasses.dataclass
class ExperimentConfig(hyperparams.Config):
"""Base configuration for an image classification experiment.
Attributes:
model_dir: The directory to use when running an experiment.
mode: e.g. 'train_and_eval', 'export'
runtime: A `RuntimeConfig` instance.
train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance.
export: An `ExportConfig` instance.
"""
model_dir: str = None
model_name: str = None
mode: str = None
runtime: RuntimeConfig = None
train_dataset: Any = None
validation_dataset: Any = None
train: TrainConfig = None
evaluation: EvalConfig = None
model: ModelConfig = None
export: ExportConfig = None
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configuration utils for image classification experiments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dataclasses
from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification.configs import base_configs
from official.legacy.image_classification.efficientnet import efficientnet_config
from official.legacy.image_classification.resnet import resnet_config
@dataclasses.dataclass
class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
"""Base configuration to train efficientnet-b0 on ImageNet.
Attributes:
export: An `ExportConfig` instance
runtime: A `RuntimeConfig` instance.
dataset: A `DatasetConfig` instance.
train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance.
"""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
split='train')
validation_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
split='validation')
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=500,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = efficientnet_config.EfficientNetModelConfig(
)
@dataclasses.dataclass
class ResNetImagenetConfig(base_configs.ExperimentConfig):
"""Base configuration to train resnet-50 on ImageNet."""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='train',
one_hot=False,
mean_subtract=True,
standardize=True)
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation',
one_hot=False,
mean_subtract=True,
standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
"""Given model and dataset names, return the ExperimentConfig."""
dataset_model_config_map = {
'imagenet': {
'efficientnet': EfficientNetImageNetConfig(),
'resnet': ResNetImagenetConfig(),
}
}
try:
return dataset_model_config_map[dataset][model]
except KeyError:
if dataset not in dataset_model_config_map:
raise KeyError('Invalid dataset received. Received: {}. Supported '
'datasets include: {}'.format(
dataset, ', '.join(dataset_model_config_map.keys())))
raise KeyError('Invalid model received. Received: {}. Supported models for'
'{} include: {}'.format(
model, dataset,
', '.join(dataset_model_config_map[dataset].keys())))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment