"src/graph/vscode:/vscode.git/clone" did not exist on "8a83027274812ebc7f6613f0f418a81d0b2ba843"
Commit 852e098a authored by Austin Myers, committed by TF Object Detection Team

Enable ExponentialMovingAverage (EMA) in ODAPI TF2 training and evaluation.

PiperOrigin-RevId: 358070404
parent d7ce106b
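
With this change, EMA is enabled through the existing fields on the optimizer proto rather than a new flag. A minimal sketch of the relevant fragment inside train_config (field names are taken from the optimizer_config accesses in the diff below; the learning-rate values are illustrative):

optimizer: {
  adam_optimizer: {
    learning_rate: {
      constant_learning_rate: {
        learning_rate: 0.001
      }
    }
  }
  use_moving_average: true
  moving_average_decay: 0.9999
}

Previously, setting use_moving_average under TF2 raised ValueError('Moving average not supported in eager mode.'); the hunks below replace that error path with a working EMA wrapper.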
research/object_detection/builders/optimizer_builder.py
@@ -18,6 +18,12 @@
 import tensorflow.compat.v1 as tf
 from object_detection.utils import learning_schedules
 from object_detection.utils import tf_version

+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from official.modeling.optimization import ema_optimizer
+# pylint: enable=g-import-not-at-top
+
 try:
   from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
@@ -130,7 +136,9 @@ def build_optimizers_tf_v2(optimizer_config, global_step=None):

     raise ValueError('Optimizer %s not supported.' % optimizer_type)

   if optimizer_config.use_moving_average:
-    raise ValueError('Moving average not supported in eager mode.')
+    optimizer = ema_optimizer.ExponentialMovingAverage(
+        optimizer=optimizer,
+        average_decay=optimizer_config.moving_average_decay)

   return optimizer, summary_vars
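
For context, the wrapper constructed here behaves like a regular Keras optimizer plus an EMA lifecycle. A minimal sketch of that lifecycle, assuming the official.modeling.optimization.ema_optimizer API at this revision (shadow_copy / swap_weights, both of which appear in the hunks below); the toy model and decay value are placeholders:

import tensorflow as tf
from official.modeling.optimization import ema_optimizer

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = ema_optimizer.ExponentialMovingAverage(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    average_decay=0.9999)

# Allocate the EMA "shadow" variables; the model's own variables must
# already exist at this point (see the dummy-input change below).
optimizer.shadow_copy(model)

# Train as usual; each apply_gradients also updates the shadow variables.
with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(tf.ones([2, 4]))))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# Swap the averaged weights into the model for evaluation; calling
# swap_weights again would swap the raw training weights back in.
optimizer.swap_weights()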
research/object_detection/builders/optimizer_builder_tf2_test.py
@@ -82,7 +82,7 @@ class OptimizerBuilderV2Test(tf.test.TestCase):
     optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)

-  def testMovingAverageOptimizerUnsupported(self):
+  def testBuildMovingAverageOptimizer(self):
     optimizer_text_proto = """
       adam_optimizer: {
         learning_rate: {
@@ -95,8 +95,8 @@ class OptimizerBuilderV2Test(tf.test.TestCase):
     """
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    with self.assertRaises(ValueError):
-      optimizer_builder.build(optimizer_proto)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
+    self.assertIsInstance(optimizer, tf.keras.optimizers.Optimizer)


 if __name__ == '__main__':
research/object_detection/model_lib_v2.py
@@ -514,6 +514,13 @@ def train_loop(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)
+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
+    dummy_image, dummy_shapes = detection_model.preprocess(
+        tf.zeros([1, 512, 512, 3], dtype=tf.float32))
+    dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)

   def train_dataset_fn(input_context):
     """Callable to create train input."""
@@ -536,6 +543,8 @@ def train_loop(
         aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
     optimizer, (learning_rate,) = optimizer_builder.build(
         train_config.optimizer, global_step=global_step)
+    if train_config.optimizer.use_moving_average:
+      optimizer.shadow_copy(detection_model)

     if callable(learning_rate):
       learning_rate_fn = learning_rate
@@ -1057,6 +1066,13 @@ def eval_continuously(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)
+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
+    dummy_image, dummy_shapes = detection_model.preprocess(
+        tf.zeros([1, 512, 512, 3], dtype=tf.float32))
+    dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)

   eval_input = strategy.experimental_distribute_dataset(
       inputs.eval_input(
@@ -1068,13 +1084,22 @@ def eval_continuously(
   global_step = tf.compat.v2.Variable(
       0, trainable=False, dtype=tf.compat.v2.dtypes.int64)

+  optimizer, _ = optimizer_builder.build(
+      configs['train_config'].optimizer, global_step=global_step)
+
   for latest_checkpoint in tf.train.checkpoints_iterator(
       checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
     ckpt = tf.compat.v2.train.Checkpoint(
-        step=global_step, model=detection_model)
+        step=global_step, model=detection_model, optimizer=optimizer)
+
+    if eval_config.use_moving_averages:
+      optimizer.shadow_copy(detection_model)

     ckpt.restore(latest_checkpoint).expect_partial()

+    if eval_config.use_moving_averages:
+      optimizer.swap_weights()
+
     summary_writer = tf.compat.v2.summary.create_file_writer(
         os.path.join(model_dir, 'eval', eval_input_config.name))
     with summary_writer.as_default():
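
Two details of the eval-side wiring are easy to miss. The EMA shadow variables live on the wrapper optimizer, so adding optimizer=optimizer to the tf.train.Checkpoint is what makes the averaged weights restorable, and shadow_copy must run before restore so those variables exist to receive the checkpoint values. swap_weights then exchanges the model weights and the averages in place, so the evaluation code that follows uses the model unchanged. A condensed, self-contained sketch of that save/restore round trip, under the same API assumptions as the sketch earlier (the model, decay, and path are placeholders):

import tensorflow as tf
from official.modeling.optimization import ema_optimizer

def make_model_and_optimizer():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
  optimizer = ema_optimizer.ExponentialMovingAverage(
      optimizer=tf.keras.optimizers.SGD(0.1), average_decay=0.99)
  optimizer.shadow_copy(model)  # allocate EMA slots before save/restore
  return model, optimizer

# "Trainer" side: checkpoint the model together with the wrapper optimizer,
# which carries the EMA shadow variables.
train_model, train_optimizer = make_model_and_optimizer()
path = tf.train.Checkpoint(
    model=train_model, optimizer=train_optimizer).save('/tmp/ema_demo/ckpt')

# "Evaluator" side: rebuild the same objects, restore, then swap averages in.
eval_model, eval_optimizer = make_model_and_optimizer()
tf.train.Checkpoint(
    model=eval_model,
    optimizer=eval_optimizer).restore(path).expect_partial()
eval_optimizer.swap_weights()  # eval_model now holds the averaged weights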