ModelZoo / ResNet50_tensorflow · Commit 7bf81db8

Authored Mar 11, 2020 by Jose Baiocchi; committed by A. Unique TensorFlower on Mar 11, 2020.

Internal change

PiperOrigin-RevId: 300399639
Parent: 1fdfd973
Showing 1 changed file with 10 additions and 10 deletions.

official/r1/mnist/mnist_tpu.py (+10, -10)
@@ -28,7 +28,7 @@ import sys
 # pylint: disable=g-bad-import-order
 from absl import app as absl_app   # pylint: disable=unused-import
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 # pylint: enable=g-bad-import-order
 # For open source environment, add grandparent directory for import
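The rest of the diff drops explicit `tf.compat.v1.` prefixes, which is safe once `tf` itself is bound to the v1 compatibility module. A minimal sketch (not part of the commit) showing that the two spellings resolve to the same objects, assuming a TF 2.x release that still ships the estimator API:

# Illustrative check only; `tf2` is introduced here for comparison and does
# not appear in mnist_tpu.py.
import tensorflow as tf2            # top-level TF 2.x namespace
import tensorflow.compat.v1 as tf   # the import this commit switches to

assert tf.estimator.tpu.TPUEstimator is tf2.compat.v1.estimator.tpu.TPUEstimator
assert tf.tpu.CrossShardOptimizer is tf2.compat.v1.tpu.CrossShardOptimizer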
@@ -98,7 +98,7 @@ def model_fn(features, labels, mode, params):
         'class_ids': tf.argmax(logits, axis=1),
         'probabilities': tf.nn.softmax(logits),
     }
-    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions)
+    return tf.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions)

   logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
   loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
@@ -111,14 +111,14 @@ def model_fn(features, labels, mode, params):
         decay_rate=0.96)
     optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
     if FLAGS.use_tpu:
-      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
-    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+      optimizer = tf.tpu.CrossShardOptimizer(optimizer)
+    return tf.estimator.tpu.TPUEstimatorSpec(
         mode=mode,
         loss=loss,
         train_op=optimizer.minimize(loss, tf.train.get_global_step()))

   if mode == tf.estimator.ModeKeys.EVAL:
-    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+    return tf.estimator.tpu.TPUEstimatorSpec(
         mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
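Two details of this hunk are easy to miss: `CrossShardOptimizer` wraps the optimizer so gradients are aggregated across TPU shards, and `TPUEstimatorSpec` takes evaluation metrics as an `eval_metrics=(metric_fn, tensors)` pair rather than v1-style `eval_metric_ops`. The actual `metric_fn` is defined outside this hunk; a hedged sketch of the shape `TPUEstimator` expects, using the file's `tf = tensorflow.compat.v1` binding:

def metric_fn(labels, logits):
  # Runs on the host CPU with the tensors listed in `eval_metrics`.
  # The name and body here are illustrative, not taken from the diff.
  accuracy = tf.metrics.accuracy(
      labels=labels, predictions=tf.argmax(logits, axis=1))
  return {"accuracy": accuracy}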
@@ -128,7 +128,7 @@ def train_input_fn(params):
   data_dir = params["data_dir"]
   # Retrieves the batch size for the current shard. The # of shards is
   # computed according to the input pipeline deployment. See
-  # `tf.compat.v1.estimator.tpu.RunConfig` for details.
+  # `tf.estimator.tpu.RunConfig` for details.
   ds = dataset.train(data_dir).cache().repeat().shuffle(
       buffer_size=50000).batch(batch_size, drop_remainder=True)
   return ds
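`TPUEstimator` injects the per-shard batch size into `params`, and `drop_remainder=True` matters because the TPU compiler needs a static batch dimension. An illustrative snippet, independent of mnist_tpu.py, showing the difference:

import tensorflow as tf  # any TF 2.x install; standalone example

ds = tf.data.Dataset.range(12).batch(4, drop_remainder=True)
print(ds.element_spec.shape)   # (4,)    static batch dimension, TPU-friendly
ds = tf.data.Dataset.range(12).batch(4)
print(ds.element_spec.shape)   # (None,) dynamic batch dimension, rejected by the TPU compiler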
@@ -159,16 +159,15 @@ def main(argv):
       project=FLAGS.gcp_project
   )

-  run_config = tf.compat.v1.estimator.tpu.RunConfig(
+  run_config = tf.estimator.tpu.RunConfig(
       cluster=tpu_cluster_resolver,
       model_dir=FLAGS.model_dir,
       session_config=tf.ConfigProto(
           allow_soft_placement=True, log_device_placement=True),
-      tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
-          FLAGS.iterations, FLAGS.num_shards),
+      tpu_config=tf.estimator.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
   )

-  estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
+  estimator = tf.estimator.tpu.TPUEstimator(
       model_fn=model_fn,
       use_tpu=FLAGS.use_tpu,
       train_batch_size=FLAGS.batch_size,
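Below the visible hunk, the estimator built here is driven by the usual train/evaluate loop. A hedged continuation sketch; `train_input_fn` appears in the hunk above, while `eval_input_fn` and the step-count flags are placeholders standing in for names defined elsewhere in the file:

# Hypothetical continuation, not shown in this diff.
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)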
@@ -199,4 +198,5 @@ def main(argv):
 if __name__ == "__main__":
+  tf.disable_v2_behavior()
   absl_app.run(main)
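With `tf.disable_v2_behavior()` in place, the script runs under graph-mode v1 semantics on TF 2.x. A hypothetical invocation; the flag names below are inferred only from the FLAGS references visible in this diff and are defined elsewhere in the file:

python official/r1/mnist/mnist_tpu.py \
  --use_tpu=True \
  --gcp_project=<project> \
  --model_dir=gs://<bucket>/mnist_model \
  --batch_size=1024 \
  --iterations=500 \
  --num_shards=8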