Commit bf2d1354 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by TF Object Detection Team
Browse files

Fix the typo-ed name "remove_unecessary_ema" -> "remove_unnecessary_ema" and...

Fix the typo-ed name "remove_unecessary_ema" -> "remove_unnecessary_ema" and also re-implement it with deep dictionary copy.

PiperOrigin-RevId: 368338501
parent 3dacd474
...@@ -170,7 +170,7 @@ def replace_variable_values_with_moving_averages(graph, ...@@ -170,7 +170,7 @@ def replace_variable_values_with_moving_averages(graph,
with graph.as_default(): with graph.as_default():
variable_averages = tf.train.ExponentialMovingAverage(0.0) variable_averages = tf.train.ExponentialMovingAverage(0.0)
ema_variables_to_restore = variable_averages.variables_to_restore() ema_variables_to_restore = variable_averages.variables_to_restore()
ema_variables_to_restore = config_util.remove_unecessary_ema( ema_variables_to_restore = config_util.remove_unnecessary_ema(
ema_variables_to_restore, no_ema_collection) ema_variables_to_restore, no_ema_collection)
with tf.Session() as sess: with tf.Session() as sess:
read_saver = tf.train.Saver(ema_variables_to_restore) read_saver = tf.train.Saver(ema_variables_to_restore)
......
...@@ -1042,7 +1042,7 @@ def _update_retain_original_image_additional_channels( ...@@ -1042,7 +1042,7 @@ def _update_retain_original_image_additional_channels(
retain_original_image_additional_channels) retain_original_image_additional_channels)
def remove_unecessary_ema(variables_to_restore, no_ema_collection=None): def remove_unnecessary_ema(variables_to_restore, no_ema_collection=None):
"""Remap and Remove EMA variable that are not created during training. """Remap and Remove EMA variable that are not created during training.
ExponentialMovingAverage.variables_to_restore() returns a map of EMA names ExponentialMovingAverage.variables_to_restore() returns a map of EMA names
...@@ -1054,9 +1054,8 @@ def remove_unecessary_ema(variables_to_restore, no_ema_collection=None): ...@@ -1054,9 +1054,8 @@ def remove_unecessary_ema(variables_to_restore, no_ema_collection=None):
} }
This function takes care of the extra ExponentialMovingAverage variables This function takes care of the extra ExponentialMovingAverage variables
that get created during eval but aren't available in the checkpoint, by that get created during eval but aren't available in the checkpoint, by
remapping the key to the shallow copy of the variable itself, and remove remapping the key to the variable itself, and remove the entry of its EMA from
the entry of its EMA from the variables to restore. An example resulting the variables to restore. An example resulting dictionary would look like:
dictionary would look like:
{ {
conv/batchnorm/gamma: conv/batchnorm/gamma, conv/batchnorm/gamma: conv/batchnorm/gamma,
conv_4/conv2d_params: conv_4/conv2d_params, conv_4/conv2d_params: conv_4/conv2d_params,
...@@ -1075,14 +1074,15 @@ def remove_unecessary_ema(variables_to_restore, no_ema_collection=None): ...@@ -1075,14 +1074,15 @@ def remove_unecessary_ema(variables_to_restore, no_ema_collection=None):
if no_ema_collection is None: if no_ema_collection is None:
return variables_to_restore return variables_to_restore
restore_map = {}
for key in variables_to_restore: for key in variables_to_restore:
if "ExponentialMovingAverage" in key: if ("ExponentialMovingAverage" in key
for name in no_ema_collection: and any([name in key for name in no_ema_collection])):
if name in key: new_key = key.replace("/ExponentialMovingAverage", "")
variables_to_restore[key.replace("/ExponentialMovingAverage", else:
"")] = variables_to_restore[key] new_key = key
del variables_to_restore[key] restore_map[new_key] = variables_to_restore[key]
return variables_to_restore return restore_map
def _update_num_classes(model_config, num_classes): def _update_num_classes(model_config, num_classes):
......
...@@ -985,7 +985,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -985,7 +985,7 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertEqual(config_util.get_number_of_classes(configs["model"]), 2) self.assertEqual(config_util.get_number_of_classes(configs["model"]), 2)
def testRemoveUnecessaryEma(self): def testRemoveUnnecessaryEma(self):
input_dict = { input_dict = {
"expanded_conv_10/project/act_quant/min": "expanded_conv_10/project/act_quant/min":
1, 1,
...@@ -1016,7 +1016,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -1016,7 +1016,7 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertEqual( self.assertEqual(
output_dict, output_dict,
config_util.remove_unecessary_ema(input_dict, no_ema_collection)) config_util.remove_unnecessary_ema(input_dict, no_ema_collection))
def testUpdateRescoreInstances(self): def testUpdateRescoreInstances(self):
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment