Commit 2949cfd8 authored by MyungSung Kwak

Modify flag name for the checkpoint path



Change the flag name to checkpoint_dir to match the variable name used by
checkpoint_utils in the TensorFlow Python framework.

Notably, running the run_eval script currently fails with an error because
of the mismatched flag name.
Signed-off-by: MyungSung Kwak <yesmung@gmail.com>
parent 5be37277
@@ -35,13 +35,13 @@ import sonnet as snt
 from tensorflow.contrib.framework.python.framework import checkpoint_utils

-flags.DEFINE_string("checkpoint", None, "Dir to load pretrained update rule from")
+flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from")
 flags.DEFINE_string("train_log_dir", None, "Training log directory")

 FLAGS = flags.FLAGS

-def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
+def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
   dataset_fn = datasets.mnist.TinyMnist
   w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
   theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess
@@ -77,8 +77,8 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
   summary_op = tf.summary.merge_all()
   file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

-  if checkpoint:
-    str_var_list = checkpoint_utils.list_variables(checkpoint)
+  if checkpoint_dir:
+    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
     name_to_v_map = {v.op.name: v for v in tf.all_variables()}
     var_list = [
         name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
@@ -99,9 +99,9 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
     # global step should be restored from the evals job checkpoint or zero for fresh.
     step = sess.run(global_step)

-    if step == 0 and checkpoint:
+    if step == 0 and checkpoint_dir:
       tf.logging.info("force restore")
-      saver.restore(sess, checkpoint)
+      saver.restore(sess, checkpoint_dir)
       tf.logging.info("force restore done")

       sess.run(reset_global_step)
       step = sess.run(global_step)
@@ -115,7 +115,7 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
 def main(argv):
-  train(FLAGS.train_log_dir, FLAGS.checkpoint)
+  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)

 if __name__ == "__main__":
...
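
For context (not part of this commit): in tf.contrib's checkpoint_utils, the argument that carries the checkpoint path is itself named checkpoint_dir, which is the name the flag now mirrors. A minimal sketch of that call, assuming a TF 1.x environment with tf.contrib available and using a placeholder checkpoint path:

# Minimal sketch, not part of the commit; assumes TF 1.x with tf.contrib available.
from tensorflow.contrib.framework.python.framework import checkpoint_utils

checkpoint_dir = "/tmp/pretrained_update_rule"  # placeholder path for illustration

# list_variables takes the checkpoint path (its parameter is named checkpoint_dir)
# and returns (name, shape) pairs for the variables stored in that checkpoint.
for var_name, var_shape in checkpoint_utils.list_variables(checkpoint_dir):
  print(var_name, var_shape)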