Unverified Commit 8c7a0e75 authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7309)

260060237  by zongweiz<zongweiz@google.com>:

    [BERT SQuAD] Enable mixed precision training

    Add mixed precision training support for the BERT SQuAD model using the experimental Keras mixed precision API. For numeric stability, keep layer normalization and the GELU-activated intermediate dense layer in fp32.

--

PiperOrigin-RevId: 260060237
parent 745a06a9
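
For reference, a minimal sketch of the setup this change wires into train_squad(), using only the experimental Keras mixed precision calls that appear in the diff below. The plain Adam optimizer is a stand-in assumption for the BERT optimizer built by the repo's optimization.create_optimizer helper.

# --- illustrative sketch (not part of the diff) ---
import tensorflow as tf

# Compute in fp16 while keeping variables in fp32 ('infer_float32_vars' policy,
# as set in train_squad below).
policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)

# Wrap the optimizer so the loss is dynamically scaled during backprop.
# Adam here is only a placeholder for optimization.create_optimizer(...).
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer, loss_scale='dynamic')
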
......@@ -218,6 +218,54 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._run_and_report_benchmark()
def benchmark_1_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16')
FLAGS.train_batch_size = 4
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_2_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 2 GPUs and FP16."""
self._setup()
self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad_fp16')
FLAGS.train_batch_size = 8
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_4_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 4 GPUs and FP16."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad_fp16')
FLAGS.train_batch_size = 16
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
class BertSquadAccuracy(BertSquadBenchmarkBase):
"""Short accuracy test for BERT SQuAD model.
......@@ -281,6 +329,18 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Tests BERT SQuAD model accuracy with 8 GPUs and FP16."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_8_gpu_xla(self):
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
......
......@@ -318,6 +318,8 @@ class BertSquadLogitsLayer(tf.keras.layers.Layer):
logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
logits = tf.transpose(logits, [2, 0, 1])
unstacked_logits = tf.unstack(logits, axis=0)
if self.float_type == tf.float16:
unstacked_logits = tf.cast(unstacked_logits, tf.float32)
return unstacked_logits[0], unstacked_logits[1]
......
......@@ -15,6 +15,9 @@
"""Defining common flags used across all BERT models/applications."""
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core
def define_common_bert_flags():
......@@ -42,3 +45,26 @@ def define_common_bert_flags():
'inside.')
flags.DEFINE_float('learning_rate', 5e-5,
'The initial learning rate for Adam.')
# add flags for mixed precision training.
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=True,
dynamic_loss_scale=True,
loss_scale=True,
all_reduce_alg=False,
num_packs=False,
enable_xla=False
)
def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
......@@ -205,6 +205,8 @@ def run_customized_training_loop(
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
optimizer = model.optimizer
use_float16 = isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
if init_checkpoint:
logging.info(
......@@ -242,9 +244,15 @@ def run_customized_training_loop(
with tf.GradientTape() as tape:
model_outputs = model(inputs)
loss = loss_fn(labels, model_outputs)
if use_float16:
scaled_loss = optimizer.get_scaled_loss(loss)
# De-dupes variables due to keras tracking issues.
tvars = list(set(model.trainable_variables))
if use_float16:
scaled_grads = tape.gradient(scaled_loss, tvars)
grads = optimizer.get_unscaled_gradients(scaled_grads)
else:
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(zip(grads, tvars))
# For reporting, the metric takes the mean of losses.
......
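
The training-loop hunk above is the core of the change. Below is a standalone sketch of the same scaled-loss / unscaled-gradients pattern, assuming the model, optimizer, loss_fn, and batch are passed in as placeholder arguments (in the real loop they come from model_fn and loss_fn).

# --- illustrative sketch (not part of the diff) ---
import tensorflow as tf

def train_step(model, optimizer, loss_fn, inputs, labels):
  """One step mirroring the fp16 branch of run_customized_training_loop."""
  use_float16 = isinstance(
      optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(inputs))
    if use_float16:
      # Scale the loss up so small fp16 gradients do not underflow to zero.
      scaled_loss = optimizer.get_scaled_loss(loss)
  tvars = list(set(model.trainable_variables))  # de-dupe, as in the loop above
  if use_float16:
    # Undo the scaling before applying the update.
    grads = optimizer.get_unscaled_gradients(tape.gradient(scaled_loss, tvars))
  else:
    grads = tape.gradient(loss, tvars)
  optimizer.apply_gradients(zip(grads, tvars))
  return loss
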
......@@ -156,7 +156,7 @@ class BertModel(tf.keras.layers.Layer):
vocab_size=self.config.vocab_size,
embedding_size=self.config.hidden_size,
initializer_range=self.config.initializer_range,
dtype=self.float_type,
dtype=tf.float32,
name="word_embeddings")
self.embedding_postprocessor = EmbeddingPostprocessor(
use_type_embeddings=True,
......@@ -176,6 +176,7 @@ class BertModel(tf.keras.layers.Layer):
attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
initializer_range=self.config.initializer_range,
backward_compatible=self.config.backward_compatible,
float_type=self.float_type,
name="encoder")
self.pooler_transform = tf.keras.layers.Dense(
units=self.config.hidden_size,
......@@ -202,6 +203,8 @@ class BertModel(tf.keras.layers.Layer):
word_embeddings = self.embedding_lookup(input_word_ids)
embedding_tensor = self.embedding_postprocessor(
word_embeddings=word_embeddings, token_type_ids=input_type_ids)
if self.float_type == tf.float16:
embedding_tensor = tf.cast(embedding_tensor, tf.float16)
attention_mask = None
if input_mask is not None:
attention_mask = create_attention_mask_from_input_mask(
......@@ -441,7 +444,7 @@ class Attention(tf.keras.layers.Layer):
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(attention_mask, self.dtype)) * -10000.0
adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
......@@ -654,6 +657,7 @@ class TransformerBlock(tf.keras.layers.Layer):
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
backward_compatible=False,
float_type=tf.float32,
**kwargs):
super(TransformerBlock, self).__init__(**kwargs)
self.hidden_size = hidden_size
......@@ -664,6 +668,7 @@ class TransformerBlock(tf.keras.layers.Layer):
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
self.float_type = float_type
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
......@@ -719,13 +724,24 @@ class TransformerBlock(tf.keras.layers.Layer):
attention_mask=attention_mask)
attention_output = self.attention_output_dense(attention_output)
attention_output = self.attention_dropout(attention_output)
# Use float32 in keras layer norm and the gelu activation in the
# intermediate dense layer for numeric stability
if self.float_type == tf.float16:
input_tensor = tf.cast(input_tensor, tf.float32)
attention_output = tf.cast(attention_output, tf.float32)
attention_output = self.attention_layer_norm(input_tensor +
attention_output)
intermediate_output = self.intermediate_dense(attention_output)
if self.float_type == tf.float16:
intermediate_output = tf.cast(intermediate_output, tf.float16)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
# Use float32 in keras layer norm for numeric stability
if self.float_type == tf.float16:
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self.output_layer_norm(layer_output + attention_output)
if self.float_type == tf.float16:
layer_output = tf.cast(layer_output, tf.float16)
return layer_output
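
The casts around the two layer norms above follow a single pattern: upcast to fp32, apply the residual add and layer norm, then cast back to the compute dtype. A small helper showing it in isolation; the helper name and signature are illustrative and not part of the model code.

# --- illustrative sketch (not part of the diff) ---
import tensorflow as tf

def layer_norm_fp32(layer_norm, residual, x, float_type):
  """Adds `residual + x` and applies `layer_norm` in float32, then casts back."""
  if float_type == tf.float16:
    residual = tf.cast(residual, tf.float32)
    x = tf.cast(x, tf.float32)
  output = layer_norm(residual + x)  # LayerNormalization runs in fp32
  if float_type == tf.float16:
    output = tf.cast(output, tf.float16)
  return output
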
......@@ -751,6 +767,7 @@ class Transformer(tf.keras.layers.Layer):
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
backward_compatible=False,
float_type=tf.float32,
**kwargs):
super(Transformer, self).__init__(**kwargs)
self.num_hidden_layers = num_hidden_layers
......@@ -762,6 +779,7 @@ class Transformer(tf.keras.layers.Layer):
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
self.float_type = float_type
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
......@@ -777,6 +795,7 @@ class Transformer(tf.keras.layers.Layer):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
backward_compatible=self.backward_compatible,
float_type=self.float_type,
name=("layer_%d" % i)))
super(Transformer, self).build(unused_input_shapes)
......
......@@ -81,7 +81,7 @@ def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_scale=1.0):
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.backend.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
......@@ -89,11 +89,11 @@ def squad_loss_fn(start_positions,
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_scale
total_loss *= loss_factor
return total_loss
def get_loss_fn(loss_scale=1.0):
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
......@@ -105,7 +105,7 @@ def get_loss_fn(loss_scale=1.0):
end_positions,
start_logits,
end_logits,
loss_scale=loss_scale)
loss_factor=loss_factor)
return _loss_fn
......@@ -182,6 +182,11 @@ def train_squad(strategy,
logging.info('Training using customized training loop with distribution'
' strategy.')
use_float16 = common_flags.use_float16()
if use_float16:
policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
......@@ -196,17 +201,24 @@ def train_squad(strategy,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config, max_seq_length, float_type=tf.float32)
bert_config,
max_seq_length,
float_type=tf.float16 if use_float16 else tf.float32)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(loss_scale=1.0)
loss_fn = get_loss_fn(loss_factor=1.0)
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
model_training_utils.run_customized_training_loop(
......