Commit 6765b16d authored by tranvohuy's avatar tranvohuy Committed by Paige Bailey
Browse files

change tf.to_int32 to tf.cast (#6359)

tf.to_int32 raise deprecated warning.
change tf.to_int32(labels) to tf.cast(labels, tf.int32)
parent 1aa241d9
...@@ -97,7 +97,7 @@ def dataset(directory, images_file, labels_file): ...@@ -97,7 +97,7 @@ def dataset(directory, images_file, labels_file):
def decode_label(label): def decode_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
label = tf.reshape(label, []) # label is a scalar label = tf.reshape(label, []) # label is a scalar
return tf.to_int32(label) return tf.cast(label, tf.int32)
images = tf.data.FixedLengthRecordDataset( images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image) images_file, 28 * 28, header_bytes=16).map(decode_image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment