Commit eb6fa0b2 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 305380592
parent 98074f7a
@@ -20,12 +20,12 @@ from __future__ import division
from __future__ import print_function
import os
+from typing import Any, List, MutableMapping, Text
from absl import logging
import tensorflow as tf
-from typing import Any, List, MutableMapping
from official.utils.misc import keras_utils
from official.vision.image_classification import optimizer_factory
def get_callbacks(model_checkpoint: bool = True,
@@ -33,6 +33,7 @@ def get_callbacks(model_checkpoint: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
apply_moving_average: bool = False,
initial_step: int = 0,
batch_size: int = 0,
log_steps: int = 0,
@@ -42,8 +43,7 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks = []
if model_checkpoint:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
-callbacks.append(
-tf.keras.callbacks.ModelCheckpoint(
+callbacks.append(tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if include_tensorboard:
callbacks.append(
@@ -58,6 +58,17 @@ def get_callbacks(model_checkpoint: bool = True,
batch_size,
log_steps,
logdir=model_dir if include_tensorboard else None))
if apply_moving_average:
# Save moving average model to a different file so that
# we can resume training from a checkpoint
ckpt_full_path = os.path.join(
model_dir, 'average', 'model.ckpt-{epoch:04d}')
callbacks.append(AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
return callbacks
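With `apply_moving_average=True`, the returned list gains two extra callbacks: an `AverageModelCheckpoint` that writes the averaged weights under `<model_dir>/average/`, and a `MovingAverageCallback` that swaps the averaged weights in around evaluation. A rough usage sketch; `model`, `train_ds`, and the directory are illustrative, and the model is assumed to be compiled with the `MovingAverage`-wrapped optimizer introduced in `optimizer_factory` below:

```python
# Sketch only: model, train_ds, and the model_dir path are assumptions.
callbacks = custom_callbacks.get_callbacks(
    model_checkpoint=True,
    apply_moving_average=True,  # adds AverageModelCheckpoint + MovingAverageCallback
    batch_size=128,
    log_steps=100,
    model_dir='/tmp/model_dir')
model.fit(train_ds, epochs=90, callbacks=callbacks)
```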
@@ -136,7 +147,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def _calculate_lr(self) -> int:
"""Calculates the learning rate given the current step."""
return get_scalar_from_tensor(
-self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))
+self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))  # pylint:disable=protected-access
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Get the base optimizer used by the current model."""
@@ -148,3 +159,100 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
optimizer = optimizer._optimizer  # pylint:disable=protected-access
return optimizer
class MovingAverageCallback(tf.keras.callbacks.Callback):
"""A Callback to be used with a `MovingAverage` optimizer.
Applies moving average weights to the model during validation time to test
and predict on the averaged weights rather than the current model weights.
Once training is complete, the model weights will be overwritten with the
averaged weights (by default).
Attributes:
overwrite_weights_on_train_end: Whether to overwrite the current model
weights with the averaged weights from the moving average optimizer.
**kwargs: Any additional callback arguments.
"""
def __init__(self,
overwrite_weights_on_train_end: bool = False,
**kwargs):
super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer,
optimizer_factory.MovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
self.model.optimizer.swap_weights()
def on_test_end(self, logs: MutableMapping[Text, Any] = None):
self.model.optimizer.swap_weights()
def on_train_end(self, logs: MutableMapping[Text, Any] = None):
if self.overwrite_weights_on_train_end:
self.model.optimizer.assign_average_vars(self.model.variables)
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
"""Saves and, optionally, assigns the averaged weights.
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
update_weights: If True, assign the moving average weights
to the model, and save them. If False, keep the old
non-averaged weights, but the saved model uses the
average weights.
See `tf.keras.callbacks.ModelCheckpoint` for the other args.
"""
def __init__(
self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
self.update_weights = update_weights
super().__init__(
filepath,
monitor,
verbose,
save_best_only,
save_weights_only,
mode,
save_freq,
**kwargs)
def set_model(self, model):
if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
raise TypeError(
'AverageModelCheckpoint is only used when training '
'with MovingAverage')
return super().set_model(model)
def _save_model(self, epoch, logs):
assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs)
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
non_avg_weights = self.model.get_weights()
self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(epoch, logs)
self.model.set_weights(non_avg_weights)
return result
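Together the two checkpoint callbacks leave two parallel sets of checkpoints: the regular ones at `<model_dir>/model.ckpt-{epoch:04d}` (suitable for resuming training) and the averaged ones under `<model_dir>/average/`. A hedged sketch of evaluating from the averaged checkpoint; `build_model`, the path, and `eval_ds` are assumptions for illustration:

```python
# Assumes build_model() recreates the same architecture that was trained.
eval_model = build_model()
eval_model.load_weights('/tmp/model_dir/average/model.ckpt-0090')
eval_model.evaluate(eval_ds)
```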
@@ -360,8 +360,6 @@ def train_and_eval(
model_dir=params.model_dir,
train_steps=train_steps)
serialize_config(params=params, model_dir=params.model_dir)
# TODO(dankondratyuk): callbacks significantly slow down training
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
@@ -373,6 +371,8 @@ def train_and_eval(
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir)
serialize_config(params=params, model_dir=params.model_dir)
if params.evaluation.skip_eval:
validation_kwargs = {}
else:
@@ -388,7 +388,9 @@ def train_and_eval(
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
-**validation_kwargs)
+**validation_kwargs,
experimental_steps_per_execution=params.train.steps_per_loop,
verbose=2)
validation_output = None
if not params.evaluation.skip_eval:
...
@@ -82,6 +82,8 @@ class TrainConfig(base_config.Config):
callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorboardConfig.
steps_per_loop: The number of batches to run during each `tf.function`
call during training, which can increase training speed.
""" """
resume_checkpoint: bool = None resume_checkpoint: bool = None
@@ -91,6 +93,7 @@ class TrainConfig(base_config.Config):
metrics: MetricsConfig = None
tensorboard: TensorboardConfig = TensorboardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig()
steps_per_loop: int = None
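`steps_per_loop` is forwarded to `model.fit` as `experimental_steps_per_execution` in the trainer change above, so larger values run more batches per `tf.function` call. A minimal override sketch, assuming the predefined experiment configs shown further down (`ResNetImagenetConfig` and friends) allow attribute assignment:

```python
# Hedged example: run 10 batches per tf.function invocation instead of 1.
config = ResNetImagenetConfig()
config.train.steps_per_loop = 10
```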
@dataclasses.dataclass
@@ -176,6 +179,7 @@ class LearningRateConfig(base_config.Config):
multipliers: multipliers used in piecewise constant decay with warmup.
scale_by_batch_size: Scale the learning rate by a fraction of the batch
size. Set to 0 for no scaling (default).
staircase: Apply exponential decay at discrete values instead of continuous.
""" """
name: str = None name: str = None
...@@ -187,6 +191,7 @@ class LearningRateConfig(base_config.Config): ...@@ -187,6 +191,7 @@ class LearningRateConfig(base_config.Config):
boundaries: List[int] = None boundaries: List[int] = None
multipliers: List[float] = None multipliers: List[float] = None
scale_by_batch_size: float = 0. scale_by_batch_size: float = 0.
staircase: bool = None
@dataclasses.dataclass
...
@@ -54,7 +54,8 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
-write_model_weights=False))
+write_model_weights=False),
steps_per_loop=1)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
@@ -86,7 +87,8 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
-write_model_weights=False))
+write_model_weights=False),
steps_per_loop=1)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
...
@@ -40,6 +40,8 @@ model:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
...
@@ -39,7 +39,7 @@ model:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
-moving_average_decay: 0.
+moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
...
@@ -33,6 +33,8 @@ model:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
...
@@ -38,6 +38,8 @@ model:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
...
@@ -19,15 +19,14 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-import tensorflow.compat.v1 as tf1
import tensorflow as tf
+import tensorflow.compat.v1 as tf1
from typing import Text, Optional
from tensorflow.python.tpu import tpu_function
-@tf.keras.utils.register_keras_serializable(package='Text')
+@tf.keras.utils.register_keras_serializable(package='Vision')
class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
"""Cross replica batch normalization."""
...
@@ -72,4 +72,5 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
decay_epochs=2.4,
decay_rate=0.97,
warmup_epochs=5,
-scale_by_batch_size=1. / 128.)
+scale_by_batch_size=1. / 128.,
staircase=True)
@@ -22,10 +22,230 @@ from absl import logging
import tensorflow as tf
import tensorflow_addons as tfa
-from typing import Any, Dict, Text
+from typing import Any, Dict, Text, List
from official.vision.image_classification import learning_rate
from official.vision.image_classification.configs import base_configs
# pylint: disable=protected-access
class MovingAverage(tf.keras.optimizers.Optimizer):
"""Optimizer that computes a moving average of the variables.
Empirically it has been found that using the moving average of the trained
parameters of a deep network is better than using its trained parameters
directly. This optimizer allows you to compute this moving average and swap
the variables at save time so that any code outside of the training loop
will use by default the average values instead of the original ones.
Example of usage for training:
```python
opt = tf.keras.optimizers.SGD(learning_rate)
opt = MovingAverage(opt)
opt.shadow_copy(model)
```
At test time, swap the shadow variables to evaluate on the averaged weights:
```python
opt.swap_weights()
# Test eval the model here
opt.swap_weights()
```
"""
def __init__(self,
optimizer: tf.keras.optimizers.Optimizer,
average_decay: float = 0.99,
start_step: int = 0,
dynamic_decay: bool = True,
name: Text = 'moving_average',
**kwargs):
"""Construct a new MovingAverage optimizer.
Args:
optimizer: `tf.keras.optimizers.Optimizer` that will be
used to compute and apply gradients.
average_decay: float. Decay to use to maintain the moving averages
of trained variables.
start_step: int. What step to start the moving average.
dynamic_decay: bool. Whether to change the decay based on the number
of optimizer updates. Decay will start at 0.1 and gradually increase
up to `average_decay` after each optimizer update. This behavior is
similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
name: Optional name for the operations created when applying
gradients. Defaults to "moving_average".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}.
"""
super(MovingAverage, self).__init__(name, **kwargs)
self._optimizer = optimizer
self._average_decay = average_decay
self._start_step = tf.constant(start_step, tf.float32)
self._dynamic_decay = dynamic_decay
self._model_weights = None  # populated by shadow_copy()
def shadow_copy(self, model: tf.keras.Model):
"""Creates shadow variables for the given model weights."""
for var in model.weights:
self.add_slot(var, 'average', initializer='zeros')
self._average_weights = [
self.get_slot(var, 'average') for var in model.weights
]
self._model_weights = model.weights
@property
def has_shadow_copy(self):
"""Whether this optimizer has created shadow variables."""
return self._model_weights is not None
def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
def apply_gradients(self, grads_and_vars, name: Text = None):
result = self._optimizer.apply_gradients(grads_and_vars, name)
self.update_average(self._optimizer.iterations)
return result
@tf.function
def update_average(self, step: tf.Tensor):
step = tf.cast(step, tf.float32)
if step < self._start_step:
decay = tf.constant(0., tf.float32)
elif self._dynamic_decay:
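# Ramp the decay from 0.1 toward `average_decay` as (1 + t) / (10 + t),
# where t = step - start_step; mirrors tf.train.ExponentialMovingAverage.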
decay = step - self._start_step
decay = tf.minimum(self._average_decay, (1. + decay) / (10. + decay))
else:
decay = self._average_decay
def _apply_moving(v_moving, v_normal):
diff = v_moving - v_normal
v_moving.assign_sub(tf.cast(1. - decay, v_moving.dtype) * diff)
return v_moving
def _update(strategy, v_moving_and_v_normal):
for v_moving, v_normal in v_moving_and_v_normal:
strategy.extended.update(v_moving, _apply_moving, args=(v_normal,))
ctx = tf.distribute.get_replica_context()
return ctx.merge_call(_update, args=(zip(self._average_weights,
self._model_weights),))
def swap_weights(self):
"""Swap the average and moving weights.
This is a convenience method to allow one to evaluate the averaged weights
at test time. Loads the weights stored in `self._average` into the model,
keeping a copy of the original model weights. Swapping twice will return
the original weights.
"""
if tf.distribute.in_cross_replica_context():
strategy = tf.distribute.get_strategy()
strategy.run(self._swap_weights, args=())
else:
raise ValueError('Swapping weights must occur under a '
'tf.distribute.Strategy')
@tf.function
def _swap_weights(self):
def fn_0(a, b):
a.assign_add(b)
return a
def fn_1(b, a):
b.assign(a - b)
return b
def fn_2(a, b):
a.assign_sub(b)
return a
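# fn_0/fn_1/fn_2 implement the classic add/subtract in-place swap of two
# variables, so no temporary buffers are allocated for the weights.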
def swap(strategy, a_and_b):
"""Swap `a` and `b` and mirror to all devices."""
for a, b in a_and_b:
strategy.extended.update(a, fn_0, args=(b,)) # a = a + b
strategy.extended.update(b, fn_1, args=(a,)) # b = a - b
strategy.extended.update(a, fn_2, args=(b,)) # a = a - b
ctx = tf.distribute.get_replica_context()
return ctx.merge_call(
swap, args=(zip(self._average_weights, self._model_weights),))
def assign_average_vars(self, var_list: List[tf.Variable]):
"""Assign variables in var_list with their respective averages.
Args:
var_list: List of model variables to be assigned to their average.
Returns:
assign_op: The op corresponding to the assignment operation of
variables to their average.
"""
assign_op = tf.group([
var.assign(self.get_slot(var, 'average')) for var in var_list
if var.trainable
])
return assign_op
def _create_hypers(self):
self._optimizer._create_hypers() # pylint: disable=protected-access
def _prepare(self, var_list):
return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access
@property
def iterations(self):
return self._optimizer.iterations
@iterations.setter
def iterations(self, variable):
self._optimizer.iterations = variable
@property
def weights(self):
# return self._weights + self._optimizer.weights
return self._optimizer.weights
@property
def lr(self):
return self._optimizer._get_hyper('learning_rate')
@lr.setter
def lr(self, lr):
self._optimizer._set_hyper('learning_rate', lr)
@property
def learning_rate(self):
return self._optimizer._get_hyper('learning_rate')
@learning_rate.setter
def learning_rate(self, learning_rate): # pylint: disable=redefined-outer-name
self._optimizer._set_hyper('learning_rate', learning_rate)
def _resource_apply_dense(self, grad, var):
return self._optimizer._resource_apply_dense(grad, var)
def _resource_apply_sparse(self, grad, var, indices):
return self._optimizer._resource_apply_sparse(grad, var, indices)
def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
return self._optimizer._resource_apply_sparse_duplicate_indices(
grad, var, indices)
def get_config(self):
config = {
'optimizer': tf.keras.optimizers.serialize(self._optimizer),
'average_decay': self._average_decay,
'start_step': self._start_step,
'dynamic_decay': self._dynamic_decay,
}
base_config = super(MovingAverage, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
optimizer = tf.keras.optimizers.deserialize(
config.pop('optimizer'),
custom_objects=custom_objects,
)
return cls(optimizer, **config)
def build_optimizer(
optimizer_name: Text,
@@ -95,16 +315,17 @@ def build_optimizer(
else:
raise ValueError('Unknown optimizer %s' % optimizer_name)
+if params.get('lookahead', None):
+logging.info('Using lookahead optimizer.')
+optimizer = tfa.optimizers.Lookahead(optimizer)
+# Moving average should be applied last, as it's applied at test time
moving_average_decay = params.get('moving_average_decay', 0.)
if moving_average_decay is not None and moving_average_decay > 0.:
logging.info('Including moving average decay.')
-optimizer = tfa.optimizers.MovingAverage(
+optimizer = MovingAverage(
optimizer,
-average_decay=params['moving_average_decay'],
-num_updates=None)
+average_decay=moving_average_decay)
-if params.get('lookahead', None):
-logging.info('Using lookahead optimizer.')
-optimizer = tfa.optimizers.Lookahead(optimizer)
return optimizer
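The reordering above makes the moving-average wrapper the outermost optimizer, so the averages track the final (post-Lookahead) parameter updates. A hedged sketch of the resulting wrapping, built by hand here rather than through `build_optimizer`:

```python
# Illustrative only: mirrors the nesting order build_optimizer now produces.
inner = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
inner = tfa.optimizers.Lookahead(inner)                # applied first, if enabled
optimizer = MovingAverage(inner, average_decay=0.999)  # applied last
```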
@@ -139,7 +360,8 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
lr = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=base_lr,
decay_steps=decay_steps,
-decay_rate=decay_rate)
+decay_rate=decay_rate,
staircase=params.staircase)
elif decay_type == 'piecewise_constant_with_warmup':
logging.info('Using Piecewise constant decay with warmup. '
'Parameters: batch_size: %d, epoch_size: %d, '
...
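With `staircase=True`, `ExponentialDecay` floors `step / decay_steps`, so the learning rate drops at discrete boundaries instead of decaying continuously. A small illustration with made-up constants:

```python
lr = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.016,
    decay_steps=1000,
    decay_rate=0.97,
    staircase=True)
print(float(lr(500)))   # 0.016        -> 500 // 1000 == 0, no decay yet
print(float(lr(1500)))  # 0.016 * 0.97 -> 1500 // 1000 == 1
```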
@@ -35,9 +35,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
('adam', 'adam', 0., False),
('adamw', 'adamw', 0., False),
('momentum_lookahead', 'momentum', 0., True),
-('sgd_ema', 'sgd', 0.001, False),
-('momentum_ema', 'momentum', 0.001, False),
-('rmsprop_ema', 'rmsprop', 0.001, False))
+('sgd_ema', 'sgd', 0.999, False),
+('momentum_ema', 'momentum', 0.999, False),
+('rmsprop_ema', 'rmsprop', 0.999, False))
def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
"""Smoke test to be sure no syntax errors."""
params = {
...