Commit 7332c1ca authored by A. Unique TensorFlower


Add norm_first to Transformer Scaffold; Add an option in gated_feedforward to disable the output layer_norm.

PiperOrigin-RevId: 333591020
parent ad615d4c
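
The two new options are designed to work together: with norm_first=True, TransformerScaffold applies the output layer norm before calling the feedforward block, so a GatedFeedforward used as that block should skip its own output layer norm. Below is a minimal usage sketch, not part of this commit, assuming TransformerScaffold and GatedFeedforward are exported from official.nlp.modeling.layers and that feedforward_cfg is passed as keyword arguments to feedforward_cls.

import tensorflow as tf
from official.nlp.modeling import layers

# Feedforward block without its own output layer norm, because the scaffold
# below already normalizes before the feedforward block (norm_first=True).
feedforward_cfg = {
    "intermediate_size": 3072,
    "intermediate_activation": "relu",
    "dropout": 0.1,
    "apply_output_layer_norm": False,  # option added by this commit
}

transformer = layers.TransformerScaffold(
    num_attention_heads=12,
    intermediate_size=3072,
    intermediate_activation="relu",
    feedforward_cls=layers.GatedFeedforward,
    feedforward_cfg=feedforward_cfg,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    norm_first=True)  # option added by this commit

# (batch, seq_len, hidden) inputs and a (batch, seq_len, seq_len) attention mask.
data = tf.random.uniform((2, 16, 768))
mask = tf.ones((2, 16, 16))
output = transformer([data, mask])  # shape: (2, 16, 768)
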
@@ -59,6 +59,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
               intermediate_activation,
               dropout,
               use_gate=True,
               apply_output_layer_norm=True,
               num_blocks=1,
               dropout_position="before_residual",
               kernel_initializer="glorot_uniform",
@@ -75,6 +76,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
    self._dropout = dropout
    self._use_gate = use_gate
    self._num_blocks = num_blocks
    self._apply_output_layer_norm = apply_output_layer_norm
    self._dropout_position = dropout_position
    if self._dropout_position not in ("before_residual", "after_residual"):
      raise ValueError(
@@ -140,6 +142,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
              **common_kwargs))
      self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
      # Use float32 in layernorm for numeric stability.
      if self._apply_output_layer_norm:
        self._output_layer_norm.append(
            tf.keras.layers.LayerNormalization(
                name="output_layer_norm_%d" % i,
@@ -199,6 +202,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
      # add.
      if layer_input.dtype == tf.float32:
        layer_output = tf.cast(layer_output, tf.float32)
      if self._apply_output_layer_norm:
        layer_output = self._output_layer_norm[i](layer_output + layer_input)
      if self._dropout_position == "after_residual":
        layer_output = self._output_dropout[i](layer_output)
......
@@ -82,6 +82,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
               feedforward_cfg=None,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               norm_first=False,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
@@ -96,6 +97,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
    self._attention_cls = attention_cls
    self._feedforward_cls = feedforward_cls
    self._feedforward_cfg = feedforward_cfg
    self._norm_first = norm_first
    self._num_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._intermediate_activation = intermediate_activation
@@ -246,11 +248,23 @@ class TransformerScaffold(tf.keras.layers.Layer):
    else:
      input_tensor, attention_mask = (inputs, None)
    if self._norm_first:
      source_tensor = input_tensor
      input_tensor = self._attention_layer_norm(input_tensor)
    attention_output = self._attention_layer(
        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_tensor + attention_output
    else:
      attention_output = self._attention_layer_norm(input_tensor +
                                                    attention_output)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    if self._feedforward_block is None:
      intermediate_output = self._intermediate_dense(attention_output)
      intermediate_output = self._intermediate_activation_layer(
@@ -261,8 +275,17 @@ class TransformerScaffold(tf.keras.layers.Layer):
      # and is always fp32 for now. Cast layer_output to fp32 for the subsequent
      # add.
      layer_output = tf.cast(layer_output, tf.float32)
      if self._norm_first:
        layer_output = source_attention_output + layer_output
      else:
        layer_output = self._output_layer_norm(layer_output + attention_output)
    else:
      if self._norm_first:
        # if norm_first, assume the feedforward block will not apply layer norm
        layer_output = self._feedforward_block(attention_output)
        layer_output += source_attention_output
      else:
        # if not norm_first, assume that the feedforward does apply layer norm
        layer_output = self._feedforward_block(attention_output)
    return layer_output
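
For reference, the residual orderings that norm_first switches between, written as a standalone sketch (a simplification of the scaffold code above, not part of the commit; sublayer stands for either the attention or the feedforward block):

def post_norm_block(x, sublayer, layer_norm):
  # norm_first=False: residual add first, then layer norm (original Transformer).
  return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer, layer_norm):
  # norm_first=True: layer norm on the input, residual add last, no final norm here.
  return x + sublayer(layer_norm(x))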