Commit 93b8168a authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Switch line orders in trainer so that restore_map is called after moving...

Switch line orders in trainer so that restore_map is called after moving average variables are created.

Moving averages are now properly loaded during fine-tuning, instead of being recreated.

PiperOrigin-RevId: 190496046
parent ed877c00
......@@ -264,29 +264,6 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
total_num_replicas=worker_replicas)
sync_optimizer = training_optimizer
# Create ops required to initialize the model from a given checkpoint.
init_fn = None
if train_config.fine_tune_checkpoint:
if not train_config.fine_tune_checkpoint_type:
# train_config.from_detection_checkpoint field is deprecated. For
# backward compatibility, fine_tune_checkpoint_type is set based on
# from_detection_checkpoint.
if train_config.from_detection_checkpoint:
train_config.fine_tune_checkpoint_type = 'detection'
else:
train_config.fine_tune_checkpoint_type = 'classification'
var_map = detection_model.restore_map(
fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
load_all_detection_checkpoint_vars=(
train_config.load_all_detection_checkpoint_vars))
available_var_map = (variables_helper.
get_variables_available_in_checkpoint(
var_map, train_config.fine_tune_checkpoint))
init_saver = tf.train.Saver(available_var_map)
def initializer_fn(sess):
init_saver.restore(sess, train_config.fine_tune_checkpoint)
init_fn = initializer_fn
with tf.device(deploy_config.optimizer_device()):
regularization_losses = (None if train_config.add_regularization_loss
else [])
......@@ -354,6 +331,29 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
saver = tf.train.Saver(
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
# Create ops required to initialize the model from a given checkpoint.
init_fn = None
if train_config.fine_tune_checkpoint:
if not train_config.fine_tune_checkpoint_type:
# train_config.from_detection_checkpoint field is deprecated. For
# backward compatibility, fine_tune_checkpoint_type is set based on
# from_detection_checkpoint.
if train_config.from_detection_checkpoint:
train_config.fine_tune_checkpoint_type = 'detection'
else:
train_config.fine_tune_checkpoint_type = 'classification'
var_map = detection_model.restore_map(
fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
load_all_detection_checkpoint_vars=(
train_config.load_all_detection_checkpoint_vars))
available_var_map = (variables_helper.
get_variables_available_in_checkpoint(
var_map, train_config.fine_tune_checkpoint))
init_saver = tf.train.Saver(available_var_map)
def initializer_fn(sess):
init_saver.restore(sess, train_config.fine_tune_checkpoint)
init_fn = initializer_fn
slim.learning.train(
train_tensor,
logdir=train_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment