Commit dc03c043 authored by xinliupitt

intermediate dropout

parent e93afea8
@@ -55,6 +55,10 @@ class Transformer(tf.keras.layers.Layer):
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability applied to the output of the
intermediate (feed-forward) activation. If larger than 0.0, a dropout layer
is created in build() and applied right after the intermediate activation;
otherwise no intermediate dropout layer is created.
"""
def __init__(self,
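For reference, a minimal usage sketch of the new argument. The values mirror the updated test further down in this commit; the constructor arguments not visible in this diff (num_attention_heads, intermediate_size, intermediate_activation, dropout_rate), the import path, and the call convention encoder_block([inputs, mask]) are assumptions about the layer's existing API, not part of this change:

import tensorflow as tf
from official.nlp.modeling.layers import transformer  # assumed import path

encoder_block = transformer.Transformer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation='relu',
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6,
    intermediate_dropout=0.1)  # new argument added by this commit
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
output = encoder_block([dummy_tensor, dummy_mask])  # same shape as the input, [2, 4, 16]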
@@ -74,6 +78,7 @@ class Transformer(tf.keras.layers.Layer):
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
**kwargs):
super(Transformer, self).__init__(**kwargs)
@@ -93,6 +98,7 @@ class Transformer(tf.keras.layers.Layer):
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -155,6 +161,11 @@ class Transformer(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
if self._intermediate_dropout > 0.0:
self.intermediate_dropout_layer = tf.keras.layers.Dropout(
rate=self._intermediate_dropout)
else:
self.intermediate_dropout_layer = None
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
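A side note on the behavior being relied on here: a Keras Dropout layer only perturbs its input when called in training mode and is effectively an identity at inference, so the new layer adds essentially no cost to serving. A small self-contained check, independent of this commit:

import tensorflow as tf

dropout = tf.keras.layers.Dropout(rate=0.5)
x = tf.ones([2, 4])
print(dropout(x, training=True))   # ~half the entries zeroed, the rest scaled by 1 / (1 - rate) = 2.0
print(dropout(x, training=False))  # identity: all ones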
@@ -204,7 +215,9 @@ class Transformer(tf.keras.layers.Layer):
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
self._norm_epsilon,
"intermediate_dropout":
self._intermediate_dropout
}
base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
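With the new key in get_config, the argument survives a config round trip, as the updated tests below exercise. A hedged sketch; the module alias transformer and the required constructor arguments are the same assumptions as in the earlier example:

encoder_block = transformer.Transformer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation='relu',
    intermediate_dropout=0.1)
config = encoder_block.get_config()
assert config['intermediate_dropout'] == 0.1
restored = transformer.Transformer.from_config(config)
assert restored.get_config() == config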
@@ -238,6 +251,8 @@ class Transformer(tf.keras.layers.Layer):
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
if self.intermediate_dropout_layer:
intermediate_output = self.intermediate_dropout_layer(intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
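Functionally, the modified feed-forward path is now dense -> activation -> optional intermediate dropout -> output dense -> output dropout. Below is an illustrative, self-contained simplification of that flow, not the literal call() body: plain Dense stands in for EinsumDense, the residual connection, layer norm, and mixed-precision handling are omitted, and the function name is hypothetical:

import tensorflow as tf

def feed_forward_sketch(x, intermediate_size, hidden_size,
                        intermediate_dropout=0.0, dropout_rate=0.0,
                        training=False):
  """Toy version of the Transformer feed-forward sub-block after this change."""
  x = tf.keras.layers.Dense(intermediate_size, activation='relu')(x)
  if intermediate_dropout > 0.0:
    # Mirrors the new, optionally created intermediate dropout layer.
    x = tf.keras.layers.Dropout(intermediate_dropout)(x, training=training)
  x = tf.keras.layers.Dense(hidden_size)(x)
  return tf.keras.layers.Dropout(dropout_rate)(x, training=training)

out = feed_forward_sketch(tf.zeros([2, 4, 16]), intermediate_size=32,
                          hidden_size=16, intermediate_dropout=0.1,
                          training=True)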
@@ -291,6 +306,10 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability applied to the output of the
intermediate (feed-forward) activation. If larger than 0.0, a dropout layer
is created in build() and applied right after the intermediate activation;
otherwise no intermediate dropout layer is created.
"""
def __init__(self,
@@ -310,6 +329,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads
@@ -329,6 +349,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
@@ -401,6 +422,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
if self._intermediate_dropout > 0.0:
self.intermediate_dropout_layer = tf.keras.layers.Dropout(
rate=self._intermediate_dropout)
else:
self.intermediate_dropout_layer = None
self.output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
@@ -445,7 +471,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
self._norm_epsilon,
"intermediate_dropout":
self._intermediate_dropout
}
base_config = super(TransformerDecoderLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -508,6 +536,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
if self.intermediate_dropout_layer:
intermediate_output = self.intermediate_dropout_layer(intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
if self._norm_first:
......
@@ -230,7 +230,8 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
norm_epsilon=1e-6,
intermediate_dropout=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -248,7 +249,8 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
norm_epsilon=1e-6,
intermediate_dropout=0.1)
encoder_block_config = encoder_block.get_config()
new_encoder_block = transformer.Transformer.from_config(
encoder_block_config)
@@ -299,7 +301,8 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
norm_epsilon=1e-6,
intermediate_dropout=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -317,7 +320,8 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
norm_epsilon=1e-6,
intermediate_dropout=0.1)
decoder_block_config = decoder_block.get_config()
new_decoder_block = transformer.TransformerDecoderLayer.from_config(
decoder_block_config)
......