Unverified Commit 87a9205b authored by Frank Chen's avatar Frank Chen Committed by GitHub
Browse files

Add ClusterResolver capability to mnist_tpu.py

parent 2ea91716
...@@ -27,6 +27,26 @@ import tensorflow as tf ...@@ -27,6 +27,26 @@ import tensorflow as tf
import dataset import dataset
import mnist import mnist
# Cloud TPU Cluster Resolvers
# These flags feed TPUClusterResolver in main() when --master is not given.
tf.flags.DEFINE_string(
    "gcp_project", default=None,
    help="Project name for the Cloud TPU-enabled project. If not specified, we "
    "will attempt to automatically detect the GCE project from metadata.")
tf.flags.DEFINE_string(
    "tpu_zone", default=None,
    help="GCE zone where the Cloud TPU is located in. If not specified, we "
    "will attempt to automatically detect the GCE zone from metadata.")
tf.flags.DEFINE_string(
    "tpu_name", default=None,
    help="Name of the Cloud TPU for Cluster Resolvers. You must specify either "
    "this flag or --master.")

# Model specific parameters
# --master defaults to None so main() can tell "not provided" apart from an
# explicit URL; when both --master and --tpu_name are set, --master wins.
tf.flags.DEFINE_string(
    "master", default=None,
    help="GRPC URL of the master (e.g. grpc://ip.address.of.tpu:8470). You "
    "must specify either this flag or --tpu_name.")
tf.flags.DEFINE_string("data_dir", "", tf.flags.DEFINE_string("data_dir", "",
"Path to directory containing the MNIST dataset") "Path to directory containing the MNIST dataset")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir") tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
...@@ -40,7 +60,6 @@ tf.flags.DEFINE_integer("eval_steps", 0, ...@@ -40,7 +60,6 @@ tf.flags.DEFINE_integer("eval_steps", 0,
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.") tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs") tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_string("master", "local", "GRPC URL of the Cloud TPU instance.")
tf.flags.DEFINE_integer("iterations", 50, tf.flags.DEFINE_integer("iterations", 50,
"Number of iterations per TPU training loop.") "Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).") tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")
...@@ -111,9 +130,25 @@ def main(argv): ...@@ -111,9 +130,25 @@ def main(argv):
del argv # Unused. del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.master is None and FLAGS.tpu_name is None:
raise RuntimeError("You must specify either --master or --tpu_name.")
if FLAGS.master is not None:
if FLAGS.tpu_name is not None:
tf.logging.warn("Both --master and --tpu_name are set. Ignoring "
"--tpu_name and using --master.")
tpu_grpc_url = FLAGS.master
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
tpu_names=[FLAGS.tpu_name],
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
tpu_grpc_url = tpu_cluster_resolver.get_master()
run_config = tf.contrib.tpu.RunConfig( run_config = tf.contrib.tpu.RunConfig(
master=FLAGS.master, master=tpu_grpc_url,
evaluation_master=FLAGS.master, evaluation_master=tpu_grpc_url,
model_dir=FLAGS.model_dir, model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto( session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True), allow_soft_placement=True, log_device_placement=True),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment