"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b00bacf751433701473b596a3e429b515109a1f8"
Unverified Commit ea7481c8 authored by Neal Wu, committed by GitHub
Browse files

Merge pull request #3510 from tensorflow/tpu_mnist_cluster_resolver

Upgrade mnist_tpu.py to use the new TPUClusterResolver.
parents 6a84aa6e 9c03af08
...@@ -27,19 +27,22 @@ import tensorflow as tf ...@@ -27,19 +27,22 @@ import tensorflow as tf
import dataset import dataset
import mnist import mnist
# Cloud TPU Cluster Resolvers # Cloud TPU Cluster Resolver flags
tf.flags.DEFINE_string( tf.flags.DEFINE_string(
"gcp_project", default=None, "tpu", default=None,
help="Project name for the Cloud TPU-enabled project. If not specified, we " help="The Cloud TPU to use for training. This should be either the name "
"will attempt to automatically detect the GCE project from metadata.") "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string( tf.flags.DEFINE_string(
"tpu_zone", default=None, "tpu_zone", default=None,
help="GCE zone where the Cloud TPU is located in. If not specified, we " help="[Optional] GCE zone where the Cloud TPU is located in. If not "
"will attempt to automatically detect the GCE project from metadata.") "specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string( tf.flags.DEFINE_string(
"tpu_name", default=None, "gcp_project", default=None,
help="Name of the Cloud TPU for Cluster Resolvers. You must specify either " help="[Optional] Project name for the Cloud TPU-enabled project. If not "
"this flag or --master.") "specified, we will attempt to automatically detect the GCE project from "
"metadata.")
# Model specific parameters # Model specific parameters
tf.flags.DEFINE_string( tf.flags.DEFINE_string(
...@@ -74,6 +77,8 @@ def metric_fn(labels, logits): ...@@ -74,6 +77,8 @@ def metric_fn(labels, logits):
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
"""model_fn constructs the ML model used to predict handwritten digits."""
del params del params
if mode == tf.estimator.ModeKeys.PREDICT: if mode == tf.estimator.ModeKeys.PREDICT:
raise RuntimeError("mode {} is not supported yet".format(mode)) raise RuntimeError("mode {} is not supported yet".format(mode))
...@@ -105,6 +110,7 @@ def model_fn(features, labels, mode, params): ...@@ -105,6 +110,7 @@ def model_fn(features, labels, mode, params):
def train_input_fn(params): def train_input_fn(params):
"""train_input_fn defines the input pipeline used for training."""
batch_size = params["batch_size"] batch_size = params["batch_size"]
data_dir = params["data_dir"] data_dir = params["data_dir"]
# Retrieves the batch size for the current shard. The # of shards is # Retrieves the batch size for the current shard. The # of shards is
...@@ -130,25 +136,11 @@ def main(argv): ...@@ -130,25 +136,11 @@ def main(argv):
del argv # Unused. del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.master is None and FLAGS.tpu_name is None: tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
raise RuntimeError("You must specify either --master or --tpu_name.") FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
if FLAGS.master is not None:
if FLAGS.tpu_name is not None:
tf.logging.warn("Both --master and --tpu_name are set. Ignoring "
"--tpu_name and using --master.")
tpu_grpc_url = FLAGS.master
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
tpu_names=[FLAGS.tpu_name],
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
tpu_grpc_url = tpu_cluster_resolver.get_master()
run_config = tf.contrib.tpu.RunConfig( run_config = tf.contrib.tpu.RunConfig(
master=tpu_grpc_url, cluster=tpu_cluster_resolver,
evaluation_master=tpu_grpc_url,
model_dir=FLAGS.model_dir, model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto( session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True), allow_soft_placement=True, log_device_placement=True),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment