Commit 40f8e23e authored by Allen Lavoie

Update the eager MNIST example to use object-based checkpointing

parent d4a4dd04
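
The change replaces the example's name-based saving (a tfe.Saver over a hand-collected variable list, plus tfe.restore_variables_on_create) with object-based checkpointing: a tfe.Checkpoint is handed named objects and saves and restores them by walking the object graph rather than by matching variable names. A minimal sketch of the pattern in isolation, assuming the same TF 1.x contrib.eager API the example uses; the variable `v` and the paths are hypothetical stand-ins, not part of the commit:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

# Hypothetical state; in the example this role is played by the model,
# optimizer, and step counter.
v = tfe.Variable(0.0, name='v')
checkpoint = tfe.Checkpoint(v=v)  # dependencies are named keyword arguments

v.assign(3.0)
save_path = checkpoint.save('/tmp/ckpt_demo/ckpt')  # hypothetical prefix

v.assign(0.0)
# restore() matches state through the object graph (the v= name above),
# not through the underlying variable's own name.
checkpoint.restore(save_path)
assert v.numpy() == 3.0
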
@@ -53,14 +53,13 @@ def compute_accuracy(logits, labels):
       tf.cast(tf.equal(predictions, labels), dtype=tf.float32)) / batch_size
 
 
-def train(model, optimizer, dataset, log_interval=None):
+def train(model, optimizer, dataset, step_counter, log_interval=None):
   """Trains model on `dataset` using `optimizer`."""
 
-  global_step = tf.train.get_or_create_global_step()
-
   start = time.time()
   for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
-    with tf.contrib.summary.record_summaries_every_n_global_steps(10):
+    with tf.contrib.summary.record_summaries_every_n_global_steps(
+        10, global_step=step_counter):
       # Record the operations used to compute the loss given the input,
       # so that the gradient of the loss with respect to the variables
       # can be computed.
@@ -71,7 +70,7 @@ def train(model, optimizer, dataset, log_interval=None):
         tf.contrib.summary.scalar('accuracy', compute_accuracy(logits, labels))
       grads = tape.gradient(loss_value, model.variables)
       optimizer.apply_gradients(
-          zip(grads, model.variables), global_step=global_step)
+          zip(grads, model.variables), global_step=step_counter)
       if log_interval and batch % log_interval == 0:
         rate = log_interval / (time.time() - start)
         print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
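
With the new signature the caller owns the step counter and passes it in; train() threads it through both the summary cadence and apply_gradients(). A hedged sketch of a caller, mirroring the benchmark change further down; model, optimizer, and dataset are assumed to be constructed as elsewhere in the example:

# model, optimizer, and dataset are assumed to exist (e.g. built as in
# main()); only the step-counter wiring is new here.
step_counter = tf.train.get_or_create_global_step()
train(model, optimizer, dataset, step_counter, log_interval=100)
# Each optimizer.apply_gradients() call advances step_counter, so the
# record_summaries_every_n_global_steps(10, global_step=step_counter)
# context emits summaries once every 10 training steps.
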
@@ -128,23 +127,25 @@ def main(_):
   test_summary_writer = tf.contrib.summary.create_file_writer(
       test_dir, flush_millis=10000, name='test')
   checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
-
-  # Train and evaluate for 11 epochs.
+  step_counter = tf.train.get_or_create_global_step()
+  checkpoint = tfe.Checkpoint(
+      model=model, optimizer=optimizer, step_counter=step_counter)
+  # Restore variables on creation if a checkpoint exists.
+  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
+  # Train and evaluate for 10 epochs.
   with tf.device(device):
-    for epoch in range(1, 11):
-      with tfe.restore_variables_on_create(
-          tf.train.latest_checkpoint(FLAGS.checkpoint_dir)):
-        global_step = tf.train.get_or_create_global_step()
-        start = time.time()
-        with summary_writer.as_default():
-          train(model, optimizer, train_ds, FLAGS.log_interval)
-        end = time.time()
-        print('\nTrain time for epoch #%d (global step %d): %f' %
-              (epoch, global_step.numpy(), end - start))
+    for _ in range(10):
+      start = time.time()
+      with summary_writer.as_default():
+        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
+      end = time.time()
+      print('\nTrain time for epoch #%d (%d total steps): %f' %
+            (checkpoint.save_counter.numpy() + 1,
+             step_counter.numpy(),
+             end - start))
       with test_summary_writer.as_default():
         test(model, test_ds)
-      all_variables = (model.variables + optimizer.variables() + [global_step])
-      tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
+      checkpoint.save(checkpoint_prefix)
 
 
 if __name__ == '__main__':
...
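
In main(), checkpoint.save() now runs once per epoch, so checkpoint.save_counter doubles as a persistent epoch index: the counter is saved and restored along with everything else, and the printed epoch number keeps climbing across restarts instead of resetting. A sketch of that bookkeeping under the same assumptions as above; the directory is a hypothetical stand-in for FLAGS.checkpoint_dir:

import os

# model, optimizer, and step_counter as in the example.
checkpoint = tfe.Checkpoint(
    model=model, optimizer=optimizer, step_counter=step_counter)
checkpoint.restore(tf.train.latest_checkpoint('/tmp/mnist_ckpt'))
for _ in range(10):
  # ... one epoch of training and evaluation ...
  # save_counter counts completed save() calls, so +1 is the epoch that
  # is finishing now; after a restart it resumes from the restored value.
  print('epoch #%d' % (checkpoint.save_counter.numpy() + 1))
  checkpoint.save(os.path.join('/tmp/mnist_ckpt', 'ckpt'))
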
@@ -46,7 +46,8 @@ def train(defun=False):
   optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
   dataset = random_dataset()
   with tf.device(device()):
-    mnist_eager.train(model, optimizer, dataset)
+    mnist_eager.train(model, optimizer, dataset,
+                      step_counter=tf.train.get_or_create_global_step())
 
 
 def evaluate(defun=False):
...