Unverified commit f8ec01ae, authored by Reed, committed by GitHub

Add mixed precision support to Transformer (#7011)

parent 269581dc
@@ -20,9 +20,13 @@ from __future__ import print_function
import math
import numpy as np
import tensorflow as tf
_NEG_INF = -1e9
# Very low numbers to represent -infinity. We do not actually use -Inf, since we
# want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
_NEG_INF_FP32 = -1e9
_NEG_INF_FP16 = np.finfo(np.float16).min
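
For illustration only, a minimal numpy sketch of why finite stand-ins for -Inf are used, and why float16 gets its own constant:

    import numpy as np

    mask = np.array([0.0, 1.0])            # 0 -> valid position, 1 -> masked position
    print(-np.inf * mask)                  # [ nan -inf]  <- -Inf * 0 yields NaN
    print(-1e9 * mask)                     # [-0.e+00 -1.e+09]  <- finite value multiplies cleanly
    print(np.float16(-1e9))                # -inf: -1e9 overflows float16,
    print(np.finfo(np.float16).min)        # so -65504.0 is used for fp16 instead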
def get_position_encoding(
@@ -42,6 +46,9 @@ def get_position_encoding(
Returns:
Tensor with shape [length, hidden_size]
"""
# We compute the positional encoding in float32 even if the model uses
# float16, as many of the ops used, like log and exp, are numerically unstable
# in float16.
position = tf.cast(tf.range(length), tf.float32)
num_timescales = hidden_size // 2
log_timescale_increment = (
@@ -54,7 +61,7 @@
return signal
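
A short usage sketch of the pattern above (illustrative values only): the encoding is always produced in float32, and a float16 model casts it down only afterwards:

    import tensorflow as tf

    # get_position_encoding is the function defined above; it returns float32.
    pos_encoding = get_position_encoding(length=10, hidden_size=64)
    pos_encoding = tf.cast(pos_encoding, tf.float16)  # the cast happens outside the function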
def get_decoder_self_attention_bias(length):
def get_decoder_self_attention_bias(length, dtype=tf.float32):
"""Calculate bias for decoder that maintains model's autoregressive property.
Creates a tensor that masks out locations that correspond to illegal
@@ -63,30 +70,34 @@
Args:
length: int length of sequences in batch.
dtype: The dtype of the return value.
Returns:
float tensor of shape [1, 1, length, length]
"""
neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
with tf.name_scope("decoder_self_attention_bias"):
valid_locs = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
valid_locs = tf.linalg.band_part(tf.ones([length, length], dtype=dtype),
-1, 0)
valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
decoder_bias = _NEG_INF * (1.0 - valid_locs)
decoder_bias = neg_inf * (1.0 - valid_locs)
return decoder_bias
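
A brief usage sketch of the two dtypes (illustration only): with float16 the fill value is the float16 minimum, since -1e9 would overflow to -Inf in half precision:

    import tensorflow as tf

    # get_decoder_self_attention_bias is the function defined above.
    bias_fp32 = get_decoder_self_attention_bias(4)                    # invalid positions get -1e9
    bias_fp16 = get_decoder_self_attention_bias(4, dtype=tf.float16)  # invalid positions get ~-65504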
def get_padding(x, padding_value=0):
def get_padding(x, padding_value=0, dtype=tf.float32):
"""Return float tensor representing the padding values in x.
Args:
x: int tensor with any shape
padding_value: int value that is treated as padding.
dtype: The dtype of the return value.
Returns:
float tensor with same shape as x containing values 0 or 1.
0 -> non-padding, 1 -> padding
"""
with tf.name_scope("padding"):
return tf.cast(tf.equal(x, padding_value), tf.float32)
return tf.cast(tf.equal(x, padding_value), dtype)
def get_padding_bias(x):
@@ -104,7 +115,7 @@ def get_padding_bias(x):
"""
with tf.name_scope("attention_bias"):
padding = get_padding(x)
attention_bias = padding * _NEG_INF
attention_bias = padding * _NEG_INF_FP32
attention_bias = tf.expand_dims(
tf.expand_dims(attention_bias, axis=1), axis=1)
return attention_bias
@@ -21,6 +21,24 @@ from __future__ import print_function
import tensorflow as tf
def _float32_softmax(logits, name=None):
"""Computes a softmax activation in float32.
When training a model using float16, softmax is still done in float32 for
numeric stability.
Args:
logits: A tensor, with any shape accepted by `tf.nn.softmax`.
name: Optional name for the softmax op, forwarded to `tf.nn.softmax`.
Returns:
A tensor with the same dtype as `logits`.
"""
input_dtype = logits.dtype
logits = tf.cast(logits, tf.float32)
output = tf.nn.softmax(logits, name=name)
return tf.cast(output, input_dtype)
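
A minimal round-trip sketch (illustration only): float16 logits go in, the softmax itself runs in float32, and float16 probabilities come back out:

    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0, 3.0]], dtype=tf.float16)
    probs = _float32_softmax(logits, name="attention_weights")
    assert probs.dtype == tf.float16   # the caller's dtype is preserved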
class Attention(tf.keras.layers.Layer):
"""Multi-headed attention layer."""
@@ -129,8 +147,8 @@ class Attention(tf.keras.layers.Layer):
if cache is not None:
# Combine cached keys and values with new keys and values.
k = tf.concat([cache["k"], k], axis=1)
v = tf.concat([cache["v"], v], axis=1)
k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
v = tf.concat([tf.cast(cache["v"], v.dtype), v], axis=1)
# Update cache
cache["k"] = k
@@ -148,7 +166,7 @@ class Attention(tf.keras.layers.Layer):
# Calculate dot product attention
logits = tf.matmul(q, k, transpose_b=True)
logits += bias
weights = tf.nn.softmax(logits, name="attention_weights")
weights = _float32_softmax(logits, name="attention_weights")
if training:
weights = tf.nn.dropout(weights, rate=self.attention_dropout)
attention_output = tf.matmul(weights, v)
@@ -68,7 +68,8 @@ def define_transformer_flags():
intra_op=False,
synthetic_data=True,
max_train_steps=False,
dtype=False,
dtype=True,
loss_scale=True,
all_reduce_alg=True,
enable_xla=True
)
@@ -102,6 +102,7 @@ class Transformer(tf.keras.Model):
returns a dictionary {
outputs: [batch_size, decoded length]
scores: [batch_size, float]}
Even when float16 is used, the output tensor(s) are always float32.
"""
if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1]
@@ -141,12 +142,15 @@ class Transformer(tf.keras.Model):
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_softmax_layer(inputs)
embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
inputs_padding = model_utils.get_padding(inputs)
attention_bias = tf.cast(attention_bias, self.params["dtype"])
with tf.name_scope("add_pos_encoding"):
length = tf.shape(embedded_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params["hidden_size"])
pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
encoder_inputs = embedded_inputs + pos_encoding
if training:
@@ -174,21 +178,25 @@ class Transformer(tf.keras.Model):
# Prepare inputs to decoder layers by shifting targets, adding positional
# encoding and applying dropout.
decoder_inputs = self.embedding_softmax_layer(targets)
decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
attention_bias = tf.cast(attention_bias, self.params["dtype"])
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs,
[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
decoder_inputs += model_utils.get_position_encoding(
pos_encoding = model_utils.get_position_encoding(
length, self.params["hidden_size"])
pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
decoder_inputs += pos_encoding
if training:
decoder_inputs = tf.nn.dropout(
decoder_inputs, rate=self.params["layer_postprocess_dropout"])
# Compute the decoder self-attention bias and run the decoder stack.
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
length)
length, dtype=self.params["dtype"])
outputs = self.decoder_stack(
decoder_inputs,
encoder_outputs,
@@ -196,6 +204,7 @@ class Transformer(tf.keras.Model):
attention_bias,
training=training)
logits = self.embedding_softmax_layer(outputs, mode="linear")
logits = tf.cast(logits, tf.float32)
return logits
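
Because the logits are cast back to float32 here, anything downstream of decode(), such as the cross-entropy loss, runs in float32 even for a float16 model. A hedged sketch with hypothetical shapes (vocab size 41, as in the unit test further below):

    import tensorflow as tf

    logits = tf.random.normal([2, 5, 41])                             # [batch, length, vocab], float32
    targets = tf.random.uniform([2, 5], maxval=41, dtype=tf.int32)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)                                # computed in float32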
def _get_symbols_to_logits_fn(self, max_decode_length, training):
@@ -244,6 +253,9 @@ class Transformer(tf.keras.Model):
def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
"""Return predicted sequence."""
# Currently, we always do prediction in float32.
# TODO(reedwm): Add float16 support.
encoder_outputs = tf.cast(encoder_outputs, tf.float32)
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params["extra_decode_length"]
@@ -295,16 +307,22 @@ class LayerNormalization(tf.keras.layers.Layer):
def build(self, input_shape):
"""Builds the layer."""
# Passing experimental_autocast=False prevents these variables from being
# automatically cast to fp16 when mixed precision is used. Since call() does
# its math in float32 for numeric stability, we want the variables to stay
# in float32 as well.
self.scale = self.add_weight(
"layer_norm_scale",
shape=[self.hidden_size],
dtype="float32",
initializer=tf.ones_initializer())
initializer=tf.ones_initializer(),
experimental_autocast=False)
self.bias = self.add_weight(
"layer_norm_bias",
shape=[self.hidden_size],
dtype="float32",
initializer=tf.zeros_initializer())
initializer=tf.zeros_initializer(),
experimental_autocast=False)
super(LayerNormalization, self).build(input_shape)
def get_config(self):
@@ -313,10 +331,13 @@ class LayerNormalization(tf.keras.layers.Layer):
}
def call(self, x, epsilon=1e-6):
input_dtype = x.dtype
if input_dtype == tf.float16:
x = tf.cast(x, tf.float32)
mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
return norm_x * self.scale + self.bias
return tf.cast(norm_x * self.scale + self.bias, input_dtype)
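
A hedged usage sketch (assuming the TF 2.0-era autocast behavior this diff targets): the layer's scale and bias stay in float32 while the output mirrors the input dtype:

    import tensorflow as tf

    layer = LayerNormalization(8)                        # assuming the constructor takes hidden_size,
                                                         # per self.hidden_size in build()
    x = tf.random.normal([2, 4, 8], dtype=tf.float16)
    y = layer(x)                                         # math in float32, output in float16
    print(layer.scale.dtype, layer.bias.dtype, y.dtype)  # float32 float32 float16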
class PrePostProcessingWrapper(tf.keras.layers.Layer):
@@ -118,6 +118,7 @@ class TransformerTask(object):
params["use_synthetic_data"] = flags_obj.use_synthetic_data
params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
params["repeat_dataset"] = None
params["dtype"] = flags_core.get_tf_dtype(flags_obj)
def train(self):
"""Trains the model."""
@@ -246,6 +247,10 @@ class TransformerTask(object):
params["optimizer_adam_beta1"],
params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
if params["dtype"] == tf.float16:
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
default_for_fp16="dynamic"))
return opt
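
The same wrapping pattern in isolation, as a hedged sketch against the TF 2.0-era experimental API used here; dynamic loss scaling keeps small float16 gradients from underflowing by growing the scale and backing off on overflow:

    import tensorflow as tf

    opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale="dynamic")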
@@ -258,6 +263,11 @@ def _ensure_dir(log_dir):
def main(_):
flags_obj = flags.FLAGS
with logger.benchmark_context(flags_obj):
if flags_core.get_tf_dtype(flags_obj) == tf.float16:
policy = tf.keras.mixed_precision.experimental.Policy(
'infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)
task = TransformerTask(flags_obj)
if flags_obj.mode == "train":
task.train()
@@ -51,12 +51,17 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS.batch_size = 8
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = "off"
FLAGS.dtype = "fp32"
self.model_dir = FLAGS.model_dir
self.temp_dir = temp_dir
self.vocab_file = os.path.join(temp_dir, "vocab")
self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
self.bleu_source = os.path.join(temp_dir, "bleu_source")
self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
def tearDown(self):
tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)
def _assert_exists(self, filepath):
self.assertTrue(os.path.exists(filepath))
@@ -82,6 +87,17 @@ class TransformerTaskTest(tf.test.TestCase):
t = tm.TransformerTask(FLAGS)
t.train()
def test_train_2_gpu_fp16(self):
FLAGS.distribution_strategy = "mirrored"
FLAGS.num_gpus = 2
FLAGS.param_set = "base"
FLAGS.dtype = "fp16"
policy = tf.keras.mixed_precision.experimental.Policy(
'infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)
t = tm.TransformerTask(FLAGS)
t.train()
def _prepare_files_and_flags(self, *extra_flags):
# Make log dir.
if not os.path.exists(self.temp_dir):
@@ -113,6 +129,14 @@ class TransformerTaskTest(tf.test.TestCase):
t = tm.TransformerTask(FLAGS)
t.predict()
def test_predict_fp16(self):
self._prepare_files_and_flags("--dtype=fp16")
policy = tf.keras.mixed_precision.experimental.Policy(
'infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)
t = tm.TransformerTask(FLAGS)
t.predict()
def test_eval(self):
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
@@ -37,6 +37,7 @@ class TransformerV2Test(tf.test.TestCase):
params["vocab_size"] = 41
params["extra_decode_length"] = 2
params["beam_size"] = 3
params["dtype"] = tf.float32
def test_create_model_train(self):
model = transformer.create_model(self.params, True)