ModelZoo / ResNet50_tensorflow / Commits / 497989e0

Commit 497989e0
Authored Sep 24, 2019 by Bruce Fontaine; committed by A. Unique TensorFlower on Sep 24, 2019.

Use experimental_connect_to_cluster API in TPU lib to support training on a slice of a TPU pod.

PiperOrigin-RevId: 270926016
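
The core of the change is how the TPU system is reached from the TF 2.0 eager client. A minimal sketch of the before/after using the public TF 2.0 APIs touched by this commit; `tpu_address` is a placeholder for a TPU name or gRPC address, not a value from the diff:

    import tensorflow as tf

    tpu_address = "my-tpu"  # placeholder: TPU name or grpc:// address

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)

    # Before this commit: only the master host of the TPU job was registered for
    # remote eager execution, so multi-host pod slices were not fully reachable.
    #   tf.config.experimental_connect_to_host(resolver.master())

    # After this commit: register every task in the resolved cluster, so ops
    # (including the input pipeline) can run on any worker host in the slice.
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
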
Parent: a52564cb
Changes: 11
Showing 11 changed files with 345 additions and 403 deletions (+345 -403):

  official/modeling/model_training_utils.py         +218  -218
  official/modeling/model_training_utils_test.py      +0    -1
  official/nlp/bert/run_classifier.py                 +2    -8
  official/nlp/bert/run_pretraining.py                +3    -8
  official/nlp/bert/run_squad.py                      +0    -3
  official/nlp/xlnet/run_classifier.py                +16   -27
  official/nlp/xlnet/run_pretrain.py                  +16   -27
  official/nlp/xlnet/run_squad.py                     +15   -26
  official/recommendation/ncf_keras_main.py           +72   -77
  official/utils/misc/distribution_utils.py           +0    -1
  official/utils/misc/tpu_lib.py                      +3    -7
official/modeling/model_training_utils.py

...
@@ -130,8 +130,7 @@ def run_customized_training_loop(
         after every epoch.
     init_checkpoint: Optional checkpoint to load to `sub_model` returned by
       `model_fn`.
-    use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
-      as an optimization.
+    use_remote_tpu: Ignored, will be removed in the future.
     custom_callbacks: A list of Keras Callbacks objects to run during
       training. More specifically, `on_batch_begin()`, `on_batch_end()`,
       methods are invoked during training.
...
@@ -146,6 +145,8 @@ def run_customized_training_loop(
       attribute or when required parameters are set to none. (2) eval args are
       not specified correctly. (3) metric_fn must be a callable if specified.
   """
+  # TODO(bfontain): Remove use_remote_tpu once there are no models using it.
+  del use_remote_tpu
   if _sentinel is not None:
     raise ValueError('only call `run_customized_training_loop()` '
...
@@ -188,7 +189,6 @@ def run_customized_training_loop(
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
-  with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
     train_iterator = _get_input_iterator(train_input_fn, strategy)
     with distribution_utils.get_strategy_scope(strategy):
...

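The `del use_remote_tpu` pattern above keeps the deprecated keyword in the signature so existing call sites keep working while the value is discarded. A minimal, hypothetical sketch of that pattern (not the library's API):

    def run_training(steps, use_remote_tpu=None, run_eagerly=False):
      """Hypothetical example: `use_remote_tpu` is accepted but ignored."""
      # TODO: drop `use_remote_tpu` once no callers pass it.
      del use_remote_tpu
      for _ in range(steps):
        pass  # training step would go here

    run_training(10, use_remote_tpu=False)  # old call sites still work
    run_training(10)                        # preferred going forward
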
official/modeling/model_training_utils_test.py

...
@@ -152,7 +152,6 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
-        use_remote_tpu=False,
         custom_callbacks=None,
         run_eagerly=run_eagerly)
...

official/nlp/bert/run_classifier.py

...
@@ -90,7 +90,6 @@ def run_customized_training(strategy,
                             warmup_steps,
                             initial_lr,
                             init_checkpoint,
-                            use_remote_tpu=False,
                             custom_callbacks=None,
                             run_eagerly=False):
   """Run BERT classifier training using low-level API."""
...
@@ -151,7 +150,6 @@ def run_customized_training(strategy,
       eval_steps=eval_steps,
       init_checkpoint=init_checkpoint,
       metric_fn=metric_fn,
-      use_remote_tpu=use_remote_tpu,
       custom_callbacks=custom_callbacks,
       run_eagerly=run_eagerly)
...
@@ -201,7 +199,6 @@ def run_bert(strategy, input_meta_data):
   # Runs customized training loop.
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   trained_model = run_customized_training(
       strategy,
       bert_config,
...
@@ -214,11 +211,9 @@ def run_bert(strategy, input_meta_data):
       warmup_steps,
       FLAGS.learning_rate,
       FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=FLAGS.run_eagerly)
   if FLAGS.model_export_path:
-    with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
-      model_saving_utils.export_bert_model(
-          FLAGS.model_export_path, model=trained_model)
+    model_saving_utils.export_bert_model(
+        FLAGS.model_export_path, model=trained_model)
   return trained_model
...
@@ -238,7 +233,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...

official/nlp/bert/run_pretraining.py

...
@@ -114,8 +114,7 @@ def run_customized_training(strategy,
                             initial_lr,
                             warmup_steps,
                             input_files,
-                            train_batch_size,
-                            use_remote_tpu=False):
+                            train_batch_size):
   """Run BERT pretrain model training using low-level API."""
   train_input_fn = functools.partial(get_pretrain_input_data, input_files,
...
@@ -148,8 +147,7 @@ def run_customized_training(strategy,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
-      epochs=epochs,
-      use_remote_tpu=use_remote_tpu)
+      epochs=epochs)
   # Creates the BERT core model outside distribution strategy scope.
   _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
...
@@ -173,7 +171,6 @@ def run_bert_pretrain(strategy):
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   return run_customized_training(
       strategy,
       bert_config,
...
@@ -186,8 +183,7 @@ def run_bert_pretrain(strategy):
       FLAGS.learning_rate,
       FLAGS.warmup_steps,
       FLAGS.input_files,
-      FLAGS.train_batch_size,
-      use_remote_tpu=use_remote_tpu)
+      FLAGS.train_batch_size)


 def main(_):
...
@@ -200,7 +196,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...

official/nlp/bert/run_squad.py

...
@@ -245,7 +245,6 @@ def train_squad(strategy,
   loss_fn = get_loss_fn(
       loss_factor=1.0 /
       strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
...
@@ -257,7 +256,6 @@ def train_squad(strategy,
       epochs=epochs,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=run_eagerly,
       custom_callbacks=custom_callbacks)
...
@@ -366,7 +364,6 @@ def main(_):
   elif FLAGS.strategy_type == 'multi_worker_mirror':
     strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...

official/nlp/xlnet/run_classifier.py

...
@@ -126,23 +126,13 @@ def get_metric_fn():
   return train_acc_metric


-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
-
-
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
...
@@ -180,7 +170,6 @@ def main(unused_argv):
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["n_class"] = FLAGS.n_class
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
         strategy=strategy,
         model_fn=model_fn,
...

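The deleted `get_primary_cpu_task` helper existed only so callers could pin input-pipeline setup to the remote worker host with an explicit `tf.device` scope. A small illustrative sketch of the old versus new pattern; it runs locally because the device string is empty when `use_remote_tpu` is False:

    import tensorflow as tf

    def get_primary_cpu_task(use_remote_tpu=False):
      # Copy of the helper being removed; '' means the default (local) device.
      return "/job:worker" if use_remote_tpu else ""

    # Old pattern: wrap dataset/training setup in an explicit device scope.
    with tf.device(get_primary_cpu_task(use_remote_tpu=False)):
      dataset_old = tf.data.Dataset.range(8).batch(2)

    # New pattern: no device scope; experimental_connect_to_cluster lets the
    # runtime place input-pipeline ops on the appropriate worker hosts.
    dataset_new = tf.data.Dataset.range(8).batch(2)
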
official/nlp/xlnet/run_pretrain.py

...
@@ -52,24 +52,14 @@ def get_pretrainxlnet_model(model_config, run_config):
   return model


-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
-
-
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   num_hosts = 1
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
     topology = FLAGS.tpu_topology.split("x")
     total_num_core = 2 * int(topology[0]) * int(topology[1])
     num_hosts = total_num_core // FLAGS.num_core_per_host
...
@@ -111,7 +101,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
         strategy=strategy,
         model_fn=model_fn,
...

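The host count above follows directly from the topology flag. A worked example, assuming 2 TPU cores per chip (the factor of 2 in the code) and an assumed `FLAGS.num_core_per_host` of 8:

    tpu_topology = "4x4"          # e.g. a 4x4 slice of a pod
    num_core_per_host = 8         # assumed value of FLAGS.num_core_per_host

    topology = tpu_topology.split("x")
    total_num_core = 2 * int(topology[0]) * int(topology[1])  # 2 * 4 * 4 = 32
    num_hosts = total_num_core // num_core_per_host           # 32 // 8 = 4

    print(total_num_core, num_hosts)  # 32 4
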
official/nlp/xlnet/run_squad.py

...
@@ -91,13 +91,6 @@ class InputFeatures(object):
     self.is_impossible = is_impossible


-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
-
-
 # pylint: disable=unused-argument
 def run_evaluation(strategy,
                    test_input_fn,
...
@@ -224,14 +217,11 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
...
@@ -285,7 +275,6 @@ def main(unused_argv):
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                               eval_steps, input_meta_data)
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
         strategy=strategy,
         model_fn=model_fn,
...

official/recommendation/ncf_keras_main.py

...
@@ -43,7 +43,6 @@ from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
 from official.utils.flags import core as flags_core
-from official.utils.misc import tpu_lib

 FLAGS = flags.FLAGS
...
@@ -254,10 +253,6 @@ def run_ncf(_):
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)

-  use_remote_tpu = params["use_tpu"] and FLAGS.tpu
-  primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)
-
-  with tf.device(primary_cpu_task):
     (train_input_dataset, eval_input_dataset,
      num_train_steps, num_eval_steps) = \
         (ncf_input_pipeline.create_ncf_input_data(
...

official/utils/misc/distribution_utils.py

...
@@ -128,7 +128,6 @@ def get_distribution_strategy(distribution_strategy="default",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
     return tf.distribute.experimental.TPUStrategy(cluster_resolver)
...

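For callers that go through `distribution_utils` instead of calling `tpu_lib` directly, the TPU branch above returns a `TPUStrategy`. A hedged usage sketch: only the `distribution_strategy` and `tpu_address` arguments come from this hunk, and the rest of the `get_distribution_strategy` keyword list is not shown here:

    from official.utils.misc import distribution_utils

    # "" targets local TPUs (per the comment in the hunk); a name or bns/grpc
    # address targets a remote TPU or pod slice.
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy="tpu", tpu_address="my-tpu")
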
official/utils/misc/tpu_lib.py

...
@@ -21,18 +21,14 @@ def tpu_initialize(tpu_address):
   """Initializes TPU for TF 2.0 training.

   Args:
-    tpu_address: string, bns address of TPU workers.
+    tpu_address: string, bns address of master TPU worker.

   Returns:
     A TPUClusterResolver.
   """
   cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       tpu=tpu_address)
-  tf.config.experimental_connect_to_host(cluster_resolver.master())
+  if tpu_address not in ('', 'local'):
+    tf.config.experimental_connect_to_cluster(cluster_resolver)
   tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
   return cluster_resolver
-
-
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns remote TPU worker address. No-op for GPU/CPU training."""
-  return "/job:worker" if use_remote_tpu else ""