Commit d4bb3055 authored by A. Unique TensorFlower
Browse files

Update nlp.modeling.layers.ReZeroTransformer to be consistent with nlp.modeling.layers.Transformer

PiperOrigin-RevId: 315584374
parent 465354df
...@@ -143,8 +143,14 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -143,8 +143,14 @@ class ReZeroTransformer(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint,
name="intermediate") name="intermediate")
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation) self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum( self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size, output_shape=hidden_size,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment