Commit 9691ef7a authored by minoring's avatar minoring
Browse files

Add compat.v1 to support TF 2.0 in mnist

parent 06412123
......@@ -35,7 +35,7 @@ def read32(bytestream):
def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
with tf.io.gfile.GFile(filename, 'rb') as f:
magic = read32(f)
read32(f) # num_images, unused
rows = read32(f)
......@@ -51,7 +51,7 @@ def check_image_file_header(filename):
def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
with tf.io.gfile.GFile(filename, 'rb') as f:
magic = read32(f)
read32(f) # num_items, unused
if magic != 2049:
......@@ -62,17 +62,17 @@ def check_labels_file_header(filename):
def download(directory, filename):
"""Download (and unzip) a file from the MNIST dataset if not already done."""
filepath = os.path.join(directory, filename)
if tf.gfile.Exists(filepath):
if tf.io.gfile.exists(filepath):
return filepath
if not tf.gfile.Exists(directory):
tf.gfile.MakeDirs(directory)
if not tf.io.gfile.exists(directory):
tf.io.gfile.mkdir(directory)
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
_, zipped_filepath = tempfile.mkstemp(suffix='.gz')
print('Downloading %s to %s' % (url, zipped_filepath))
urllib.request.urlretrieve(url, zipped_filepath)
with gzip.open(zipped_filepath, 'rb') as f_in, \
tf.gfile.Open(filepath, 'wb') as f_out:
tf.io.gfile.GFile(filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath)
return filepath
......@@ -89,13 +89,13 @@ def dataset(directory, images_file, labels_file):
def decode_image(image):
# Normalize from [0, 255] to [0.0, 1.0]
image = tf.decode_raw(image, tf.uint8)
image = tf.io.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
return image / 255.0
def decode_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
label = tf.io.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
label = tf.reshape(label, []) # label is a scalar
return tf.cast(label, tf.int32)
......
......@@ -125,11 +125,12 @@ def model_fn(features, labels, mode, params):
'classify': tf.estimator.export.PredictOutput(predictions)
})
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)
logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=labels,
logits=logits)
accuracy = tf.compat.v1.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1))
# Name tensors to be logged with LoggingTensorHook.
......@@ -143,7 +144,8 @@ def model_fn(features, labels, mode, params):
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
train_op=optimizer.minimize(loss,
tf.compat.v1.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
......@@ -166,7 +168,7 @@ def run_mnist(flags_obj):
model_helpers.apply_clean(flags_obj)
model_function = model_fn
session_config = tf.ConfigProto(
session_config = tf.compat.v1.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True)
......@@ -227,7 +229,7 @@ def run_mnist(flags_obj):
# Export the model
if flags_obj.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28])
image = tf.compat.v1.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image,
})
......@@ -240,6 +242,6 @@ def main(_):
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
define_mnist_flags()
absl_app.run(main)
......@@ -29,8 +29,8 @@ BATCH_SIZE = 100
def dummy_input_fn():
image = tf.random_uniform([BATCH_SIZE, 784])
labels = tf.random_uniform([BATCH_SIZE, 1], maxval=9, dtype=tf.int32)
image = tf.random.uniform([BATCH_SIZE, 784])
labels = tf.random.uniform([BATCH_SIZE, 1], maxval=9, dtype=tf.int32)
return image, labels
......@@ -64,7 +64,7 @@ class Tests(tf.test.TestCase):
self.assertEqual(2, global_step)
self.assertEqual(accuracy.shape, ())
input_fn = lambda: tf.random_uniform([3, 784])
input_fn = lambda: tf.random.uniform([3, 784])
predictions_generator = classifier.predict(input_fn)
for _ in range(3):
predictions = next(predictions_generator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment