Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c5ad244e
Commit
c5ad244e
authored
Feb 24, 2020
by
Yanhui Liang
Committed by
A. Unique TensorFlower
Feb 24, 2020
Browse files
Modify `_get_distribution_strategy` for multi-worker benchmark.
PiperOrigin-RevId: 297016418
parent
fb35d6be
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
26 deletions
+40
-26
official/benchmark/bert_squad_benchmark.py
official/benchmark/bert_squad_benchmark.py
+40
-26
No files found.
official/benchmark/bert_squad_benchmark.py
View file @
c5ad244e
...
...
@@ -82,15 +82,27 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
with
tf
.
io
.
gfile
.
GFile
(
predictions_file
,
'r'
)
as
reader
:
return
json
.
load
(
reader
)
def
_get_distribution_strategy
(
self
,
use_ds
=
True
):
"""Gets the distribution strategy."""
if
self
.
tpu
:
def
_get_distribution_strategy
(
self
,
ds_type
=
'mirrored'
):
"""Gets the distribution strategy.
Args:
ds_type: String, the distribution strategy type to be used. Can be
'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
Returns:
A `tf.distribute.DistributionStrategy` object.
"""
if
self
.
tpu
or
ds_type
==
'tpu'
:
return
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
'tpu'
,
tpu_address
=
self
.
tpu
)
else
:
return
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
'mirrored'
if
use_ds
else
'off'
,
num_gpus
=
self
.
num_gpus
)
elif
ds_type
==
'multi_worker_mirrored'
:
# Configures cluster spec for multi-worker distribution strategy.
_
=
distribution_utils
.
configure_cluster
(
FLAGS
.
worker_hosts
,
FLAGS
.
task_index
)
return
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
ds_type
,
num_gpus
=
self
.
num_gpus
,
all_reduce_alg
=
FLAGS
.
all_reduce_alg
)
def
_init_gpu_and_data_threads
(
self
):
"""Set env variables before any TF calls."""
...
...
@@ -102,12 +114,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
datasets_num_private_threads
=
FLAGS
.
datasets_num_private_threads
)
@
flagsaver
.
flagsaver
def
_train_squad
(
self
,
use_ds
=
True
,
run_eagerly
=
False
):
"""Runs BERT SQuAD training."""
def
_train_squad
(
self
,
run_eagerly
=
False
,
ds_type
=
'mirrored'
):
"""Runs BERT SQuAD training.
Uses mirrored strategy by default.
"""
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
self
.
_init_gpu_and_data_threads
()
input_meta_data
=
self
.
_read_input_meta_data_from_file
()
strategy
=
self
.
_get_distribution_strategy
(
use_ds
)
strategy
=
self
.
_get_distribution_strategy
(
ds_type
)
run_squad
.
train_squad
(
strategy
=
strategy
,
...
...
@@ -116,12 +128,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
custom_callbacks
=
[
self
.
timer_callback
])
@
flagsaver
.
flagsaver
def
_evaluate_squad
(
self
,
use_ds
=
True
):
"""Runs BERT SQuAD evaluation."""
def
_evaluate_squad
(
self
,
ds_type
=
'mirrored'
):
"""Runs BERT SQuAD evaluation.
Uses mirrored strategy by default.
"""
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
self
.
_init_gpu_and_data_threads
()
input_meta_data
=
self
.
_read_input_meta_data_from_file
()
strategy
=
self
.
_get_distribution_strategy
(
use_ds
)
strategy
=
self
.
_get_distribution_strategy
(
ds_type
)
run_squad
.
predict_squad
(
strategy
=
strategy
,
input_meta_data
=
input_meta_data
)
...
...
@@ -157,15 +169,15 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
@
benchmark_wrappers
.
enable_runtime_flags
def
_run_and_report_benchmark
(
self
,
use_ds
=
True
,
run_eagerly
=
False
):
run_eagerly
=
False
,
ds_type
=
'mirrored'
):
"""Runs the benchmark and reports various metrics."""
if
FLAGS
.
train_batch_size
<=
4
:
FLAGS
.
input_meta_data_path
=
SQUAD_MEDIUM_INPUT_META_DATA_PATH
else
:
FLAGS
.
input_meta_data_path
=
SQUAD_LONG_INPUT_META_DATA_PATH
start_time_sec
=
time
.
time
()
self
.
_train_squad
(
use_ds
=
use_ds
,
run_eagerly
=
run_eagerly
)
self
.
_train_squad
(
run_eagerly
=
run_eagerly
,
ds_type
=
ds_type
)
wall_time_sec
=
time
.
time
()
-
start_time_sec
summary
=
self
.
_read_training_summary_from_file
()
...
...
@@ -217,7 +229,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_no_dist_strat_squad'
)
FLAGS
.
train_batch_size
=
4
self
.
_run_and_report_benchmark
(
use_ds
=
False
)
self
.
_run_and_report_benchmark
(
ds_type
=
'off'
)
def
benchmark_1_gpu_eager_no_dist_strat
(
self
):
"""Tests BERT SQuAD model performance with 1 GPU with eager execution."""
...
...
@@ -228,7 +240,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
'benchmark_1_gpu_eager_no_dist_strat_squad'
)
FLAGS
.
train_batch_size
=
4
self
.
_run_and_report_benchmark
(
use_ds
=
False
,
run_eagerly
=
True
)
self
.
_run_and_report_benchmark
(
ds_type
=
'off'
,
run_eagerly
=
True
)
def
benchmark_2_gpu
(
self
):
"""Tests BERT SQuAD model performance with 2 GPUs."""
...
...
@@ -420,12 +432,12 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
@
benchmark_wrappers
.
enable_runtime_flags
def
_run_and_report_benchmark
(
self
,
use_ds
=
True
,
run_eagerly
=
False
):
run_eagerly
=
False
,
ds_type
=
'mirrored'
):
"""Runs the benchmark and reports various metrics."""
start_time_sec
=
time
.
time
()
self
.
_train_squad
(
use_ds
=
use_ds
,
run_eagerly
=
run_eagerly
)
self
.
_evaluate_squad
()
self
.
_train_squad
(
run_eagerly
=
run_eagerly
,
ds_type
=
ds_type
)
self
.
_evaluate_squad
(
ds_type
=
ds_type
)
wall_time_sec
=
time
.
time
()
-
start_time_sec
summary
=
self
.
_read_training_summary_from_file
()
...
...
@@ -445,7 +457,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_squad_eager'
)
FLAGS
.
train_batch_size
=
4
self
.
_run_and_report_benchmark
(
use_ds
=
False
,
run_eagerly
=
True
)
self
.
_run_and_report_benchmark
(
ds_type
=
'off'
,
run_eagerly
=
True
)
def
benchmark_8_gpu
(
self
):
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
...
...
@@ -518,8 +530,9 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
run_eagerly
=
False
):
"""Runs the benchmark and reports various metrics."""
start_time_sec
=
time
.
time
()
self
.
_train_squad
(
use_ds
=
use_ds
,
run_eagerly
=
run_eagerly
)
self
.
_evaluate_squad
()
self
.
_train_squad
(
run_eagerly
=
run_eagerly
,
ds_type
=
'multi_worker_mirrored'
)
self
.
_evaluate_squad
(
ds_type
=
'multi_worker_mirrored'
)
wall_time_sec
=
time
.
time
()
-
start_time_sec
summary
=
self
.
_read_training_summary_from_file
()
...
...
@@ -595,7 +608,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
else
:
FLAGS
.
input_meta_data_path
=
SQUAD_FULL_INPUT_META_DATA_PATH
start_time_sec
=
time
.
time
()
self
.
_train_squad
(
use_ds
=
use_ds
,
run_eagerly
=
run_eagerly
)
self
.
_train_squad
(
run_eagerly
=
run_eagerly
,
ds_type
=
'multi_worker_mirrored'
)
wall_time_sec
=
time
.
time
()
-
start_time_sec
summary
=
self
.
_read_training_summary_from_file
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment