"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "ed618bb0661a86d45e039dec2d104110a852e9f0"
Commit a629af4c authored by A. Unique TensorFlower

Merge pull request #7531 from vinhngx:amp_bert

PiperOrigin-RevId: 267629422
parents 4dbdb450 b8ceb49c
...@@ -227,6 +227,43 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
    """Performance for 1 GPU no DS with automatic mixed precision."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_1_gpu_amp_mrpc_no_dist_strat')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)

  def benchmark_8_gpu_amp_mrpc(self):
    """Tests BERT model performance with 8 GPUs with automatic mixed precision."""
    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 32
    FLAGS.eval_batch_size = 32
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)


class BertClassifyAccuracy(BertClassifyBenchmarkBase):
  """Short accuracy test for BERT model.
...
...@@ -281,6 +281,42 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
    self._run_and_report_benchmark()

  def benchmark_1_gpu_amp(self):
    """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
    FLAGS.train_batch_size = 4
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()

  def benchmark_4_gpu_amp(self):
    """Tests BERT SQuAD model performance with 4 GPUs with automatic mixed precision."""
    self._setup()
    self.num_gpus = 4
    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_amp_squad')
    FLAGS.train_batch_size = 16
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Tests BERT SQuAD model performance with 8 GPUs with automatic mixed precision."""
    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
    FLAGS.train_batch_size = 32
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()


class BertSquadAccuracy(BertSquadBenchmarkBase):
  """Short accuracy test for BERT SQuAD model.
...
...@@ -69,7 +69,8 @@ def define_common_bert_flags():
      loss_scale=True,
      all_reduce_alg=False,
      num_packs=False,
      enable_xla=True,
      fp16_implementation=True,
  )
...
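The new fp16_implementation=True entry asks flags_core.define_performance to also register an --fp16_implementation flag alongside dtype and loss_scale. As a rough, hypothetical illustration only (the real definition lives in flags_core and its exact default, choices, and help text are not shown in this diff), such a flag could be declared with absl roughly like this:

# Simplified, hypothetical illustration; not the actual flags_core code.
from absl import flags

flags.DEFINE_enum(
    'fp16_implementation', 'keras', ('keras', 'graph_rewrite'),
    'How fp16 is implemented when mixed precision is requested: via the '
    'Keras mixed-precision API, or via the mixed-precision graph rewrite '
    'wrapped around the optimizer.')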
...@@ -111,11 +111,20 @@ def run_customized_training(strategy,
      drop_remainder=False)

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(bert_config, tf.float32, num_classes,
                                     max_seq_length))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
      # which ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          classifier_model.optimizer)
    return classifier_model, core_model

  loss_fn = get_loss_fn(
...
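The note above is the key invariant: with fp16_implementation == 'graph_rewrite' the model is built in float32, and mixed precision comes entirely from wrapping the optimizer, so the Keras mixed-precision path and the graph rewrite never stack. A minimal standalone sketch of that wrapping, assuming a TF build of this era that exposes tf.train.experimental.enable_mixed_precision_graph_rewrite (model and layer here are placeholders, only the optimizer wrapping mirrors the diff):

# Minimal sketch of the graph-rewrite AMP path; not the model's code.
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Variables stay float32; the rewrite inserts float16 casts into the graph
# and returns the optimizer wrapped with loss scaling.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

model = tf.keras.Sequential([tf.keras.layers.Dense(2, dtype=tf.float32)])
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')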
...@@ -123,10 +123,19 @@ def run_customized_training(strategy,
      train_batch_size, strategy)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
      # which ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          pretrain_model.optimizer)
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
...
...@@ -226,6 +226,14 @@ def train_squad(strategy,
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
      # which ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          squad_model.optimizer)
    return squad_model, core_model

  # The original BERT model does not scale the loss by
...
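In train_squad the two wrappers above are intended to be mutually exclusive: an fp16 dtype takes the Keras LossScaleOptimizer path, while a float32 dtype combined with fp16_implementation == 'graph_rewrite' takes the rewrite path. A simplified, hedged sketch of that selection (the helper below is for illustration and is not a literal function from run_squad.py):

# Simplified illustration of the mutually exclusive AMP paths.
import tensorflow as tf

def wrap_optimizer(optimizer, dtype, fp16_implementation, loss_scale='dynamic'):
  if dtype == tf.float16:
    # Keras mixed-precision path: explicit loss-scale wrapper.
    return tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        optimizer, loss_scale=loss_scale)
  if fp16_implementation == 'graph_rewrite':
    # Graph-rewrite path: dtype stays float32, so the two never double up.
    return tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
  return optimizer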