Commit 415e8a45, authored Jun 06, 2019 by davidmochen, committed Jun 06, 2019 by saberkun
Add BERT SQuAD benchmark (#6976)
parent 42a8af1d
Showing 5 changed files with 280 additions and 91 deletions (+280 / -91)
official/bert/benchmark/benchmark_utils.py        +120   -0
official/bert/benchmark/bert_benchmark.py           +5  -86
official/bert/benchmark/bert_squad_benchmark.py   +149   -0
official/bert/model_training_utils.py               +3   -3
official/bert/run_squad.py                          +3   -2
official/bert/benchmark/benchmark_utils.py  0 → 100644 (new file)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

FLAGS = flags.FLAGS


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records time it takes to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    self.num_batches_to_skip = num_batches_to_skip
    self.timer_records = []
    self.start_time = None

  def on_batch_start(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    self.start_time = time.time()

  def on_batch_end(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    assert self.start_time
    self.timer_records.append(time.time() - self.start_time)

  def get_examples_per_sec(self, batch_size):
    return batch_size / np.mean(self.timer_records)


class BertBenchmarkBase(tf.test.Benchmark):
  """Base class to hold methods common to test classes."""
  local_flags = None

  def __init__(self, output_dir=None):
    self.num_gpus = 8
    if not output_dir:
      output_dir = '/tmp'
    self.output_dir = output_dir
    self.timer_callback = None

  def _get_model_dir(self, folder_name):
    """Returns directory to store info, e.g. saved model and event log."""
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    self.timer_callback = BenchmarkTimerCallback()
    if BertBenchmarkBase.local_flags is None:
      # Loads flags to get defaults to then override. List cannot be empty.
      flags.FLAGS(['foo'])
      saved_flag_values = flagsaver.save_flag_values()
      BertBenchmarkBase.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
                        max_accuracy):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from BERT models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
min_accuracy: Minimum classification accuracy constraint to verify
correctness of the model.
max_accuracy: Maximum classification accuracy constraint to verify
correctness of the model.
"""
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }, {
        'name': 'exp_per_second',
        'value': self.timer_callback.get_examples_per_sec(
            FLAGS.train_batch_size)
    }]
    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })
    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics)
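To make the shared utilities above easier to follow, here is a minimal sketch, not part of the commit, of how the timer callback's hooks are meant to be driven and how the throughput figure falls out. The loop, the sleep stand-in, and the batch size are invented; it assumes TensorFlow and this models repo are importable.

# Illustrative only: drive the timer hooks the way a custom training loop
# would, then read back examples/sec. Everything below is a stand-in.
import time

from official.bert.benchmark import benchmark_utils

timer = benchmark_utils.BenchmarkTimerCallback(num_batches_to_skip=2)
for batch in range(10):
  timer.on_batch_start(batch)
  time.sleep(0.01)  # stand-in for one training step
  timer.on_batch_end(batch)

# Only batches >= num_batches_to_skip contribute to the mean batch time.
print(timer.get_examples_per_sec(batch_size=32))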
official/bert/benchmark/bert_benchmark.py
...
...
@@ -24,7 +24,6 @@ import os
import time
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
...
...
@@ -32,6 +31,7 @@ import tensorflow as tf
from official.bert import modeling
from official.bert import run_classifier
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils
# pylint: disable=line-too-long
...
...
@@ -45,95 +45,14 @@ MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1
FLAGS = flags.FLAGS


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records time it takes to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    self.num_batches_to_skip = num_batches_to_skip
    self.timer_records = []
    self.start_time = None

  def on_batch_start(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    self.start_time = time.time()

  def on_batch_end(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    assert self.start_time
    self.timer_records.append(time.time() - self.start_time)

  def get_examples_per_sec(self, batch_size):
    return batch_size / np.mean(self.timer_records)


class BertBenchmarkBase(tf.test.Benchmark):
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
  """Base class to hold methods common to test classes in the module."""
  local_flags = None

  def __init__(self, output_dir=None):
    self.num_gpus = 8
    self.num_epochs = None
    self.num_steps_per_epoch = None
    if not output_dir:
      output_dir = '/tmp'
    self.output_dir = output_dir
    self.timer_callback = None

  def _get_model_dir(self, folder_name):
    """Returns directory to store info, e.g. saved model and event log."""
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    self.timer_callback = BenchmarkTimerCallback()
    if BertBenchmarkBase.local_flags is None:
      # Loads flags to get defaults to then override. List cannot be empty.
      flags.FLAGS(['foo'])
      saved_flag_values = flagsaver.save_flag_values()
      BertBenchmarkBase.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
                        max_accuracy):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from BERT models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
min_accuracy: Minimum classification accuracy constraint to verify
correctness of the model.
max_accuracy: Maximum classification accuracy constraint to verify
correctness of the model.
"""
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }, {
        'name': 'exp_per_second',
        'value': self.timer_callback.get_examples_per_sec(
            FLAGS.train_batch_size)
    }]
    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })
    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics)

    super(BertClassifyBenchmarkBase, self).__init__(output_dir)

  @flagsaver.flagsaver
  def _run_bert_classifier(self, callbacks=None):
...
...
@@ -168,7 +87,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
        custom_callbacks=callbacks)
class BertClassifyBenchmarkReal(BertBenchmarkBase):
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
"""Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU configurations.
...
...
@@ -272,7 +191,7 @@ class BertClassifyBenchmarkReal(BertBenchmarkBase):
    self._run_and_report_benchmark(summary_path)
class BertClassifyAccuracy(BertBenchmarkBase):
class BertClassifyAccuracy(BertClassifyBenchmarkBase):
"""Short accuracy test for BERT model.
Tests BERT classification task model accuracy. The naming
...
...
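The net effect of the bert_benchmark.py changes is that the classifier benchmark classes now inherit the shared setup and reporting code from the new benchmark_utils.BertBenchmarkBase instead of carrying their own copy. A quick sanity check one could run after applying the commit (not part of the diff, and assuming the package imports cleanly):

# Illustrative only: confirm the refactored class hierarchy.
from official.bert.benchmark import benchmark_utils, bert_benchmark

assert issubclass(bert_benchmark.BertClassifyBenchmarkReal,
                  benchmark_utils.BertBenchmarkBase)
assert issubclass(bert_benchmark.BertClassifyAccuracy,
                  benchmark_utils.BertBenchmarkBase)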
official/bert/benchmark/bert_squad_benchmark.py  0 → 100644 (new file)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import time

# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.bert import run_squad
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils

# pylint: disable=line-too-long
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_SMALL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_small_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long

FLAGS = flags.FLAGS
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
  """Base class to hold methods common to test classes in the module."""

  @flagsaver.flagsaver
  def _run_bert_squad(self):
    """Starts BERT SQuAD task."""
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='mirrored', num_gpus=self.num_gpus)

    run_squad.train_squad(
        strategy=strategy,
        input_meta_data=input_meta_data,
        custom_callbacks=[self.timer_callback])
class BertSquadBenchmark(BertSquadBenchmarkBase):
  """Short benchmark performance tests for BERT SQuAD model.

  Tests BERT SQuAD performance in different GPU configurations. The naming
  convention of the test cases below follows the
  `benchmark_(number of gpus)_gpu` format.
  """
  def __init__(self, output_dir=None, **kwargs):
    super(BertSquadBenchmark, self).__init__(output_dir=output_dir)

  def _setup(self):
    super(BertSquadBenchmark, self)._setup()
    FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
    FLAGS.predict_file = SQUAD_PREDICT_FILE
    FLAGS.vocab_file = SQUAD_VOCAB_FILE
    FLAGS.input_meta_data_path = SQUAD_SMALL_INPUT_META_DATA_PATH
    FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
    FLAGS.num_train_epochs = 1

  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0,
                                max_accuracy=1):
    """Starts BERT SQuAD performance benchmark test."""
    start_time_sec = time.time()
    self._run_bert_squad()
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

    super(BertSquadBenchmark, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)
  def benchmark_1_gpu(self):
    """Test BERT SQuAD model performance with 1 GPU."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
    FLAGS.train_batch_size = 4
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')

    self._run_and_report_benchmark(summary_path)

  def benchmark_2_gpu(self):
    """Test BERT SQuAD model performance with 2 GPUs."""
    self._setup()
    self.num_gpus = 2
    FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
    FLAGS.train_batch_size = 8
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')

    self._run_and_report_benchmark(summary_path)

  def benchmark_4_gpu(self):
    """Test BERT SQuAD model performance with 4 GPUs."""
    self._setup()
    self.num_gpus = 4
    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
    FLAGS.train_batch_size = 16
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')

    self._run_and_report_benchmark(summary_path)

  def benchmark_8_gpu(self):
    """Test BERT SQuAD model performance with 8 GPUs."""
    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
    FLAGS.train_batch_size = 32
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')

    self._run_and_report_benchmark(summary_path)


if __name__ == '__main__':
  tf.test.main()
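As a usage note beyond the diff itself: each benchmark_* method on this tf.test.Benchmark subclass can also be invoked directly from Python. A minimal sketch, assuming the repo is on PYTHONPATH, a GPU is available, and the default GCS paths set in _setup() are reachable; the output directory is an invented example.

# Illustrative only: run a single SQuAD benchmark case programmatically.
from official.bert.benchmark import bert_squad_benchmark

bench = bert_squad_benchmark.BertSquadBenchmark(output_dir='/tmp/bert_squad')
# Trains SQuAD for one epoch on one GPU, then reports wall time, training
# loss, and exp_per_second through BertBenchmarkBase._report_benchmark.
bench.benchmark_1_gpu()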
official/bert/model_training_utils.py
...
...
@@ -287,8 +287,8 @@ def run_customized_training_loop(
  if eval_metric_result:
    training_summary['eval_metrics'] = eval_metric_result
    summary_path = os.path.join(model_dir, SUMMARY_TXT)
    with tf.io.gfile.GFile(summary_path, 'wb') as f:
      f.write(json.dumps(training_summary, indent=4))

  summary_path = os.path.join(model_dir, SUMMARY_TXT)
  with tf.io.gfile.GFile(summary_path, 'wb') as f:
    f.write(json.dumps(training_summary, indent=4))

  return model
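For orientation, the summary file written here is what the SQuAD benchmark reads back and forwards to _report_benchmark. Judging only from the keys referenced in this commit, its contents look roughly like the dict below; the values are invented.

# Assumed shape of training_summary.txt as consumed by the benchmarks above;
# key names come from this commit, the numbers are made up.
example_training_summary = {
    'total_training_steps': 1000,  # reported as `iters`
    'train_loss': 1.23,            # reported as `training_loss`
    'eval_metrics': 0.87,          # optional; reported as `eval_accuracy`
}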
official/bert/run_squad.py
...
...
@@ -189,7 +189,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
  return all_results


def train_squad(strategy, input_meta_data):
def train_squad(strategy, input_meta_data, custom_callbacks=None):
  """Run bert squad training."""
  if not strategy:
    raise ValueError('Distribution strategy cannot be None.')
...
...
@@ -233,7 +233,8 @@ def train_squad(strategy, input_meta_data):
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      use_remote_tpu=use_remote_tpu)
      use_remote_tpu=use_remote_tpu,
      custom_callbacks=custom_callbacks)
def predict_squad(strategy, input_meta_data):
...
...