Unverified commit ea7481c8, authored by Neal Wu, committed by GitHub
Browse files

Merge pull request #3510 from tensorflow/tpu_mnist_cluster_resolver

Upgrade mnist_tpu.py to use the new TPUClusterResolver.
parents 6a84aa6e 9c03af08
......@@ -27,19 +27,22 @@ import tensorflow as tf
import dataset
import mnist
# Cloud TPU Cluster Resolvers
# Cloud TPU Cluster Resolver flags
tf.flags.DEFINE_string(
"gcp_project", default=None,
help="Project name for the Cloud TPU-enabled project. If not specified, we "
"will attempt to automatically detect the GCE project from metadata.")
"tpu", default=None,
help="The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string(
"tpu_zone", default=None,
help="GCE zone where the Cloud TPU is located in. If not specified, we "
"will attempt to automatically detect the GCE project from metadata.")
help="[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string(
"tpu_name", default=None,
help="Name of the Cloud TPU for Cluster Resolvers. You must specify either "
"this flag or --master.")
"gcp_project", default=None,
help="[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
# Model specific parameters
tf.flags.DEFINE_string(
......@@ -74,6 +77,8 @@ def metric_fn(labels, logits):
def model_fn(features, labels, mode, params):
"""model_fn constructs the ML model used to predict handwritten digits."""
del params
if mode == tf.estimator.ModeKeys.PREDICT:
raise RuntimeError("mode {} is not supported yet".format(mode))
......@@ -105,6 +110,7 @@ def model_fn(features, labels, mode, params):
def train_input_fn(params):
"""train_input_fn defines the input pipeline used for training."""
batch_size = params["batch_size"]
data_dir = params["data_dir"]
# Retrieves the batch size for the current shard. The # of shards is
......@@ -130,25 +136,11 @@ def main(argv):
del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.master is None and FLAGS.tpu_name is None:
raise RuntimeError("You must specify either --master or --tpu_name.")
if FLAGS.master is not None:
if FLAGS.tpu_name is not None:
tf.logging.warn("Both --master and --tpu_name are set. Ignoring "
"--tpu_name and using --master.")
tpu_grpc_url = FLAGS.master
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
tpu_names=[FLAGS.tpu_name],
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
tpu_grpc_url = tpu_cluster_resolver.get_master()
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
run_config = tf.contrib.tpu.RunConfig(
master=tpu_grpc_url,
evaluation_master=tpu_grpc_url,
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment