Unverified Commit 95385809 authored by Asim Shankar, committed by GitHub

Merge pull request #3535 from allenlavoie/master

Update the eager MNIST example to use object-based checkpointing
parents d4a4dd04 40f8e23e
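
For reviewers new to the API: below is a minimal sketch of the object-based checkpointing pattern this PR adopts, assuming a TF 1.x build with eager execution and tf.contrib.eager available (the /tmp/ckpt_demo path is purely illustrative, not part of this change). tfe.Checkpoint tracks whatever objects are attached to it by keyword, and restoring from a None path is a no-op, so one code path covers both fresh and resumed runs.

import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager

# Trackable objects are attached by keyword; matching at restore time is
# done through this object graph rather than by variable name.
step_counter = tf.train.get_or_create_global_step()
checkpoint = tfe.Checkpoint(step_counter=step_counter)

# tf.train.latest_checkpoint() returns None when no checkpoint exists yet,
# and restore(None) is a no-op, so fresh and resumed runs share this path.
checkpoint.restore(tf.train.latest_checkpoint('/tmp/ckpt_demo'))

step_counter.assign_add(1)
save_path = checkpoint.save('/tmp/ckpt_demo/ckpt')
print('Saved to', save_path)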
@@ -53,14 +53,13 @@ def compute_accuracy(logits, labels):
       tf.cast(tf.equal(predictions, labels), dtype=tf.float32)) / batch_size
 
 
-def train(model, optimizer, dataset, log_interval=None):
+def train(model, optimizer, dataset, step_counter, log_interval=None):
   """Trains model on `dataset` using `optimizer`."""
 
-  global_step = tf.train.get_or_create_global_step()
   start = time.time()
   for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
-    with tf.contrib.summary.record_summaries_every_n_global_steps(10):
+    with tf.contrib.summary.record_summaries_every_n_global_steps(
+        10, global_step=step_counter):
       # Record the operations used to compute the loss given the input,
       # so that the gradient of the loss with respect to the variables
       # can be computed.
@@ -71,7 +70,7 @@ def train(model, optimizer, dataset, log_interval=None):
         tf.contrib.summary.scalar('accuracy', compute_accuracy(logits, labels))
       grads = tape.gradient(loss_value, model.variables)
       optimizer.apply_gradients(
-          zip(grads, model.variables), global_step=global_step)
+          zip(grads, model.variables), global_step=step_counter)
       if log_interval and batch % log_interval == 0:
         rate = log_interval / (time.time() - start)
         print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
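
The train loop now threads an explicit step_counter through both the summary-recording condition and apply_gradients, instead of creating and reading a hidden global step inside train(). A hedged sketch of how the contrib summary API consumes that counter (the log directory is hypothetical, and assign_add stands in for the increment that apply_gradients performs in the real loop):

import tensorflow as tf

tf.enable_eager_execution()

step_counter = tf.train.get_or_create_global_step()
writer = tf.contrib.summary.create_file_writer('/tmp/summary_demo')

with writer.as_default():
  for _ in range(30):
    # Summaries are emitted only when step_counter is a multiple of 10;
    # passing the counter explicitly avoids a hidden global-step lookup.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
        10, global_step=step_counter):
      tf.contrib.summary.scalar('loss', 0.5)
    # apply_gradients(..., global_step=step_counter) does this in train().
    step_counter.assign_add(1)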
@@ -128,23 +127,25 @@ def main(_):
   test_summary_writer = tf.contrib.summary.create_file_writer(
       test_dir, flush_millis=10000, name='test')
   checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
 
-  # Train and evaluate for 11 epochs.
+  step_counter = tf.train.get_or_create_global_step()
+  checkpoint = tfe.Checkpoint(
+      model=model, optimizer=optimizer, step_counter=step_counter)
+  # Restore variables on creation if a checkpoint exists.
+  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
+  # Train and evaluate for 10 epochs.
   with tf.device(device):
-    for epoch in range(1, 11):
-      with tfe.restore_variables_on_create(
-          tf.train.latest_checkpoint(FLAGS.checkpoint_dir)):
-        global_step = tf.train.get_or_create_global_step()
+    for _ in range(10):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, FLAGS.log_interval)
+        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
       end = time.time()
-      print('\nTrain time for epoch #%d (global step %d): %f' %
-            (epoch, global_step.numpy(), end - start))
+      print('\nTrain time for epoch #%d (%d total steps): %f' %
+            (checkpoint.save_counter.numpy() + 1,
+             step_counter.numpy(),
+             end - start))
       with test_summary_writer.as_default():
         test(model, test_ds)
-      all_variables = (model.variables + optimizer.variables() + [global_step])
-      tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
+      checkpoint.save(checkpoint_prefix)
 
 
 if __name__ == '__main__':
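One detail worth noting in the hunk above: checkpoint.save() maintains its own save_counter, which is why the print statement can recover the epoch number as save_counter + 1 even after a restart. A small sketch of that numbering behavior, with hypothetical paths:

import os
import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager

checkpoint = tfe.Checkpoint(step=tf.train.get_or_create_global_step())
prefix = os.path.join('/tmp/ckpt_counter_demo', 'ckpt')

for _ in range(3):
  # Each save() increments save_counter and appends it to the prefix,
  # producing .../ckpt-1, .../ckpt-2, .../ckpt-3.
  print(checkpoint.save(prefix))
print('Completed saves:', checkpoint.save_counter.numpy())  # 3
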
@@ -46,7 +46,8 @@ def train(defun=False):
   optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
   dataset = random_dataset()
   with tf.device(device()):
-    mnist_eager.train(model, optimizer, dataset)
+    mnist_eager.train(model, optimizer, dataset,
+                      step_counter=tf.train.get_or_create_global_step())
 
 
 def evaluate(defun=False):