Commit acdf24c5 authored by Reed

Use new mixed_float16 policy for transformer.

The old infer_float32_vars policy will be removed from TensorFlow soon.
parent a6f10ec0
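
For context, a minimal sketch of the setup this commit moves to, using the same experimental API that appears in the diff below (the flags_core/flags_obj plumbing from the real code is omitted, and the behavioral comments reflect the documented mixed_float16 behavior rather than anything transformer-specific):

import tensorflow as tf

# Global policy: float16 compute dtype, float32 variable dtype, dynamic loss
# scaling attached to the policy itself.
policy = tf.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")
tf.keras.mixed_precision.experimental.set_policy(policy)

# Layers built after this point pick up the policy automatically: their
# variables are float32 while their computations run in float16.
dense = tf.keras.layers.Dense(8)
out = dense(tf.zeros([1, 4]))
print(out.dtype)           # float16 (compute dtype)
print(dense.kernel.dtype)  # float32 (variable dtype)
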
@@ -24,24 +24,14 @@ import tensorflow as tf
 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""
-  def __init__(self, vocab_size, hidden_size, dtype=None):
+  def __init__(self, vocab_size, hidden_size):
     """Specify characteristic parameters of embedding layer.
     Args:
       vocab_size: Number of tokens in the embedding. (Typically ~32,000)
       hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
-      dtype: The dtype of the layer: float16 or float32.
     """
-    if dtype == tf.float16:
-      # We cannot rely on the global policy of "infer_with_float32_vars", as
-      # this layer is called on both int64 inputs and floating-point inputs.
-      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
-      # int64, which means floating-point inputs would not be casted.
-      # TODO(b/138859351): Remove this logic once we stop using the deprecated
-      # "infer_with_float32_vars" policy
-      dtype = tf.keras.mixed_precision.experimental.Policy(
-          "float16_with_float32_vars")
-    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
+    super(EmbeddingSharedWeights, self).__init__()
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size
@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
       self.shared_weights = self.add_weight(
           "weights",
           shape=[self.vocab_size, self.hidden_size],
-          dtype="float32",
           initializer=tf.random_normal_initializer(
               mean=0., stddev=self.hidden_size**-0.5))
     super(EmbeddingSharedWeights, self).build(input_shape)
...
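
The per-layer dtype plumbing removed above existed only because the old "infer" policies derived a layer's dtype from its inputs, and this layer's inputs are int64 ids. Under the global mixed_float16 policy the layer dtype no longer depends on the inputs, and add_weight creates float32 variables by default, so the explicit dtype arguments become redundant. A minimal sketch of that behavior with a toy embedding-style layer (not the actual EmbeddingSharedWeights code; the explicit cast to float16 is an assumption about how the output is brought to the compute dtype):

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy(
    tf.keras.mixed_precision.experimental.Policy("mixed_float16"))

class ToyEmbedding(tf.keras.layers.Layer):
  """Gathers rows of a shared weight matrix for int token ids."""

  def build(self, input_shape):
    # No dtype argument: under mixed_float16 the variable dtype is float32.
    self.shared_weights = self.add_weight(
        "weights", shape=[100, 16],
        initializer=tf.random_normal_initializer(stddev=16 ** -0.5))
    super(ToyEmbedding, self).build(input_shape)

  def call(self, ids):
    # Integer ids are not autocast by the policy; cast the gathered rows to
    # the float16 compute dtype explicitly.
    return tf.cast(tf.gather(self.shared_weights, ids), tf.float16)

layer = ToyEmbedding()
out = layer(tf.constant([[1, 2, 3]], dtype=tf.int64))
print(out.dtype)                   # float16
print(layer.shared_weights.dtype)  # float32
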
@@ -49,8 +49,10 @@ def create_model(params, is_train):
       label_smoothing = params["label_smoothing"]
       if params["enable_metrics_in_training"]:
         logits = metrics.MetricLayer(vocab_size)([logits, targets])
-      logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
+      logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+                                      dtype="float32")(logits)
       model = tf.keras.Model([inputs, targets], logits)
+      # TODO(reedwm): Can we do this loss in float16 instead of float32?
       loss = metrics.transformer_loss(
           logits, targets, label_smoothing, vocab_size)
       model.add_loss(loss)
@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
+        params["vocab_size"], params["hidden_size"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)
...
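
The dtype="float32" on the output Lambda follows the usual mixed-precision pattern: intermediate activations run in float16, but the final logits (and therefore the softmax/cross-entropy loss) are computed in float32 for numerical stability. A minimal sketch of the same pattern on a generic model, not the transformer code itself:

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy(
    tf.keras.mixed_precision.experimental.Policy("mixed_float16"))

inputs = tf.keras.layers.Input((16,))
x = tf.keras.layers.Dense(32, activation="relu")(inputs)  # float16 compute
x = tf.keras.layers.Dense(10)(x)                          # float16 logits
# Override the policy on the last layer so the softmax runs in float32.
outputs = tf.keras.layers.Activation("softmax", dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)
print(model.output.dtype)  # float32
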
@@ -124,8 +124,10 @@ class TransformerTask(object):
       # like this. What if multiple instances of TransformerTask are created?
       # We should have a better way in the tf.keras.mixed_precision API of doing
       # this.
+      loss_scale = flags_core.get_loss_scale(flags_obj,
+                                             default_for_fp16="dynamic")
       policy = tf.keras.mixed_precision.experimental.Policy(
-          "infer_float32_vars")
+          "mixed_float16", loss_scale=loss_scale)
       tf.keras.mixed_precision.experimental.set_policy(policy)
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
@@ -372,10 +374,6 @@ class TransformerTask(object):
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
-    if params["dtype"] == tf.float16:
-      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-          opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
-              default_for_fp16="dynamic"))
     return opt
...
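
With the loss scale attached to the global mixed_float16 policy, the manual LossScaleOptimizer wrapping removed above should no longer be necessary: the expectation with this experimental API is that Keras wraps the optimizer itself when the model is compiled under a policy that carries a loss scale. A minimal sketch of that assumed behavior, using placeholder layers rather than the transformer model:

import tensorflow as tf

policy = tf.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")
tf.keras.mixed_precision.experimental.set_policy(policy)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(1, dtype="float32"),  # keep output in float32
])
opt = tf.keras.optimizers.Adam()
model.compile(optimizer=opt, loss="mse")
# After compile, the optimizer is expected to be wrapped so that dynamic loss
# scaling is applied during fit() without any manual wrapping.
print(type(model.optimizer).__name__)  # LossScaleOptimizer (expected)
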