Commit eaff981c authored by A. Unique TensorFlower

Add norm_first to Transformer Scaffold; add an option in gated_feedforward to disable the output layer_norm.

PiperOrigin-RevId: 333591020
parent ee4b011b
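For context, norm_first switches a residual sub-layer from the usual post-layer-norm ordering to the pre-layer-norm ("pre-LN") ordering: layer norm is applied to the sub-layer input and the residual is added without normalization. A minimal sketch of the two orderings for a single sub-layer; the callable names here are illustrative, not part of this commit:

```python
def post_ln(x, sublayer, layer_norm, dropout):
  # norm_first=False: apply layer norm after the residual add.
  return layer_norm(x + dropout(sublayer(x)))


def pre_ln(x, sublayer, layer_norm, dropout):
  # norm_first=True: normalize the input, add the raw residual.
  return x + dropout(sublayer(layer_norm(x)))
```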
@@ -59,6 +59,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
                intermediate_activation,
                dropout,
                use_gate=True,
+               apply_output_layer_norm=True,
                num_blocks=1,
                dropout_position="before_residual",
                kernel_initializer="glorot_uniform",
@@ -75,6 +76,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
     self._dropout = dropout
     self._use_gate = use_gate
     self._num_blocks = num_blocks
+    self._apply_output_layer_norm = apply_output_layer_norm
     self._dropout_position = dropout_position
     if self._dropout_position not in ("before_residual", "after_residual"):
       raise ValueError(
@@ -140,6 +142,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
               **common_kwargs))
       self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
       # Use float32 in layernorm for numeric stability.
-      self._output_layer_norm.append(
-          tf.keras.layers.LayerNormalization(
-              name="output_layer_norm_%d" % i,
+      if self._apply_output_layer_norm:
+        self._output_layer_norm.append(
+            tf.keras.layers.LayerNormalization(
+                name="output_layer_norm_%d" % i,
@@ -199,6 +202,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
       # add.
       if layer_input.dtype == tf.float32:
         layer_output = tf.cast(layer_output, tf.float32)
-      layer_output = self._output_layer_norm[i](layer_output + layer_input)
+      if self._apply_output_layer_norm:
+        layer_output = self._output_layer_norm[i](layer_output + layer_input)
       if self._dropout_position == "after_residual":
         layer_output = self._output_dropout[i](layer_output)
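The new apply_output_layer_norm flag controls the per-block output step. With the flag on (the default), each block adds the residual and applies its own LayerNormalization; with it off, the block returns the raw feedforward output and leaves the residual connection and normalization to the caller, which is what a norm_first scaffold expects. A minimal sketch of that step, mirroring the logic in the hunks above (the free-standing function and its argument names are illustrative, not library code):

```python
import tensorflow as tf


def gated_ffn_output_step(layer_input, layer_output, output_layer_norm,
                          apply_output_layer_norm=True):
  # Under mixed precision the feedforward output may be float16 while the
  # block input is float32; cast before the residual add, as in the diff.
  if layer_input.dtype == tf.float32:
    layer_output = tf.cast(layer_output, tf.float32)
  if apply_output_layer_norm:
    # Default behavior: residual add + LayerNorm inside the block.
    layer_output = output_layer_norm(layer_output + layer_input)
  # When disabled, the raw output is returned; the caller (e.g. a norm_first
  # TransformerScaffold) adds the residual itself.
  return layer_output
```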
@@ -82,6 +82,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
                feedforward_cfg=None,
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
+               norm_first=False,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
                kernel_regularizer=None,
@@ -96,6 +97,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._attention_cls = attention_cls
     self._feedforward_cls = feedforward_cls
     self._feedforward_cfg = feedforward_cfg
+    self._norm_first = norm_first
     self._num_heads = num_attention_heads
     self._intermediate_size = intermediate_size
     self._intermediate_activation = intermediate_activation
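The two options are meant to be combined: a scaffold built with norm_first=True should be given a feedforward block that skips its own output layer norm. A usage sketch only; the import path, the concrete sizes, and the exact set of constructor/config keys are assumptions based on the arguments visible in this diff, not a verbatim API reference:

```python
# Assumed import path for the TF Model Garden NLP modeling layers.
from official.nlp.modeling import layers

scaffold = layers.TransformerScaffold(
    num_attention_heads=8,
    intermediate_size=3072,                # assumed size, for illustration
    intermediate_activation="gelu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    norm_first=True,                       # new: pre-layer-norm ordering
    feedforward_cls=layers.GatedFeedforward,
    feedforward_cfg=dict(
        intermediate_size=3072,            # assumed; not shown in this diff
        intermediate_activation="gelu",
        dropout=0.1,
        use_gate=True,
        num_blocks=1,
        dropout_position="before_residual",
        # New: let the scaffold own the residual add and layer norm.
        apply_output_layer_norm=False))
```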
@@ -246,11 +248,23 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)
 
+    if self._norm_first:
+      source_tensor = input_tensor
+      input_tensor = self._attention_layer_norm(input_tensor)
+
     attention_output = self._attention_layer(
         query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(input_tensor +
-                                                  attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(input_tensor +
+                                                    attention_output)
+
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
+
     if self._feedforward_block is None:
       intermediate_output = self._intermediate_dense(attention_output)
       intermediate_output = self._intermediate_activation_layer(
@@ -261,8 +275,17 @@ class TransformerScaffold(tf.keras.layers.Layer):
       # and is always fp32 for now. Cast layer_output to fp32 for the subsequent
       # add.
       layer_output = tf.cast(layer_output, tf.float32)
-      layer_output = self._output_layer_norm(layer_output + attention_output)
+      if self._norm_first:
+        layer_output = source_attention_output + layer_output
+      else:
+        layer_output = self._output_layer_norm(layer_output + attention_output)
     else:
-      layer_output = self._feedforward_block(attention_output)
+      if self._norm_first:
+        # If norm_first, assume the feedforward block does not apply layer norm.
+        layer_output = self._feedforward_block(attention_output)
+        layer_output += source_attention_output
+      else:
+        # If not norm_first, assume the feedforward block does apply layer norm.
+        layer_output = self._feedforward_block(attention_output)
 
     return layer_output
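Putting the branches together, the norm_first=True path normalizes each sub-layer input inside the scaffold and adds the raw residuals itself, which is why a custom feedforward block should not normalize its own output in that mode. A condensed sketch of the flow implemented above, with the sub-layers passed in as callables rather than layer attributes (illustrative, not the library source):

```python
def pre_ln_transformer_block(x, attention, attention_norm, output_norm,
                             feedforward, attention_dropout,
                             attention_mask=None):
  """Condensed norm_first=True call flow mirroring the diff above."""
  # Attention sub-layer: normalize the input, add the raw residual.
  normed = attention_norm(x)
  attn = attention(query=normed, value=normed, attention_mask=attention_mask)
  attn = x + attention_dropout(attn)

  # Feedforward sub-layer: normalize, run the block, add the raw residual.
  # The block must not apply its own output layer norm here, e.g. a gated
  # feedforward built with apply_output_layer_norm=False.
  return attn + feedforward(output_norm(attn))
```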