ModelZoo / ResNet50_tensorflow

Commit 357f30f4 authored Dec 13, 2019 by A. Unique TensorFlower

Clearly demarcate contrib symbols from standard tf symbols by importing them directly.

PiperOrigin-RevId: 285503670
Parent: c71043da
Showing 8 changed files with 62 additions and 51 deletions:
official/r1/boosted_trees/train_higgs.py (+2, −1)
official/r1/mnist/mnist_eager.py (+9, −8)
official/r1/mnist/mnist_tpu.py (+11, −12)
official/r1/resnet/resnet_run_loop.py (+2, −1)
official/r1/utils/tpu.py (+5, −4)
official/recommendation/neumf_model.py (+3, −2)
official/transformer/transformer_main.py (+23, −16)
official/utils/misc/distribution_utils.py (+7, −7)
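Every file applies the same mechanical change: attribute access through the tf.contrib namespace is replaced by a direct, aliased import from tensorflow.contrib, with no change in behavior. A minimal sketch of the pattern, using the opt module's LazyAdamOptimizer as in transformer_main.py (illustrative only; it assumes a TensorFlow 1.x installation, where tensorflow.contrib still exists, and is not a line taken from the diff):

    # Sketch of the renaming pattern applied throughout this commit.
    # Assumes TensorFlow 1.x, where the tensorflow.contrib package exists.
    import tensorflow as tf
    from tensorflow.contrib import opt as contrib_opt  # direct, aliased import

    # Before: the contrib symbol is reached as an attribute of tf, so it is
    # easy to mistake for a core TensorFlow API.
    optimizer = tf.contrib.opt.LazyAdamOptimizer(learning_rate=0.001)

    # After: the contrib_-prefixed alias makes the call site visibly contrib.
    optimizer = contrib_opt.LazyAdamOptimizer(learning_rate=0.001)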
official/r1/boosted_trees/train_higgs.py

@@ -53,6 +53,7 @@ import tensorflow as tf
 # pylint: enable=g-bad-import-order

 from official.utils.flags import core as flags_core
+from tensorflow.contrib import estimator as contrib_estimator
 from official.utils.flags._conventions import help_wrap
 from official.utils.logs import logger
@@ -229,7 +230,7 @@ def train_boosted_trees(flags_obj):
   # Though BoostedTreesClassifier is under tf.estimator, faster in-memory
   # training is yet provided as a contrib library.
-  classifier = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
+  classifier = contrib_estimator.boosted_trees_classifier_train_in_memory(
       train_input_fn,
      feature_columns,
      model_dir=flags_obj.model_dir or None,
official/r1/mnist/mnist_eager.py

@@ -33,6 +33,7 @@ import time
 from absl import app as absl_app
 from absl import flags
 import tensorflow as tf
+from tensorflow.contrib import summary as contrib_summary
 from tensorflow.python import eager as tfe
 # pylint: enable=g-bad-import-order
@@ -61,7 +62,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
   start = time.time()
   for (batch, (images, labels)) in enumerate(dataset):
-    with tf.contrib.summary.record_summaries_every_n_global_steps(
+    with contrib_summary.record_summaries_every_n_global_steps(
         10, global_step=step_counter):
       # Record the operations used to compute the loss given the input,
       # so that the gradient of the loss with respect to the variables
@@ -69,8 +70,8 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
       with tf.GradientTape() as tape:
         logits = model(images, training=True)
         loss_value = loss(logits, labels)
-        tf.contrib.summary.scalar('loss', loss_value)
-        tf.contrib.summary.scalar('accuracy', compute_accuracy(logits, labels))
+        contrib_summary.scalar('loss', loss_value)
+        contrib_summary.scalar('accuracy', compute_accuracy(logits, labels))
       grads = tape.gradient(loss_value, model.variables)
       optimizer.apply_gradients(
           zip(grads, model.variables), global_step=step_counter)
@@ -93,9 +94,9 @@ def test(model, dataset):
                tf.cast(labels, tf.int64))
   print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
         (avg_loss.result(), 100 * accuracy.result()))
-  with tf.contrib.summary.always_record_summaries():
-    tf.contrib.summary.scalar('loss', avg_loss.result())
-    tf.contrib.summary.scalar('accuracy', accuracy.result())
+  with contrib_summary.always_record_summaries():
+    contrib_summary.scalar('loss', avg_loss.result())
+    contrib_summary.scalar('accuracy', accuracy.result())


 def run_mnist_eager(flags_obj):
@@ -137,9 +138,9 @@ def run_mnist_eager(flags_obj):
   else:
     train_dir = None
     test_dir = None
-  summary_writer = tf.contrib.summary.create_file_writer(
+  summary_writer = contrib_summary.create_file_writer(
       train_dir, flush_millis=10000)
-  test_summary_writer = tf.contrib.summary.create_file_writer(
+  test_summary_writer = contrib_summary.create_file_writer(
       test_dir, flush_millis=10000, name='test')
   # Create and restore checkpoint (if one exists on the path)
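The writers created in run_mnist_eager are installed as default writers a few lines below the hunk shown above. A hedged sketch of how the renamed contrib_summary pieces fit together in TF 1.x eager mode (illustrative; the log directory, constant loss value, and step handling below are placeholders, not code from this file):

    import tensorflow as tf
    from tensorflow.contrib import summary as contrib_summary

    tf.enable_eager_execution()
    step_counter = tf.train.get_or_create_global_step()
    writer = contrib_summary.create_file_writer('/tmp/mnist_demo',
                                                flush_millis=10000)
    with writer.as_default():
      with contrib_summary.record_summaries_every_n_global_steps(
          10, global_step=step_counter):
        # Only every 10th global step actually writes the scalar.
        contrib_summary.scalar('loss', tf.constant(0.25))
      step_counter.assign_add(1)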
official/r1/mnist/mnist_tpu.py

@@ -33,6 +33,8 @@ import tensorflow as tf
 # For open source environment, add grandparent directory for import
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.path[0]))))

+from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
+from tensorflow.contrib import tpu as contrib_tpu
 from official.r1.mnist import dataset  # pylint: disable=wrong-import-position
 from official.r1.mnist import mnist  # pylint: disable=wrong-import-position
@@ -98,7 +100,7 @@ def model_fn(features, labels, mode, params):
         'class_ids': tf.argmax(logits, axis=1),
         'probabilities': tf.nn.softmax(logits),
     }
-    return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)
+    return contrib_tpu.TPUEstimatorSpec(mode, predictions=predictions)

   logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
   loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
@@ -111,14 +113,14 @@ def model_fn(features, labels, mode, params):
         decay_rate=0.96)
     optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
     if FLAGS.use_tpu:
-      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
-    return tf.contrib.tpu.TPUEstimatorSpec(
+      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
+    return contrib_tpu.TPUEstimatorSpec(
         mode=mode,
         loss=loss,
         train_op=optimizer.minimize(loss, tf.train.get_global_step()))

   if mode == tf.estimator.ModeKeys.EVAL:
-    return tf.contrib.tpu.TPUEstimatorSpec(
+    return contrib_tpu.TPUEstimatorSpec(
         mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
@@ -153,21 +155,18 @@ def main(argv):
   del argv  # Unused.
   tf.logging.set_verbosity(tf.logging.INFO)

-  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+  tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
       FLAGS.tpu,
       zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

-  run_config = tf.contrib.tpu.RunConfig(
+  run_config = contrib_tpu.RunConfig(
       cluster=tpu_cluster_resolver,
       model_dir=FLAGS.model_dir,
       session_config=tf.ConfigProto(
           allow_soft_placement=True, log_device_placement=True),
-      tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
+      tpu_config=contrib_tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
   )

-  estimator = tf.contrib.tpu.TPUEstimator(
+  estimator = contrib_tpu.TPUEstimator(
       model_fn=model_fn,
       use_tpu=FLAGS.use_tpu,
       train_batch_size=FLAGS.batch_size,
official/r1/resnet/resnet_run_loop.py

@@ -30,6 +30,7 @@ import os
 from absl import flags
 import tensorflow as tf
+from tensorflow.contrib import opt as contrib_opt

 from official.r1.resnet import imagenet_preprocessing
 from official.r1.resnet import resnet_model
@@ -445,7 +446,7 @@ def resnet_model_fn(features, labels, mode, model_class,
     tf.compat.v1.summary.scalar('learning_rate', learning_rate)

     if flags.FLAGS.enable_lars:
-      optimizer = tf.contrib.opt.LARSOptimizer(
+      optimizer = contrib_opt.LARSOptimizer(
           learning_rate,
          momentum=momentum,
          weight_decay=weight_decay,
official/r1/utils/tpu.py

@@ -15,6 +15,7 @@
 """Functions specific to running TensorFlow on TPUs."""

 import tensorflow as tf
+from tensorflow.contrib import summary as contrib_summary

 # "local" is a magic word in the TPU cluster resolver; it informs the resolver
@@ -58,13 +59,13 @@ def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
       List of summary ops to run on the CPU host.
     """
     step = global_step[0]
-    with tf.contrib.summary.create_file_writer(
+    with contrib_summary.create_file_writer(
         logdir=model_dir, filename_suffix=".host_call").as_default():
-      with tf.contrib.summary.always_record_summaries():
+      with contrib_summary.always_record_summaries():
         for i, name in enumerate(metric_names):
-          tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)
-        return tf.contrib.summary.all_summary_ops()
+          contrib_summary.scalar(prefix + name, args[i][0], step=step)
+        return contrib_summary.all_summary_ops()

   # To log the current learning rate, and gradient norm for Tensorboard, the
   # summary op needs to be run on the host CPU via host_call. host_call
official/recommendation/neumf_model.py

@@ -37,6 +37,7 @@ import sys
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
+from tensorflow.contrib import tpu as contrib_tpu

 from official.recommendation import constants as rconst
 from official.recommendation import movielens
@@ -116,7 +117,7 @@ def neumf_model_fn(features, labels, mode, params):
         epsilon=params["epsilon"])
     if params["use_tpu"]:
       # TODO(seemuch): remove this contrib import
-      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
+      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

     mlperf_helper.ncf_print(key=mlperf_helper.TAGS.MODEL_HP_LOSS_FN,
                             value=mlperf_helper.TAGS.BCE)
@@ -274,7 +275,7 @@ def _get_estimator_spec_with_metrics(logits,  # type: tf.Tensor
                                      use_tpu_spec)

   if use_tpu_spec:
-    return tf.contrib.tpu.TPUEstimatorSpec(
+    return contrib_tpu.TPUEstimatorSpec(
         mode=tf.estimator.ModeKeys.EVAL,
         loss=cross_entropy,
         eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
official/transformer/transformer_main.py

@@ -33,6 +33,9 @@ import tensorflow as tf
 # pylint: enable=g-bad-import-order

 from official.r1.utils import export
+from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
+from tensorflow.contrib import opt as contrib_opt
+from tensorflow.contrib import tpu as contrib_tpu
 from official.r1.utils import tpu as tpu_util
 from official.transformer import compute_bleu
 from official.transformer import translate
@@ -115,8 +118,10 @@ def model_fn(features, labels, mode, params):
         metric_fn = lambda logits, labels: (
             metrics.get_eval_metrics(logits, labels, params=params))
         eval_metrics = (metric_fn, [logits, labels])
-        return tf.contrib.tpu.TPUEstimatorSpec(
+        return contrib_tpu.TPUEstimatorSpec(
             mode=mode, loss=loss, predictions={"predictions": logits},
             eval_metrics=eval_metrics)
       return tf.estimator.EstimatorSpec(
           mode=mode, loss=loss, predictions={"predictions": logits},
@@ -128,12 +133,14 @@ def model_fn(features, labels, mode, params):
       # in TensorBoard.
       metric_dict["minibatch_loss"] = loss
       if params["use_tpu"]:
-        return tf.contrib.tpu.TPUEstimatorSpec(
+        return contrib_tpu.TPUEstimatorSpec(
             mode=mode, loss=loss, train_op=train_op,
             host_call=tpu_util.construct_scalar_host_call(
-                metric_dict=metric_dict, model_dir=params["model_dir"],
+                metric_dict=metric_dict,
+                model_dir=params["model_dir"],
                 prefix="training/")
         )
       record_scalars(metric_dict)
       return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
@@ -173,14 +180,14 @@ def get_train_op_and_metrics(loss, params):
     # Create optimizer. Use LazyAdamOptimizer from TF contrib, which is faster
     # than the TF core Adam optimizer.
-    optimizer = tf.contrib.opt.LazyAdamOptimizer(
+    optimizer = contrib_opt.LazyAdamOptimizer(
         learning_rate,
         beta1=params["optimizer_adam_beta1"],
         beta2=params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])

     if params["use_tpu"] and params["tpu"] != tpu_util.LOCAL:
-      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
+      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

     # Uses automatic mixed precision FP16 training if on GPU.
     if params["dtype"] == "fp16":
@@ -528,31 +535,31 @@ def construct_estimator(flags_obj, params, schedule_manager):
         model_fn=model_fn, model_dir=flags_obj.model_dir, params=params,
         config=tf.estimator.RunConfig(train_distribute=distribution_strategy))

-  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+  tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
       tpu=flags_obj.tpu,
       zone=flags_obj.tpu_zone,
       project=flags_obj.tpu_gcp_project
   )

-  tpu_config = tf.contrib.tpu.TPUConfig(
+  tpu_config = contrib_tpu.TPUConfig(
       iterations_per_loop=schedule_manager.single_iteration_train_steps,
       num_shards=flags_obj.num_tpu_shards)

-  run_config = tf.contrib.tpu.RunConfig(
+  run_config = contrib_tpu.RunConfig(
       cluster=tpu_cluster_resolver,
       model_dir=flags_obj.model_dir,
       session_config=tf.ConfigProto(
           allow_soft_placement=True, log_device_placement=True),
       tpu_config=tpu_config)

-  return tf.contrib.tpu.TPUEstimator(
+  return contrib_tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
      train_batch_size=schedule_manager.batch_size,
      eval_batch_size=schedule_manager.batch_size,
      params={
          # TPUEstimator needs to populate batch_size itself due to sharding.
          key: value for key, value in params.items() if key != "batch_size"},
      config=run_config)
official/utils/misc/distribution_utils.py

@@ -23,6 +23,7 @@ import os
 import random
 import string

 import tensorflow as tf
+from tensorflow.contrib import distribute as contrib_distribute

 from official.utils.misc import tpu_lib
@@ -285,10 +286,9 @@ def set_up_synthetic_data():
       tf.distribute.experimental.MultiWorkerMirroredStrategy)
   # TODO(tobyboyd): Remove when contrib.distribute is all in core.
   if hasattr(tf, 'contrib'):
-    _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-    _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
-    _monkey_patch_dataset_method(
-        tf.contrib.distribute.CollectiveAllReduceStrategy)
+    _monkey_patch_dataset_method(contrib_distribute.MirroredStrategy)
+    _monkey_patch_dataset_method(contrib_distribute.OneDeviceStrategy)
+    _monkey_patch_dataset_method(contrib_distribute.CollectiveAllReduceStrategy)
   else:
     print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
@@ -300,10 +300,10 @@ def undo_set_up_synthetic_data():
       tf.distribute.experimental.MultiWorkerMirroredStrategy)
   # TODO(tobyboyd): Remove when contrib.distribute is all in core.
   if hasattr(tf, 'contrib'):
-    _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-    _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
-    _undo_monkey_patch_dataset_method(
-        tf.contrib.distribute.CollectiveAllReduceStrategy)
+    _undo_monkey_patch_dataset_method(contrib_distribute.MirroredStrategy)
+    _undo_monkey_patch_dataset_method(contrib_distribute.OneDeviceStrategy)
+    _undo_monkey_patch_dataset_method(
+        contrib_distribute.CollectiveAllReduceStrategy)
   else:
     print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')