Commit 9691ef7a authored by minoring's avatar minoring
Browse files

Add compat.v1 to support TF 2.0 in mnist

parent 06412123
...@@ -35,7 +35,7 @@ def read32(bytestream): ...@@ -35,7 +35,7 @@ def read32(bytestream):
def check_image_file_header(filename): def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset.""" """Validate that filename corresponds to images for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f: with tf.io.gfile.GFile(filename, 'rb') as f:
magic = read32(f) magic = read32(f)
read32(f) # num_images, unused read32(f) # num_images, unused
rows = read32(f) rows = read32(f)
...@@ -51,7 +51,7 @@ def check_image_file_header(filename): ...@@ -51,7 +51,7 @@ def check_image_file_header(filename):
def check_labels_file_header(filename): def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset.""" """Validate that filename corresponds to labels for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f: with tf.io.gfile.GFile(filename, 'rb') as f:
magic = read32(f) magic = read32(f)
read32(f) # num_items, unused read32(f) # num_items, unused
if magic != 2049: if magic != 2049:
...@@ -62,17 +62,17 @@ def check_labels_file_header(filename): ...@@ -62,17 +62,17 @@ def check_labels_file_header(filename):
def download(directory, filename):
  """Download (and unzip) a file from the MNIST dataset if not already done.

  Args:
    directory: Directory in which the unzipped file is placed.
    filename: Base name of the MNIST file (without the '.gz' suffix).

  Returns:
    Path to the unzipped file inside `directory`.
  """
  filepath = os.path.join(directory, filename)
  if tf.io.gfile.exists(filepath):
    return filepath
  if not tf.io.gfile.exists(directory):
    # Use makedirs, not mkdir: tf.io.gfile.mkdir fails when parent
    # directories are missing, whereas the old tf.gfile.MakeDirs (which this
    # code ported from) creates the whole path. makedirs is the TF2
    # equivalent.
    tf.io.gfile.makedirs(directory)
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
  print('Downloading %s to %s' % (url, zipped_filepath))
  urllib.request.urlretrieve(url, zipped_filepath)
  # Decompress into the target location, then drop the temporary .gz file.
  with gzip.open(zipped_filepath, 'rb') as f_in, \
      tf.io.gfile.GFile(filepath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(zipped_filepath)
  return filepath
...@@ -89,13 +89,13 @@ def dataset(directory, images_file, labels_file): ...@@ -89,13 +89,13 @@ def dataset(directory, images_file, labels_file):
def decode_image(image):
  """Decode one raw MNIST image record into a flat float tensor in [0, 1]."""
  pixels = tf.io.decode_raw(image, tf.uint8)  # tf.string -> [tf.uint8]
  pixels = tf.reshape(tf.cast(pixels, tf.float32), [784])
  # Normalize from [0, 255] to [0.0, 1.0].
  return pixels / 255.0
def decode_label(label):
  """Decode one raw MNIST label record into a scalar int32 tensor."""
  raw = tf.io.decode_raw(label, tf.uint8)  # tf.string -> [tf.uint8]
  scalar = tf.reshape(raw, [])             # label is a scalar
  return tf.cast(scalar, tf.int32)
......
...@@ -125,11 +125,12 @@ def model_fn(features, labels, mode, params): ...@@ -125,11 +125,12 @@ def model_fn(features, labels, mode, params):
'classify': tf.estimator.export.PredictOutput(predictions) 'classify': tf.estimator.export.PredictOutput(predictions)
}) })
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)
logits = model(image, training=True) logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=labels,
accuracy = tf.metrics.accuracy( logits=logits)
accuracy = tf.compat.v1.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)) labels=labels, predictions=tf.argmax(logits, axis=1))
# Name tensors to be logged with LoggingTensorHook. # Name tensors to be logged with LoggingTensorHook.
...@@ -143,7 +144,8 @@ def model_fn(features, labels, mode, params): ...@@ -143,7 +144,8 @@ def model_fn(features, labels, mode, params):
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN, mode=tf.estimator.ModeKeys.TRAIN,
loss=loss, loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step())) train_op=optimizer.minimize(loss,
tf.compat.v1.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL: if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False) logits = model(image, training=False)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
...@@ -166,7 +168,7 @@ def run_mnist(flags_obj): ...@@ -166,7 +168,7 @@ def run_mnist(flags_obj):
model_helpers.apply_clean(flags_obj) model_helpers.apply_clean(flags_obj)
model_function = model_fn model_function = model_fn
session_config = tf.ConfigProto( session_config = tf.compat.v1.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads, inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads, intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True) allow_soft_placement=True)
...@@ -227,7 +229,7 @@ def run_mnist(flags_obj): ...@@ -227,7 +229,7 @@ def run_mnist(flags_obj):
# Export the model # Export the model
if flags_obj.export_dir is not None: if flags_obj.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28]) image = tf.compat.v1.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image, 'image': image,
}) })
...@@ -240,6 +242,6 @@ def main(_): ...@@ -240,6 +242,6 @@ def main(_):
if __name__ == '__main__':
  # Route through the tf.compat.v1 shim so verbosity control keeps working
  # when this script runs under TensorFlow 2.0.
  tf_logging = tf.compat.v1.logging
  tf_logging.set_verbosity(tf_logging.INFO)
  define_mnist_flags()
  absl_app.run(main)
...@@ -29,8 +29,8 @@ BATCH_SIZE = 100 ...@@ -29,8 +29,8 @@ BATCH_SIZE = 100
def dummy_input_fn():
  """Return one random (images, labels) batch for exercising the estimator."""
  # NOTE(review): tf.random.uniform's maxval is exclusive, so these labels
  # fall in [0, 8] and class 9 never appears — confirm whether maxval=10 was
  # intended for a 10-class model. Behavior kept as-is here.
  images = tf.random.uniform([BATCH_SIZE, 784])
  labels = tf.random.uniform([BATCH_SIZE, 1], maxval=9, dtype=tf.int32)
  return images, labels
...@@ -64,7 +64,7 @@ class Tests(tf.test.TestCase): ...@@ -64,7 +64,7 @@ class Tests(tf.test.TestCase):
self.assertEqual(2, global_step) self.assertEqual(2, global_step)
self.assertEqual(accuracy.shape, ()) self.assertEqual(accuracy.shape, ())
input_fn = lambda: tf.random_uniform([3, 784]) input_fn = lambda: tf.random.uniform([3, 784])
predictions_generator = classifier.predict(input_fn) predictions_generator = classifier.predict(input_fn)
for _ in range(3): for _ in range(3):
predictions = next(predictions_generator) predictions = next(predictions_generator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment