Commit 852e098a authored by Austin Myers, committed by TF Object Detection Team

Enable ExponentialMovingAverage (EMA) in ODAPI TF2 training and evaluation.

PiperOrigin-RevId: 358070404
parent d7ce106b
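Note: the feature is driven by the existing use_moving_average and
moving_average_decay fields of the optimizer proto. A minimal sketch of
building an EMA-wrapped optimizer from a text config under TF2 (the Adam
and decay settings here are placeholder values, mirroring the updated test
below):

from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2

optimizer_text_proto = """
  adam_optimizer: {
    learning_rate: {
      constant_learning_rate {
        learning_rate: 0.001
      }
    }
  }
  use_moving_average: true
  moving_average_decay: 0.9999
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
# With this change, build() wraps the inner Adam optimizer in an EMA
# optimizer under TF2 instead of raising ValueError.
optimizer, _ = optimizer_builder.build(optimizer_proto)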
object_detection/builders/optimizer_builder.py
@@ -18,6 +18,12 @@
 import tensorflow.compat.v1 as tf
 from object_detection.utils import learning_schedules
+from object_detection.utils import tf_version
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from official.modeling.optimization import ema_optimizer
+# pylint: enable=g-import-not-at-top

 try:
   from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
@@ -130,7 +136,9 @@ def build_optimizers_tf_v2(optimizer_config, global_step=None):
     raise ValueError('Optimizer %s not supported.' % optimizer_type)

   if optimizer_config.use_moving_average:
-    raise ValueError('Moving average not supported in eager mode.')
+    optimizer = ema_optimizer.ExponentialMovingAverage(
+        optimizer=optimizer,
+        average_decay=optimizer_config.moving_average_decay)

   return optimizer, summary_vars
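The EMA wrapper keeps a shadow copy of every model variable and, after each
training step, moves the shadow values toward the live weights by the
standard exponential-moving-average rule. A minimal sketch of that rule in
plain Python, independent of the ema_optimizer implementation details:

def ema_update(shadow, value, decay=0.9999):
  """One EMA step: shadow <- decay * shadow + (1 - decay) * value."""
  return decay * shadow + (1.0 - decay) * value

# Toy usage: the shadow trails a sequence of noisy live weights.
shadow = 1.0
for live_weight in [1.0, 1.2, 0.9, 1.1]:
  shadow = ema_update(shadow, live_weight, decay=0.5)  # low decay so the drift is visible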
object_detection/builders/optimizer_builder_tf2_test.py
@@ -82,7 +82,7 @@ class OptimizerBuilderV2Test(tf.test.TestCase):
     optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)

-  def testMovingAverageOptimizerUnsupported(self):
+  def testBuildMovingAverageOptimizer(self):
     optimizer_text_proto = """
       adam_optimizer: {
         learning_rate: {
@@ -95,8 +95,8 @@ class OptimizerBuilderV2Test(tf.test.TestCase):
     """
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    with self.assertRaises(ValueError):
-      optimizer_builder.build(optimizer_proto)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
+    self.assertIsInstance(optimizer, tf.keras.optimizers.Optimizer)

 if __name__ == '__main__':
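The assertion is loosened from Adam to the Optimizer base class because
build() now returns the EMA wrapper rather than the inner Adam instance. A
follow-on check one could add under TF2, continuing from the builder sketch
above and assuming the wrapper class named in the builder change:

import tensorflow.compat.v2 as tf
from official.modeling.optimization import ema_optimizer

assert isinstance(optimizer, ema_optimizer.ExponentialMovingAverage)
# The wrapper delegates to Adam but is not itself an Adam instance.
assert not isinstance(optimizer, tf.keras.optimizers.Adam)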
object_detection/model_lib_v2.py
@@ -514,6 +514,13 @@ def train_loop(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)
+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
+    dummy_image, dummy_shapes = detection_model.preprocess(
+        tf.zeros([1, 512, 512, 3], dtype=tf.float32))
+    dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)

   def train_dataset_fn(input_context):
     """Callable to create train input."""
@@ -536,6 +543,8 @@ def train_loop(
         aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
     optimizer, (learning_rate,) = optimizer_builder.build(
         train_config.optimizer, global_step=global_step)
+    if train_config.optimizer.use_moving_average:
+      optimizer.shadow_copy(detection_model)

   if callable(learning_rate):
     learning_rate_fn = learning_rate
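The dummy forward pass matters because Keras-style models create their
variables lazily on the first call, and shadow_copy() can only snapshot
variables that already exist. A minimal sketch of the ordering constraint,
using a stand-in Keras model rather than a real detection model:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
assert not model.variables   # no variables before the first call
_ = model(tf.zeros([1, 8]))  # dummy input forces variable creation,
                             # mirroring the preprocess/predict calls above
assert model.variables       # now safe to create EMA shadow copies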
object_detection/model_lib_v2.py (continued)
@@ -1057,6 +1066,13 @@ def eval_continuously(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)
+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
+    dummy_image, dummy_shapes = detection_model.preprocess(
+        tf.zeros([1, 512, 512, 3], dtype=tf.float32))
+    dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)

   eval_input = strategy.experimental_distribute_dataset(
       inputs.eval_input(
@@ -1068,13 +1084,22 @@ def eval_continuously(
   global_step = tf.compat.v2.Variable(
       0, trainable=False, dtype=tf.compat.v2.dtypes.int64)

+  optimizer, _ = optimizer_builder.build(
+      configs['train_config'].optimizer, global_step=global_step)
+
   for latest_checkpoint in tf.train.checkpoints_iterator(
       checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
     ckpt = tf.compat.v2.train.Checkpoint(
-        step=global_step, model=detection_model)
+        step=global_step, model=detection_model, optimizer=optimizer)
+
+    if eval_config.use_moving_averages:
+      optimizer.shadow_copy(detection_model)

     ckpt.restore(latest_checkpoint).expect_partial()

+    if eval_config.use_moving_averages:
+      optimizer.swap_weights()

     summary_writer = tf.compat.v2.summary.create_file_writer(
         os.path.join(model_dir, 'eval', eval_input_config.name))
     with summary_writer.as_default():
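At evaluation time the sequence is: rebuild the training optimizer so the
checkpoint's EMA slots have somewhere to restore into, create the shadow
copies, restore, then swap the averaged weights in before evaluating. A
conceptual sketch of what swap_weights() does for a single variable pair
(the real ema_optimizer swaps every model variable with its average):

import tensorflow as tf

def swap(live_var, shadow_var):
  # Exchange the live weight and its EMA shadow in place, so subsequent
  # forward passes read the averaged weights.
  tmp = tf.identity(live_var)
  live_var.assign(shadow_var)
  shadow_var.assign(tmp)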