Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a22d0715
Commit
a22d0715
authored
Feb 25, 2020
by
Zongwei Zhou
Committed by
A. Unique TensorFlower
Feb 25, 2020
Browse files
Internal change
PiperOrigin-RevId: 297271153
parent
7b9365dd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
5 deletions
+14
-5
official/benchmark/bert_squad_benchmark.py
official/benchmark/bert_squad_benchmark.py
+14
-5
No files found.
official/benchmark/bert_squad_benchmark.py
View file @
a22d0715
...
@@ -551,6 +551,8 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
...
@@ -551,6 +551,8 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
num_gpus
=
8
num_gpus
=
8
FLAGS
.
num_gpus
=
num_gpus
FLAGS
.
num_gpus
=
num_gpus
FLAGS
.
dtype
=
'fp16'
FLAGS
.
dtype
=
'fp16'
# Enable gradient allreduce in fp16
FLAGS
.
explicit_allreduce
=
True
FLAGS
.
enable_xla
=
False
FLAGS
.
enable_xla
=
False
FLAGS
.
distribution_strategy
=
'multi_worker_mirrored'
FLAGS
.
distribution_strategy
=
'multi_worker_mirrored'
FLAGS
.
tf_gpu_thread_mode
=
'gpu_private'
FLAGS
.
tf_gpu_thread_mode
=
'gpu_private'
...
@@ -621,7 +623,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
...
@@ -621,7 +623,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
min_accuracy
=
0
,
min_accuracy
=
0
,
max_accuracy
=
1
)
max_accuracy
=
1
)
def
_benchmark_common
(
self
,
num_workers
,
all_reduce_alg
):
def
_benchmark_common
(
self
,
num_workers
,
all_reduce_alg
,
explicit_allreduce
=
False
):
"""Common to all benchmarks in this class."""
"""Common to all benchmarks in this class."""
self
.
_setup
()
self
.
_setup
()
...
@@ -637,6 +640,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
...
@@ -637,6 +640,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
num_workers
,
all_reduce_alg
))
num_workers
,
all_reduce_alg
))
FLAGS
.
train_batch_size
=
4
*
num_gpus
*
num_workers
FLAGS
.
train_batch_size
=
4
*
num_gpus
*
num_workers
FLAGS
.
all_reduce_alg
=
all_reduce_alg
FLAGS
.
all_reduce_alg
=
all_reduce_alg
# Enable gradient allreduce in fp16
FLAGS
.
explicit_allreduce
=
explicit_allreduce
self
.
_run_and_report_benchmark
()
self
.
_run_and_report_benchmark
()
...
@@ -650,19 +655,23 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
...
@@ -650,19 +655,23 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
def
benchmark_8_gpu_2_workers_fp16_ring_tweaked
(
self
):
def
benchmark_8_gpu_2_workers_fp16_ring_tweaked
(
self
):
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
self
.
_benchmark_common
(
num_workers
=
2
,
all_reduce_alg
=
'ring'
)
self
.
_benchmark_common
(
num_workers
=
2
,
all_reduce_alg
=
'ring'
,
explicit_allreduce
=
True
)
def
benchmark_8_gpu_2_workers_fp16_nccl_tweaked
(
self
):
def
benchmark_8_gpu_2_workers_fp16_nccl_tweaked
(
self
):
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
self
.
_benchmark_common
(
num_workers
=
2
,
all_reduce_alg
=
'nccl'
)
self
.
_benchmark_common
(
num_workers
=
2
,
all_reduce_alg
=
'nccl'
,
explicit_allreduce
=
True
)
def
benchmark_8_gpu_8_workers_fp16_ring_tweaked
(
self
):
def
benchmark_8_gpu_8_workers_fp16_ring_tweaked
(
self
):
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
self
.
_benchmark_common
(
num_workers
=
8
,
all_reduce_alg
=
'ring'
)
self
.
_benchmark_common
(
num_workers
=
8
,
all_reduce_alg
=
'ring'
,
explicit_allreduce
=
True
)
def
benchmark_8_gpu_8_workers_fp16_nccl_tweaked
(
self
):
def
benchmark_8_gpu_8_workers_fp16_nccl_tweaked
(
self
):
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
self
.
_benchmark_common
(
num_workers
=
8
,
all_reduce_alg
=
'nccl'
)
self
.
_benchmark_common
(
num_workers
=
8
,
all_reduce_alg
=
'nccl'
,
explicit_allreduce
=
True
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment