ModelZoo / ResNet50_tensorflow · Commits

Commit 415e8a45: Add BERT SQuAD benchmark (#6976)
Authored Jun 06, 2019 by davidmochen; committed by saberkun on Jun 06, 2019.
Parent: 42a8af1d

Showing 5 changed files with 280 additions and 91 deletions (+280 −91).
official/bert/benchmark/benchmark_utils.py        +120  −0
official/bert/benchmark/bert_benchmark.py         +5    −86
official/bert/benchmark/bert_squad_benchmark.py   +149  −0
official/bert/model_training_utils.py             +3    −3
official/bert/run_squad.py                        +3    −2
official/bert/benchmark/benchmark_utils.py (new file, mode 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

FLAGS = flags.FLAGS


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records time it takes to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    self.num_batches_to_skip = num_batches_to_skip
    self.timer_records = []
    self.start_time = None

  def on_batch_start(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    self.start_time = time.time()

  def on_batch_end(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    assert self.start_time
    self.timer_records.append(time.time() - self.start_time)

  def get_examples_per_sec(self, batch_size):
    return batch_size / np.mean(self.timer_records)


class BertBenchmarkBase(tf.test.Benchmark):
  """Base class to hold methods common to test classes."""
  local_flags = None

  def __init__(self, output_dir=None):
    self.num_gpus = 8
    if not output_dir:
      output_dir = '/tmp'
    self.output_dir = output_dir
    self.timer_callback = None

  def _get_model_dir(self, folder_name):
    """Returns directory to store info, e.g. saved model and event log."""
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    self.timer_callback = BenchmarkTimerCallback()

    if BertBenchmarkBase.local_flags is None:
      # Loads flags to get defaults to then override. List cannot be empty.
      flags.FLAGS(['foo'])
      saved_flag_values = flagsaver.save_flag_values()
      BertBenchmarkBase.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
                        max_accuracy):
    """Report benchmark results by writing to local protobuf file.

    Args:
      stats: dict returned from BERT models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds.
      min_accuracy: Minimum classification accuracy constraint to verify
        correctness of the model.
      max_accuracy: Maximum classification accuracy constraint to verify
        correctness of the model.
    """
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }, {
        'name': 'exp_per_second',
        'value': self.timer_callback.get_examples_per_sec(
            FLAGS.train_batch_size)
    }]
    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })
    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics)
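For context on how the timer is meant to be used: a minimal sketch (not part of the commit, assuming the models repo is on PYTHONPATH) that drives the callback's hooks by hand the way a training loop would; the sleep is a stand-in for a real train step.

import time

from official.bert.benchmark import benchmark_utils

timer = benchmark_utils.BenchmarkTimerCallback(num_batches_to_skip=2)
for batch in range(10):
  timer.on_batch_start(batch)  # no-op for the first two warm-up batches
  time.sleep(0.01)             # stand-in for one training step
  timer.on_batch_end(batch)

# Mean time of the timed batches converted to throughput; with a batch size
# of 32 and ~10 ms steps this prints roughly 3200 examples per second.
print(timer.get_examples_per_sec(batch_size=32))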
official/bert/benchmark/bert_benchmark.py

@@ -24,7 +24,6 @@ import os
 import time
 
 # pylint: disable=g-bad-import-order
-import numpy as np
 from absl import flags
 from absl.testing import flagsaver
 import tensorflow as tf
@@ -32,6 +31,7 @@ import tensorflow as tf
 from official.bert import modeling
 from official.bert import run_classifier
+from official.bert.benchmark import benchmark_utils
 from official.utils.misc import distribution_utils
 
 # pylint: disable=line-too-long
@@ -45,95 +45,14 @@ MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1
 FLAGS = flags.FLAGS
 
 
-class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
-  """Callback that records time it takes to run each batch."""
-
-  def __init__(self, num_batches_to_skip=10):
-    super(BenchmarkTimerCallback, self).__init__()
-    self.num_batches_to_skip = num_batches_to_skip
-    self.timer_records = []
-    self.start_time = None
-
-  def on_batch_start(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    self.start_time = time.time()
-
-  def on_batch_end(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    assert self.start_time
-    self.timer_records.append(time.time() - self.start_time)
-
-  def get_examples_per_sec(self, batch_size):
-    return batch_size / np.mean(self.timer_records)
-
-
-class BertBenchmarkBase(tf.test.Benchmark):
+class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   """Base class to hold methods common to test classes in the module."""
-  local_flags = None
 
   def __init__(self, output_dir=None):
-    self.num_gpus = 8
     self.num_epochs = None
    self.num_steps_per_epoch = None
-    if not output_dir:
-      output_dir = '/tmp'
-    self.output_dir = output_dir
-    self.timer_callback = None
-
-  def _get_model_dir(self, folder_name):
-    """Returns directory to store info, e.g. saved model and event log."""
-    return os.path.join(self.output_dir, folder_name)
-
-  def _setup(self):
-    """Sets up and resets flags before each test."""
-    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
-    self.timer_callback = BenchmarkTimerCallback()
-
-    if BertBenchmarkBase.local_flags is None:
-      # Loads flags to get defaults to then override. List cannot be empty.
-      flags.FLAGS(['foo'])
-      saved_flag_values = flagsaver.save_flag_values()
-      BertBenchmarkBase.local_flags = saved_flag_values
-    else:
-      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)
-
-  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
-                        max_accuracy):
-    """Report benchmark results by writing to local protobuf file.
-
-    Args:
-      stats: dict returned from BERT models with known entries.
-      wall_time_sec: the during of the benchmark execution in seconds
-      min_accuracy: Minimum classification accuracy constraint to verify
-        correctness of the model.
-      max_accuracy: Maximum classification accuracy constraint to verify
-        correctness of the model.
-    """
-    metrics = [{
-        'name': 'training_loss',
-        'value': stats['train_loss'],
-    }, {
-        'name': 'exp_per_second',
-        'value': self.timer_callback.get_examples_per_sec(
-            FLAGS.train_batch_size)
-    }]
-    if 'eval_metrics' in stats:
-      metrics.append({
-          'name': 'eval_accuracy',
-          'value': stats['eval_metrics'],
-          'min_value': min_accuracy,
-          'max_value': max_accuracy,
-      })
-    self.report_benchmark(
-        iters=stats['total_training_steps'],
-        wall_time=wall_time_sec,
-        metrics=metrics)
+    super(BertClassifyBenchmarkBase, self).__init__(output_dir)
 
   @flagsaver.flagsaver
   def _run_bert_classifier(self, callbacks=None):
@@ -168,7 +87,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
       custom_callbacks=callbacks)
 
 
-class BertClassifyBenchmarkReal(BertBenchmarkBase):
+class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
   """Short benchmark performance tests for BERT model.
 
   Tests BERT classification performance in different GPU configurations.
@@ -272,7 +191,7 @@ class BertClassifyBenchmarkReal(BertBenchmarkBase):
     self._run_and_report_benchmark(summary_path)
 
 
-class BertClassifyAccuracy(BertBenchmarkBase):
+class BertClassifyAccuracy(BertClassifyBenchmarkBase):
   """Short accuracy test for BERT model.
 
   Tests BERT classification task model accuracy. The naming
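The practical upshot of this refactor is that a new benchmark module only has to subclass `benchmark_utils.BertBenchmarkBase` and supply task-specific run logic; flag bookkeeping, the timer callback, and metric reporting are inherited. A hypothetical sketch (the class and method names below are illustrative, not from the commit):

from official.bert.benchmark import benchmark_utils


class MyTaskBenchmark(benchmark_utils.BertBenchmarkBase):
  """Hypothetical benchmark built on the shared base class."""

  def benchmark_1_gpu_my_task(self):
    self._setup()  # inherited: resets flags, installs self.timer_callback
    self.num_gpus = 1
    model_dir = self._get_model_dir('benchmark_1_gpu_my_task')
    # ...run the task with model_dir, read its training summary into a stats
    # dict, then delegate reporting to the inherited helper:
    # self._report_benchmark(stats, wall_time_sec, min_accuracy=0,
    #                        max_accuracy=1)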
official/bert/benchmark/bert_squad_benchmark.py (new file, mode 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import time

# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.bert import run_squad
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils

# pylint: disable=line-too-long
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_SMALL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_small_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long

FLAGS = flags.FLAGS


class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
  """Base class to hold methods common to test classes in the module."""

  @flagsaver.flagsaver
  def _run_bert_squad(self):
    """Starts BERT SQuAD task."""
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='mirrored', num_gpus=self.num_gpus)
    run_squad.train_squad(
        strategy=strategy,
        input_meta_data=input_meta_data,
        custom_callbacks=[self.timer_callback])


class BertSquadBenchmark(BertSquadBenchmarkBase):
  """Short benchmark performance tests for BERT SQuAD model.

  Tests BERT SQuAD performance in different GPU configurations. The naming
  convention of the test cases below follows the
  `benchmark_(number of gpus)_gpu` format.
  """

  def __init__(self, output_dir=None, **kwargs):
    super(BertSquadBenchmark, self).__init__(output_dir=output_dir)

  def _setup(self):
    super(BertSquadBenchmark, self)._setup()
    FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
    FLAGS.predict_file = SQUAD_PREDICT_FILE
    FLAGS.vocab_file = SQUAD_VOCAB_FILE
    FLAGS.input_meta_data_path = SQUAD_SMALL_INPUT_META_DATA_PATH
    FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
    FLAGS.num_train_epochs = 1

  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0,
                                max_accuracy=1):
    """Starts BERT SQuAD performance benchmark test."""
    start_time_sec = time.time()
    self._run_bert_squad()
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

    super(BertSquadBenchmark, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)

  def benchmark_1_gpu(self):
    """Test BERT SQuAD model performance with 1 GPU."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
    FLAGS.train_batch_size = 4
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_2_gpu(self):
    """Test BERT SQuAD model performance with 2 GPUs."""
    self._setup()
    self.num_gpus = 2
    FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
    FLAGS.train_batch_size = 8
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_4_gpu(self):
    """Test BERT SQuAD model performance with 4 GPUs."""
    self._setup()
    self.num_gpus = 4
    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
    FLAGS.train_batch_size = 16
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_8_gpu(self):
    """Test BERT SQuAD model performance with 8 GPUs."""
    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
    FLAGS.train_batch_size = 32
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)


if __name__ == '__main__':
  tf.test.main()
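Assuming access to the GCS buckets referenced above and a GPU machine, the new benchmark can also be driven directly by instantiating the class; in automated runs these `tf.test.Benchmark` subclasses would typically be selected by an external benchmark runner (e.g. PerfZero, which the `tf-perfzero-data` bucket suggests). A minimal sketch, with an illustrative output directory:

from official.bert.benchmark import bert_squad_benchmark

# output_dir is illustrative; _setup() pins the input flags to the GCS paths
# defined at the top of the module.
bench = bert_squad_benchmark.BertSquadBenchmark(output_dir='/tmp/squad_bench')
bench.benchmark_1_gpu()  # trains one epoch on one GPU and reports metrics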
official/bert/model_training_utils.py (+3 −3; diff not expanded in this view)
official/bert/run_squad.py

@@ -189,7 +189,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   return all_results
 
 
-def train_squad(strategy, input_meta_data):
+def train_squad(strategy, input_meta_data, custom_callbacks=None):
   """Run bert squad training."""
   if not strategy:
     raise ValueError('Distribution strategy cannot be None.')
@@ -233,7 +233,8 @@ def train_squad(strategy, input_meta_data):
       epochs=epochs,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu)
+      use_remote_tpu=use_remote_tpu,
+      custom_callbacks=custom_callbacks)
 
 
 def predict_squad(strategy, input_meta_data):
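This two-line change is what lets the benchmark inject its timer: `train_squad` now forwards `custom_callbacks` into the training loop. A minimal sketch of the plumbing, mirroring `_run_bert_squad` above; the metadata path is illustrative, and flag defaults/overrides are elided.

import json

from absl import flags
import tensorflow as tf

from official.bert import run_squad
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils

flags.FLAGS(['prog'])  # load flag defaults, mirroring the benchmark's _setup()

timer = benchmark_utils.BenchmarkTimerCallback()

# Illustrative local path; the benchmark reads this location from a flag.
with tf.io.gfile.GFile('/tmp/squad_meta_data', 'rb') as reader:
  input_meta_data = json.loads(reader.read().decode('utf-8'))

strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy='mirrored', num_gpus=1)
run_squad.train_squad(
    strategy=strategy,
    input_meta_data=input_meta_data,
    custom_callbacks=[timer])  # the new keyword argument from this commit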