Commit 6a7de210 authored by Mark Daoust's avatar Mark Daoust
Browse files

Clear softmax_cross_entropy deprecation warning.

The `sparse` version is more efficient anyway.

I'm returning the labels shape [1] instead of []
because tf.accuracy fails otherwise.
parent 7d8a2c0d
......@@ -89,15 +89,14 @@ def dataset(directory, images_file, labels_file):
image = tf.reshape(image, [784])
return image / 255.0
def one_hot_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> tf.uint8
label = tf.reshape(label, []) # label is a scalar
return tf.one_hot(label, 10)
def decode_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
return tf.to_int32(label)
images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
labels_file, 1, header_bytes=8).map(one_hot_label)
labels_file, 1, header_bytes=8).map(decode_label)
return tf.data.Dataset.zip((images, labels))
......
......@@ -102,9 +102,9 @@ def model_fn(features, labels, mode, params):
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
logits = model(image, training=True)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
labels=labels, predictions=tf.argmax(logits, axis=1))
# Name the accuracy tensor 'train_accuracy' to demonstrate the
# LoggingTensorHook.
tf.identity(accuracy[1], name='train_accuracy')
......@@ -115,7 +115,7 @@ def model_fn(features, labels, mode, params):
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
......
......@@ -27,8 +27,8 @@ BATCH_SIZE = 100
def dummy_input_fn():
image = tf.random_uniform([BATCH_SIZE, 784])
labels = tf.random_uniform([BATCH_SIZE], maxval=9, dtype=tf.int32)
return image, tf.one_hot(labels, 10)
labels = tf.random_uniform([BATCH_SIZE, 1], maxval=9, dtype=tf.int32)
return image, labels
def make_estimator():
......
......@@ -50,7 +50,7 @@ FLAGS = tf.flags.FLAGS
def metric_fn(labels, logits):
accuracy = tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
labels=labels, predictions=tf.argmax(logits, axis=1))
return {"accuracy": accuracy}
......@@ -64,7 +64,7 @@ def model_fn(features, labels, mode, params):
model = mnist.Model("channels_last")
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = tf.train.exponential_decay(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment