"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "208611a4068617f3c8c9b1796ba742c3e8a0dc11"
Unverified Commit 0864b2a4 authored by Niru Maheswaranathan's avatar Niru Maheswaranathan Committed by GitHub
Browse files

Merge pull request #5084 from yesmung/master

Modify flag name for the checkpoint path
parents 468d8bb6 2949cfd8
...@@ -35,13 +35,13 @@ import sonnet as snt ...@@ -35,13 +35,13 @@ import sonnet as snt
from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.framework import checkpoint_utils
# Command-line flags for the training script (post-merge version: the flag
# was renamed from "checkpoint" to "checkpoint_dir" in PR #5084).
# NOTE(review): original scraped lines contained both pre- and post-change
# text concatenated; this is the reconstructed new-side code.
flags.DEFINE_string("checkpoint_dir", None,
                    "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")

# Parsed flag values, populated by the app entry point at startup.
FLAGS = flags.FLAGS
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
dataset_fn = datasets.mnist.TinyMnist dataset_fn = datasets.mnist.TinyMnist
w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess
...@@ -77,8 +77,8 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): ...@@ -77,8 +77,8 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
summary_op = tf.summary.merge_all() summary_op = tf.summary.merge_all()
file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"]) file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
if checkpoint: if checkpoint_dir:
str_var_list = checkpoint_utils.list_variables(checkpoint) str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
name_to_v_map = {v.op.name: v for v in tf.all_variables()} name_to_v_map = {v.op.name: v for v in tf.all_variables()}
var_list = [ var_list = [
name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
...@@ -99,9 +99,9 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): ...@@ -99,9 +99,9 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
# global step should be restored from the evals job checkpoint or zero for fresh. # global step should be restored from the evals job checkpoint or zero for fresh.
step = sess.run(global_step) step = sess.run(global_step)
if step == 0 and checkpoint: if step == 0 and checkpoint_dir:
tf.logging.info("force restore") tf.logging.info("force restore")
saver.restore(sess, checkpoint) saver.restore(sess, checkpoint_dir)
tf.logging.info("force restore done") tf.logging.info("force restore done")
sess.run(reset_global_step) sess.run(reset_global_step)
step = sess.run(global_step) step = sess.run(global_step)
...@@ -115,7 +115,7 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): ...@@ -115,7 +115,7 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
def main(argv):
  """App entry point: run training with flag-provided directories.

  Args:
    argv: Unused positional command-line arguments (required by the
      TF/absl app runner signature).
  """
  # Post-merge call site: passes the renamed `checkpoint_dir` flag through
  # to `train` (defined earlier in this file).
  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment