Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ea7481c8
Unverified
Commit
ea7481c8
authored
Mar 02, 2018
by
Neal Wu
Committed by
GitHub
Mar 02, 2018
Browse files
Merge pull request #3510 from tensorflow/tpu_mnist_cluster_resolver
Upgrade mnist_tpu.py to use the new TPUClusterResolver.
parents
6a84aa6e
9c03af08
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
26 deletions
+18
-26
official/mnist/mnist_tpu.py
official/mnist/mnist_tpu.py
+18
-26
No files found.
official/mnist/mnist_tpu.py
View file @
ea7481c8
...
...
@@ -27,19 +27,22 @@ import tensorflow as tf
import
dataset
import
mnist
# Cloud TPU Cluster Resolvers
# Cloud TPU Cluster Resolver
flag
s
tf
.
flags
.
DEFINE_string
(
"gcp_project"
,
default
=
None
,
help
=
"Project name for the Cloud TPU-enabled project. If not specified, we "
"will attempt to automatically detect the GCE project from metadata."
)
"tpu"
,
default
=
None
,
help
=
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url."
)
tf
.
flags
.
DEFINE_string
(
"tpu_zone"
,
default
=
None
,
help
=
"GCE zone where the Cloud TPU is located in. If not specified, we "
"will attempt to automatically detect the GCE project from metadata."
)
help
=
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata."
)
tf
.
flags
.
DEFINE_string
(
"tpu_name"
,
default
=
None
,
help
=
"Name of the Cloud TPU for Cluster Resolvers. You must specify either "
"this flag or --master."
)
"gcp_project"
,
default
=
None
,
help
=
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata."
)
# Model specific parameters
tf
.
flags
.
DEFINE_string
(
...
...
@@ -74,6 +77,8 @@ def metric_fn(labels, logits):
def
model_fn
(
features
,
labels
,
mode
,
params
):
"""model_fn constructs the ML model used to predict handwritten digits."""
del
params
if
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
raise
RuntimeError
(
"mode {} is not supported yet"
.
format
(
mode
))
...
...
@@ -105,6 +110,7 @@ def model_fn(features, labels, mode, params):
def
train_input_fn
(
params
):
"""train_input_fn defines the input pipeline used for training."""
batch_size
=
params
[
"batch_size"
]
data_dir
=
params
[
"data_dir"
]
# Retrieves the batch size for the current shard. The # of shards is
...
...
@@ -130,25 +136,11 @@ def main(argv):
del
argv
# Unused.
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
if
FLAGS
.
master
is
None
and
FLAGS
.
tpu_name
is
None
:
raise
RuntimeError
(
"You must specify either --master or --tpu_name."
)
if
FLAGS
.
master
is
not
None
:
if
FLAGS
.
tpu_name
is
not
None
:
tf
.
logging
.
warn
(
"Both --master and --tpu_name are set. Ignoring "
"--tpu_name and using --master."
)
tpu_grpc_url
=
FLAGS
.
master
else
:
tpu_cluster_resolver
=
(
tf
.
contrib
.
cluster_resolver
.
TPUClusterResolver
(
tpu_names
=
[
FLAGS
.
tpu_name
],
zone
=
FLAGS
.
tpu_zone
,
project
=
FLAGS
.
gcp_project
))
tpu_grpc_url
=
tpu_cluster_resolver
.
get_master
()
tpu_cluster_resolver
=
tf
.
contrib
.
cluster_resolver
.
TPUClusterResolver
(
FLAGS
.
tpu
,
zone
=
FLAGS
.
tpu_zone
,
project
=
FLAGS
.
gcp_project
))
run_config
=
tf
.
contrib
.
tpu
.
RunConfig
(
master
=
tpu_grpc_url
,
evaluation_master
=
tpu_grpc_url
,
cluster
=
tpu_cluster_resolver
,
model_dir
=
FLAGS
.
model_dir
,
session_config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
,
log_device_placement
=
True
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment