Unverified Commit 7922c9eb authored by Asim Shankar's avatar Asim Shankar Committed by GitHub
Browse files

Merge pull request #4853 from tensorflow/mnist_metrics_dtype

Use float32 metrics in mnist_eager
parents 05ec6d87 dfafba4a
...@@ -83,8 +83,8 @@ def train(model, optimizer, dataset, step_counter, log_interval=None): ...@@ -83,8 +83,8 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
def test(model, dataset): def test(model, dataset):
"""Perform an evaluation of `model` on the examples from `dataset`.""" """Perform an evaluation of `model` on the examples from `dataset`."""
avg_loss = tfe.metrics.Mean('loss') avg_loss = tfe.metrics.Mean('loss', dtype=tf.float32)
accuracy = tfe.metrics.Accuracy('accuracy') accuracy = tfe.metrics.Accuracy('accuracy', dtype=tf.float32)
for (images, labels) in dataset: for (images, labels) in dataset:
logits = model(images, training=False) logits = model(images, training=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment