ModelZoo / ResNet50_tensorflow — Commit b6ece654

Authored Nov 16, 2020 by Hongkun Yu; committed by A. Unique TensorFlower, Nov 16, 2020.

[Cleanup] Replace tf.distribute.experimental.TPUStrategy with tf.distribute.TPUStrategy

PiperOrigin-RevId: 342770296
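Every hunk below makes the same mechanical substitution: the TPUStrategy class moves out of the tf.distribute.experimental namespace into tf.distribute proper (the stable alias landed around TF 2.3). In sketch form, with the cluster resolver construction elided:

# Before: experimental namespace.
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)

# After: stable namespace, same constructor argument.
strategy = tf.distribute.TPUStrategy(cluster_resolver)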
Parent: a9d13bf1

Showing 12 changed files with 25 additions and 19 deletions (+25 −19):
official/common/distribute_utils.py                        +1 −1
official/nlp/bert/model_training_utils.py                  +3 −1
official/nlp/bert/model_training_utils_test.py             +3 −1
official/nlp/modeling/models/seq2seq_transformer_test.py   +3 −2
official/nlp/nhnet/input_pipeline.py                       +1 −1
official/nlp/nhnet/models_test.py                          +6 −4
official/nlp/nhnet/trainer.py                              +1 −1
official/nlp/transformer/transformer_main.py               +1 −2
official/nlp/xlnet/data_utils.py                           +3 −3
official/recommendation/ncf_common.py                      +1 −1
official/vision/beta/tasks/semantic_segmentation.py        +1 −1
official/vision/image_classification/README.md             +1 −1
official/common/distribute_utils.py

@@ -137,7 +137,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
     cluster_resolver = tpu_initialize(tpu_address)
-    return tf.distribute.experimental.TPUStrategy(cluster_resolver)
+    return tf.distribute.TPUStrategy(cluster_resolver)
   if distribution_strategy == "multi_worker_mirrored":
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(
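tpu_initialize is defined elsewhere in distribute_utils.py; a plausible sketch of what such a helper does, using only standard TF 2.x APIs (the body here is an assumption, not the file's actual code):

import tensorflow as tf

def tpu_initialize(tpu_address):
  """Resolves the TPU cluster, connects, and initializes the TPU system."""
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  return resolver

# An empty address means a local TPU, per the comment in the hunk above.
strategy = tf.distribute.TPUStrategy(tpu_initialize(""))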
official/nlp/bert/model_training_utils.py

@@ -245,7 +245,9 @@ def run_customized_training_loop(
   assert tf.executing_eagerly()

   if run_eagerly:
-    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+    if isinstance(
+        strategy,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
       raise ValueError(
           'TPUStrategy should not run eagerly as it heavily relies on graph'
           ' optimization for the distributed system.')
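Note the tuple form of the isinstance check: during the transition either class may be handed in, so the guard accepts both. One way to factor that pattern out (a hypothetical helper, not part of this commit):

# Both aliases exist in TF releases of this era; checking the tuple keeps
# the guard working whichever one a caller constructed.
_TPU_STRATEGIES = (tf.distribute.TPUStrategy,
                   tf.distribute.experimental.TPUStrategy)

def is_tpu_strategy(strategy):
  return isinstance(strategy, _TPU_STRATEGIES)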
official/nlp/bert/model_training_utils_test.py

@@ -186,7 +186,9 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_strategy_combinations())
   def test_train_eager_single_step(self, distribution):
     model_dir = self.create_tempdir().full_path
-    if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
+    if isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
       with self.assertRaises(ValueError):
         self.run_training(
             distribution, model_dir, steps_per_loop=1, run_eagerly=True)
official/nlp/modeling/models/seq2seq_transformer_test.py

@@ -66,8 +66,9 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
           mode="eager"))
   def test_create_model_with_ds(self, distribution):
     with distribution.scope():
-      padded_decode = isinstance(distribution,
-                                 tf.distribute.experimental.TPUStrategy)
+      padded_decode = isinstance(
+          distribution,
+          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
       decode_max_length = 10
       batch_size = 4
       model = self._build_model(padded_decode, decode_max_length)
official/nlp/nhnet/input_pipeline.py

@@ -218,7 +218,7 @@ def get_input_dataset(input_file_pattern,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(
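The comment in this hunk is the motivation for the dataset-function path: on TPU pods each worker must build its own input pipeline, so the strategy receives a callable plus a per-replica batch size rather than a finished dataset instance. A minimal sketch of that pattern (file pattern and record parsing are placeholders, not this repo's code):

import tensorflow as tf

def distribute_inputs(strategy, file_pattern, global_batch_size):
  def dataset_fn(input_context):
    # Each worker reads a distinct shard of the input files.
    ds = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
    ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=8)
    per_replica = input_context.get_per_replica_batch_size(global_batch_size)
    return ds.batch(per_replica, drop_remainder=True)

  # Hands the strategy a function, not a dataset instance, as the comment
  # above requires for TPU pods.
  return strategy.experimental_distribute_datasets_from_function(dataset_fn)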
official/nlp/nhnet/models_test.py

@@ -179,8 +179,9 @@ class Bert2BertTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_bert2bert_eval(self, distribution):
     seq_length = 10
-    padded_decode = isinstance(distribution,
-                               tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
     self._config.override(
         {
             "beam_size": 3,

@@ -286,8 +287,9 @@ class NHNetTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_nhnet_eval(self, distribution):
     seq_length = 10
-    padded_decode = isinstance(distribution,
-                               tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
     self._nhnet_config.override(
         {
             "beam_size": 4,
official/nlp/nhnet/trainer.py

@@ -210,7 +210,7 @@ def run():
   if "eval" in FLAGS.mode:
     timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
     # Uses padded decoding for TPU. Always uses cache.
-    padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(strategy, tf.distribute.TPUStrategy)
     params.override({
         "padded_decode": padded_decode,
     }, is_strict=False)
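padded_decode exists because XLA on TPU compiles fixed tensor shapes: autoregressive decoding cannot grow a sequence tensor step by step, so it writes into a buffer pre-padded to the maximum decode length. A shape-level illustration of the difference (decode_step stands in for the real model call and is an assumption):

import tensorflow as tf

def greedy_decode(decode_step, batch_size, max_len, padded_decode):
  if padded_decode:
    # TPU-friendly: the buffer shape [batch, max_len] never changes;
    # each step scatters one new column into it.
    ids = tf.zeros([batch_size, max_len], tf.int32)
    for i in range(max_len):
      next_ids = decode_step(ids)  # -> [batch_size] int32
      indices = tf.stack(
          [tf.range(batch_size), tf.fill([batch_size], i)], axis=1)
      ids = tf.tensor_scatter_nd_update(ids, indices, next_ids)
    return ids
  # CPU/GPU: the sequence may simply grow by concatenation.
  ids = tf.zeros([batch_size, 0], tf.int32)
  for _ in range(max_len):
    ids = tf.concat([ids, decode_step(ids)[:, None]], axis=1)
  return ids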
official/nlp/transformer/transformer_main.py

@@ -182,8 +182,7 @@ class TransformerTask(object):
   @property
   def use_tpu(self):
     if self.distribution_strategy:
-      return isinstance(self.distribution_strategy,
-                        tf.distribute.experimental.TPUStrategy)
+      return isinstance(self.distribution_strategy, tf.distribute.TPUStrategy)
     return False

   def train(self):
official/nlp/xlnet/data_utils.py

@@ -175,7 +175,7 @@ def get_classification_input_data(batch_size, seq_len, strategy, is_training,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(

@@ -208,7 +208,7 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(

@@ -592,7 +592,7 @@ def get_pretrain_input_data(batch_size,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   split = "train"
   bsz_per_host = int(batch_size / num_hosts)
   record_glob_base = format_filename(
official/recommendation/ncf_common.py

@@ -135,7 +135,7 @@ def get_v1_distribution_strategy(params):
     }
     os.environ["TF_CONFIG"] = json.dumps(tf_config_env)
-    distribution = tf.distribute.experimental.TPUStrategy(
+    distribution = tf.distribute.TPUStrategy(
         tpu_cluster_resolver, steps_per_run=100)
   else:
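For context, TF_CONFIG is the standard JSON environment variable that TensorFlow's cluster resolvers read; the real tf_config_env dict is assembled earlier in get_v1_distribution_strategy. An illustrative payload under that assumption (addresses are placeholders):

import json
import os

# A two-worker cluster description; this process is worker 0.
tf_config_env = {
    "cluster": {"worker": ["10.0.0.1:8470", "10.0.0.2:8470"]},
    "task": {"type": "worker", "index": 0},
}
os.environ["TF_CONFIG"] = json.dumps(tf_config_env)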
official/vision/beta/tasks/semantic_segmentation.py

@@ -135,7 +135,7 @@ class SemanticSegmentationTask(base_task.Task):
     if training:
       # TODO(arashwan): make MeanIoU tpu friendly.
       if not isinstance(tf.distribute.get_strategy(),
-                        tf.distribute.experimental.TPUStrategy):
+                        tf.distribute.TPUStrategy):
         metrics.append(segmentation_metrics.MeanIoU(
             name='mean_iou',
             num_classes=self.task_config.model.num_classes,
official/vision/image_classification/README.md

@@ -43,7 +43,7 @@ builder to 'records' or 'tfds' in the configurations.
 Note: These models will **not** work with TPUs on Colab.

 You can train image classification models on Cloud TPUs using
-[tf.distribute.experimental.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy?version=nightly).
+[tf.distribute.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy?version=nightly).
 If you are not familiar with Cloud TPUs, it is strongly recommended that you go
 through the
 [quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
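Once a TPUStrategy exists (constructed as in the distribute_utils.py hunk above), Keras training is unchanged apart from the scope. A minimal sketch with a placeholder model:

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")  # local TPU
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  # Variables created inside the scope are replicated across TPU cores.
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation="softmax", input_shape=(784,)),
  ])
  model.compile(optimizer="adam",
                loss="sparse_categorical_crossentropy",
                metrics=["accuracy"])
# model.fit(...) then runs each training step on the TPU.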