Unverified commit f16a7b5b, authored by vedanshu and committed via GitHub
Browse files

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utils for mock models and tasks."""
from typing import Dict, Text
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.multitask import base_model
class MockFooModel(tf.keras.Model):
  """A mock model that consumes either 'foo' or 'bar' inputs."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # The shared layer is created by the caller so several task models can
    # reuse the same variables.
    self._share_layer = shared_layer
    self._foo_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # Register a dummy auxiliary loss so loss-aggregation code paths are
    # exercised in tests.
    self.add_loss(tf.zeros((1,), dtype=tf.float32))
    input_tensor = inputs["foo"] if "foo" in inputs else inputs["bar"]
    shared_out = self._share_layer(input_tensor)
    return self._foo_specific_layer(shared_out)
class MockBarModel(tf.keras.Model):
  """A mock model that consumes only 'bar' inputs."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Shared layer supplied by the caller; variables may be reused across
    # task models.
    self._share_layer = shared_layer
    self._bar_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # Dummy auxiliary loss (distinct shape from MockFooModel's) to exercise
    # loss-aggregation code paths.
    self.add_loss(tf.zeros((2,), dtype=tf.float32))
    shared_out = self._share_layer(inputs["bar"])
    return self._bar_specific_layer(shared_out)
class MockMultiTaskModel(base_model.MultiTaskBaseModel):
  """Multi-task model wiring the mock foo/bar models through a shared layer."""

  def __init__(self, *args, **kwargs):
    # Create the shared layer first: super().__init__ triggers
    # _instantiate_sub_tasks, which reads self._shared_dense.
    self._shared_dense = tf.keras.layers.Dense(1)
    super().__init__(*args, **kwargs)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    sub_tasks = {}
    sub_tasks["foo"] = MockFooModel(self._shared_dense)
    sub_tasks["bar"] = MockBarModel(self._shared_dense)
    return sub_tasks
def mock_data(feature_name):
  """Returns an infinite mock dataset keyed by `feature_name`.

  Each element is a ({feature_name: zeros(2,)}, zeros(1,)) pair, batched
  by 2 with the remainder dropped.
  """

  def _generate_data(_):
    features = {feature_name: tf.zeros(shape=(2,), dtype=tf.float32)}
    label = tf.zeros([1], dtype=tf.int32)
    return features, label

  return (
      tf.data.Dataset.range(1)
      .repeat()
      .map(_generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .prefetch(buffer_size=1)
      .batch(2, drop_remainder=True))
class FooConfig(cfg.TaskConfig):
  """Task config for the mock 'foo' task; adds no fields beyond TaskConfig."""
  pass
class BarConfig(cfg.TaskConfig):
  """Task config for the mock 'bar' task; adds no fields beyond TaskConfig."""
  pass
@task_factory.register_task_cls(FooConfig)
class MockFooTask(base_task.Task):
  """Mock foo task object for testing."""

  def build_metrics(self, training: bool = True):
    # The same metric set is used for both training and evaluation.
    del training
    return [tf.keras.metrics.Accuracy(name="foo_acc")]

  def build_inputs(self, params):
    return mock_data("foo")

  def build_model(self) -> tf.keras.Model:
    return MockFooModel(shared_layer=tf.keras.layers.Dense(1))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    # Per-example MSE, plus any auxiliary model losses, reduced to a scalar.
    per_example_loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      per_example_loss = per_example_loss + tf.add_n(aux_losses)
    return tf.reduce_mean(per_example_loss)
@task_factory.register_task_cls(BarConfig)
class MockBarTask(base_task.Task):
  """Mock bar task object for testing."""

  def build_metrics(self, training: bool = True):
    """Returns task metrics; identical for training and evaluation."""
    del training
    return [tf.keras.metrics.Accuracy(name="bar_acc")]

  def build_inputs(self, params):
    """Returns the mock 'bar' dataset."""
    return mock_data("bar")

  def build_model(self) -> tf.keras.Model:
    # Added for consistency with MockFooTask, so this task can also build its
    # own standalone model (e.g. when used as a single train task).
    return MockBarModel(shared_layer=tf.keras.layers.Dense(1))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """MSE loss plus any auxiliary model losses, reduced to a scalar."""
    loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as core_lib
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
# Maps the `params.trainer.trainer_type` config value to the trainer class
# used by run_experiment.
TRAINERS = {
    'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
    'joint': base_trainer.MultiTaskBaseTrainer
}
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel, mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A MultiTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiTaskExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: `base_model.MultiTaskBaseModel` instance.
  """
  # Note: 'train_and_eval' satisfies both substring checks.
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      # The interleaving trainer additionally needs a sampler that decides
      # which task runs on each step.
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
      evaluator = evaluator_lib.MultiTaskEvaluator(
          task=task,
          model=model,
          global_step=trainer.global_step if is_training else None)
    else:
      evaluator = None

  # Either the trainer or the evaluator must exist; prefer the trainer's
  # checkpoint/global_step when training.
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step
  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop continuous evaluation once training has reached its final step.
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  return model
def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: multitask.MultiTask,
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True) -> tf.keras.Model:
  """Runs train/eval configured by the experiment params.

  Trains a single task while evaluating on a collection of tasks.

  Args:
    distribution_strategy: A distribution strategy.
    train_task: A base_task.Task instance.
    eval_tasks: A multitask.MultiTask with evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.

  Returns:
    A `(model, eval_logs)` tuple: the trained `tf.keras.Model` and, when
    `run_post_eval` is True, the metrics logs of a final evaluation
    (an empty dict otherwise).
  """
  # Note: 'train_and_eval' satisfies both substring checks.
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = train_task.create_optimizer(params.trainer.optimizer_config,
                                            params.runtime)
    # The single-task model is built once and shared by trainer and evaluator.
    model = train_task.build_model()
    if is_training:
      trainer = core_lib.Trainer(
          config=params,
          task=train_task,
          model=model,
          optimizer=optimizer,
          train=True,
          evaluate=False)
    else:
      trainer = None
    if is_eval:
      # Evaluator covers all eval tasks and may export the best checkpoint.
      evaluator = evaluator_lib.MultiTaskEvaluator(
          task=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  # Prefer the trainer's checkpoint/global_step when training.
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(model_dir, 'validation') if
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
      (save_summary) else None)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop continuous evaluation once training has reached its final step.
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if run_post_eval:
    return model, evaluator.evaluate(
        tf.convert_to_tensor(params.trainer.validation_steps))
  else:
    return model, {}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.train_lib."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import task_factory
from official.modeling.hyperparams import params_dict
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
from official.modeling.multitask import train_lib
class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for the multitask train_lib drivers."""

  def setUp(self):
    super().setUp()
    # Minimal trainer config (SGD + constant LR, 10 train steps) so the
    # end-to-end runs finish quickly; merged into the experiment config below.
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, flag_mode):
    # Runs run_experiment with a two-task (foo/bar) multitask config under
    # each strategy/mode combination; passes if no exception is raised.
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
                    task_name='foo',
                    task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=test_multitask,
        model=model,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    # Runs run_experiment_with_multitask_eval: trains the single foo task
    # while evaluating on both foo and bar tasks.
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
                    task_name='foo',
                    task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=distribution_strategy,
        train_task=train_task,
        eval_tasks=eval_tasks,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization package definition."""
