Commit cd5e9b7c authored by Christopher Shallue

Fix a bug in the im2txt code where the Saver is created before the optimizer.
parent 71f239fd
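
Background on the bug: a tf.train.Saver constructed with no var_list snapshots only the variables that exist in the graph at construction time. The optimizer creates additional slot variables (for example momentum accumulators) when it is set up, so a Saver built before the optimizer silently leaves those slots out of every checkpoint. A minimal sketch of the buggy versus fixed ordering, using an illustrative MomentumOptimizer rather than the actual im2txt training setup:

import tensorflow as tf

w = tf.Variable(tf.zeros([3]), name="w")  # a stand-in model variable
loss = tf.reduce_sum(tf.square(w))

# Buggy ordering: this Saver snapshots the variable set now, before the
# optimizer exists, so it can never save the slot variables created below.
early_saver = tf.train.Saver()

# minimize() creates one Momentum slot variable per trainable variable.
optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
train_op = optimizer.minimize(loss)

# Fixed ordering: built after the optimizer, this Saver also covers the
# slot variables, so a checkpoint holds the full resumable training state.
late_saver = tf.train.Saver()

A checkpoint written by early_saver would restore the weights but lose the accumulators; this commit moves Saver construction after the optimizer (and out of the model class) so the training job checkpoints the complete state.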
@@ -77,10 +77,6 @@ class ModelConfig(object):
     # If < 1.0, the dropout keep probability applied to LSTM variables.
     self.lstm_dropout_keep_prob = 0.7
 
-    # How many model checkpoints to keep.
-    self.max_checkpoints_to_keep = 5
-    self.keep_checkpoint_every_n_hours = 10000
-
 
 class TrainingConfig(object):
   """Wrapper class for training hyperparameters."""
@@ -103,3 +99,6 @@ class TrainingConfig(object):
     # If not None, clip gradients to this value.
     self.clip_gradients = 5.0
+
+    # How many model checkpoints to keep.
+    self.max_checkpoints_to_keep = 5
@@ -104,11 +104,12 @@ def evaluate_model(sess, model, global_step, summary_writer, summary_op):
                     global_step)
 
 
-def run_once(model, summary_writer, summary_op):
+def run_once(model, saver, summary_writer, summary_op):
   """Evaluates the latest model checkpoint.
 
   Args:
     model: Instance of ShowAndTellModel; the model to evaluate.
+    saver: Instance of tf.train.Saver for restoring model Variables.
     summary_writer: Instance of SummaryWriter.
     summary_op: Op for generating model summaries.
   """
@@ -121,7 +122,7 @@ def run_once(model, saver, summary_writer, summary_op):
   with tf.Session() as sess:
     # Load model from checkpoint.
     tf.logging.info("Loading model from checkpoint: %s", model_path)
-    model.saver.restore(sess, model_path)
+    saver.restore(sess, model_path)
     global_step = tf.train.global_step(sess, model.global_step.name)
     tf.logging.info("Successfully loaded %s at global step = %d.",
                     os.path.basename(model_path), global_step)
@@ -166,6 +167,9 @@ def run():
     model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
     model.build()
 
+    # Create the Saver to restore model Variables.
+    saver = tf.train.Saver()
+
     # Create the summary operation and the summary writer.
     summary_op = tf.merge_all_summaries()
     summary_writer = tf.train.SummaryWriter(eval_dir)
@@ -177,7 +181,7 @@ def run():
       start = time.time()
      tf.logging.info("Starting evaluation at " + time.strftime(
           "%Y-%m-%d-%H:%M:%S", time.localtime()))
-      run_once(model, summary_writer, summary_op)
+      run_once(model, saver, summary_writer, summary_op)
       time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
       if time_to_next_eval > 0:
         time.sleep(time_to_next_eval)
...
@@ -112,10 +112,8 @@ class InferenceWrapperBase(object):
         from the checkpoint file.
     """
     tf.logging.info("Building model.")
-    model = self.build_model(model_config)
-    saver = model.saver
-    if not saver:
-      saver = tf.Saver()
+    self.build_model(model_config)
+    saver = tf.train.Saver()
 
     return self._create_restore_fn(checkpoint_path, saver)
...
@@ -347,12 +347,6 @@ class ShowAndTellModel(object):
     self.global_step = global_step
 
-  def setup_saver(self):
-    """Sets up the Saver for loading and saving model checkpoints."""
-    self.saver = tf.train.Saver(
-        max_to_keep=self.config.max_checkpoints_to_keep,
-        keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)
-
   def build(self):
     """Creates all ops for training and evaluation."""
     self.build_inputs()
@@ -361,4 +355,3 @@ class ShowAndTellModel(object):
     self.build_model()
     self.setup_inception_initializer()
     self.setup_global_step()
-    self.setup_saver()
@@ -95,6 +95,9 @@ def main(unused_argv):
       clip_gradients=training_config.clip_gradients,
       learning_rate_decay_fn=learning_rate_decay_fn)
 
+  # Set up the Saver for saving and restoring model checkpoints.
+  saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
+
   # Run training.
   tf.contrib.slim.learning.train(
       train_op,
@@ -104,7 +107,7 @@ def main(unused_argv):
       global_step=model.global_step,
      number_of_steps=FLAGS.number_of_steps,
       init_fn=model.init_fn,
-      saver=model.saver)
+      saver=saver)
 
 
 if __name__ == "__main__":
...
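
The eval and inference paths are updated for the same reason: the model no longer owns a Saver, so each job builds its own after constructing its graph. A condensed sketch of the new caller-owned restore flow on the eval side, with an illustrative checkpoint directory:

import tensorflow as tf

# Build the evaluation graph first (a stand-in for model.build()).
w = tf.Variable(tf.zeros([3]), name="w")

# The caller, not the model, now constructs the Saver; with no var_list
# it covers exactly the variables of the graph built above.
saver = tf.train.Saver()

# Restore the newest training checkpoint (directory is illustrative).
checkpoint_path = tf.train.latest_checkpoint("/tmp/im2txt/train")
with tf.Session() as sess:
  saver.restore(sess, checkpoint_path)

Because restore only looks up the variables in the Saver's own list, the eval graph, which has no optimizer and therefore no slot variables, can load a training checkpoint containing extra slot variables without error.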