Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0612e190
Commit
0612e190
authored
Dec 17, 2019
by
David Chen
Committed by
A. Unique TensorFlower
Dec 17, 2019
Browse files
Internal change
PiperOrigin-RevId: 286087529
parent
7e67dbbc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
28 deletions
+41
-28
official/benchmark/bert_benchmark_utils.py
official/benchmark/bert_benchmark_utils.py
+4
-19
official/benchmark/bert_squad_benchmark.py
official/benchmark/bert_squad_benchmark.py
+37
-9
No files found.
official/benchmark/bert_benchmark_utils.py
View file @
0612e190
...
@@ -18,17 +18,16 @@ from __future__ import absolute_import
...
@@ -18,17 +18,16 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
time
import
time
# pylint: disable=g-bad-import-order
# pylint: disable=g-bad-import-order
import
numpy
as
np
import
numpy
as
np
from
absl
import
flags
from
absl
import
flags
from
absl.testing
import
flagsaver
import
tensorflow.compat.v2
as
tf
import
tensorflow.compat.v2
as
tf
# pylint: enable=g-bad-import-order
# pylint: enable=g-bad-import-order
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -59,34 +58,20 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
...
@@ -59,34 +58,20 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
return
self
.
batch_start_times
[
0
]
-
program_start_time
return
self
.
batch_start_times
[
0
]
-
program_start_time
class
BertBenchmarkBase
(
tf
.
test
.
Benchmark
):
class
BertBenchmarkBase
(
PerfZero
Benchmark
):
"""Base class to hold methods common to test classes."""
"""Base class to hold methods common to test classes."""
local_flags
=
None
local_flags
=
None
def
__init__
(
self
,
output_dir
=
None
):
def
__init__
(
self
,
output_dir
=
None
):
super
(
BertBenchmarkBase
,
self
).
__init__
(
output_dir
=
output_dir
)
self
.
num_gpus
=
8
self
.
num_gpus
=
8
if
not
output_dir
:
output_dir
=
'/tmp'
self
.
output_dir
=
output_dir
self
.
timer_callback
=
None
self
.
timer_callback
=
None
def
_get_model_dir
(
self
,
folder_name
):
"""Returns directory to store info, e.g. saved model and event log."""
return
os
.
path
.
join
(
self
.
output_dir
,
folder_name
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up and resets flags before each test."""
"""Sets up and resets flags before each test."""
super
(
BertBenchmarkBase
,
self
).
_setup
()
self
.
timer_callback
=
BenchmarkTimerCallback
()
self
.
timer_callback
=
BenchmarkTimerCallback
()
if
BertBenchmarkBase
.
local_flags
is
None
:
# Loads flags to get defaults to then override. List cannot be empty.
flags
.
FLAGS
([
'foo'
])
saved_flag_values
=
flagsaver
.
save_flag_values
()
BertBenchmarkBase
.
local_flags
=
saved_flag_values
else
:
flagsaver
.
restore_flag_values
(
BertBenchmarkBase
.
local_flags
)
def
_report_benchmark
(
self
,
stats
,
wall_time_sec
,
min_accuracy
,
max_accuracy
):
def
_report_benchmark
(
self
,
stats
,
wall_time_sec
,
min_accuracy
,
max_accuracy
):
"""Report benchmark results by writing to local protobuf file.
"""Report benchmark results by writing to local protobuf file.
...
...
official/benchmark/bert_squad_benchmark.py
View file @
0612e190
...
@@ -52,6 +52,10 @@ FLAGS = flags.FLAGS
...
@@ -52,6 +52,10 @@ FLAGS = flags.FLAGS
class
BertSquadBenchmarkBase
(
benchmark_utils
.
BertBenchmarkBase
):
class
BertSquadBenchmarkBase
(
benchmark_utils
.
BertBenchmarkBase
):
"""Base class to hold methods common to test classes in the module."""
"""Base class to hold methods common to test classes in the module."""
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
):
super
(
BertSquadBenchmarkBase
,
self
).
__init__
(
output_dir
=
output_dir
)
self
.
tpu
=
tpu
def
_read_training_summary_from_file
(
self
):
def
_read_training_summary_from_file
(
self
):
"""Reads the training summary from a file."""
"""Reads the training summary from a file."""
summary_path
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
summary_path
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
...
@@ -78,6 +82,10 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
...
@@ -78,6 +82,10 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
def
_get_distribution_strategy
(
self
,
use_ds
=
True
):
def
_get_distribution_strategy
(
self
,
use_ds
=
True
):
"""Gets the distribution strategy."""
"""Gets the distribution strategy."""
if
self
.
tpu
:
return
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
'tpu'
,
tpu_address
=
self
.
tpu
)
else
:
return
distribution_utils
.
get_distribution_strategy
(
return
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
'mirrored'
if
use_ds
else
'off'
,
distribution_strategy
=
'mirrored'
if
use_ds
else
'off'
,
num_gpus
=
self
.
num_gpus
)
num_gpus
=
self
.
num_gpus
)
...
@@ -117,11 +125,12 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
...
@@ -117,11 +125,12 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
Tests BERT SQuAD performance in different GPU configurations.
Tests BERT SQuAD performance in different GPU configurations.
The naming convention of below test cases follow
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format.
`benchmark_(number of gpus)_gpu` format for GPUs and
`benchmark_(topology)_tpu` format for TPUs.
"""
"""
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
)
super
(
BertSquadBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up the benchmark and SQuAD flags."""
"""Sets up the benchmark and SQuAD flags."""
...
@@ -322,16 +331,26 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
...
@@ -322,16 +331,26 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self
.
_run_and_report_benchmark
()
self
.
_run_and_report_benchmark
()
def
benchmark_2x2_tpu
(
self
):
"""Tests BERT SQuAD model performance with 2x2 TPU."""
self
.
_setup
()
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_2x2_tpu'
)
FLAGS
.
train_batch_size
=
48
self
.
_run_and_report_benchmark
()
class
BertSquadAccuracy
(
BertSquadBenchmarkBase
):
class
BertSquadAccuracy
(
BertSquadBenchmarkBase
):
"""Short accuracy test for BERT SQuAD model.
"""Short accuracy test for BERT SQuAD model.
Tests BERT SQuAD accuracy. The naming convention of below test cases follow
Tests BERT SQuAD accuracy. The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format.
`benchmark_(number of gpus)_gpu` format for GPUs and
`benchmark_(topology)_tpu` format for TPUs.
"""
"""
def
__init__
(
self
,
output_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadAccuracy
,
self
).
__init__
(
output_dir
=
output_dir
)
super
(
BertSquadAccuracy
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up the benchmark and SQuAD flags."""
"""Sets up the benchmark and SQuAD flags."""
...
@@ -407,6 +426,15 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
...
@@ -407,6 +426,15 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
self
.
_run_and_report_benchmark
()
self
.
_run_and_report_benchmark
()
def
benchmark_2x2_tpu
(
self
):
"""Tests BERT SQuAD model accuracy with 2x2 TPU."""
self
.
_setup
()
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_2x2_tpu'
)
FLAGS
.
train_batch_size
=
48
self
.
_run_and_report_benchmark
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment