Unverified commit f8ec01ae authored by Reed, committed by GitHub

Add mixed precision support to Transformer (#7011)

parent 269581dc
@@ -20,9 +20,13 @@ from __future__ import print_function
 import math
+import numpy as np
 import tensorflow as tf

-_NEG_INF = -1e9
+# Very low numbers to represent -infinity. We do not actually use -Inf, since we
+# want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
+_NEG_INF_FP32 = -1e9
+_NEG_INF_FP16 = np.finfo(np.float16).min

 def get_position_encoding(
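These sentinels are large finite numbers rather than a true -Inf because the bias is later multiplied by a 0/1 mask, and -Inf * 0 yields NaN. A quick NumPy illustration (not part of the diff):

import numpy as np

print(float("-inf") * 0.0)                       # nan
print(np.finfo(np.float16).min * np.float16(0))  # -0.0, still a usable zero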
@@ -42,6 +46,9 @@ def get_position_encoding(
   Returns:
     Tensor with shape [length, hidden_size]
   """
+  # We compute the positional encoding in float32 even if the model uses
+  # float16, as many of the ops used, like log and exp, are numerically unstable
+  # in float16.
   position = tf.cast(tf.range(length), tf.float32)
   num_timescales = hidden_size // 2
   log_timescale_increment = (
@@ -54,7 +61,7 @@ def get_position_encoding(
   return signal
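Callers that run the model in float16 therefore compute the encoding in float32 and cast afterwards, as the Transformer model does below. A small usage sketch with assumed sizes (illustrative, not from the diff):

pos_encoding = get_position_encoding(64, 512)     # always float32, shape [64, 512]
pos_encoding = tf.cast(pos_encoding, tf.float16)  # cast only after the fp32 log/exp math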
-def get_decoder_self_attention_bias(length):
+def get_decoder_self_attention_bias(length, dtype=tf.float32):
   """Calculate bias for decoder that maintains model's autoregressive property.

   Creates a tensor that masks out locations that correspond to illegal
@@ -63,30 +70,34 @@ def get_decoder_self_attention_bias(length):
   Args:
     length: int length of sequences in batch.
+    dtype: The dtype of the return value.

   Returns:
     float tensor of shape [1, 1, length, length]
   """
+  neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
   with tf.name_scope("decoder_self_attention_bias"):
-    valid_locs = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
+    valid_locs = tf.linalg.band_part(tf.ones([length, length], dtype=dtype),
+                                     -1, 0)
     valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
-    decoder_bias = _NEG_INF * (1.0 - valid_locs)
+    decoder_bias = neg_inf * (1.0 - valid_locs)
   return decoder_bias
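For illustration (values follow from the definitions above), a length-3 bias is zero on and below the diagonal and a very large negative number above it, so a position cannot attend to later positions:

bias = get_decoder_self_attention_bias(3, dtype=tf.float16)  # shape [1, 1, 3, 3]
# bias[0, 0] is approximately:
#   [[     0., -65504., -65504.],
#    [     0.,      0., -65504.],
#    [     0.,      0.,      0.]]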
-def get_padding(x, padding_value=0):
+def get_padding(x, padding_value=0, dtype=tf.float32):
   """Return float tensor representing the padding values in x.

   Args:
     x: int tensor with any shape
     padding_value: int value that
+    dtype: The dtype of the return value.

   Returns:
     float tensor with same shape as x containing values 0 or 1.
       0 -> non-padding, 1 -> padding
   """
   with tf.name_scope("padding"):
-    return tf.cast(tf.equal(x, padding_value), tf.float32)
+    return tf.cast(tf.equal(x, padding_value), dtype)

 def get_padding_bias(x):
@@ -104,7 +115,7 @@ def get_padding_bias(x):
   """
   with tf.name_scope("attention_bias"):
     padding = get_padding(x)
-    attention_bias = padding * _NEG_INF
+    attention_bias = padding * _NEG_INF_FP32
     attention_bias = tf.expand_dims(
         tf.expand_dims(attention_bias, axis=1), axis=1)
   return attention_bias
@@ -21,6 +21,24 @@ from __future__ import print_function
 import tensorflow as tf

+def _float32_softmax(logits, name=None):
+  """Computes a softmax activation in float32.
+
+  When training a model using float16, softmax is still done in float32 for
+  numeric stability.
+
+  Args:
+    logits: A tensor, with any shape accepted by `tf.nn.softmax`.
+
+  Returns:
+    A tensor with the same dtype as `logits`.
+  """
+  input_dtype = logits.dtype
+  logits = tf.cast(logits, tf.float32)
+  output = tf.nn.softmax(logits, name=name)
+  return tf.cast(output, input_dtype)
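A minimal usage sketch (shapes assumed): float16 logits are upcast internally and the returned weights keep the caller's dtype:

logits = tf.random.normal([2, 4], dtype=tf.float16)
weights = _float32_softmax(logits, name="attention_weights")
assert weights.dtype == tf.float16  # the softmax itself ran in float32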
 class Attention(tf.keras.layers.Layer):
   """Multi-headed attention layer."""

@@ -129,8 +147,8 @@ class Attention(tf.keras.layers.Layer):
     if cache is not None:
       # Combine cached keys and values with new keys and values.
-      k = tf.concat([cache["k"], k], axis=1)
-      v = tf.concat([cache["v"], v], axis=1)
+      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
+      v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)

       # Update cache
       cache["k"] = k
@@ -148,7 +166,7 @@ class Attention(tf.keras.layers.Layer):
     # Calculate dot product attention
     logits = tf.matmul(q, k, transpose_b=True)
     logits += bias
-    weights = tf.nn.softmax(logits, name="attention_weights")
+    weights = _float32_softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
     attention_output = tf.matmul(weights, v)
...
@@ -68,7 +68,8 @@ def define_transformer_flags():
       intra_op=False,
       synthetic_data=True,
       max_train_steps=False,
-      dtype=False,
+      dtype=True,
+      loss_scale=True,
       all_reduce_alg=True,
       enable_xla=True
   )
...
@@ -102,6 +102,7 @@ class Transformer(tf.keras.Model):
         returns a dictionary {
           outputs: [batch_size, decoded length]
           scores: [batch_size, float]}
+      Even when float16 is used, the output tensor(s) are always float32.
     """
     if len(inputs) == 2:
       inputs, targets = inputs[0], inputs[1]
@@ -141,12 +142,15 @@ class Transformer(tf.keras.Model):
       # Prepare inputs to the layer stack by adding positional encodings and
       # applying dropout.
       embedded_inputs = self.embedding_softmax_layer(inputs)
+      embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
       inputs_padding = model_utils.get_padding(inputs)
+      attention_bias = tf.cast(attention_bias, self.params["dtype"])

       with tf.name_scope("add_pos_encoding"):
         length = tf.shape(embedded_inputs)[1]
         pos_encoding = model_utils.get_position_encoding(
             length, self.params["hidden_size"])
+        pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
         encoder_inputs = embedded_inputs + pos_encoding

       if training:
@@ -174,21 +178,25 @@ class Transformer(tf.keras.Model):
       # Prepare inputs to decoder layers by shifting targets, adding positional
       # encoding and applying dropout.
       decoder_inputs = self.embedding_softmax_layer(targets)
+      decoder_inputs = tf.cast(decoder_inputs, self.params['dtype'])
+      attention_bias = tf.cast(attention_bias, self.params["dtype"])
      with tf.name_scope("shift_targets"):
         # Shift targets to the right, and remove the last element
         decoder_inputs = tf.pad(decoder_inputs,
                                 [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
       with tf.name_scope("add_pos_encoding"):
         length = tf.shape(decoder_inputs)[1]
-        decoder_inputs += model_utils.get_position_encoding(
+        pos_encoding = model_utils.get_position_encoding(
             length, self.params["hidden_size"])
+        pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+        decoder_inputs += pos_encoding
       if training:
         decoder_inputs = tf.nn.dropout(
             decoder_inputs, rate=self.params["layer_postprocess_dropout"])

       # Run values
       decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-          length)
+          length, dtype=self.params['dtype'])
       outputs = self.decoder_stack(
           decoder_inputs,
           encoder_outputs,
@@ -196,6 +204,7 @@ class Transformer(tf.keras.Model):
           attention_bias,
           training=training)
       logits = self.embedding_softmax_layer(outputs, mode="linear")
+      logits = tf.cast(logits, tf.float32)
       return logits

   def _get_symbols_to_logits_fn(self, max_decode_length, training):
@@ -244,6 +253,9 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
+    # Currently, we always do prediction in float32.
+    # TODO(reedwm): Add float16 support.
+    encoder_outputs = tf.cast(encoder_outputs, tf.float32)
     batch_size = tf.shape(encoder_outputs)[0]
     input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
@@ -295,16 +307,22 @@ class LayerNormalization(tf.keras.layers.Layer):
   def build(self, input_shape):
     """Builds the layer."""
+    # Passing experimental_autocast=False causes these variables to not be
+    # automatically casted to fp16 when mixed precision is used. Since we use
+    # float32 in call() for numeric stability, we do not want variables to be
+    # casted to fp16.
     self.scale = self.add_weight(
         "layer_norm_scale",
         shape=[self.hidden_size],
         dtype="float32",
-        initializer=tf.ones_initializer())
+        initializer=tf.ones_initializer(),
+        experimental_autocast=False)
     self.bias = self.add_weight(
         "layer_norm_bias",
         shape=[self.hidden_size],
         dtype="float32",
-        initializer=tf.zeros_initializer())
+        initializer=tf.zeros_initializer(),
+        experimental_autocast=False)
     super(LayerNormalization, self).build(input_shape)

   def get_config(self):
@@ -313,10 +331,13 @@ class LayerNormalization(tf.keras.layers.Layer):
     }

   def call(self, x, epsilon=1e-6):
+    input_dtype = x.dtype
+    if input_dtype == tf.float16:
+      x = tf.cast(x, tf.float32)
     mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
     variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
     norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
-    return norm_x * self.scale + self.bias
+    return tf.cast(norm_x * self.scale + self.bias, input_dtype)
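Taken together, the layer keeps float32 variables (experimental_autocast=False), upcasts float16 inputs, normalizes in float32, and casts the result back. An illustrative sketch, assuming the constructor takes hidden_size as elsewhere in this file and that a float16 policy such as 'infer_float32_vars' is active:

layer = LayerNormalization(512)
y = layer(tf.zeros([2, 8, 512], dtype=tf.float16))
# y.dtype == tf.float16, while layer.scale.dtype and layer.bias.dtype stay float32.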
 class PrePostProcessingWrapper(tf.keras.layers.Layer):
...
@@ -118,6 +118,7 @@ class TransformerTask(object):
     params["use_synthetic_data"] = flags_obj.use_synthetic_data
     params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
     params["repeat_dataset"] = None
+    params["dtype"] = flags_core.get_tf_dtype(flags_obj)

   def train(self):
     """Trains the model."""
@@ -246,6 +247,10 @@ class TransformerTask(object):
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
+    if params["dtype"] == tf.float16:
+      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+          opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
+                                                    default_for_fp16="dynamic"))
     return opt
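Dynamic loss scaling multiplies the loss before the float16 backward pass and divides the gradients by the same factor before they are applied, so small gradients do not underflow. A standalone sketch of the wrapping (same experimental Keras API as above):

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale="dynamic")  # the scale is raised and lowered automatically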
@@ -258,6 +263,11 @@ def _ensure_dir(log_dir):
 def main(_):
   flags_obj = flags.FLAGS
   with logger.benchmark_context(flags_obj):
+    if flags_core.get_tf_dtype(flags_obj) == 'float16':
+      policy = tf.keras.mixed_precision.experimental.Policy(
+          'infer_float32_vars')
+      tf.keras.mixed_precision.experimental.set_policy(policy)
     task = TransformerTask(flags_obj)
     if flags_obj.mode == "train":
       task.train()
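Under the 'infer_float32_vars' policy, layers create float32 variables but compute in the dtype of their inputs, so float16 activations flow through while weights stay float32. A minimal sketch of the effect (behavior assumed for this experimental API):

policy = tf.keras.mixed_precision.experimental.Policy("infer_float32_vars")
tf.keras.mixed_precision.experimental.set_policy(policy)
dense = tf.keras.layers.Dense(4)
y = dense(tf.zeros([1, 4], dtype=tf.float16))
# dense.kernel.dtype is float32; y.dtype is float16.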
...
@@ -51,12 +51,17 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.batch_size = 8
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = "off"
+    FLAGS.dtype = "fp32"
     self.model_dir = FLAGS.model_dir
     self.temp_dir = temp_dir
     self.vocab_file = os.path.join(temp_dir, "vocab")
     self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
     self.bleu_source = os.path.join(temp_dir, "bleu_source")
     self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
+    self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
+
+  def tearDown(self):
+    tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)

   def _assert_exists(self, filepath):
     self.assertTrue(os.path.exists(filepath))
@@ -82,6 +87,17 @@ class TransformerTaskTest(tf.test.TestCase):
     t = tm.TransformerTask(FLAGS)
     t.train()

+  def test_train_2_gpu_fp16(self):
+    FLAGS.distribution_strategy = "mirrored"
+    FLAGS.num_gpus = 2
+    FLAGS.param_set = "base"
+    FLAGS.dtype = "fp16"
+    policy = tf.keras.mixed_precision.experimental.Policy(
+        'infer_float32_vars')
+    tf.keras.mixed_precision.experimental.set_policy(policy)
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
   def _prepare_files_and_flags(self, *extra_flags):
     # Make log dir.
     if not os.path.exists(self.temp_dir):
@@ -113,6 +129,14 @@ class TransformerTaskTest(tf.test.TestCase):
     t = tm.TransformerTask(FLAGS)
     t.predict()

+  def test_predict_fp16(self):
+    self._prepare_files_and_flags("--dtype=fp16")
+    policy = tf.keras.mixed_precision.experimental.Policy(
+        'infer_float32_vars')
+    tf.keras.mixed_precision.experimental.set_policy(policy)
+    t = tm.TransformerTask(FLAGS)
+    t.predict()
+
   def test_eval(self):
     self._prepare_files_and_flags()
     t = tm.TransformerTask(FLAGS)
...
@@ -37,6 +37,7 @@ class TransformerV2Test(tf.test.TestCase):
     params["vocab_size"] = 41
     params["extra_decode_length"] = 2
     params["beam_size"] = 3
+    params["dtype"] = tf.float32

   def test_create_model_train(self):
     model = transformer.create_model(self.params, True)
...