"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "0886b384e7d78c56bb23b5be41e17e5447231eca"
Unverified Commit 5ffcc5b6 authored by Anirudh Vegesana, committed by GitHub

Merge branch 'purdue-yolo' into detection_generator_pr

parents 0b81a843 76e0c014
@@ -30,3 +30,18 @@ If you want to contribute, please review the [contribution guidelines](https://g
## License
[Apache License 2.0](LICENSE)
## Citing TensorFlow Model Garden
If you use TensorFlow Model Garden in your research, please cite this repository.
```
@misc{tensorflowmodelgarden2020,
author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
Frederick Liu and Jaeyoun Kim and Jing Li},
title = {{TensorFlow Model Garden}},
howpublished = {\url{https://github.com/tensorflow/models}},
year = {2020}
}
```
@@ -40,7 +40,7 @@ In the near future, we will add:
| Model | Reference (Paper) |
|-------|-------------------|
| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
-| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
+| [ResNet](vision/beta/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [ResNet-RS](vision/beta/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
@@ -48,10 +48,10 @@ In the near future, we will add:
| Model | Reference (Paper) |
|-------|-------------------|
-| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
-| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
+| [RetinaNet](vision/beta/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
+| [Mask R-CNN](vision/beta/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
-| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
+| [SpineNet](vision/beta/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing
@@ -163,17 +163,3 @@ pip3 install tensorflow-text-nightly
## Contributions
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
## Citing TF Official Model Garden
To cite this repository:
```
@software{tfmodels2020github,
author = {Chen Chen and Xianzhi Du and Le Hou and Jaeyoun Kim and Jing Li and
Yeqing Li and Abdullah Rashwan and Fan Yang and Hongkun Yu},
title = {TensorFlow Official Model Garden},
url = {https://github.com/tensorflow/models/tree/master/official},
year = {2020},
}
```
@@ -102,8 +102,10 @@ def get_distribution_strategy(distribution_strategy="mirrored",
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
-insensitive. "off" means not to use Distribution Strategy; "tpu" means to
-use TPUStrategy using `tpu_address`.
+insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
+"off" means to use the default strategy which is obtained from
+tf.distribute.get_strategy (for details on the default strategy, see
+https://www.tensorflow.org/guide/distributed_training#default_strategy).
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
@@ -141,7 +143,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy "
"flag cannot be set to `off`.".format(num_gpus))
-return None
+# Return the default distribution strategy.
+return tf.distribute.get_strategy()
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
......
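With this change, asking for the "off" strategy no longer yields `None`; callers get the default strategy object, so `strategy.scope()` can be used unconditionally. A minimal sketch of the new behavior, assuming `official.common.distribute_utils` is the module shown in this diff:

```
import tensorflow as tf
from official.common import distribute_utils  # assumed module path

strategy = distribute_utils.get_distribution_strategy("off", num_gpus=0)
# "off" now returns the default strategy rather than None, so scope() works.
assert strategy is tf.distribute.get_strategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
```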
@@ -43,7 +43,7 @@ class GetDistributionStrategyTest(tf.test.TestCase):
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
-self.assertIsNone(ds)
+self.assertIs(ds, tf.distribute.get_strategy())
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
......
@@ -18,9 +18,27 @@ from absl import flags
def define_flags():
-"""Defines flags."""
+"""Defines flags.
All flags are defined as optional, but in practice most models use some of
these flags and so mark_flags_as_required() should be called after calling
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
For example:
```
from absl import flags
from official.common import flags as tfm_flags # pylint: disable=line-too-long
...
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
```
The reason all flags are optional is that unit tests often do not set or
use any of the flags.
"""
flags.DEFINE_string(
-'experiment', default=None, help='The experiment type registered.')
+'experiment', default=None, help=
+'The experiment type registered, specifying an ExperimentConfig.')
flags.DEFINE_enum(
'mode',
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides TFM orbit actions and associated helper functions/classes."""
import os
from typing import List
import gin
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import base_trainer
from official.core import config_definitions
from official.modeling import optimization
class PruningActions:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
pruning metrics to tensorboard.
This action must be used when training a pruned model to avoid pruning error.
"""
def __init__(
self,
export_dir: str,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the pruning summaries.
model: `tf.keras.Model` model instance used for training. This will be
used to assign a pruning step to each prunable weight.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to find the current training steps.
"""
self._optimizer = optimizer
self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
self.update_pruning_step.set_model(model)
self.update_pruning_step.on_train_begin()
self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
log_dir=export_dir)
model.optimizer = optimizer
self.pruning_summaries.set_model(model)
def __call__(self, output: orbit.runner.Output):
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
class EMACheckpointing:
"""Eval action to save checkpoint with average weights when EMA is used.
This action swaps the weights of the model with the average weights, then it
saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
expensive for large models, so doing this action during eval is more efficient
than doing it during training.
"""
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the EMA average weights.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to swap the model weights with the average
weights.
checkpoint: `tf.train.Checkpoint` instance.
max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
"""
if not isinstance(optimizer, optimization.ExponentialMovingAverage):
raise ValueError('Optimizer has to be an instance of '
'optimization.ExponentialMovingAverage for '
'EMACheckpointing action')
export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs(
os.path.dirname(export_dir))
self._optimizer = optimizer
self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=export_dir,
max_to_keep=max_to_keep,
checkpoint_name='average_weights')
def __call__(self, output: orbit.runner.Output):
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output to test.
"""
self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
self._optimizer.swap_weights()
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer."""
eval_actions = []
# Adds ema checkpointing action to save the average weights under
# ema_checkpoints subdir.
if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
eval_actions.append(
EMACheckpointing(
export_dir=model_dir,
optimizer=trainer.optimizer,
checkpoint=trainer.checkpoint,
max_to_keep=params.trainer.max_to_keep))
return eval_actions
@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets train actions for TFM trainer."""
train_actions = []
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
return train_actions
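A hypothetical wiring sketch for these helpers (the actual hookup happens in train_lib later in this commit); `params`, `trainer`, `strategy`, and `model_dir` are assumed to come from the usual TFM experiment setup and are not defined here.

```
import orbit
from official.core import actions

controller = orbit.Controller(
    strategy=strategy,
    trainer=trainer,
    evaluator=trainer,
    global_step=trainer.global_step,
    steps_per_loop=params.trainer.steps_per_loop,
    train_actions=actions.get_train_actions(params, trainer, model_dir),
    eval_actions=actions.get_eval_actions(params, trainer, model_dir))
```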
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""
import os
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import actions
from official.modeling import optimization
class TestModel(tf.Module):
def __init__(self):
self.value = tf.Variable(0)
@tf.function(input_signature=[])
def __call__(self):
return self.value
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_ema_checkpointing(self, distribution):
with distribution.scope():
directory = self.create_tempdir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
optimizer = optimization.ExponentialMovingAverage(
optimizer, trainable_weights_only=False)
# Creates average weights for the model variables. Average weights are
# initialized to zero.
optimizer.shadow_copy(model)
checkpoint = tf.train.Checkpoint(model=model)
# Changes model.value to 3, average value is still 0.
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
if __name__ == '__main__':
tf.test.main()
@@ -79,7 +79,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
# Configuring optimizer when loss_scale is set in runtime config. This helps
# avoiding overflow/underflow for float16 computations.
-if runtime_config and runtime_config.loss_scale:
+if runtime_config:
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == "float16",
......
@@ -303,13 +303,16 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
},
})))
trainer = self.create_test_trainer(config)
-if mixed_precision_dtype != 'float16':
-  self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-elif mixed_precision_dtype == 'float16' and loss_scale is None:
-  self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-else:
-  self.assertIsInstance(trainer.optimizer,
-                        tf.keras.mixed_precision.LossScaleOptimizer)
+if mixed_precision_dtype == 'float16':
+  self.assertIsInstance(trainer.optimizer,
+                        tf.keras.mixed_precision.LossScaleOptimizer)
+  if loss_scale in (None, 'dynamic'):
+    self.assertTrue(trainer.optimizer.dynamic)
+  else:
+    self.assertFalse(trainer.optimizer.dynamic)
+    self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
+else:
+  self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
......
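The rewritten assertions above mirror the Keras mixed-precision API directly: a dynamic loss scale exposes `optimizer.dynamic`, while a fixed numeric scale is recorded in `initial_scale`. A small standalone sketch of those attributes (not part of the diff):

```
import tensorflow as tf

dynamic_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD())  # dynamic loss scaling by default
assert dynamic_opt.dynamic

fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), dynamic=False, initial_scale=128)
assert not fixed_opt.dynamic
assert fixed_opt.initial_scale == 128
```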
@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
-input_path: The path to the input. It can be either (1) a str indicating
-  a file path/pattern, or (2) a str indicating multiple file paths/patterns
-  separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
-  (3) a list of str, each of which is a file path/pattern or multiple file
-  paths/patterns separated by comma.
-  It should not be specified when the following `tfds_name` is specified.
+input_path: The path to the input. It can be either (1) a str indicating a
+  file path/pattern, or (2) a str indicating multiple file paths/patterns
+  separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+  str, each of which is a file path/pattern or multiple file paths/patterns
+  separated by comma, or (4) a dictionary of the previous three approaches
+  for more advanced data mixing using named access. It should not be
+  specified when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. If `True`, we will cache the
dataset after applying the decode_fn and parse_fn. It can be used to avoid
-re-reading from disk, re-decoding and re-parsing the example on the
-second epoch, but it requires significant memory overhead.
+re-reading from disk, re-decoding and re-parsing the example on the second
+epoch, but it requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
@@ -59,11 +60,10 @@ class DataConfig(base_config.Config):
tf_data_service_address: The URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". It can be
overridden by `FLAGS.tf_data_service` flag in the binary.
-tf_data_service_job_name: The name of the tf.data service job. This
-  argument makes it possible for multiple datasets to share the same job.
-  The default behavior is that the dataset creates anonymous, exclusively
-  owned jobs.
+tf_data_service_job_name: The name of the tf.data service job. This argument
+  makes it possible for multiple datasets to share the same job. The default
+  behavior is that the dataset creates anonymous, exclusively owned jobs.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
returned tf.data.Dataset will have a 2-tuple structure (input, label)
@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
performance.
seed: An optional seed to use for deterministic shuffling/preprocessing.
"""
-input_path: Union[Sequence[str], str] = ""
+input_path: Union[Sequence[str], str, base_config.Config] = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
......
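For illustration only, a dictionary-valued `input_path` for the new named mixing support might look like the sketch below; the dataset names and paths are made up, and a matching `combine_fn` has to be passed to the `InputReader` (see the input_reader diff further down).

```
from official.core import config_definitions as cfg

# Hypothetical mixed-data config; exact construction of nested Config values
# may differ in your setup.
mixed_data = cfg.DataConfig(
    input_path={
        'coco': '/data/coco/train*',
        'objects365': '/data/objects365/train*',
    },
    global_batch_size=64,
    is_training=True)
```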
@@ -82,7 +82,7 @@ def export(export_module: ExportModule,
The savedmodel directory path.
"""
ckpt_dir_or_file = checkpoint_path
-if tf.io.gfile.isdir(ckpt_dir_or_file):
+if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if ckpt_dir_or_file:
checkpoint = tf.train.Checkpoint(model=export_module.model)
......
@@ -14,7 +14,7 @@
"""A common dataset reader."""
import random
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union, Dict, Sequence
from absl import logging
import tensorflow as tf
@@ -45,6 +45,7 @@ class InputReader:
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
combine_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ class InputReader:
example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary.
combine_fn: An optional `callable` that takes a dictionary of
`tf.data.Dataset` objects as input and outputs a combined dataset. It
will be executed after the decoder_fn and before the sample_fn.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
@@ -78,10 +82,23 @@ class InputReader:
raise ValueError('At most one of `input_path` and `tfds_name` can be '
'specified, but got %s and %s.' %
(params.input_path, params.tfds_name))
if isinstance(params.input_path,
cfg.base_config.Config) and combine_fn is None:
raise ValueError(
'A `combine_fn` is required if the `input_path` is a dictionary.')
self._tfds_builder = None
-self._matched_files = []
+self._matched_files = None
if params.input_path:
-  self._matched_files = self._match_files(params.input_path)
+  # we want to combine / mix datasets
+  if isinstance(params.input_path, cfg.base_config.Config):
+    self._matched_files = {}
+    for k, v in params.input_path.as_dict().items():
+      self._matched_files[k] = self._match_files(v)
+  # single dataset
+  else:
+    self._matched_files = self._match_files(params.input_path)
else:
# Read dataset from TFDS.
if not params.tfds_split:
@@ -106,6 +123,7 @@ class InputReader:
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
self._combine_fn = combine_fn
self._sample_fn = sample_fn
self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn
@@ -131,7 +149,7 @@ class InputReader:
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
-def _match_files(self, input_path: str) -> List[str]:
+def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
@@ -195,8 +213,8 @@ class InputReader:
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
-if self._sharding and input_context and (
-    input_context.num_input_pipelines > 1):
+if self._sharding and input_context and (input_context.num_input_pipelines >
+                                          1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
@@ -231,8 +249,8 @@ class InputReader:
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
-if self._sharding and input_context and (
-    input_context.num_input_pipelines > 1):
+if self._sharding and input_context and (input_context.num_input_pipelines >
+                                          1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
@@ -281,42 +299,53 @@ class InputReader:
def _read_decode_and_parse_dataset(
self,
-matched_files: List[str],
+matched_files: Union[Dict[str, List[str]], List[str]],
dataset_fn,
batch_size: int,
input_context: Optional[tf.distribute.InputContext] = None,
tfds_builder: bool = False) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after reading, decoding, and parsing."""
def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
if len(files) > 1:
if input_context and (len(files) < input_context.num_input_pipelines):
logging.warn(
'The number of files %d is less than the number of input pipelines '
'%d. We will send all input files to every worker. '
'Please consider sharding your data into more files.', len(files),
input_context.num_input_pipelines)
return self._read_files_then_shard(files, dataset_fn, input_context)
else:
return self._shard_files_then_read(files, dataset_fn, input_context)
elif len(files) == 1:
return self._read_files_then_shard(files, dataset_fn, input_context)
else:
raise ValueError('It is unexpected that `tfds_builder` is None and '
'there is also no `files`.')
def _shuffle_and_decode(ds):
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
# Decode
ds = _maybe_map_fn(ds, self._decoder_fn)
return ds
if tfds_builder:
  dataset = self._read_tfds(input_context)
+  dataset = _shuffle_and_decode(dataset)
-elif len(matched_files) > 1:
-  if input_context and (len(matched_files) <
-                        input_context.num_input_pipelines):
-    logging.warn(
-        'The number of files %d is less than the number of input pipelines '
-        '%d. We will send all input files to every worker. '
-        'Please consider sharding your data into more files.',
-        len(matched_files), input_context.num_input_pipelines)
-    dataset = self._read_files_then_shard(matched_files,
-                                          dataset_fn,
-                                          input_context)
-  else:
-    dataset = self._shard_files_then_read(matched_files,
-                                          dataset_fn,
-                                          input_context)
-elif len(matched_files) == 1:
-  dataset = self._read_files_then_shard(matched_files,
-                                        dataset_fn,
-                                        input_context)
+elif isinstance(matched_files, (list, tuple)):
+  dataset = _files_to_dataset(matched_files)
+  dataset = _shuffle_and_decode(dataset)
+elif isinstance(matched_files, dict):
+  datasets = {}
+  for k, fs in matched_files.items():
+    datasets[k] = _files_to_dataset(fs)
+    datasets[k] = _shuffle_and_decode(datasets[k])
+  dataset = self._combine_fn(datasets)
else:
-  raise ValueError('It is unexpected that `tfds_builder` is None and '
-                   'there is also no `matched_files`.')
-# If cache is enabled, we will call `shuffle()` later after `cache()`.
-if self._is_training and not self._cache:
-  dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
-dataset = _maybe_map_fn(dataset, self._decoder_fn)
+  raise ValueError('`matched_files` should be a list or dict.')
if self._sample_fn is not None:
  dataset = dataset.apply(self._sample_fn)
dataset = _maybe_map_fn(dataset, self._parser_fn)
@@ -333,8 +362,7 @@ class InputReader:
per_replica_batch_size = input_context.get_per_replica_batch_size(
batch_size) if input_context else batch_size
dataset = dataset.batch(
-    per_replica_batch_size, drop_remainder=self._drop_remainder
-)
+    per_replica_batch_size, drop_remainder=self._drop_remainder)
return dataset
......
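A possible `combine_fn` sketch (not part of the change): interleave the named, already-decoded datasets with fixed sampling weights via `tf.data.experimental.sample_from_datasets`. The weights and keys here are assumptions for illustration.

```
import tensorflow as tf

def uniform_combine_fn(datasets):
  """Mixes a dict of decoded datasets into one dataset with equal weights."""
  names = sorted(datasets.keys())
  weights = [1.0 / len(names)] * len(names)
  return tf.data.experimental.sample_from_datasets(
      [datasets[name] for name in names], weights=weights)

# reader = InputReader(params=mixed_data, decoder_fn=..., combine_fn=uniform_combine_fn)
```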
@@ -15,13 +15,15 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
-from typing import Any, Mapping, Tuple, Optional
+from typing import Any, Mapping, Optional, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import actions
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
@@ -38,7 +40,8 @@ def run_experiment(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
-trainer: Optional[base_trainer.Trainer] = None
+trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
@@ -54,6 +57,8 @@ def run_experiment(
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be an orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
@@ -73,6 +78,8 @@ def run_experiment(
params, model_dir))
if trainer.checkpoint:
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
@@ -85,7 +92,7 @@ def run_experiment(
else:
checkpoint_manager = None
-controller = orbit.Controller(
+controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
@@ -97,7 +104,9 @@ def run_experiment(
params.trainer.validation_summary_subdir) if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
-(save_summary) else None)
+(save_summary) else None,
train_actions=actions.get_train_actions(params, trainer, model_dir),
eval_actions=actions.get_eval_actions(params, trainer, model_dir))
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
@@ -129,6 +138,11 @@ def run_experiment(
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
......
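The new `controller_cls` argument lets callers substitute an `orbit.Controller` subclass. A hypothetical sketch, assuming the usual experiment setup objects (`strategy`, `task`, `params`, `model_dir`) already exist:

```
import orbit
from absl import logging
from official.core import train_lib

class LoggingController(orbit.Controller):
  """Controller that logs before delegating to the standard train loop."""

  def train(self, steps, checkpoint_at_completion=True):
    logging.info('Custom controller: training until global step %d.', steps)
    super().train(steps, checkpoint_at_completion=checkpoint_at_completion)

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir=model_dir,
    controller_cls=LoggingController)
```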
@@ -17,7 +17,7 @@ import copy
import json
import os
import pprint
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging
import dataclasses
@@ -25,6 +25,9 @@ import gin
import orbit
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
@@ -241,6 +244,9 @@ class ParseConfigOptions:
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
if flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
# 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment)
@@ -285,7 +291,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
if print_return:
pp = pprint.PrettyPrinter()
-logging.info('Final experiment parameters: %s',
+logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict()))
return params
@@ -294,6 +300,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str):
"""Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
@@ -388,3 +396,48 @@ def try_count_params(model: tf.keras.Model):
'train step already reached before this run.')
return None
return None
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
inputs_kwargs: Optional[Dict[str, Any]] = None):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications used to get the corresponding concrete function.
Returns:
The model's FLOPs.
"""
if hasattr(model, 'inputs'):
try:
# Get input shape and set batch size to 1.
if model.inputs:
inputs = [
tf.TensorSpec([1] + input.shape[1:], input.dtype)
for input in model.inputs
]
concrete_func = tf.function(model).get_concrete_function(inputs)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else:
concrete_func = tf.function(model.call).get_concrete_function(
**inputs_kwargs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
# Calculate FLOPs.
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
opts['output'] = 'none'
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, options=opts)
return flops.total_float_ops
except Exception as e: # pylint: disable=broad-except
logging.info(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not fed any input, e.g., the max train step already '
'reached before this run.', e)
return None
return None
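A usage sketch for `try_count_flops` (added to train_utils in this diff) on a tiny stand-in model; the profiler returns total float ops, which `run_experiment` halves when logging multiply-adds.

```
import tensorflow as tf
from official.core import train_utils  # assumed module path for this diff

inputs = tf.keras.Input(shape=(224, 224, 3))
outputs = tf.keras.layers.Conv2D(8, 3)(inputs)
model = tf.keras.Model(inputs, outputs)

flops = train_utils.try_count_flops(model)
if flops is not None:
  # Divide by 2 to report multiply-adds, matching the run_experiment log line.
  print('Multiply-adds: %.3f M' % (flops / 2 / 1e6))
```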
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adafactor optimizer.
A new optimizer that will be open sourced soon.
"""
# pylint: disable=invalid-name
# Represents an unimplemented class definition.
Adafactor = "Unimplemented"
@@ -56,10 +56,12 @@ class StepwiseLrConfig(base_config.Config):
values[0] [boundaries[0], boundaries[1]] -> values[1]
[boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
end] -> values[n+1] Defaults to None.
offset: An int. The offset applied to steps. Defaults to 0.
"""
name: str = 'PiecewiseConstantDecay'
boundaries: Optional[List[int]] = None
values: Optional[List[float]] = None
offset: int = 0
@dataclasses.dataclass
@@ -76,12 +78,14 @@ class ExponentialLrConfig(base_config.Config):
decay_rate: A float. Defaults to None.
staircase: A boolean, if true, learning rate is decreased at discrete
intervals. Defaults to False.
offset: An int. The offset applied to steps. Defaults to 0.
"""
name: str = 'ExponentialDecay'
initial_learning_rate: Optional[float] = None
decay_steps: Optional[int] = None
decay_rate: Optional[float] = None
staircase: Optional[bool] = None
offset: int = 0
@dataclasses.dataclass
@@ -99,6 +103,7 @@ class PolynomialLrConfig(base_config.Config):
power: A float. The power of the polynomial. Defaults to linear, 1.0.
cycle: A boolean, whether or not it should cycle beyond decay_steps.
Defaults to False.
offset: An int. The offset applied to steps. Defaults to 0.
"""
name: str = 'PolynomialDecay'
initial_learning_rate: Optional[float] = None
@@ -106,6 +111,7 @@ class PolynomialLrConfig(base_config.Config):
end_learning_rate: float = 0.0001
power: float = 1.0
cycle: bool = False
offset: int = 0
@dataclasses.dataclass
@@ -122,11 +128,13 @@ class CosineLrConfig(base_config.Config):
to None.
alpha: A float. Minimum learning rate value as a fraction of
initial_learning_rate.
offset: An int. The offset applied to steps. Defaults to 0.
"""
name: str = 'CosineDecay'
initial_learning_rate: Optional[float] = None
decay_steps: Optional[int] = None
alpha: float = 0.0
offset: int = 0
@dataclasses.dataclass
......
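For illustration, the new `offset` field is typically set to the warmup length so the decay schedule starts fresh once warmup ends. The config sketch below assumes the standard `OptimizationConfig` wiring and uses made-up step counts.

```
from official.modeling.optimization.configs import optimization_config

opt_config = optimization_config.OptimizationConfig({
    'optimizer': {'type': 'sgd'},
    'learning_rate': {
        'type': 'cosine',
        'cosine': {
            'initial_learning_rate': 0.1,
            'decay_steps': 10000,
            'offset': 500,  # new field: shift the decay past the warmup
        },
    },
    'warmup': {
        'type': 'linear',
        'linear': {'warmup_steps': 500},
    },
})
```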
@@ -52,6 +52,7 @@ class OptimizerConfig(oneof.OneOfConfig):
lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
slide: opt_cfg.SLIDEConfig = opt_cfg.SLIDEConfig()
adafactor: opt_cfg.AdafactorConfig = opt_cfg.AdafactorConfig()
@dataclasses.dataclass
......
@@ -247,3 +247,22 @@ class SLIDEConfig(BaseOptimizerConfig):
do_gradient_rescaling: bool = True
norm_type: str = "layer"
ratio_clip_norm: float = 1e5
@dataclasses.dataclass
class AdafactorConfig(BaseOptimizerConfig):
"""Configuration for Adafactor optimizer.
The attributes for this class match the arguments of the Adafactor
implementation.
"""
name: str = "Adafactor"
factored: bool = True
multiply_by_parameter_scale: bool = True
beta1: Optional[float] = None
decay_rate: float = 0.8
step_offset: int = 0
clipping_threshold: float = 1.0
min_dim_size_to_factor: int = 128
epsilon1: float = 1e-30
epsilon2: float = 1e-3
@@ -19,6 +19,75 @@ from typing import Mapping, Any, Union, Optional
import tensorflow as tf
def _make_offset_wrapper(new_class_name: str, base_lr_class):
"""Generates a offset wrapper of learning rate schedule.
It will returns a subclass of the the `base_lr_class`, the subclass takes an
`offset` argument in the constructor. When the new class instance is called,
the behavior is:
new_class_object(step) = base_lr_class_object(step - offset)
Example:
CosineDecayWithOffset = _make_offset_wrapper(
'CosineDecayWithOffset', tf.keras.experimental.CosineDecay)
# Use the lr:
lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
decay_steps=1000)
lr(101) # equals to tf.keras.experimental.CosineDecay(...)(101-100)
Args:
new_class_name: the name of the new class.
base_lr_class: the base learning rate schedule class. Should be subclass of
tf.keras.optimizers.schedules.LearningRateSchedule
Returns:
A new class (subclass of the base_lr_class) that can take an offset.
"""
assert issubclass(base_lr_class,
tf.keras.optimizers.schedules.LearningRateSchedule), (
"base_lr_class should be subclass of keras "
f"LearningRateSchedule, got {base_lr_class}")
# pylint: disable=protected-access,pointless-statement
def offset_learning_rate_init(self, offset=0, **kwargs):
"""Construct learning rate schedule object.
When this object is called, its behavior is
self.__call__(step) == base_lr_class.__call__(step - offset)
Args:
self: this object.
offset: The offset when computing the learning rate schedule.
**kwargs: Pass through to base learning rate class constructor.
"""
base_lr_class.__init__(self, **kwargs)
self._offset = offset
def offset_learning_rate_call(self, step):
step = tf.cast(step - self._offset, tf.float32)
return base_lr_class.__call__(self, step)
# pylint: enable=protected-access,pointless-statement
return type(
new_class_name, (base_lr_class,), {
"base_lr_class": base_lr_class,
"__init__": offset_learning_rate_init,
"__call__": offset_learning_rate_call
})
PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
"PiecewiseConstantDecayWithOffset",
tf.keras.optimizers.schedules.PiecewiseConstantDecay)
PolynomialDecayWithOffset = _make_offset_wrapper(
"PolynomialDecayWithOffset", tf.keras.optimizers.schedules.PolynomialDecay)
ExponentialDecayWithOffset = _make_offset_wrapper(
"ExponentialDecayWithOffset",
tf.keras.optimizers.schedules.ExponentialDecay)
CosineDecayWithOffset = _make_offset_wrapper("CosineDecayWithOffset",
tf.keras.experimental.CosineDecay)
class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule): class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Linear warmup schedule.""" """Linear warmup schedule."""
......
@@ -70,5 +70,40 @@ class PowerAndLinearDecayTest(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(lr(step).numpy(), value)
class OffsetLearningRateTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
dict(class_name=lr_schedule.PiecewiseConstantDecayWithOffset),
dict(class_name=lr_schedule.PolynomialDecayWithOffset),
dict(class_name=lr_schedule.ExponentialDecayWithOffset),
dict(class_name=lr_schedule.CosineDecayWithOffset),
)
def test_generated_docstring(self, class_name):
self.assertNotEmpty(class_name.__init__.__doc__)
@parameterized.parameters(
dict(
class_name=lr_schedule.PiecewiseConstantDecayWithOffset,
kwarg=dict(boundaries=[50, 80], values=[1.0, 0.5, 0.1])),
dict(
class_name=lr_schedule.PolynomialDecayWithOffset,
kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
dict(
class_name=lr_schedule.ExponentialDecayWithOffset,
kwarg=dict(
initial_learning_rate=1.0, decay_steps=100, decay_rate=0.5)),
dict(
class_name=lr_schedule.CosineDecayWithOffset,
kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
)
def test_offset(self, class_name, kwarg):
offset = 10
offset_lr = class_name(offset=offset, **kwarg)
base_lr = class_name.base_lr_class(**kwarg)
self.assertIsInstance(offset_lr, class_name)
for step in range(10, 101, 10):
self.assertEqual(offset_lr(step), base_lr(step - offset))
if __name__ == '__main__':
tf.test.main()