Commit a629af4c authored Sep 06, 2019 by A. Unique TensorFlower

Merge pull request #7531 from vinhngx:amp_bert

PiperOrigin-RevId: 267629422

Parents: 4dbdb450, b8ceb49c

Showing 6 changed files with 101 additions and 1 deletion:

official/bert/benchmark/bert_benchmark.py        +37 -0
official/bert/benchmark/bert_squad_benchmark.py  +36 -0
official/bert/common_flags.py                    +2 -1
official/bert/run_classifier.py                  +9 -0
official/bert/run_pretraining.py                 +9 -0
official/bert/run_squad.py                       +8 -0

official/bert/benchmark/bert_benchmark.py

@@ -227,6 +227,43 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
     summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
     self._run_and_report_benchmark(summary_path)

+  def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
+    """Performance for 1 GPU no DS with automatic mixed precision."""
+    self._setup()
+    self.num_gpus = 1
+    FLAGS.model_dir = self._get_model_dir(
+        'benchmark_1_gpu_amp_mrpc_no_dist_strat')
+    FLAGS.train_data_path = self.train_data_path
+    FLAGS.eval_data_path = self.eval_data_path
+    FLAGS.input_meta_data_path = self.input_meta_data_path
+    FLAGS.bert_config_file = self.bert_config_file
+    FLAGS.train_batch_size = 4
+    FLAGS.eval_batch_size = 4
+    FLAGS.dtype = 'fp16'
+    FLAGS.fp16_implementation = 'graph_rewrite'
+    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
+    self._run_and_report_benchmark(summary_path, use_ds=False)
+
+  def benchmark_8_gpu_amp_mrpc(self):
+    """Test BERT model performance with 8 GPUs with automatic mixed precision."""
+    self._setup()
+    self.num_gpus = 8
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
+    FLAGS.train_data_path = self.train_data_path
+    FLAGS.eval_data_path = self.eval_data_path
+    FLAGS.input_meta_data_path = self.input_meta_data_path
+    FLAGS.bert_config_file = self.bert_config_file
+    FLAGS.train_batch_size = 32
+    FLAGS.eval_batch_size = 32
+    FLAGS.dtype = 'fp16'
+    FLAGS.fp16_implementation = 'graph_rewrite'
+    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
+    self._run_and_report_benchmark(summary_path, use_ds=False)
+

 class BertClassifyAccuracy(BertClassifyBenchmarkBase):
   """Short accuracy test for BERT model.
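The two new methods only set flags and then delegate to _run_and_report_benchmark. As a rough usage sketch, one of them could be driven directly; the harness that normally discovers and runs these benchmark methods is not part of this diff, so the constructor argument below is an assumption:

# Hypothetical direct invocation of one of the new AMP benchmarks;
# output_dir is an assumed constructor argument.
from official.bert.benchmark import bert_benchmark

bench = bert_benchmark.BertClassifyBenchmarkReal(output_dir='/tmp/bert_amp')
bench.benchmark_8_gpu_amp_mrpc()  # MRPC training with fp16 graph-rewrite AMP
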
official/bert/benchmark/bert_squad_benchmark.py

@@ -281,6 +281,42 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     self._run_and_report_benchmark()

+  def benchmark_1_gpu_amp(self):
+    """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
+    self._setup()
+    self.num_gpus = 1
+    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
+    FLAGS.train_batch_size = 4
+    FLAGS.dtype = 'fp16'
+    FLAGS.fp16_implementation = 'graph_rewrite'
+    self._run_and_report_benchmark()
+
+  def benchmark_4_gpu_amp(self):
+    """Tests BERT SQuAD model performance with 4 GPUs with automatic mixed precision."""
+    self._setup()
+    self.num_gpus = 4
+    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_amp_squad')
+    FLAGS.train_batch_size = 16
+    FLAGS.dtype = 'fp16'
+    FLAGS.fp16_implementation = 'graph_rewrite'
+    self._run_and_report_benchmark()
+
+  def benchmark_8_gpu_amp(self):
+    """Tests BERT SQuAD model performance with 8 GPUs with automatic mixed precision."""
+    self._setup()
+    self.num_gpus = 8
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
+    FLAGS.train_batch_size = 32
+    FLAGS.dtype = 'fp16'
+    FLAGS.fp16_implementation = 'graph_rewrite'
+    self._run_and_report_benchmark()
+

 class BertSquadAccuracy(BertSquadBenchmarkBase):
   """Short accuracy test for BERT SQuAD model.

official/bert/common_flags.py

@@ -69,7 +69,8 @@ def define_common_bert_flags():
       loss_scale=True,
       all_reduce_alg=False,
       num_packs=False,
-      enable_xla=True
+      enable_xla=True,
+      fp16_implementation=True,
   )
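The only change here is that define_common_bert_flags() now passes fp16_implementation=True to flags_core.define_performance, asking it to also define the --fp16_implementation flag. flags_core itself is not part of this diff; below is a minimal stand-in for such a flag definition, written directly against absl. The default and the 'keras' enum value are assumptions; only 'graph_rewrite' is confirmed by the checks elsewhere in this commit:

from absl import flags

# Hypothetical equivalent of the flag enabled above; only 'graph_rewrite'
# is confirmed by this commit, the other value and the default are assumed.
flags.DEFINE_enum(
    'fp16_implementation', 'keras', ('keras', 'graph_rewrite'),
    'How fp16 should be implemented when --dtype=fp16: via the Keras '
    'mixed-precision API or via the automatic graph rewrite.')
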
official/bert/run_classifier.py

@@ -111,11 +111,20 @@ def run_customized_training(strategy,
       drop_remainder=False)

   def _get_classifier_model():
     """Gets a classifier model."""
     classifier_model, core_model = (
         bert_models.classifier_model(bert_config, tf.float32, num_classes,
                                      max_seq_length))
     classifier_model.optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps)
+    if FLAGS.fp16_implementation == 'graph_rewrite':
+      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+      # which will ensure tf.compat.v2.keras.mixed_precision and
+      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
+      # double up.
+      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+          classifier_model.optimizer)
     return classifier_model, core_model

   loss_fn = get_loss_fn(
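The pattern above, repeated in run_pretraining.py and run_squad.py below, wraps an already-built optimizer with the graph-rewrite AMP API while the model itself stays float32. A self-contained sketch of the same pattern on a toy Keras model, assuming the TF 1.14/2.0-era API this commit targets (the call was later deprecated in favor of Keras mixed-precision policies):

import tensorflow as tf

# Toy float32 model; as in the code above, variables stay float32 and the
# rewrite inserts float16 casts around rewrite-safe ops at graph level.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(64,)),
    tf.keras.layers.Dense(2),
])

optimizer = tf.keras.optimizers.Adam(1e-4)
# The rewrite also adds dynamic loss scaling, which is why it must not be
# combined with a Keras mixed-precision policy (the "double up" the
# comments in this commit warn about).
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale='dynamic')

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
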
official/bert/run_pretraining.py

@@ -123,10 +123,19 @@ def run_customized_training(strategy,
       train_batch_size, strategy)

   def _get_pretrain_model():
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)
     pretrain_model.optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps)
+    if FLAGS.fp16_implementation == 'graph_rewrite':
+      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+      # which will ensure tf.compat.v2.keras.mixed_precision and
+      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
+      # double up.
+      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+          pretrain_model.optimizer)
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(
official/bert/run_squad.py

@@ -226,6 +226,14 @@ def train_squad(strategy,
       squad_model.optimizer = (
           tf.keras.mixed_precision.experimental.LossScaleOptimizer(
               squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
+    if FLAGS.fp16_implementation == 'graph_rewrite':
+      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+      # which will ensure tf.compat.v2.keras.mixed_precision and
+      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
+      # double up.
+      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+          squad_model.optimizer)
     return squad_model, core_model

   # The original BERT model does not scale the loss by