Commit acdf24c5 authored by Reed

Use new mixed_float16 policy for transformer.

The old infer_float32_vars policy will be removed from TensorFlow soon.
parent a6f10ec0
@@ -24,24 +24,14 @@ import tensorflow as tf
 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""
 
-  def __init__(self, vocab_size, hidden_size, dtype=None):
+  def __init__(self, vocab_size, hidden_size):
     """Specify characteristic parameters of embedding layer.
 
     Args:
       vocab_size: Number of tokens in the embedding. (Typically ~32,000)
       hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
-      dtype: The dtype of the layer: float16 or float32.
     """
-    if dtype == tf.float16:
-      # We cannot rely on the global policy of "infer_with_float32_vars", as
-      # this layer is called on both int64 inputs and floating-point inputs.
-      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
-      # int64, which means floating-point inputs would not be casted.
-      # TODO(b/138859351): Remove this logic once we stop using the deprecated
-      # "infer_with_float32_vars" policy
-      dtype = tf.keras.mixed_precision.experimental.Policy(
-          "float16_with_float32_vars")
-    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
+    super(EmbeddingSharedWeights, self).__init__()
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size
@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     self.shared_weights = self.add_weight(
         "weights",
         shape=[self.vocab_size, self.hidden_size],
-        dtype="float32",
         initializer=tf.random_normal_initializer(
             mean=0., stddev=self.hidden_size**-0.5))
     super(EmbeddingSharedWeights, self).build(input_shape)
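
Both embedding-layer changes above follow from how the new global policy treats variables: under "mixed_float16", add_weight creates float32 variables by default and floating-point inputs to a layer are cast to float16, so neither the explicit dtype="float32" on the shared weights nor the "float16_with_float32_vars" workaround is needed. A minimal standalone sketch (not part of this commit) using the same experimental API, just to show the resulting dtypes:

    import tensorflow as tf

    # Set the same global policy this commit switches to.
    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    tf.keras.mixed_precision.experimental.set_policy(policy)

    layer = tf.keras.layers.Dense(4)
    y = layer(tf.ones((2, 8)))   # the float32 input is cast to float16
    print(layer.kernel.dtype)    # float32: variables are kept in float32
    print(y.dtype)               # float16: computation runs in float16
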
@@ -49,8 +49,10 @@ def create_model(params, is_train):
       label_smoothing = params["label_smoothing"]
       if params["enable_metrics_in_training"]:
         logits = metrics.MetricLayer(vocab_size)([logits, targets])
-      logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
+      logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+                                      dtype="float32")(logits)
       model = tf.keras.Model([inputs, targets], logits)
+      # TODO(reedwm): Can we do this loss in float16 instead of float32?
       loss = metrics.transformer_loss(
           logits, targets, label_smoothing, vocab_size)
       model.add_loss(loss)
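
Giving the identity Lambda layer dtype="float32" makes it cast the float16 logits back to float32, so the softmax cross-entropy in transformer_loss is computed in full precision; this is the usual way to keep the numerically sensitive end of a mixed-precision model in float32. A rough sketch of the same pattern on a toy model (the layer sizes and names here are illustrative, not from the transformer code):

    import tensorflow as tf

    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    tf.keras.mixed_precision.experimental.set_policy(policy)

    inputs = tf.keras.Input(shape=(16,))
    x = tf.keras.layers.Dense(32, activation="relu")(inputs)  # float16 compute
    logits = tf.keras.layers.Dense(10)(x)                     # float16 output
    # An identity layer with an explicit float32 dtype casts the logits up, so
    # the loss that consumes them runs in float32.
    logits = tf.keras.layers.Lambda(lambda v: v, dtype="float32")(logits)
    model = tf.keras.Model(inputs, logits)
    print(model.output.dtype)  # float32
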
@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
+        params["vocab_size"], params["hidden_size"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)
@@ -124,8 +124,10 @@ class TransformerTask(object):
       # like this. What if multiple instances of TransformerTask are created?
       # We should have a better way in the tf.keras.mixed_precision API of doing
       # this.
+      loss_scale = flags_core.get_loss_scale(flags_obj,
+                                             default_for_fp16="dynamic")
       policy = tf.keras.mixed_precision.experimental.Policy(
-          "infer_float32_vars")
+          "mixed_float16", loss_scale=loss_scale)
       tf.keras.mixed_precision.experimental.set_policy(policy)
 
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
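
Because "mixed_float16" computes gradients in float16, a loss scale is attached to the policy to keep small gradient values from underflowing; the "dynamic" setting adjusts the scaling factor automatically during training. A hedged sketch of what this hunk configures, with the flags_core.get_loss_scale lookup replaced by its fp16 default:

    import tensorflow as tf

    # "dynamic" stands in for flags_core.get_loss_scale(flags_obj,
    # default_for_fp16="dynamic"); a fixed number such as 128 would also work.
    loss_scale = "dynamic"
    policy = tf.keras.mixed_precision.experimental.Policy(
        "mixed_float16", loss_scale=loss_scale)
    tf.keras.mixed_precision.experimental.set_policy(policy)
    print(policy.name)  # mixed_float16
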
@@ -372,10 +374,6 @@
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
-    if params["dtype"] == tf.float16:
-      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-          opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
-                                                    default_for_fp16="dynamic"))
     return opt
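
The manual wrapping removed here is presumably no longer needed because the loss scale now travels with the global policy set in __init__, so the Keras training path can apply it without the optimizer factory wrapping Adam by hand. For reference, a short sketch of what the deleted lines did, using the same experimental LossScaleOptimizer (the Adam hyperparameters are placeholders):

    import tensorflow as tf

    opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    # Manual dynamic loss scaling, as in the removed code: the loss is scaled
    # up before the float16 backward pass and the gradients are unscaled
    # before being applied, so small gradients do not flush to zero.
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale="dynamic")
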