# pylint: disable=wildcard-import
from official.modeling.optimization.configs.learning_rate_config import *
from official.modeling.optimization.configs.optimization_config import *
from official.modeling.optimization.configs.optimizer_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.modeling.optimization.lr_schedule import *
from official.modeling.optimization.optimizer_factory import OptimizerFactory
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for learning rate schedule config."""
from typing import List, Optional
......@@ -50,16 +49,13 @@ class StepwiseLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
boundaries: A list of ints of strictly increasing entries.
Defaults to None.
boundaries: A list of ints of strictly increasing entries. Defaults to None.
values: A list of floats that specifies the values for the intervals defined
by `boundaries`. It should have one more element than `boundaries`.
The learning rate is computed as follows:
[0, boundaries[0]] -> values[0]
[boundaries[0], boundaries[1]] -> values[1]
[boundaries[n-1], boundaries[n]] -> values[n]
[boundaries[n], end] -> values[n+1]
Defaults to None.
The learning rate is computed as follows: [0, boundaries[0]] ->
values[0] [boundaries[0], boundaries[1]] -> values[1]
[boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
end] -> values[n+1] Defaults to None.
"""
name: str = 'PiecewiseConstantDecay'
boundaries: Optional[List[int]] = None
......@@ -74,10 +70,9 @@ class ExponentialLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to ExponentialDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to
None.
decay_steps: A positive integer that is used for decay computation.
Defaults to None.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
decay_steps: A positive integer that is used for decay computation. Defaults
to None.
decay_rate: A float. Defaults to None.
staircase: A boolean, if true, learning rate is decreased at discreate
intervals. Defaults to False.
......@@ -97,10 +92,9 @@ class PolynomialLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to PolynomialDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to
None.
decay_steps: A positive integer that is used for decay computation.
Defaults to None.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
decay_steps: A positive integer that is used for decay computation. Defaults
to None.
end_learning_rate: A float. The minimal end learning rate.
power: A float. The power of the polynomial. Defaults to linear, 1.0.
cycle: A boolean, whether or not it should cycle beyond decay_steps.
......@@ -123,10 +117,9 @@ class CosineLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to CosineDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to
None.
decay_steps: A positive integer that is used for decay computation.
Defaults to None.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
decay_steps: A positive integer that is used for decay computation. Defaults
to None.
alpha: A float. Minimum learning rate value as a fraction of
initial_learning_rate.
"""
......@@ -136,6 +129,66 @@ class CosineLrConfig(base_config.Config):
alpha: float = 0.0
@dataclasses.dataclass
class DirectPowerLrConfig(base_config.Config):
  """Configuration for DirectPower learning rate decay.

  This class configures a schedule following lr * (step)^power.

  Attributes:
    name: The name of the learning rate schedule. Defaults to DirectPowerDecay.
    initial_learning_rate: A float. The initial learning rate. Defaults to None.
    power: A float. Defaults to -0.5, for sqrt decay.
  """
  name: str = 'DirectPowerDecay'
  initial_learning_rate: Optional[float] = None
  power: float = -0.5
@dataclasses.dataclass
class PowerAndLinearDecayLrConfig(base_config.Config):
  """Configuration for power-then-linear learning rate decay.

  This class configures a schedule that follows lr * (step)^power for the
  first total_decay_steps * (1 - linear_decay_fraction) steps, and follows
  lr * (step)^power * (total_decay_steps - step) / (total_decay_steps *
  linear_decay_fraction) for the rest of the steps.

  Attributes:
    name: The name of the learning rate schedule. Defaults to
      PowerAndLinearDecay.
    initial_learning_rate: A float. The initial learning rate. Defaults to None.
    total_decay_steps: An int. The total number of steps of the schedule; the
      linear tail applies over its final fraction. Defaults to None.
    power: A float. Defaults to -0.5, for sqrt decay.
    linear_decay_fraction: A float. The final fraction of total_decay_steps
      over which the extra linear factor is applied. Defaults to 0.1.
  """
  name: str = 'PowerAndLinearDecay'
  initial_learning_rate: Optional[float] = None
  total_decay_steps: Optional[int] = None
  power: float = -0.5
  linear_decay_fraction: float = 0.1
@dataclasses.dataclass
class PowerDecayWithOffsetLrConfig(base_config.Config):
  """Configuration for power learning rate decay with step offset.

  Learning rate equals to `pre_offset_learning_rate` if `step` < `offset`.
  Otherwise, learning rate equals to lr * (step - offset)^power.

  Attributes:
    name: The name of the learning rate schedule. Defaults to
      PowerDecayWithOffset.
    initial_learning_rate: A float. The initial learning rate. Defaults to None.
    power: A float. Defaults to -0.5, for sqrt decay.
    offset: An integer. Power decay happens after `offset` steps.
    pre_offset_learning_rate: A float. The constant learning rate before
      `offset` steps.
  """
  name: str = 'PowerDecayWithOffset'
  initial_learning_rate: Optional[float] = None
  power: float = -0.5
  offset: int = 0
  pre_offset_learning_rate: float = 1.0e6
@dataclasses.dataclass
class LinearWarmupConfig(base_config.Config):
"""Configuration for linear warmup schedule config.
......@@ -173,4 +226,3 @@ class PolynomialWarmupConfig(base_config.Config):
name: str = 'polynomial'
power: float = 1
warmup_steps: Optional[int] = None
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for optimization configs.
This file define the dataclass for optimization configs (OptimizationConfig).
......@@ -40,6 +39,8 @@ class OptimizerConfig(oneof.OneOfConfig):
adamw: adam with weight decay.
lamb: lamb optimizer.
rmsprop: rmsprop optimizer.
lars: lars optimizer.
adagrad: adagrad optimizer.
"""
type: Optional[str] = None
sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
......@@ -47,6 +48,8 @@ class OptimizerConfig(oneof.OneOfConfig):
adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
@dataclasses.dataclass
......@@ -60,6 +63,10 @@ class LrConfig(oneof.OneOfConfig):
exponential: exponential learning rate config.
polynomial: polynomial learning rate config.
cosine: cosine learning rate config.
power: step^power learning rate config.
power_linear: learning rate config of step^power followed by
step^power*linear.
power_with_offset: power decay with a step offset.
"""
type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
......@@ -67,6 +74,11 @@ class LrConfig(oneof.OneOfConfig):
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
cosine: lr_cfg.CosineLrConfig = lr_cfg.CosineLrConfig()
power: lr_cfg.DirectPowerLrConfig = lr_cfg.DirectPowerLrConfig()
power_linear: lr_cfg.PowerAndLinearDecayLrConfig = (
lr_cfg.PowerAndLinearDecayLrConfig())
power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = (
lr_cfg.PowerDecayWithOffsetLrConfig())
@dataclasses.dataclass
......@@ -89,9 +101,12 @@ class OptimizationConfig(base_config.Config):
Attributes:
optimizer: optimizer oneof config.
ema: optional exponential moving average optimizer config, if specified, ema
optimizer will be used.
learning_rate: learning rate oneof config.
warmup: warmup oneof config.
"""
optimizer: OptimizerConfig = OptimizerConfig()
ema: Optional[opt_cfg.EMAConfig] = None
learning_rate: LrConfig = LrConfig()
warmup: WarmupConfig = WarmupConfig()
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimization_config.py."""
import tensorflow as tf
......@@ -26,15 +25,15 @@ class OptimizerConfigTest(tf.test.TestCase):
def test_no_optimizer(self):
optimizer = optimization_config.OptimizationConfig({}).optimizer.get()
self.assertEqual(optimizer, None)
self.assertIsNone(optimizer)
def test_no_lr_schedule(self):
lr = optimization_config.OptimizationConfig({}).learning_rate.get()
self.assertEqual(lr, None)
self.assertIsNone(lr)
def test_no_warmup_schedule(self):
warmup = optimization_config.OptimizationConfig({}).warmup.get()
self.assertEqual(warmup, None)
self.assertIsNone(warmup)
def test_config(self):
opt_config = optimization_config.OptimizationConfig({
......@@ -50,12 +49,11 @@ class OptimizerConfigTest(tf.test.TestCase):
'type': 'linear'
}
})
self.assertEqual(opt_config.optimizer.get(),
opt_cfg.SGDConfig())
self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
self.assertEqual(opt_config.learning_rate.get(),
lr_cfg.PolynomialLrConfig())
self.assertEqual(opt_config.warmup.get(),
lr_cfg.LinearWarmupConfig())
self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for optimizer configs."""
from typing import List, Optional
......@@ -21,7 +20,24 @@ from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class SGDConfig(base_config.Config):
class BaseOptimizerConfig(base_config.Config):
"""Base optimizer config.
Attributes:
clipnorm: float >= 0 or None. If not None, Gradients will be clipped when
their L2 norm exceeds this value.
clipvalue: float >= 0 or None. If not None, Gradients will be clipped when
their absolute value exceeds this value.
global_clipnorm: float >= 0 or None. If not None, gradient of all weights is
clipped so that their global norm is no higher than this value
"""
clipnorm: Optional[float] = None
clipvalue: Optional[float] = None
global_clipnorm: Optional[float] = None
@dataclasses.dataclass
class SGDConfig(BaseOptimizerConfig):
"""Configuration for SGD optimizer.
The attributes for this class matches the arguments of tf.keras.optimizer.SGD.
......@@ -39,7 +55,7 @@ class SGDConfig(base_config.Config):
@dataclasses.dataclass
class RMSPropConfig(base_config.Config):
class RMSPropConfig(BaseOptimizerConfig):
"""Configuration for RMSProp optimizer.
The attributes for this class matches the arguments of
......@@ -60,7 +76,25 @@ class RMSPropConfig(base_config.Config):
@dataclasses.dataclass
class AdamConfig(base_config.Config):
class AdagradConfig(BaseOptimizerConfig):
"""Configuration for Adagrad optimizer.
The attributes of this class match the arguments of
tf.keras.optimizer.Adagrad.
Attributes:
name: name of the optimizer.
initial_accumulator_value: A floating point value. Starting value for the
accumulators, must be non-negative.
epsilon: A small floating point value to avoid zero denominator.
"""
name: str = "Adagrad"
initial_accumulator_value: float = 0.1
epsilon: float = 1e-07
@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
"""Configuration for Adam optimizer.
The attributes for this class matches the arguments of
......@@ -82,7 +116,7 @@ class AdamConfig(base_config.Config):
@dataclasses.dataclass
class AdamWeightDecayConfig(base_config.Config):
class AdamWeightDecayConfig(BaseOptimizerConfig):
"""Configuration for Adam optimizer with weight decay.
Attributes:
......@@ -95,8 +129,10 @@ class AdamWeightDecayConfig(base_config.Config):
weight_decay_rate: float. Weight decay rate. Default to 0.
include_in_weight_decay: list[str], or None. List of weight names to include
in weight decay.
include_in_weight_decay: list[str], or None. List of weight names to not
exclude_from_weight_decay: list[str], or None. List of weight names to not
include in weight decay.
gradient_clip_norm: A positive float. Clips the gradients to this maximum
L2-norm. Default to 1.0.
"""
name: str = "AdamWeightDecay"
beta_1: float = 0.9
......@@ -106,10 +142,11 @@ class AdamWeightDecayConfig(base_config.Config):
weight_decay_rate: float = 0.0
include_in_weight_decay: Optional[List[str]] = None
exclude_from_weight_decay: Optional[List[str]] = None
gradient_clip_norm: float = 1.0
@dataclasses.dataclass
class LAMBConfig(base_config.Config):
class LAMBConfig(BaseOptimizerConfig):
"""Configuration for LAMB optimizer.
The attributes for this class matches the arguments of
......@@ -122,12 +159,11 @@ class LAMBConfig(base_config.Config):
epsilon: epsilon value used for numerical stability in LAMB optimizer.
weight_decay_rate: float. Weight decay rate. Default to 0.
exclude_from_weight_decay: List of regex patterns of variables excluded from
weight decay. Variables whose name contain a
substring matching the pattern will be excluded.
weight decay. Variables whose name contain a substring matching the
pattern will be excluded.
exclude_from_layer_adaptation: List of regex patterns of variables excluded
from layer adaptation. Variables whose name
contain a substring matching the pattern will
be excluded.
from layer adaptation. Variables whose name contain a substring matching
the pattern will be excluded.
"""
name: str = "LAMB"
beta_1: float = 0.9
......@@ -136,3 +172,53 @@ class LAMBConfig(base_config.Config):
weight_decay_rate: float = 0.0
exclude_from_weight_decay: Optional[List[str]] = None
exclude_from_layer_adaptation: Optional[List[str]] = None
@dataclasses.dataclass
class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.

  Holds the arguments forwarded to `ema_optimizer.ExponentialMovingAverage`
  when the optimizer factory wraps the base optimizer with EMA.

  Attributes:
    name: 'str', name of the optimizer.
    average_decay: 'float', average decay value.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to apply dynamic decay or not.
  """
  name: str = "ExponentialMovingAverage"
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True
@dataclasses.dataclass
class LARSConfig(BaseOptimizerConfig):
  """Layer-wise adaptive rate scaling config.

  Holds the arguments forwarded to `lars_optimizer.LARS`.

  Attributes:
    name: 'str', name of the optimizer.
    momentum: `float` hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0.9.
    eeta: `float` LARS coefficient as used in the paper. Default set to LARS
      coefficient from the paper. (eeta / weight_decay) determines the highest
      scaling factor in LARS.
    weight_decay_rate: `float` for weight decay.
    nesterov: 'boolean' for whether to use nesterov momentum.
    classic_momentum: `boolean` for whether to use classic (or popular)
      momentum. The learning rate is applied during momentum update in classic
      momentum, but after momentum for popular momentum.
    exclude_from_weight_decay: A list of `string` for variable screening, if any
      of the string appears in a variable's name, the variable will be excluded
      for computing weight decay. For example, one could specify the list like
      ['batch_normalization', 'bias'] to exclude BN and bias from weight decay.
    exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but for
      layer adaptation. If it is None, it will be defaulted the same as
      exclude_from_weight_decay.
  """
  name: str = "LARS"
  momentum: float = 0.9
  eeta: float = 0.001
  weight_decay_rate: float = 0.0
  nesterov: bool = False
  classic_momentum: bool = True
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exponential moving average optimizer."""
from typing import Text, List
import tensorflow as tf
# pylint: disable=protected-access
class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
  """Optimizer that computes an exponential moving average of the variables.

  Empirically it has been found that using the moving average of the trained
  parameters of a deep network is better than using its trained parameters
  directly. This optimizer allows you to compute this moving average and swap
  the variables at save time so that any code outside of the training loop
  will use by default the average values instead of the original ones.

  Example of usage for training:
  ```python
  opt = tf.keras.optimizers.SGD(learning_rate)
  opt = ExponentialMovingAverage(opt)
  opt.shadow_copy(model)
  ```

  At test time, swap the shadow variables to evaluate on the averaged weights:
  ```python
  opt.swap_weights()
  # Test eval the model here
  opt.swap_weights()
  ```
  """

  def __init__(self,
               optimizer: tf.keras.optimizers.Optimizer,
               average_decay: float = 0.99,
               start_step: int = 0,
               dynamic_decay: bool = True,
               name: Text = 'ExponentialMovingAverage',
               **kwargs):
    """Construct a new ExponentialMovingAverage optimizer.

    Args:
      optimizer: `tf.keras.optimizers.Optimizer` that will be used to compute
        and apply gradients.
      average_decay: float. Decay to use to maintain the moving averages of
        trained variables.
      start_step: int. What step to start the moving average.
      dynamic_decay: bool. Whether to change the decay based on the number of
        optimizer updates. Decay will start at 0.1 and gradually increase up
        to `average_decay` after each optimizer update. This behavior is
        similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
      name: Optional name for the operations created when applying gradients.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
        `lr`, `decay`}.
    """
    super().__init__(name, **kwargs)
    self._average_decay = average_decay
    self._start_step = tf.constant(start_step, tf.float32)
    self._dynamic_decay = dynamic_decay
    self._optimizer = optimizer
    self._track_trackable(self._optimizer, 'base_optimizer')
    # Bug fix: initialize the shadow-copy state here so that the public
    # `has_shadow_copy` property can be queried before `shadow_copy()` is
    # called; previously that raised AttributeError.
    self._average_weights = None
    self._model_weights = None

  def shadow_copy(self, model: tf.keras.Model):
    """Creates shadow variables for the given model weights."""
    for var in model.weights:
      self.add_slot(var, 'average', initializer='zeros')
    self._average_weights = [
        self.get_slot(var, 'average') for var in model.weights
    ]
    self._model_weights = model.weights

  @property
  def has_shadow_copy(self):
    """Whether this optimizer has created shadow variables."""
    return self._model_weights is not None

  def _create_slots(self, var_list):
    # Slot creation is delegated to the wrapped optimizer; the 'average'
    # slots are created separately in `shadow_copy`.
    self._optimizer._create_slots(var_list=var_list)  # pylint: disable=protected-access

  def apply_gradients(self, grads_and_vars, name: Text = None):
    """Applies gradients via the wrapped optimizer, then updates averages."""
    result = self._optimizer.apply_gradients(grads_and_vars, name)
    self.update_average(self.iterations)
    return result

  @tf.function
  def update_average(self, step: tf.Tensor):
    """Moves each shadow variable toward its model weight by the decay."""
    step = tf.cast(step, tf.float32)
    if step < self._start_step:
      decay = tf.constant(0., tf.float32)
    elif self._dynamic_decay:
      # Ramp decay from ~0.1 up to `average_decay` as updates accumulate.
      decay = step - self._start_step
      decay = tf.minimum(self._average_decay, (1. + decay) / (10. + decay))
    else:
      decay = self._average_decay

    def _apply_moving(v_moving, v_normal):
      diff = v_moving - v_normal
      v_moving.assign_sub(tf.cast(1. - decay, v_moving.dtype) * diff)
      return v_moving

    def _update(strategy, v_moving_and_v_normal):
      for v_moving, v_normal in v_moving_and_v_normal:
        strategy.extended.update(v_moving, _apply_moving, args=(v_normal,))

    ctx = tf.distribute.get_replica_context()
    return ctx.merge_call(_update, args=(zip(self._average_weights,
                                             self._model_weights),))

  def swap_weights(self):
    """Swap the average and moving weights.

    This is a convenience method to allow one to evaluate the averaged weights
    at test time. Loads the weights stored in `self._average` into the model,
    keeping a copy of the original model weights. Swapping twice will return
    the original weights.
    """
    if tf.distribute.in_cross_replica_context():
      strategy = tf.distribute.get_strategy()
      strategy.run(self._swap_weights, args=())
    else:
      raise ValueError('Swapping weights must occur under a '
                       'tf.distribute.Strategy')

  @tf.function
  def _swap_weights(self):
    # In-place, allocation-free swap via the classic three-step XOR-style
    # trick on each (average, model-weight) pair.
    def fn_0(a, b):
      a.assign_add(b)
      return a

    def fn_1(b, a):
      b.assign(a - b)
      return b

    def fn_2(a, b):
      a.assign_sub(b)
      return a

    def swap(strategy, a_and_b):
      """Swap `a` and `b` and mirror to all devices."""
      for a, b in a_and_b:
        strategy.extended.update(a, fn_0, args=(b,))  # a = a + b
        strategy.extended.update(b, fn_1, args=(a,))  # b = a - b
        strategy.extended.update(a, fn_2, args=(b,))  # a = a - b

    ctx = tf.distribute.get_replica_context()
    return ctx.merge_call(
        swap, args=(zip(self._average_weights, self._model_weights),))

  def assign_average_vars(self, var_list: List[tf.Variable]):
    """Assign variables in var_list with their respective averages.

    Args:
      var_list: List of model variables to be assigned to their average.

    Returns:
      assign_op: The op corresponding to the assignment operation of
        variables to their average.
    """
    assign_op = tf.group([
        var.assign(self.get_slot(var, 'average')) for var in var_list
        if var.trainable
    ])
    return assign_op

  def _create_hypers(self):
    self._optimizer._create_hypers()  # pylint: disable=protected-access

  def _prepare(self, var_list):
    return self._optimizer._prepare(var_list=var_list)  # pylint: disable=protected-access

  @property
  def iterations(self):
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  @property
  def weights(self):
    return self._optimizer.weights

  def variables(self):
    return self._weights + [self.iterations]

  @property
  def lr(self):
    return self._optimizer._get_hyper('learning_rate')

  @lr.setter
  def lr(self, lr):
    self._optimizer._set_hyper('learning_rate', lr)

  @property
  def learning_rate(self):
    return self._optimizer._get_hyper('learning_rate')

  @learning_rate.setter
  def learning_rate(self, learning_rate):  # pylint: disable=redefined-outer-name
    self._optimizer._set_hyper('learning_rate', learning_rate)

  def _resource_apply_dense(self, grad, var):
    return self._optimizer._resource_apply_dense(grad, var)

  def _resource_apply_sparse(self, grad, var, indices):
    return self._optimizer._resource_apply_sparse(grad, var, indices)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
    return self._optimizer._resource_apply_sparse_duplicate_indices(
        grad, var, indices)

  def get_config(self):
    # NOTE(review): `_start_step` is stored as a tf.Tensor; if this config
    # needs to be JSON-serialized, consider storing the raw int instead.
    config = {
        'optimizer': tf.keras.optimizers.serialize(self._optimizer),
        'average_decay': self._average_decay,
        'start_step': self._start_step,
        'dynamic_decay': self._dynamic_decay,
    }
    base_config = super(ExponentialMovingAverage, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    optimizer = tf.keras.optimizers.deserialize(
        config.pop('optimizer'),
        custom_objects=custom_objects,
    )
    return cls(optimizer, **config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layer-wise adaptive rate scaling optimizer."""
import re
from typing import Text, List, Optional
import tensorflow as tf
# pylint: disable=protected-access
class LARS(tf.keras.optimizers.Optimizer):
  """Layer-wise Adaptive Rate Scaling for large batch training.

  Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
  I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
  """

  def __init__(self,
               learning_rate: float = 0.01,
               momentum: float = 0.9,
               weight_decay_rate: float = 0.0,
               eeta: float = 0.001,
               nesterov: bool = False,
               classic_momentum: bool = True,
               exclude_from_weight_decay: Optional[List[Text]] = None,
               exclude_from_layer_adaptation: Optional[List[Text]] = None,
               name: Text = "LARS",
               **kwargs):
    """Constructs a LARSOptimizer.

    Args:
      learning_rate: `float` for learning rate. Defaults to 0.01.
      momentum: `float` hyperparameter >= 0 that accelerates gradient descent
        in the relevant direction and dampens oscillations. Defaults to 0.9.
      weight_decay_rate: `float` for weight decay.
      eeta: `float` LARS coefficient as used in the paper. Default set to LARS
        coefficient from the paper. (eeta / weight_decay) determines the
        highest scaling factor in LARS.
      nesterov: 'boolean' for whether to use nesterov momentum.
      classic_momentum: `boolean` for whether to use classic (or popular)
        momentum. The learning rate is applied during momentum update in
        classic momentum, but after momentum for popular momentum.
      exclude_from_weight_decay: A list of `string` for variable screening, if
        any of the string appears in a variable's name, the variable will be
        excluded for computing weight decay. For example, one could specify
        the list like ['batch_normalization', 'bias'] to exclude BN and bias
        from weight decay.
      exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
        for layer adaptation. If it is None, it will be defaulted the same as
        exclude_from_weight_decay.
      name: `Text` as optional name for the operations created when applying
        gradients. Defaults to "LARS".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
        gradients by value, `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for
        backward compatibility, recommended to use `learning_rate` instead.
    """
    super(LARS, self).__init__(name, **kwargs)

    self._set_hyper("learning_rate", learning_rate)
    self._set_hyper("decay", self._initial_decay)
    self.momentum = momentum
    self.weight_decay_rate = weight_decay_rate
    self.eeta = eeta
    self.nesterov = nesterov
    self.classic_momentum = classic_momentum
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

  def _create_slots(self, var_list):
    for v in var_list:
      self.add_slot(v, "momentum")

  def _resource_apply_dense(self, grad, param, apply_state=None):
    if grad is None or param is None:
      return tf.no_op()

    var_device, var_dtype = param.device, param.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                    self._fallback_apply_state(var_device, var_dtype))
    learning_rate = coefficients["lr_t"]

    param_name = param.name
    v = self.get_slot(param, "momentum")

    if self._use_weight_decay(param_name):
      # L2 regularization folded directly into the gradient.
      grad += self.weight_decay_rate * param

    if self.classic_momentum:
      # Trust ratio scales the LR before the momentum accumulation.
      trust_ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = tf.norm(param, ord=2)
        g_norm = tf.norm(grad, ord=2)
        # Guard both norms against zero to avoid division by zero.
        trust_ratio = tf.where(
            tf.greater(w_norm, 0),
            tf.where(tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
            1.0)
      scaled_lr = learning_rate * trust_ratio

      next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
      if self.nesterov:
        update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
      else:
        update = next_v
      next_param = param - update
    else:
      # Popular momentum: accumulate raw gradients, apply LR (and trust
      # ratio) after the momentum update.
      next_v = tf.multiply(self.momentum, v) + grad
      if self.nesterov:
        update = tf.multiply(self.momentum, next_v) + grad
      else:
        update = next_v

      trust_ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = tf.norm(param, ord=2)
        v_norm = tf.norm(update, ord=2)
        trust_ratio = tf.where(
            tf.greater(w_norm, 0),
            tf.where(tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 1.0),
            1.0)
      scaled_lr = trust_ratio * learning_rate
      next_param = param - scaled_lr * update

    return tf.group(*[
        param.assign(next_param, use_locking=False),
        v.assign(next_v, use_locking=False)
    ])

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    raise NotImplementedError("Applying sparse gradients is not implemented.")

  def _use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def get_config(self):
    config = super(LARS, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._serialize_hyperparameter("decay"),
        "momentum": self.momentum,
        "classic_momentum": self.classic_momentum,
        "weight_decay_rate": self.weight_decay_rate,
        "eeta": self.eeta,
        "nesterov": self.nesterov,
        # Bug fix: include the variable-screening lists so that
        # `from_config(get_config())` round-trips the optimizer faithfully;
        # previously these settings were silently dropped on serialization.
        "exclude_from_weight_decay": self.exclude_from_weight_decay,
        "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation,
    })
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Learning rate schedule classes."""
from typing import Mapping, Any, Union, Optional
......@@ -41,12 +40,11 @@ class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
Args:
after_warmup_lr_sched: tf.keras.optimizers.schedules
.LearningRateSchedule or a constant.
warmup_steps: int. number of the warmup steps.
warmup_learning_rate: floating point number. Initial learning rate for the
warmup.
warmup_steps: Number of the warmup steps.
warmup_learning_rate: Initial learning rate for the warmup.
name: Optional, name of warmup schedule.
"""
super(LinearWarmup, self).__init__()
super().__init__()
self._name = name
self._after_warmup_lr_sched = after_warmup_lr_sched
self._warmup_steps = warmup_steps
......@@ -103,7 +101,7 @@ class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
warmup_steps: int,
power: float = 1.0,
name: str = "PolynomialWarmup"):
super(PolynomialWarmUp, self).__init__()
super().__init__()
if isinstance(after_warmup_lr_sched,
tf.keras.optimizers.schedules.LearningRateSchedule):
self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
......@@ -122,7 +120,14 @@ class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
if self._warmup_steps <= 0:
warmup_percent_done = 1.0
else:
# A zero `step` may cause Inf. So make `step` positive.
step_non_zero = tf.math.maximum(global_step_float, 1.0)
warmup_percent_done = step_non_zero / warmup_steps_float
warmup_learning_rate = (
self._initial_learning_rate *
tf.math.pow(warmup_percent_done, self._power))
......@@ -148,8 +153,154 @@ class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
config = {"after_warmup_lr_sched": self._after_warmup_lr_sched} # pytype: disable=attribute-error
config.update({
"warmup_steps": self._warmup_setps,
"warmup_steps": self._warmup_steps,
"power": self._power,
"name": self._name
})
return config
class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Schedule computing `initial_learning_rate * step**power`."""

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               name: str = "DirectPowerDecay"):
    """Creates the power-decay schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      name: Optional, name of warmup schedule.
    """
    super().__init__()
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "DirectPowerDecay"):
      step_float = tf.cast(step, tf.float32)
      # Clamp the step to be >= 1 so a zero step cannot yield Inf when the
      # power is negative.
      positive_step = tf.math.maximum(step_float, 1.0)
      return self._initial_learning_rate * tf.math.pow(
          positive_step, self._power)

  def get_config(self):
    """Returns the serializable configuration of the schedule."""
    return dict(
        initial_learning_rate=self._initial_learning_rate,
        power=self._power,
        name=self._name)
class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power decay multiplied by a linear decay at the tail.

  For the first `total_decay_steps * (1 - linear_decay_fraction)` steps the
  rate is `lr * step**power`; over the remaining steps it is additionally
  scaled by `(total_decay_steps - step) /
  (total_decay_steps * linear_decay_fraction)`, reaching zero at the end.
  """

  def __init__(self,
               initial_learning_rate: float,
               total_decay_steps: int,
               power: float = 1.0,
               linear_decay_fraction: float = 0.1,
               name: str = "PowerAndLinearDecay"):
    """Creates the power-and-linear decay schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      total_decay_steps: The total number of steps for power + linear decay.
      power: The order of the polynomial.
      linear_decay_fraction: In the last `linear_decay_fraction` fraction of
        the steps, the learning rate is additionally multiplied by a linear
        decay.
      name: Optional, name of warmup schedule.
    """
    super().__init__()
    self._initial_learning_rate = initial_learning_rate
    self._total_decay_steps = total_decay_steps
    self._power = power
    self._linear_decay_fraction = linear_decay_fraction
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerAndLinearDecay"):
      step_float = tf.cast(step, tf.float32)
      # Clamp the step to be >= 1 so a zero step cannot yield Inf when the
      # power is negative.
      positive_step = tf.math.maximum(step_float, 1.0)
      rate = self._initial_learning_rate * tf.math.pow(
          positive_step, self._power)
      linear_span = self._total_decay_steps * self._linear_decay_fraction
      if linear_span > 0:
        # Before the linear window the ratio exceeds 1 (clamped to 1x); past
        # the end it goes negative (clamped to a non-negative rate).
        rate = rate * tf.minimum(
            1.0, (self._total_decay_steps - step_float) / linear_span)
        rate = tf.maximum(0.0, rate)
      return rate

  def get_config(self):
    """Returns the serializable configuration of the schedule."""
    return dict(
        initial_learning_rate=self._initial_learning_rate,
        total_decay_steps=self._total_decay_steps,
        power=self._power,
        linear_decay_fraction=self._linear_decay_fraction,
        name=self._name)
class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power decay that only starts after a step offset.

  Returns `pre_offset_learning_rate` while `step <= offset`; afterwards it
  returns `initial_learning_rate * (step - offset)**power`, capped at
  `pre_offset_learning_rate`.
  """

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               offset: int = 0,
               pre_offset_learning_rate: float = 1.0e6,
               name: str = "PowerDecayWithOffset"):
    """Creates the offset power-decay schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      offset: The offset when computing the power decay.
      pre_offset_learning_rate: The maximum learning rate we'll use.
      name: Optional, name of warmup schedule.
    """
    super().__init__()
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._offset = offset
    self._pre_offset_lr = pre_offset_learning_rate
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerDecayWithOffset"):
      step_float = tf.cast(step, tf.float32)
      # Shift by the offset, clamping to >= 1 so the power never sees zero.
      shifted_step = tf.math.maximum(step_float - self._offset, 1.0)
      decayed = self._initial_learning_rate * tf.math.pow(
          shifted_step, self._power)
      # Arithmetic blend: selects `decayed` past the offset, the constant
      # pre-offset rate otherwise.
      past_offset = tf.cast(step_float > self._offset, tf.float32)
      blended = ((1.0 - past_offset) * self._pre_offset_lr +
                 past_offset * decayed)
      # The power term may be infinitely large; cap it with the pre-offset LR.
      return tf.math.minimum(blended, self._pre_offset_lr)

  def get_config(self):
    """Returns the serializable configuration of the schedule."""
    return dict(
        initial_learning_rate=self._initial_learning_rate,
        power=self._power,
        offset=self._offset,
        pre_offset_learning_rate=self._pre_offset_lr,
        name=self._name)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer factory class."""
from typing import Union
from typing import Callable, Union
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lars_optimizer
from official.modeling.optimization import lr_schedule
from official.modeling.optimization.configs import optimization_config as opt_cfg
from official.nlp import optimization as nlp_optimization
......@@ -29,14 +30,19 @@ OPTIMIZERS_CLS = {
'adam': tf.keras.optimizers.Adam,
'adamw': nlp_optimization.AdamWeightDecay,
'lamb': tfa_optimizers.LAMB,
'rmsprop': tf.keras.optimizers.RMSprop
'rmsprop': tf.keras.optimizers.RMSprop,
'lars': lars_optimizer.LARS,
'adagrad': tf.keras.optimizers.Adagrad,
}
LR_CLS = {
'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
'cosine': tf.keras.experimental.CosineDecay
'cosine': tf.keras.experimental.CosineDecay,
'power': lr_schedule.DirectPowerDecay,
'power_linear': lr_schedule.PowerAndLinearDecay,
'power_with_offset': lr_schedule.PowerDecayWithOffset,
}
WARMUP_CLS = {
......@@ -45,7 +51,7 @@ WARMUP_CLS = {
}
class OptimizerFactory(object):
class OptimizerFactory:
"""Optimizer factory class.
This class builds learning rate and optimizer based on an optimization config.
......@@ -88,7 +94,10 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type
if self._optimizer_type is None:
self._use_ema = config.ema is not None
self._ema_config = config.ema
if self._optimizer_config is None:
raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get()
......@@ -121,9 +130,12 @@ class OptimizerFactory(object):
return lr
@gin.configurable
def build_optimizer(
self, lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule,
float]):
self,
lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
postprocessor: Callable[[tf.keras.optimizers.Optimizer],
tf.keras.optimizers.Optimizer] = None):
"""Build optimizer.
Builds optimizer from config. It takes learning rate as input, and builds
......@@ -131,15 +143,33 @@ class OptimizerFactory(object):
rate built using self.build_lr() is passed as an argument to this method.
Args:
lr: A floating point value, or
a tf.keras.optimizers.schedules.LearningRateSchedule instance.
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
postprocessor: An optional function for postprocessing the optimizer. It
takes an optimizer and returns an optimizer.
Returns:
tf.keras.optimizers.Optimizer instance.
"""
optimizer_dict = self._optimizer_config.as_dict()
## Delete clipnorm and clipvalue if None
if optimizer_dict['clipnorm'] is None:
del optimizer_dict['clipnorm']
if optimizer_dict['clipvalue'] is None:
del optimizer_dict['clipvalue']
optimizer_dict['learning_rate'] = lr
optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
return optimizer
if self._use_ema:
optimizer = ema_optimizer.ExponentialMovingAverage(
optimizer, **self._ema_config.as_dict())
if postprocessor:
optimizer = postprocessor(optimizer)
assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
'OptimizerFactory.build_optimizer returning a non-optimizer object: '
'{}'.format(optimizer))
return optimizer
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizer_factory.py."""
"""Tests for optimizer_factory.py."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.modeling.optimization import optimizer_factory
......@@ -25,12 +23,8 @@ from official.modeling.optimization.configs import optimization_config
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('sgd'),
('rmsprop'),
('adam'),
('adamw'),
('lamb'))
@parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'),
('lars'), ('adagrad'))
def test_optimizers(self, optimizer_type):
params = {
'optimizer': {
......@@ -50,26 +44,63 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
optimizer = opt_factory.build_optimizer(lr, postprocessor=lambda x: x)
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
@parameterized.parameters((None, None), (1.0, None), (None, 1.0))
def test_gradient_clipping(self, clipnorm, clipvalue):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'clipnorm': clipnorm,
'clipvalue': clipvalue
}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 1.0
}
}
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])
grads0 = tf.constant([0.1, 0.1])
grads1 = tf.constant([2.0, 3.0])
grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
optimizer.apply_gradients(grads_and_vars)
self.assertAllClose(np.array([0.9, 1.9]), var0.numpy())
if clipvalue is not None:
self.assertAllClose(np.array([2.0, 3.0]), var1.numpy())
elif clipnorm is not None:
self.assertAllClose(np.array([2.4452999, 3.1679497]), var1.numpy())
else:
self.assertAllClose(np.array([1.0, 1.0]), var1.numpy())
def test_missing_types(self):
params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
}
}
with self.assertRaises(ValueError):
......@@ -80,22 +111,20 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
}
expected_lr_step_values = [
[0, 0.1],
[5000, 0.1],
[10000, 0.1],
[10001, 0.01],
[20000, 0.01],
[20001, 0.001]
]
}
expected_lr_step_values = [[0, 0.1], [5000, 0.1], [10000, 0.1],
[10001, 0.01], [20000, 0.01], [20001, 0.001]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
......@@ -107,28 +136,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
},
'warmup': {
'type': 'linear',
'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0.01
}
}
expected_lr_step_values = [
[0, 0.01],
[250, 0.055],
[500, 0.1],
[5500, 0.1],
[10000, 0.1],
[10001, 0.01],
[20000, 0.01],
[20001, 0.001]
]
}
expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5500, 0.1],
[10000, 0.1], [10001, 0.01], [20000, 0.01],
[20001, 0.001]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
......@@ -140,7 +169,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'exponential',
......@@ -170,7 +201,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'polynomial',
......@@ -194,7 +227,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'cosine',
......@@ -204,11 +239,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
}
}
}
expected_lr_step_values = [[0, 0.1],
[250, 0.08535534],
[500, 0.04999999],
[750, 0.01464466],
[1000, 0]]
expected_lr_step_values = [[0, 0.1], [250, 0.08535534], [500, 0.04999999],
[750, 0.01464466], [1000, 0]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
......@@ -220,7 +252,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'constant',
......@@ -250,28 +284,52 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
},
'warmup': {
'type': 'polynomial',
'polynomial': {'warmup_steps': 500, 'power': 2.}
'polynomial': {
'warmup_steps': 500,
'power': 2.
}
}
expected_lr_step_values = [
[0, 0.0],
[250, 0.025],
[500, 0.1],
[5500, 0.1],
[10000, 0.1],
[10001, 0.01],
[20000, 0.01],
[20001, 0.001]
]
}
expected_lr_step_values = [[0, 0.0], [250, 0.025], [500, 0.1], [5500, 0.1],
[10000, 0.1], [10001, 0.01], [20000, 0.01],
[20001, 0.001]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value, places=6)
def test_power_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'power',
'power': {
'initial_learning_rate': 1.0,
'power': -1.0
}
}
}
expected_lr_step_values = [[0, 1.0], [1, 1.0], [250, 1. / 250.]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
......@@ -279,5 +337,59 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
def test_power_linear_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'power_linear',
'power_linear': {
'initial_learning_rate': 1.0,
'power': -1.0,
'linear_decay_fraction': 0.5,
'total_decay_steps': 100,
}
}
}
expected_lr_step_values = [[0, 1.0], [1, 1.0], [40, 1. / 40.],
[60, 1. / 60. * 0.8]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
def test_power_with_offset_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'power_with_offset',
'power_with_offset': {
'initial_learning_rate': 1.0,
'power': -1.0,
'offset': 10,
'pre_offset_learning_rate': 3.0,
}
}
}
expected_lr_step_values = [[1, 3.0], [10, 3.0], [20, 1. / 10.]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,45 +11,75 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to training performance."""
from absl import logging
import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale="dynamic"):
loss_scale='dynamic',
use_experimental_api=False):
"""Configures optimizer object with performance options."""
if use_experimental_api:
logging.warning('Passing use_experimental_api=True is deprecated. The '
'argument will be removed in the future.')
if use_float16:
# TODO(b/171936854): Move all methods to non-experimental api.
if use_experimental_api:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=loss_scale))
elif loss_scale == 'dynamic':
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
else:
# loss_scale is a number. We interpret that as a fixed loss scale.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.ckeras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
# tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
# double up.
optimizer = (
tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
return optimizer
def set_mixed_precision_policy(dtype, loss_scale=None):
def set_mixed_precision_policy(dtype, loss_scale=None,
use_experimental_api=False):
"""Sets mix precision policy."""
if use_experimental_api:
logging.warning('Passing use_experimental_api=True is deprecated. The '
'argument will be removed in the future.')
assert use_experimental_api or loss_scale is None, (
'loss_scale cannot be specified if use_experimental_api is False. If the '
'non-experimental API is used, specify the loss scaling configuration '
'when creating the LossScaleOptimizer instead.'
)
if dtype == tf.float16:
# TODO(b/171936854): Move all methods to non-experimental api.
if use_experimental_api:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
else:
tf.keras.mixed_precision.set_global_policy('mixed_float16')
elif dtype == tf.bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
if use_experimental_api:
tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
else:
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
elif dtype == tf.float32:
if use_experimental_api:
tf.keras.mixed_precision.experimental.set_policy('float32')
else:
raise ValueError("Unexpected dtype: %s" % dtype)
tf.keras.mixed_precision.set_global_policy('float32')
else:
raise ValueError('Unexpected dtype: %s' % dtype)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base ProgressivePolicy definition for progressive training.
To write a progressive model, subclass ProgressivePolicy and implement its
abstract methods to handle each training stage.
"""
import abc
from typing import Any, Mapping
from absl import logging
import dataclasses
import six
import tensorflow as tf
from official.modeling.hyperparams import base_config
from official.modeling.progressive import utils
@dataclasses.dataclass
class ProgressiveConfig(base_config.Config):
  """Base config for progressive training policies.

  Intentionally empty: tasks subclass this to declare their own
  stage-specific settings (see ProgressiveTrainerConfig.progressive).
  """
  pass
@six.add_metaclass(abc.ABCMeta)
class ProgressivePolicy:
  """The APIs for handling progressive training stages.

  Attributes:
    cur_model: The model for the current progressive training stage.
    cur_train_dataset: The train dataset function for the current stage.
    cur_eval_dataset: The eval dataset function for the current stage.
    cur_optimizer: The optimizer for the current stage.
    cur_checkpoint_items: Items to be saved in and restored from checkpoints,
      for the progressive trainer.
    is_last_stage: Whether it is currently in the last stage.

  Interfaces:
    is_stage_advancing: Returns if progressive training is advancing to the
      next stage.
    update_pt_stage: Update progressive training stage.
  """

  def __init__(self):
    """Initialize stage policy."""
    # Datasets are built lazily by the cur_*_dataset properties and reset to
    # None whenever the stage advances (see update_pt_stage).
    self._cur_train_dataset = None
    self._cur_eval_dataset = None
    # The per-stage model and optimizer live inside a VolatileTrackable so
    # they can be swapped between stages via reassign_trackable while staying
    # part of the checkpoint (see cur_checkpoint_items).
    self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)
    stage_id = 0
    # Stage id is kept in a tf.Variable so it is saved/restored with
    # checkpoints; ONLY_FIRST_REPLICA keeps it consistent across replicas.
    self._stage_id = tf.Variable(
        stage_id,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])
    # Build the stage-0 optimizer and model. Subclasses must be able to serve
    # get_optimizer/get_model by the time this __init__ runs.
    self._volatiles.reassign_trackable(
        optimizer=self.get_optimizer(stage_id),
        model=self.get_model(stage_id, old_model=None))

  def compute_stage_id(self, global_step: int) -> int:
    """Returns the stage id whose cumulative step range contains global_step."""
    for stage_id in range(self.num_stages()):
      global_step -= self.num_steps(stage_id)
      if global_step < 0:
        return stage_id
    # global_step exceeds the sum of all stage lengths; default to last stage.
    logging.error('Global step %d found no matching progressive stages. '
                  'Default to the last stage.', global_step)
    return self.num_stages() - 1

  @abc.abstractmethod
  def num_stages(self) -> int:
    """Return the total number of progressive stages."""
    pass

  @abc.abstractmethod
  def num_steps(self, stage_id: int) -> int:
    """Return the total number of steps in this stage."""
    pass

  @abc.abstractmethod
  def get_model(self,
                stage_id: int,
                old_model: tf.keras.Model = None) -> tf.keras.Model:
    """Return model for this stage. For initialization, `old_model` = None."""
    pass

  @abc.abstractmethod
  def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    """Return optimizer for this stage."""
    pass

  @abc.abstractmethod
  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return training Dataset for this stage."""
    pass

  @abc.abstractmethod
  def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return evaluation Dataset for this stage."""
    pass

  @property
  def cur_model(self) -> tf.keras.Model:
    """The model of the current stage, held by the volatile trackable."""
    return self._volatiles.model

  @property
  def cur_train_dataset(self) -> tf.data.Dataset:
    """The train dataset of the current stage, built on first access."""
    if self._cur_train_dataset is None:
      self._cur_train_dataset = self.get_train_dataset(self._stage_id.numpy())
    return self._cur_train_dataset

  @property
  def cur_eval_dataset(self) -> tf.data.Dataset:
    """The eval dataset of the current stage, built on first access."""
    if self._cur_eval_dataset is None:
      self._cur_eval_dataset = self.get_eval_dataset(self._stage_id.numpy())
    return self._cur_eval_dataset

  @property
  def cur_optimizer(self) -> tf.keras.optimizers.Optimizer:
    """The optimizer of the current stage, held by the volatile trackable."""
    return self._volatiles.optimizer

  @property
  def is_last_stage(self) -> bool:
    """Whether the current stage is the final progressive stage."""
    stage_id = self._stage_id.numpy()
    return stage_id >= self.num_stages() - 1

  @property
  def cur_checkpoint_items(self) -> Mapping[str, Any]:
    """Trackables (stage id + model/optimizer wrapper) to checkpoint."""
    return dict(stage_id=self._stage_id, volatiles=self._volatiles)

  def is_stage_advancing(self, global_step: int) -> bool:
    """Returns True if global_step maps to a different stage than the current."""
    old_stage_id = self._stage_id.numpy()
    new_stage_id = self.compute_stage_id(global_step)
    return old_stage_id != new_stage_id

  def update_pt_stage(self, global_step: int, pass_old_model=True) -> None:
    """Update progressive training internal status.

    Call this after a training loop ends.

    Args:
      global_step: an integer scalar of the current global step.
      pass_old_model: whether to pass the old_model to get_model() function.
        This is set to False if the old_model is irrelevant (e.g, just a default
        model from stage 0).
    """
    old_stage_id = self._stage_id.numpy()
    new_stage_id = self.compute_stage_id(global_step)
    logging.info('Switching stage from %d to %d', old_stage_id, new_stage_id)

    # Update stage id.
    self._stage_id.assign(new_stage_id)
    # Update dataset function.
    # Reset cached datasets so the cur_*_dataset properties rebuild them for
    # the new stage on next access.
    self._cur_train_dataset = None
    self._cur_eval_dataset = None

    # Update optimizer and model.
    # The optimizer is rebuilt first; the new model may be grown from the old
    # one (warm start) when pass_old_model is True.
    new_optimizer = self.get_optimizer(new_stage_id)
    self._volatiles.reassign_trackable(optimizer=new_optimizer)
    new_model = self.get_model(
        new_stage_id, old_model=self.cur_model if pass_old_model else None)
    self._volatiles.reassign_trackable(model=new_model)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM binary for the progressive trainer."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_utils
from official.modeling import performance
from official.modeling.progressive import train_lib
FLAGS = flags.FLAGS
def main(_):
  """Parses flags/configs and runs the progressive training experiment."""
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(experiment_params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if experiment_params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        experiment_params.runtime.mixed_precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=experiment_params.runtime.distribution_strategy,
      all_reduce_alg=experiment_params.runtime.all_reduce_alg,
      num_gpus=experiment_params.runtime.num_gpus,
      tpu_address=experiment_params.runtime.tpu,
      **experiment_params.runtime.model_parallelism())

  # The task must be created under the strategy scope so its variables are
  # placed correctly.
  with strategy.scope():
    task = task_factory.get_task(experiment_params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=experiment_params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
  # Register TFM command-line flags before absl parses argv in app.run.
  tfm_flags.define_flags()
  app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM progressive training driver library.
Compared to the common training driver, the only difference is that we use
prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
"""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import train_lib as base_train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
def run_experiment(distribution_strategy: tf.distribute.Strategy,
                   task: base_task.Task,
                   mode: str,
                   params: config_definitions.ExperimentConfig,
                   model_dir: str,
                   run_post_eval: bool = False,
                   save_summary: bool = True
                   ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.
  """
  with distribution_strategy.scope():
    logging.info('Running progressive trainer.')
    trainer = prog_trainer_lib.ProgressiveTrainer(
        params,
        task,
        ckpt_dir=model_dir,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval,
        checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
            params, model_dir))

  # A checkpoint manager is only created when the trainer exposes a
  # checkpoint to manage.
  checkpoint_manager = None
  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)

  # All summary outputs are disabled together when save_summary is False.
  train_summary_dir = os.path.join(model_dir, 'train') if save_summary else None
  eval_summary_dir = (
      os.path.join(model_dir, 'validation') if save_summary else None)
  summary_interval = params.trainer.summary_interval if save_summary else None

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=train_summary_dir,
      eval_summary_dir=eval_summary_dir,
      summary_interval=summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop waiting for new checkpoints once training has completed.
        return trainer.global_step.numpy() >= params.trainer.train_steps

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if not run_post_eval:
    return trainer.model, {}
  with distribution_strategy.scope():
    eval_logs = trainer.evaluate(
        tf.convert_to_tensor(params.trainer.validation_steps))
  return trainer.model, eval_logs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive train_lib."""
import os
from absl import flags
from absl.testing import parameterized
import dataclasses
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import params_dict
from official.modeling.progressive import policies
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
tfm_flags.define_flags()
@dataclasses.dataclass
class ProgTaskConfig(cfg.TaskConfig):
  """Task config for the progressive mock task; no extra fields needed."""
  pass
@task_factory.register_task_cls(ProgTaskConfig)
class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
  """Progressive task for testing."""

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    # The mock task must be set up first: ProgressivePolicy.__init__ builds
    # the stage-0 model/optimizer through get_model/get_optimizer below.
    mock_task.MockTask.__init__(self, params=params, logging_dir=logging_dir)
    policies.ProgressivePolicy.__init__(self)

  def num_stages(self):
    """Two progressive stages in total."""
    return 2

  def num_steps(self, stage_id):
    """Stage 0 runs for 2 steps; the remaining stage runs for 4."""
    if stage_id == 0:
      return 2
    return 4

  def get_model(self, stage_id, old_model=None):
    """Builds a fresh model for every stage; the old model is ignored."""
    del stage_id, old_model  # unused
    return self.build_model()

  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    opt_config = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adamw',
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {
                'initial_learning_rate': 0.01,
                'end_learning_rate': 0.0,
                'power': 1.0,
                'decay_steps': 10,
            },
        },
        'warmup': {
            'polynomial': {
                'power': 1,
                'warmup_steps': 2,
            },
            'type': 'polynomial',
        }
    })
    factory = optimization.OptimizerFactory(opt_config)
    return factory.build_optimizer(factory.build_learning_rate())

  def get_train_dataset(self, stage_id):
    """Distributed training dataset; identical for every stage."""
    del stage_id  # unused
    return orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), self.build_inputs, None)

  def get_eval_dataset(self, stage_id):
    """Distributed evaluation dataset; identical for every stage."""
    del stage_id  # unused
    return orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), self.build_inputs, None)
class TrainTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for the progressive train_lib driver."""

  def setUp(self):
    super().setUp()
    trainer_overrides = {
        'checkpoint_interval': 10,
        'steps_per_loop': 10,
        'summary_interval': 10,
        'train_steps': 10,
        'validation_steps': 5,
        'validation_interval': 10,
        'continuous_eval_timeout': 1,
        'optimizer_config': {
            'optimizer': {
                'type': 'sgd',
            },
            'learning_rate': {
                'type': 'constant'
            }
        }
    }
    self._test_config = {'trainer': trainer_overrides}

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    model_dir = self.get_temp_dir()
    experiment_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=ProgTaskConfig())
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)

    with distribution_strategy.scope():
      task = task_factory.get_task(experiment_config.task,
                                   logging_dir=model_dir)

    # Shared arguments for both run_experiment invocations below.
    common_kwargs = dict(
        distribution_strategy=distribution_strategy,
        task=task,
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    _, logs = train_lib.run_experiment(mode=flag_mode, **common_kwargs)
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)

    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(mode='continuous_eval', **common_kwargs)
    print(logs)
if __name__ == '__main__':
  # Run all tests in this module.
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progressive Trainer implementation.
The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import os
from typing import Any, Optional
# Import libraries
from absl import logging
import dataclasses
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.progressive import policies
from official.modeling.progressive import utils
ExperimentConfig = config_definitions.ExperimentConfig
@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
  """Configuration for progressive trainer.

  Attributes:
    progressive: A task-specific config. Users can subclass ProgressiveConfig
      and define any task-specific settings in their subclass.
    export_checkpoint: A bool. Whether to export checkpoints in non-progressive
      manner (without the volatiles wrapper) such that your down-stream tasks
      can load checkpoints from a progressive trainer as if it is a regular
      checkpoint.
    export_checkpoint_interval: An int. The number of steps between exporting
      checkpoints. If None (by default), will use the same value as
      TrainerConfig.checkpoint_interval.
    export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
      during the final progressive training stage. In other words, whether to
      not export small, partial models. In many cases, it is not meaningful to
      finetune a small, partial model in down-stream tasks.
  """
  progressive: Optional[policies.ProgressiveConfig] = None
  export_checkpoint: bool = True
  export_checkpoint_interval: Optional[int] = None
  export_only_final_stage_ckpt: bool = True
@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
"""Implements the progressive trainer shared for TensorFlow models."""
  def __init__(
      self,
      config: ExperimentConfig,
      prog_task: base_task.Task,  # also implemented ProgressivePolicy.
      ckpt_dir: str = '',
      train: bool = True,
      evaluate: bool = True,
      checkpoint_exporter: Any = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      config: An `ExperimentConfig` instance specifying experiment config.
      prog_task: An instance both implemented policies.ProgressivePolicy and
        base_task.Task.
      ckpt_dir: Checkpoint directory.
      train: bool, whether or not this trainer will be used for training.
        default to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        default to True.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # NOTE(review): trainer_lib.Trainer.__init__ is deliberately not called;
    # this class sets up its state itself and initializes the orbit
    # Standard* bases directly below — confirm this matches base-class
    # expectations when upgrading the base trainer.
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._config = config
    self._runtime_options = trainer_lib.get_runtime_options(config)
    self._task = prog_task

    # Directory for non-progressive checkpoint
    self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
    tf.io.gfile.makedirs(self._export_ckpt_dir)

    # Receive other checkpoint export, e.g, best checkpoint exporter.
    # TODO(lehou): unify the checkpoint exporting logic, although the default
    # setting does not use checkpoint_exporter.
    self._checkpoint_exporter = checkpoint_exporter

    self._global_step = orbit.utils.create_global_step()

    # The before_load_hook switches to the correct progressive stage (and
    # builds its model/optimizer) before checkpoint variables are restored.
    self._checkpoint = utils.CheckpointWithHooks(
        before_load_hook=self._update_pt_stage_from_ckpt,
        global_step=self.global_step,
        **self._task.cur_checkpoint_items)

    self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    self._validation_loss = tf.keras.metrics.Mean(
        'validation_loss', dtype=tf.float32)
    # Task metrics plus any metrics the current model itself tracks.
    self._train_metrics = self.task.build_metrics(
        training=True) + self.model.metrics
    self._validation_metrics = self.task.build_metrics(
        training=False) + self.model.metrics

    if train:
      orbit.StandardTrainer.__init__(
          self,
          None,  # Manage train_dataset by ourselves, not by StandardTrainer.
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function))

    if evaluate:
      orbit.StandardEvaluator.__init__(
          self,
          None,  # Manage train_dataset by ourselves, not by StandardEvaluator.
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function))
  @property
  def model(self):
    """The model of the current progressive stage (owned by the task)."""
    return self._task.cur_model
  @property
  def optimizer(self):
    """The optimizer of the current progressive stage (owned by the task)."""
    return self._task.cur_optimizer
  # override
  @property
  def train_dataset(self):
    """Overriding StandardTrainer.train_dataset.

    Delegates to the progressive policy, which rebuilds the dataset when the
    stage advances.
    """
    return self._task.cur_train_dataset
# override
@train_dataset.setter
def train_dataset(self, _):
raise SyntaxError('Please do not set train_dataset. Progressive training '
'relies on progressive policy to manager train dataset.')
# override
@property
def eval_dataset(self):
"""Overriding StandardEvaluator.eval_dataset."""
return self._task.cur_eval_dataset
# override
@eval_dataset.setter
def eval_dataset(self, _):
raise SyntaxError('Please do not set eval_dataset. Progressive training '
'relies on progressive policy to manager eval dataset.')
def train_loop_end(self):
"""See base class."""
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
logs['learning_rate'] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs['learning_rate'] = self.optimizer.learning_rate
self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
if self._task.is_stage_advancing(self.global_step.numpy()):
old_train_dataset = self.train_dataset
# Update progressive properties
self._task.update_pt_stage(self.global_step.numpy())
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self._train_loop_fn = None
self._eval_loop_fn = None
if self.train_dataset != old_train_dataset:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
return logs
def _update_pt_stage_from_ckpt(self, ckpt_file):
"""Update stage properties based on the global_step variable in a ckpt file.
Before loading variables from a checkpoint file, we need to go to the
correct stage and build corresponding model and optimizer, to make sure that
we retore variables of the right model and optimizer.
Args:
ckpt_file: Checkpoint file that will be restored/read from.
"""
if not ckpt_file:
return
ckpt = tf.train.Checkpoint(global_step=self.global_step)
ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()
if self._task.is_stage_advancing(self.global_step.numpy()):
old_train_dataset = self.train_dataset
# Update progressive properties
self._task.update_pt_stage(self.global_step.numpy(), pass_old_model=False)
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self._train_loop_fn = None
self._eval_loop_fn = None
if self.train_dataset != old_train_dataset:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
"""Export checkpoints in non-progressive format.
This basically removes the wrapping of self._task.cur_checkpoint_items
-- just save the model, optimizer, etc., directly.
The purpose is to let your down-stream tasks to use these checkpoints.
Args:
export_ckpt_dir: A str. folder of exported checkpoints.
"""
if not self.config.trainer.export_checkpoint:
logging.info('Not exporting checkpoints.')
return
if not self._task.is_last_stage and (
self.config.trainer.export_only_final_stage_ckpt):
logging.info('Not exporting checkpoints until the last stage.')
return
global_step_np = self.global_step.numpy()
if self.config.trainer.export_checkpoint_interval is None:
step_interval = self.config.trainer.checkpoint_interval
else:
step_interval = self.config.trainer.export_checkpoint_interval
if global_step_np % step_interval != 0 and (
global_step_np < self._config.trainer.train_steps):
logging.info('Not exporting checkpoints in global step: %d.',
global_step_np)
return
# Create a checkpoint object just now, to make sure we use
# progressive_policy.cur_model and progressive_policy.cur_optimizer of the
# current stage.
if hasattr(self.model, 'checkpoint_items'):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
file_prefix = os.path.join(export_ckpt_dir,
'ckpt-{}'.format(global_step_np))
checkpoint.save(file_prefix=file_prefix)
logging.info('Checkpoints exported: %s.', file_prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment