Commit b0a37140 authored by A. Unique TensorFlower

Fix dropout rate bug.

Fix the attention dropout rate bug: the output_dropout rate was mistakenly used as the attention dropout rate.

PiperOrigin-RevId: 462470287
parent 6d6e881a
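
For context, a transformer encoder block uses two distinct dropout rates: attention_dropout regularizes the attention sub-layer (the attention probabilities inside MultiHeadAttention and the attention output before its residual connection), while output_dropout regularizes the feed-forward output. The sketch below is a minimal, illustrative block, not the Model Garden implementation; TinyEncoderBlock, its constructor arguments, and its simplified structure are assumptions made only to show where each rate belongs.

import tensorflow as tf


# Illustrative only: a simplified encoder block, not the Model Garden layer.
class TinyEncoderBlock(tf.keras.layers.Layer):

  def __init__(self, num_heads, key_dim, inner_dim, hidden_size,
               attention_dropout=0.0, output_dropout=0.0, **kwargs):
    super().__init__(**kwargs)
    # attention_dropout applies to the attention probabilities and to the
    # attention output before its residual connection.
    self._attention = tf.keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=key_dim, dropout=attention_dropout)
    self._attention_dropout = tf.keras.layers.Dropout(rate=attention_dropout)
    self._attention_norm = tf.keras.layers.LayerNormalization()
    # output_dropout applies to the feed-forward output.
    # hidden_size must match the input's last dimension for the residual add.
    self._inner_dense = tf.keras.layers.Dense(inner_dim, activation='relu')
    self._output_dense = tf.keras.layers.Dense(hidden_size)
    self._output_dropout = tf.keras.layers.Dropout(rate=output_dropout)
    self._output_norm = tf.keras.layers.LayerNormalization()

  def call(self, inputs, training=False):
    # Attention sub-layer: dropout on the attention output before the residual.
    attention_output = self._attention(inputs, inputs, training=training)
    attention_output = self._attention_norm(
        inputs + self._attention_dropout(attention_output, training=training))
    # Feed-forward sub-layer: dropout on the projected inner output.
    ffn_output = self._output_dense(self._inner_dense(attention_output))
    return self._output_norm(
        attention_output + self._output_dropout(ffn_output, training=training))

Before this commit, the attention-output Dropout layer in TransformerEncoderBlock was built with the output_dropout rate, so attention_dropout only affected the scores inside MultiHeadAttention; the diff below restores the intended wiring.
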
@@ -132,9 +132,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._num_heads = num_attention_heads
     self._inner_dim = inner_dim
     self._inner_activation = inner_activation
-    self._attention_dropout = attention_dropout
     self._attention_dropout_rate = attention_dropout
-    self._output_dropout = output_dropout
     self._output_dropout_rate = output_dropout
     self._output_range = output_range
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
@@ -198,7 +196,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         num_heads=self._num_heads,
         key_dim=self._key_dim,
         value_dim=self._value_dim,
-        dropout=self._attention_dropout,
+        dropout=self._attention_dropout_rate,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
@@ -206,7 +204,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         output_shape=self._output_last_dim,
         name="self_attention",
         **common_kwargs)
-    self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
+    self._attention_dropout = tf.keras.layers.Dropout(
+        rate=self._attention_dropout_rate)
     # Use float32 in layernorm for numeric stability.
     # It is probably safe in mixed_float16, but we haven't validated this yet.
     self._attention_layer_norm = (
@@ -250,7 +249,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
-    self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
+    self._output_dropout = tf.keras.layers.Dropout(
+        rate=self._output_dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm",
@@ -632,6 +632,41 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
+  @parameterized.parameters({'output_dropout': 0.1,
+                             'attention_dropout': 0.2,
+                             'inner_dropout': 0.3},
+                            {'output_dropout': 0.0,
+                             'attention_dropout': 0.2,
+                             'inner_dropout': 0.3},
+                            {'output_dropout': 0.1,
+                             'attention_dropout': 0.0,
+                             'inner_dropout': 0.3},
+                            {'output_dropout': 0.1,
+                             'attention_dropout': 0.2,
+                             'inner_dropout': 0.0})
+  def test_dropout_config(self,
+                          output_dropout,
+                          attention_dropout,
+                          inner_dropout):
+    test_layer = TransformerEncoderBlock(
+        num_attention_heads=2,
+        inner_dim=32,
+        inner_activation='relu',
+        output_dropout=output_dropout,
+        attention_dropout=attention_dropout,
+        inner_dropout=inner_dropout)
+    seq_len = 21
+    hidden_size = 512
+    input_tensor = tf.keras.Input(shape=(seq_len, hidden_size))
+    _ = test_layer(input_tensor)
+    true_output_dropout = test_layer._output_dropout.get_config()['rate']
+    true_attention_dropout = test_layer._attention_dropout.get_config()['rate']
+    true_inner_dropout = test_layer._inner_dropout_layer.get_config()['rate']
+    self.assertEqual(true_output_dropout, output_dropout)
+    self.assertEqual(true_attention_dropout, attention_dropout)
+    self.assertEqual(true_inner_dropout, inner_dropout)
+
 
 if __name__ == '__main__':
   tf.test.main()
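
To sanity-check the fix outside the test harness, the configured rates can be read back from the built layer exactly as the new test does. A minimal sketch follows; the official.nlp.modeling.layers import path is an assumption and does not appear in this diff.

import tensorflow as tf
from official.nlp.modeling import layers  # assumed import path; not shown in this diff

block = layers.TransformerEncoderBlock(
    num_attention_heads=2,
    inner_dim=32,
    inner_activation='relu',
    output_dropout=0.1,
    attention_dropout=0.2)
_ = block(tf.keras.Input(shape=(21, 512)))  # build the sub-layers

# With this commit these print 0.2 and 0.1; before it, both reported 0.1.
print(block._attention_dropout.get_config()['rate'])
print(block._output_dropout.get_config()['rate'])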