Commit 5a4c4e18 authored by Chen Chen, committed by A. Unique TensorFlower

Internal Change

PiperOrigin-RevId: 316054828
parent 9cdb5d72
@@ -105,19 +105,27 @@ class GatedFeedforward(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
     self._intermediate_dense = []
+    self._intermediate_activation_layers = []
     self._gate_dense = []
     self._output_dense = []
     self._output_dropout = []
     self._output_layer_norm = []
+    activation_policy = tf.keras.mixed_precision.experimental.global_policy()
+    if activation_policy.name == "mixed_bfloat16":
+      # bfloat16 causes BERT with the LAMB optimizer to not converge
+      # as well, so we use float32.
+      # TODO(b/154538392): Investigate this.
+      activation_policy = tf.float32
     for i in range(self._num_blocks):
       self._intermediate_dense.append(
           tf.keras.layers.experimental.EinsumDense(
               "abc,cd->abd",
               output_shape=(None, self._intermediate_size),
               bias_axes="d",
-              activation=self._intermediate_activation,
               name="intermediate_%d" % i,
               **common_kwargs))
+      self._intermediate_activation_layers.append(tf.keras.layers.Activation(
+          self._intermediate_activation, dtype=activation_policy))
       if self._use_gate:
         self._gate_dense.append(
             tf.keras.layers.experimental.EinsumDense(
@@ -180,6 +188,8 @@ class GatedFeedforward(tf.keras.layers.Layer):
     for i in range(self._num_blocks):
       layer_input = layer_output
       intermediate_output = self._intermediate_dense[i](layer_input)
+      intermediate_output = self._intermediate_activation_layers[i](
+          intermediate_output)
       if self._use_gate:
         gated_linear = self._gate_dense[i](layer_input)
         intermediate_output = intermediate_output * gated_linear
...
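For reference, a minimal standalone sketch of the pattern this hunk applies, assuming a TF 2.x build that still exposes the experimental EinsumDense and mixed-precision APIs used above; the layer size, the dummy input shape, and the "relu" stand-in for the configured intermediate activation are illustrative and not taken from the model code:

```python
import tensorflow as tf

# Mirror of the hunk above: the projection no longer carries `activation=...`;
# the nonlinearity lives in its own Activation layer, which is pinned to
# float32 whenever the global policy is mixed_bfloat16.
activation_policy = tf.keras.mixed_precision.experimental.global_policy()
if activation_policy.name == "mixed_bfloat16":
  # bfloat16 activations hurt BERT + LAMB convergence (see TODO(b/154538392)),
  # so the activation is computed in float32.
  activation_policy = tf.float32

intermediate_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, 3072),  # illustrative intermediate size
    bias_axes="d",
    name="intermediate_0")
intermediate_activation_layer = tf.keras.layers.Activation(
    "relu", dtype=activation_policy)  # "relu" stands in for the real activation

x = tf.ones((2, 8, 768))  # (batch, seq_len, hidden) dummy input
y = intermediate_activation_layer(intermediate_dense(x))
```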
@@ -198,9 +198,16 @@ class TransformerScaffold(tf.keras.layers.Layer):
           "abc,cd->abd",
           output_shape=(None, self._intermediate_size),
           bias_axes="d",
-          activation=self._intermediate_activation,
           name="intermediate",
           **common_kwargs)
+      policy = tf.keras.mixed_precision.experimental.global_policy()
+      if policy.name == "mixed_bfloat16":
+        # bfloat16 causes BERT with the LAMB optimizer to not converge
+        # as well, so we use float32.
+        # TODO(b/154538392): Investigate this.
+        policy = tf.float32
+      self._intermediate_activation_layer = tf.keras.layers.Activation(
+          self._intermediate_activation, dtype=policy)
       self._output_dense = tf.keras.layers.experimental.EinsumDense(
           "abc,cd->abd",
           output_shape=(None, hidden_size),
@@ -263,6 +270,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
             attention_output)
     if self._feedforward_block is None:
       intermediate_output = self._intermediate_dense(attention_output)
+      intermediate_output = self._intermediate_activation_layer(
+          intermediate_output)
       layer_output = self._output_dense(intermediate_output)
       layer_output = self._output_dropout(layer_output)
       # During mixed precision training, attention_output is from layer norm
...
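The same split is applied to TransformerScaffold above. A hedged sketch of the runtime effect, again assuming the experimental mixed-precision API from the diff: under a mixed_bfloat16 policy the EinsumDense projection computes in bfloat16, while the split-out float32 Activation layer lifts the tensor back to float32 for the nonlinearity. Shapes and the "relu" activation are placeholders.

```python
import tensorflow as tf

# Assumption: the experimental mixed-precision API used in the diff is present.
tf.keras.mixed_precision.experimental.set_policy("mixed_bfloat16")

dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, 3072), bias_axes="d")
act = tf.keras.layers.Activation("relu", dtype=tf.float32)

x = tf.ones((2, 8, 768))
projected = dense(x)        # computed and returned in bfloat16 under the policy
activated = act(projected)  # cast up and evaluated in float32
print(projected.dtype, activated.dtype)  # bfloat16, then float32

# Restore the default policy so nothing else is affected.
tf.keras.mixed_precision.experimental.set_policy("float32")
```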