Commit ee968bad authored by Asim Shankar's avatar Asim Shankar
Browse files

official/mnist: Updates with the release of TensorFlow 1.8.

parent 505f554c
......@@ -63,7 +63,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
# Record the operations used to compute the loss given the input,
# so that the gradient of the loss with respect to the variables
# can be computed.
with tfe.GradientTape() as tape:
with tf.GradientTape() as tape:
logits = model(images, training=True)
loss_value = loss(logits, labels)
tf.contrib.summary.scalar('loss', loss_value)
......@@ -99,11 +99,11 @@ def main(argv):
parser = MNISTEagerArgParser()
flags = parser.parse_args(args=argv[1:])
tfe.enable_eager_execution()
tf.enable_eager_execution()
# Automatically determine device and data_format
(device, data_format) = ('/gpu:0', 'channels_first')
if flags.no_gpu or tfe.num_gpus() <= 0:
if flags.no_gpu or not tf.test.is_gpu_available():
(device, data_format) = ('/cpu:0', 'channels_last')
# If data_format is defined in FLAGS, overwrite automatically set value.
if flags.data_format is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment