Unverified Commit 68a9002f authored by Asim Shankar, committed by GitHub

Merge pull request #4827 from asimshankar/mnist-eager

[official/mnist]: Avoid some now unnecessary 'tfe' symbols.
parents 71c196c1 612ec83d
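
The change drops the explicit `tfe` import and leans on two things that no longer need the `tf.contrib.eager` namespace: with eager execution enabled a `tf.data.Dataset` is directly iterable (so `tfe.Iterator` is unnecessary), and object-based checkpointing is exposed as `tf.train.Checkpoint`. A minimal sketch of the iteration idiom follows; the toy tensors are hypothetical stand-ins for the MNIST input pipeline, and it assumes a TF 1.x release with eager support, roughly what this repo targeted at the time:

```python
# Sketch only (not part of the PR): with eager execution enabled, a
# tf.data.Dataset is a plain Python iterable, so tfe.Iterator is not needed.
# The toy tensors below are hypothetical stand-ins for the MNIST pipeline.
import tensorflow as tf

tf.enable_eager_execution()

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([8, 28, 28, 1]),                   # fake images
     tf.random_uniform([8], maxval=10, dtype=tf.int32)))  # fake labels
dataset = dataset.batch(4)

for batch, (images, labels) in enumerate(dataset):  # no tfe.Iterator wrapper
  print(batch, images.shape, labels.shape)
```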
@@ -33,7 +33,6 @@ import time
 from absl import app as absl_app
 from absl import flags
 import tensorflow as tf
-import tensorflow.contrib.eager as tfe
 # pylint: enable=g-bad-import-order
 
 from official.mnist import dataset as mnist_dataset
@@ -42,6 +41,8 @@ from official.utils.flags import core as flags_core
 from official.utils.misc import model_helpers
 
+tfe = tf.contrib.eager
+
 
 def loss(logits, labels):
   return tf.reduce_mean(
       tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -60,7 +61,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
   """Trains model on `dataset` using `optimizer`."""
 
   start = time.time()
-  for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
+  for (batch, (images, labels)) in enumerate(dataset):
     with tf.contrib.summary.record_summaries_every_n_global_steps(
         10, global_step=step_counter):
       # Record the operations used to compute the loss given the input,
@@ -85,7 +86,7 @@ def test(model, dataset):
   avg_loss = tfe.metrics.Mean('loss')
   accuracy = tfe.metrics.Accuracy('accuracy')
 
-  for (images, labels) in tfe.Iterator(dataset):
+  for (images, labels) in dataset:
     logits = model(images, training=False)
     avg_loss(loss(logits, labels))
     accuracy(
@@ -145,7 +146,7 @@ def run_mnist_eager(flags_obj):
   # Create and restore checkpoint (if one exists on the path)
   checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
-  checkpoint = tfe.Checkpoint(
+  checkpoint = tf.train.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
   checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))
...
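
For the checkpointing half of the change, a minimal, hypothetical sketch of `tf.train.Checkpoint` (the replacement for the `tfe.Checkpoint` alias); the model, optimizer, and `/tmp` path below are stand-ins rather than the code from the diff:

```python
# Sketch only: object-based checkpointing via tf.train.Checkpoint instead of
# the tfe.Checkpoint alias.  Model, optimizer, and directory are hypothetical.
import os
import tensorflow as tf

tf.enable_eager_execution()

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model(tf.zeros([1, 784]))  # build the layer so its variables exist
optimizer = tf.train.GradientDescentOptimizer(0.01)
step_counter = tf.train.get_or_create_global_step()

model_dir = '/tmp/checkpoint_demo'  # hypothetical directory
tf.gfile.MakeDirs(model_dir)
checkpoint = tf.train.Checkpoint(
    model=model, optimizer=optimizer, step_counter=step_counter)

# Restore the most recent checkpoint under model_dir, if any; when none
# exists, latest_checkpoint returns None and restore is a harmless no-op.
checkpoint.restore(tf.train.latest_checkpoint(model_dir))

# ... training would go here ...

# Writes model_dir/ckpt-N, matching the 'ckpt' prefix used in the diff.
checkpoint.save(os.path.join(model_dir, 'ckpt'))
```

Checkpoint keys are derived from the keyword names given to `tf.train.Checkpoint`, so restoring later means constructing the object with the same names (`model`, `optimizer`, `step_counter`).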