Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import dataclasses
 from official.core import config_definitions as cfg
 from official.modeling import hyperparams
+from official.modeling.privacy import configs as dp_configs
 @dataclasses.dataclass
@@ -35,6 +36,8 @@ class MultiTaskConfig(hyperparams.Config):
   init_checkpoint: str = ""
   model: hyperparams.Config = None
   task_routines: Tuple[TaskRoutine, ...] = ()
+  differential_privacy_config: Optional[
+      dp_configs.DifferentialPrivacyConfig] = None
 @dataclasses.dataclass
......
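The new `differential_privacy_config` field defaults to `None`, so existing configs are unaffected and differential privacy is strictly opt-in. A minimal sketch of how the field might be populated; the `clipping_norm` and `noise_multiplier` attribute names are assumptions about `DifferentialPrivacyConfig` in `official/modeling/privacy/configs.py`, which this diff does not show:

```python
from official.modeling.multitask import configs
from official.modeling.privacy import configs as dp_configs

# Illustrative only: field names on DifferentialPrivacyConfig are assumed.
task_config = configs.MultiTaskConfig(
    differential_privacy_config=dp_configs.DifferentialPrivacyConfig(
        clipping_norm=1.0,      # per-example gradient clipping norm
        noise_multiplier=1.1,   # Gaussian noise scale relative to the clip
    ))
```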
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,9 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
                multi_task: multitask.MultiTask,
                multi_task_model: Union[tf.keras.Model,
                                        base_model.MultiTaskBaseModel],
-               optimizer: tf.optimizers.Optimizer,
+               optimizer: Union[tf.optimizers.Optimizer,
+                                tf.keras.optimizers.experimental.Optimizer,
+                                tf.keras.optimizers.legacy.Optimizer],
                task_sampler: sampler.TaskSampler,
                trainer_options=None):
     super().__init__(
@@ -69,6 +71,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
         name: orbit.utils.create_global_step() for name in self.multi_task.tasks
     }
+    # If the new Keras optimizer is used, require that all model variables be
+    # created before training, and let the optimizer create its slot variables
+    # all together.
+    if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
+      multi_task_model.build()
+      optimizer.build(multi_task_model.trainable_variables)
   def task_step_counter(self, name):
     return self._task_step_counters[name]
......
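The `isinstance` check above exists because the rewritten Keras optimizer creates all of its slot variables eagerly in `optimizer.build()`, rather than lazily on the first `apply_gradients` call as the legacy optimizer did. A standalone sketch of the same pattern with a plain Keras model (the model and shapes here are illustrative, not part of this diff):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
optimizer = tf.keras.optimizers.experimental.SGD(momentum=0.9)

# The model variables must exist before build(); the optimizer then
# allocates its slot variables (here, momentum buffers) in one shot.
model.build(input_shape=(None, 8))
optimizer.build(model.trainable_variables)
```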
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,9 +23,11 @@ from official.core import task_factory
 from official.modeling import optimization
 from official.modeling.multitask import base_model
 from official.modeling.multitask import configs
+from official.modeling.privacy import configs as dp_configs
 OptimizationConfig = optimization.OptimizationConfig
 RuntimeConfig = config_definitions.RuntimeConfig
+DifferentialPrivacyConfig = dp_configs.DifferentialPrivacyConfig
 class MultiTask(tf.Module, metaclass=abc.ABCMeta):
@@ -93,9 +95,11 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
   @classmethod
   def create_optimizer(cls,
                        optimizer_config: OptimizationConfig,
-                       runtime_config: Optional[RuntimeConfig] = None):
+                       runtime_config: Optional[RuntimeConfig] = None,
+                       dp_config: Optional[DifferentialPrivacyConfig] = None):
     return base_task.Task.create_optimizer(
-        optimizer_config=optimizer_config, runtime_config=runtime_config)
+        optimizer_config=optimizer_config, runtime_config=runtime_config,
+        dp_config=dp_config)
   def joint_train_step(self, task_inputs,
                        multi_task_model: base_model.MultiTaskBaseModel,
@@ -134,10 +138,10 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
       self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
                                        **kwargs)
     # Scales loss as the default gradients allreduce performs sum inside
     # the optimizer.
     scaled_loss = total_loss / tf.distribute.get_strategy(
     ).num_replicas_in_sync
     tvars = multi_task_model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     optimizer.apply_gradients(list(zip(grads, tvars)))
......
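The scaling in `joint_train_step` is the standard correction for synchronous data parallelism: the cross-replica allreduce sums gradients, so each replica divides its loss by `num_replicas_in_sync` to make the aggregate equal the global mean. A single-model sketch of the same step (names are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # one replica outside a device setup

def train_step(model, optimizer, features, labels):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(features) - labels))
    # Allreduce sums per-replica gradients, so dividing the loss here
    # recovers the global mean after aggregation.
    scaled_loss = loss / strategy.num_replicas_in_sync
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
```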
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,6 +28,8 @@ class MockFooModel(tf.keras.Model):
     super().__init__(*args, **kwargs)
     self._share_layer = shared_layer
     self._foo_specific_layer = tf.keras.layers.Dense(1)
+    self.inputs = {"foo": tf.keras.Input(shape=(2,), dtype=tf.float32),
+                   "bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
   def call(self, inputs):
     self.add_loss(tf.zeros((1,), dtype=tf.float32))
@@ -39,11 +41,13 @@ class MockFooModel(tf.keras.Model):
 class MockBarModel(tf.keras.Model):
+  """A mock model that can only consume 'bar' inputs."""
   def __init__(self, shared_layer, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._share_layer = shared_layer
     self._bar_specific_layer = tf.keras.layers.Dense(1)
+    self.inputs = {"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
   def call(self, inputs):
     self.add_loss(tf.zeros((2,), dtype=tf.float32))
......
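The mock models gain a `self.inputs` mapping of symbolic `tf.keras.Input` tensors so that the interleaving trainer's `multi_task_model.build()` call (see the trainer hunk above) can trace each subtask model and create its variables before `optimizer.build()` runs. A self-contained sketch of that idea; the class below is illustrative, not part of the diff:

```python
import tensorflow as tf

class TinyBarModel(tf.keras.Model):
  """Illustrative stand-in for a mock task model."""

  def __init__(self):
    super().__init__()
    self._dense = tf.keras.layers.Dense(1)
    # Symbolic inputs describe the expected feature spec up front.
    self.inputs = {"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}

  def call(self, inputs):
    return self._dense(inputs["bar"])

model = TinyBarModel()
model(model.inputs)  # one symbolic call creates the kernel and bias
assert len(model.trainable_variables) == 2
```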
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Mapping, Optional, Tuple, Union
 from absl import logging
 import orbit
 import tensorflow as tf
@@ -44,8 +44,12 @@ def run_experiment(
     mode: str,
     params: configs.MultiTaskExperimentConfig,
     model_dir: str,
-    trainer: base_trainer.MultiTaskBaseTrainer = None
-) -> base_model.MultiTaskBaseModel:
+    run_post_eval: bool = False,
+    trainer: base_trainer.MultiTaskBaseTrainer = None,
+    best_ckpt_exporter_creator: Optional[Any] = train_utils
+    .maybe_create_best_ckpt_exporter
+) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel,
+                                                Mapping[Any, Any]]]:
   """Runs train/eval configured by the experiment params.
   Args:
@@ -56,8 +60,11 @@ def run_experiment(
       or 'continuous_eval'.
     params: ExperimentConfig instance.
     model_dir: A 'str', a path to store model checkpoints and summaries.
+    run_post_eval: Whether to run one evaluation pass after training; if True,
+      the metrics logs are returned along with the model.
     trainer: (optional) A multi-task trainer to use. If none is provided, a
       default one will be created based on `params`.
+    best_ckpt_exporter_creator: A functor for creating the best-checkpoint exporter.
   Returns:
     model: `base_model.MultiTaskBaseModel` instance.
@@ -66,8 +73,7 @@ def run_experiment(
   is_training = 'train' in mode
   is_eval = 'eval' in mode
   with distribution_strategy.scope():
-    optimizer = task.create_optimizer(params.trainer.optimizer_config,
-                                      params.runtime)
+    optimizer = train_utils.create_optimizer(task, params)
   kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
   if params.trainer.trainer_type == 'interleaving':
     sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
@@ -83,8 +89,7 @@ def run_experiment(
         model=model,
         eval_steps=eval_steps,
         global_step=trainer.global_step if is_training else None,
-        checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
-            params, model_dir))
+        checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
   else:
     evaluator = None
@@ -95,7 +100,6 @@ def run_experiment(
     checkpoint = evaluator.checkpoint
     global_step = evaluator.global_step
-  # TODO(hongkuny,haozhangthu): Revisit initialization method.
   checkpoint_manager = tf.train.CheckpointManager(
       checkpoint,
       directory=model_dir,
@@ -140,7 +144,11 @@ def run_experiment(
   else:
     raise NotImplementedError('The mode is not implemented: %s' % mode)
-  return model
+  if run_post_eval:
+    return model, evaluator.evaluate(
+        tf.convert_to_tensor(params.trainer.validation_steps))  # pytype: disable=bad-return-type  # typed-keras
+  else:
+    return model
 def run_experiment_with_multitask_eval(
@@ -153,7 +161,10 @@ def run_experiment_with_multitask_eval(
     model_dir: str,
     run_post_eval: bool = False,
     save_summary: bool = True,
-    trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
+    trainer: Optional[core_lib.Trainer] = None,
+    best_ckpt_exporter_creator: Optional[Any] = train_utils
+    .maybe_create_best_ckpt_exporter,
+) -> Tuple[Any, Any]:
   """Runs train/eval configured by the experiment params.
   Args:
@@ -170,6 +181,7 @@ def run_experiment_with_multitask_eval(
     trainer: the core_lib.Trainer instance. It should be created within the
       strategy.scope(). If not provided, an instance will be created by default
      if `mode` contains 'train'.
+    best_ckpt_exporter_creator: A functor for creating the best-checkpoint exporter.
   Returns:
     model: `tf.keras.Model` instance.
@@ -183,8 +195,7 @@ def run_experiment_with_multitask_eval(
         config=params,
         task=train_task,
         model=train_task.build_model(),
-        optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
-                                              params.runtime),
+        optimizer=train_utils.create_optimizer(train_task, params),
         train=True,
         evaluate=False)
   else:
@@ -200,8 +211,7 @@ def run_experiment_with_multitask_eval(
         model=model,
         global_step=trainer.global_step if is_training else None,
         eval_steps=eval_steps,
-        checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
-            params, model_dir))
+        checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
   else:
     evaluator = None
......
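With `run_post_eval=True`, `run_experiment` now returns a `(model, metrics)` tuple instead of the bare model. A hedged usage sketch; the leading arguments (`distribution_strategy`, `task`, `model`) are assumed from the unchanged head of the signature, which this diff does not show:

```python
# Illustrative call; argument names before `mode` are assumptions.
model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=test_multitask,
    model=model,
    mode='train_and_eval',
    params=experiment_config,
    model_dir=model_dir,
    run_post_eval=True)
print(eval_logs)  # mapping of metric name -> value from the final eval pass
```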
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -58,8 +58,9 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
           strategy_combinations.one_device_strategy_gpu,
       ],
       mode='eager',
+      optimizer=['sgd_experimental', 'sgd'],
       flag_mode=['train', 'eval', 'train_and_eval']))
-  def test_end_to_end(self, distribution_strategy, flag_mode):
+  def test_end_to_end(self, distribution_strategy, optimizer, flag_mode):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiTaskExperimentConfig(
         task=configs.MultiTaskConfig(
@@ -70,6 +71,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
             task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
+    experiment_config.trainer.optimizer_config.optimizer.type = optimizer
     with distribution_strategy.scope():
       test_multitask = multitask.MultiTask.from_config(experiment_config.task)
       model = test_utils.MockMultiTaskModel()
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -216,14 +216,14 @@ class StepCosineLrConfig(base_config.Config):
   """Configuration for stepwise learning rate decay.
   This class is a container for the piecewise cosine learning rate scheduling
-  configs. It will configure an instance of StepConsineDecayWithOffset keras
+  configs. It will configure an instance of StepCosineDecayWithOffset keras
   learning rate schedule.
   ```python
   boundaries: [100000, 110000]
   values: [1.0, 0.5]
   lr_decayed_fn = (
-      lr_schedule.StepConsineDecayWithOffset(
+      lr_schedule.StepCosineDecayWithOffset(
           boundaries,
           values))
   ```
@@ -243,7 +243,7 @@ class StepCosineLrConfig(base_config.Config):
     [boundaries[n], end] -> values[n+1] to 0.
   offset: An int. The offset applied to steps. Defaults to 0.
   """
-  name: str = 'StepConsineDecayWithOffset'
+  name: str = 'StepCosineDecayWithOffset'
   boundaries: Optional[List[int]] = None
   values: Optional[List[float]] = None
   offset: int = 0
......
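To make the boundary semantics concrete: with `boundaries=[100000, 110000]` and `values=[1.0, 0.5]`, the rate follows a cosine from 1.0 down to 0.5 over steps [0, 100000), then decays from 0.5 toward 0 over the remaining interval. A back-of-the-envelope check of plain cosine interpolation; the exact schedule implementation lives in `lr_schedule.py` and is not shown in this diff:

```python
import math

def cosine_interp(start, end, progress):
  # progress 0.0 -> start, progress 1.0 -> end, cosine-shaped in between.
  return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * progress))

print(cosine_interp(1.0, 0.5, 0.0))  # 1.0 at the start of the interval
print(cosine_interp(1.0, 0.5, 0.5))  # 0.75 halfway through
print(cosine_interp(1.0, 0.5, 1.0))  # 0.5 at the boundary
```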
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,8 +45,14 @@ class OptimizerConfig(oneof.OneOfConfig):
   """
   type: Optional[str] = None
   sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
+  sgd_experimental: opt_cfg.SGDExperimentalConfig = (
+      opt_cfg.SGDExperimentalConfig())
   adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
+  adam_experimental: opt_cfg.AdamExperimentalConfig = (
+      opt_cfg.AdamExperimentalConfig())
   adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
+  adamw_experimental: opt_cfg.AdamWeightDecayExperimentalConfig = (
+      opt_cfg.AdamWeightDecayExperimentalConfig())
   lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
   rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
   lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
......
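Because `OptimizerConfig` is a `oneof` config, the new variants are selected by name through the `type` field, exactly as the updated test above does. A sketch of a dict-initialized config; the `'constant'` learning-rate block is an assumption about the surrounding `OptimizationConfig`, not part of this hunk:

```python
from official.modeling import optimization

opt_config = optimization.OptimizationConfig({
    'optimizer': {
        'type': 'sgd_experimental',          # picks SGDExperimentalConfig
        'sgd_experimental': {'momentum': 0.9, 'nesterov': True},
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 0.1},
    },
})
```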
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -54,6 +54,27 @@ class SGDConfig(BaseOptimizerConfig):
   momentum: float = 0.0
+# TODO(b/216129465): Merge this config with SGDConfig after the experimental
+# optimizer graduates.
+@dataclasses.dataclass
+class SGDExperimentalConfig(BaseOptimizerConfig):
+  """Configuration for SGD optimizer.
+
+  The attributes of this class match the arguments of
+  `tf.keras.optimizers.experimental.SGD`.
+
+  Attributes:
+    name: name of the optimizer.
+    nesterov: nesterov for SGD optimizer.
+    momentum: momentum for SGD optimizer.
+    jit_compile: if True, jit compile will be used.
+  """
+  name: str = "SGD"
+  nesterov: bool = False
+  momentum: float = 0.0
+  jit_compile: bool = False
 @dataclasses.dataclass
 class RMSPropConfig(BaseOptimizerConfig):
   """Configuration for RMSProp optimizer.
@@ -115,6 +136,30 @@ class AdamConfig(BaseOptimizerConfig):
   amsgrad: bool = False
+@dataclasses.dataclass
+class AdamExperimentalConfig(BaseOptimizerConfig):
+  """Configuration for experimental Adam optimizer.
+
+  The attributes of this class match the arguments of
+  `tf.keras.optimizers.experimental.Adam`.
+
+  Attributes:
+    name: name of the optimizer.
+    beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+    epsilon: epsilon value used for numerical stability in Adam optimizer.
+    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
+      from the paper "On the Convergence of Adam and Beyond".
+    jit_compile: if True, jit compile will be used.
+  """
+  name: str = "Adam"
+  beta_1: float = 0.9
+  beta_2: float = 0.999
+  epsilon: float = 1e-07
+  amsgrad: bool = False
+  jit_compile: bool = False
 @dataclasses.dataclass
 class AdamWeightDecayConfig(BaseOptimizerConfig):
   """Configuration for Adam optimizer with weight decay.
@@ -145,6 +190,32 @@ class AdamWeightDecayConfig(BaseOptimizerConfig):
   gradient_clip_norm: float = 1.0
+@dataclasses.dataclass
+class AdamWeightDecayExperimentalConfig(BaseOptimizerConfig):
+  """Configuration for Adam optimizer with weight decay.
+
+  Attributes:
+    name: name of the optimizer.
+    beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+    epsilon: epsilon value used for numerical stability in the optimizer.
+    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
+      from the paper "On the Convergence of Adam and Beyond".
+    weight_decay: float. Weight decay rate. Defaults to 0.
+    global_clipnorm: A positive float. Clips the gradients to this maximum
+      L2-norm. Defaults to 1.0.
+    jit_compile: if True, jit compile will be used.
+  """
+  name: str = "AdamWeightDecayExperimental"
+  beta_1: float = 0.9
+  beta_2: float = 0.999
+  epsilon: float = 1e-07
+  amsgrad: bool = False
+  weight_decay: float = 0.0
+  global_clipnorm: float = 1.0
+  jit_compile: bool = False
 @dataclasses.dataclass
 class LAMBConfig(BaseOptimizerConfig):
   """Configuration for LAMB optimizer.
@@ -266,3 +337,5 @@ class AdafactorConfig(BaseOptimizerConfig):
   min_dim_size_to_factor: int = 128
   epsilon1: float = 1e-30
   epsilon2: float = 1e-3
+  weight_decay: Optional[float] = None
+  include_in_weight_decay: Optional[str] = None
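For orientation, the fields of `AdamWeightDecayExperimentalConfig` line up one-to-one with the constructor of `tf.keras.optimizers.experimental.AdamW`. The factory that performs this mapping is outside this diff, so the wiring below is a sketch, not the project's actual factory code:

```python
import tensorflow as tf

# Assumed mapping from config fields to the new Keras AdamW constructor.
optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=1e-3,  # supplied by the LR schedule, not this config
    weight_decay=0.0,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
    global_clipnorm=1.0,
    jit_compile=False)
```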
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ import tensorflow as tf
 # pylint: disable=protected-access
-class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
+class ExponentialMovingAverage(tf.keras.optimizers.legacy.Optimizer):
   """Optimizer that computes an exponential moving average of the variables.
   Empirically it has been found that using the moving average of the trained
......
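Pinning the base class to `tf.keras.optimizers.legacy.Optimizer` matters because, as of TF 2.11, the bare `tf.keras.optimizers.Optimizer` alias resolves to the rewritten optimizer with a different internal API; wrappers built on the old slot-variable hooks keep working only under the `legacy` namespace. A quick check of the aliasing (behavior as of TF 2.11):

```python
import tensorflow as tf

# Under TF 2.11 the default SGD is the new implementation...
print(issubclass(tf.keras.optimizers.SGD,
                 tf.keras.optimizers.legacy.Optimizer))   # False
# ...while the old implementation survives under the legacy namespace.
print(issubclass(tf.keras.optimizers.legacy.SGD,
                 tf.keras.optimizers.legacy.Optimizer))   # True
```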
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ import tensorflow as tf
 # pylint: disable=protected-access
-class LARS(tf.keras.optimizers.Optimizer):
+class LARS(tf.keras.optimizers.legacy.Optimizer):
   """Layer-wise Adaptive Rate Scaling for large batch training.
   Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
......