Commit b035a227 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319058156
parent df89d3e0
@@ -339,7 +339,8 @@ def train_and_eval(
   optimizer = optimizer_factory.build_optimizer(
       optimizer_name=params.model.optimizer.name,
       base_learning_rate=learning_rate,
-      params=params.model.optimizer.as_dict())
+      params=params.model.optimizer.as_dict(),
+      model=model)
   metrics_map = _get_metrics(one_hot)
   metrics = [metrics_map[metric] for metric in params.train.metrics]
...
@@ -18,11 +18,12 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
+from typing import Any, Dict, Text, List
+
 from absl import logging
 import tensorflow as tf
 import tensorflow_addons as tfa
-from typing import Any, Dict, Text, List
 from official.vision.image_classification import learning_rate
 from official.vision.image_classification.configs import base_configs
@@ -250,7 +251,8 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
 def build_optimizer(
     optimizer_name: Text,
     base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
-    params: Dict[Text, Any]):
+    params: Dict[Text, Any],
+    model: tf.keras.Model = None):
   """Build the optimizer based on name.

   Args:
@@ -261,6 +263,8 @@ def build_optimizer(
     params: String -> Any dictionary representing the optimizer params.
       This should contain optimizer specific parameters such as
       `base_learning_rate`, `decay`, etc.
+    model: The `tf.keras.Model`. This is used for the shadow copy if using
+      `MovingAverage`.

   Returns:
     A tf.keras.Optimizer.
@@ -322,10 +326,13 @@ def build_optimizer(
   # Moving average should be applied last, as it's applied at test time
   moving_average_decay = params.get('moving_average_decay', 0.)
   if moving_average_decay is not None and moving_average_decay > 0.:
+    if model is None:
+      raise ValueError('`model` must be provided if using `MovingAverage`.')
     logging.info('Including moving average decay.')
     optimizer = MovingAverage(
-        optimizer,
+        optimizer=optimizer,
         average_decay=moving_average_decay)
+    optimizer.shadow_copy(model)
   return optimizer
...
@@ -19,15 +19,21 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

-import tensorflow as tf
 from absl.testing import parameterized
+import tensorflow as tf
 from official.vision.image_classification import optimizer_factory
 from official.vision.image_classification.configs import base_configs


 class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
def build_toy_model(self) -> tf.keras.Model:
"""Builds a minimal `tf.keras.Model` (one scalar-in, scalar-out Dense layer) for optimizer smoke tests."""
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
return model
   @parameterized.named_parameters(
       ('sgd', 'sgd', 0., False),
       ('momentum', 'momentum', 0., False),
...
@@ -40,6 +46,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
       ('rmsprop_ema', 'rmsprop', 0.999, False))
   def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
     """Smoke test to be sure no syntax errors."""
+    model = self.build_toy_model()
     params = {
         'learning_rate': 0.001,
         'rho': 0.09,
...
@@ -51,7 +58,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     optimizer = optimizer_factory.build_optimizer(
         optimizer_name=optimizer_name,
         base_learning_rate=params['learning_rate'],
-        params=params)
+        params=params,
+        model=model)
     self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer))

   def test_unknown_optimizer(self):
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment