Commit bf2d1354 authored by Liangzhe Yuan, committed by TF Object Detection Team

Fix the typoed name "remove_unecessary_ema" -> "remove_unnecessary_ema" and re-implement it with a deep dictionary copy.

PiperOrigin-RevId: 368338501
parent 3dacd474
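For orientation before the diff, a minimal sketch (not part of the change itself) of what the renamed helper does; the variable names below are hypothetical and plain ints stand in for the tf variables:

from object_detection.utils import config_util

# Keys that contain a name from no_ema_collection lose their
# "/ExponentialMovingAverage" suffix; everything else passes through unchanged.
ema_map = {
    "conv/batchnorm/gamma/ExponentialMovingAverage": 1,
    "conv_4/act_quant/min/ExponentialMovingAverage": 2,
    "global_step": 3,
}
restore_map = config_util.remove_unnecessary_ema(ema_map, no_ema_collection=["/min"])
# restore_map == {
#     "conv/batchnorm/gamma/ExponentialMovingAverage": 1,
#     "conv_4/act_quant/min": 2,
#     "global_step": 3,
# }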
@@ -170,7 +170,7 @@ def replace_variable_values_with_moving_averages(graph,
   with graph.as_default():
     variable_averages = tf.train.ExponentialMovingAverage(0.0)
     ema_variables_to_restore = variable_averages.variables_to_restore()
-    ema_variables_to_restore = config_util.remove_unecessary_ema(
+    ema_variables_to_restore = config_util.remove_unnecessary_ema(
         ema_variables_to_restore, no_ema_collection)
     with tf.Session() as sess:
       read_saver = tf.train.Saver(ema_variables_to_restore)
@@ -1042,7 +1042,7 @@ def _update_retain_original_image_additional_channels(
       retain_original_image_additional_channels)
 
 
-def remove_unecessary_ema(variables_to_restore, no_ema_collection=None):
+def remove_unnecessary_ema(variables_to_restore, no_ema_collection=None):
   """Remap and Remove EMA variable that are not created during training.
 
   ExponentialMovingAverage.variables_to_restore() returns a map of EMA names
@@ -1054,9 +1054,8 @@ def remove_unecessary_ema(variables_to_restore, no_ema_collection=None):
   }
   This function takes care of the extra ExponentialMovingAverage variables
   that get created during eval but aren't available in the checkpoint, by
-  remapping the key to the shallow copy of the variable itself, and remove
-  the entry of its EMA from the variables to restore. An example resulting
-  dictionary would look like:
+  remapping the key to the variable itself, and remove the entry of its EMA from
+  the variables to restore. An example resulting dictionary would look like:
   {
     conv/batchnorm/gamma: conv/batchnorm/gamma,
     conv_4/conv2d_params: conv_4/conv2d_params,
@@ -1075,14 +1074,15 @@ def remove_unecessary_ema(variables_to_restore, no_ema_collection=None):
   if no_ema_collection is None:
     return variables_to_restore
 
+  restore_map = {}
   for key in variables_to_restore:
-    if "ExponentialMovingAverage" in key:
-      for name in no_ema_collection:
-        if name in key:
-          variables_to_restore[key.replace("/ExponentialMovingAverage",
-                                           "")] = variables_to_restore[key]
-          del variables_to_restore[key]
-  return variables_to_restore
+    if ("ExponentialMovingAverage" in key
+        and any([name in key for name in no_ema_collection])):
+      new_key = key.replace("/ExponentialMovingAverage", "")
+    else:
+      new_key = key
+    restore_map[new_key] = variables_to_restore[key]
+  return restore_map
 
 
 def _update_num_classes(model_config, num_classes):
@@ -985,7 +985,7 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertEqual(config_util.get_number_of_classes(configs["model"]), 2)
 
-  def testRemoveUnecessaryEma(self):
+  def testRemoveUnnecessaryEma(self):
     input_dict = {
         "expanded_conv_10/project/act_quant/min":
             1,
@@ -1016,7 +1016,7 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertEqual(
         output_dict,
-        config_util.remove_unecessary_ema(input_dict, no_ema_collection))
+        config_util.remove_unnecessary_ema(input_dict, no_ema_collection))
 
   def testUpdateRescoreInstances(self):
     pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
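A note on the "deep dictionary copy" part of the change: the new implementation fills a fresh restore_map instead of inserting and deleting keys in variables_to_restore while iterating over it, so the caller's original mapping is left untouched. A quick sketch, again with hypothetical keys:

from object_detection.utils import config_util

ema_map = {"conv_4/act_quant/min/ExponentialMovingAverage": 2}
remapped = config_util.remove_unnecessary_ema(ema_map, no_ema_collection=["/min"])

assert remapped == {"conv_4/act_quant/min": 2}
# The input dict is not modified in place by the new implementation.
assert ema_map == {"conv_4/act_quant/min/ExponentialMovingAverage": 2}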