Commit 0af0abe9 authored by Rami Al-Rfou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374890121
parent b1cf38dc
@@ -51,6 +51,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
                attention_dropout=0.0,
                inner_dropout=0.0,
                attention_initializer=None,
+               attention_axes=None,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
 
@@ -83,6 +84,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       attention_initializer: Initializer for kernels of attention layers. If set
         `None`, attention layers use kernel_initializer as initializer for
         kernel.
+      attention_axes: axes over which the attention is applied. `None` means
+        attention over all axes, but batch, heads, and features.
       **kwargs: keyword arguments/
     """
     super().__init__(**kwargs)
@@ -111,6 +114,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
           attention_initializer)
     else:
       self._attention_initializer = self._kernel_initializer
+    self._attention_axes = attention_axes
 
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
@@ -121,9 +125,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       raise ValueError(
           "The type of input shape argument is not supported, got: %s" %
           type(input_shape))
-    if len(input_tensor_shape.as_list()) != 3:
-      raise ValueError("TransformerEncoderBlock expects a three-dimensional "
-                       "input of shape [batch, sequence, width].")
+    einsum_equation = "abc,cd->abd"
+    if len(input_tensor_shape.as_list()) > 3:
+      einsum_equation = "...bc,cd->...bd"
     hidden_size = input_tensor_shape[-1]
     if hidden_size % self._num_heads != 0:
       raise ValueError(
@@ -143,6 +147,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         dropout=self._attention_dropout,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        attention_axes=self._attention_axes,
         name="self_attention",
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
@@ -155,7 +160,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
             epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
-        "abc,cd->abd",
+        einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
         kernel_initializer=self._kernel_initializer,
@@ -172,7 +177,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
     self._output_dense = tf.keras.layers.experimental.EinsumDense(
-        "abc,cd->abd",
+        einsum_equation,
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
@@ -225,7 +230,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         "inner_dropout":
            self._inner_dropout,
         "attention_initializer":
-            tf.keras.initializers.serialize(self._attention_initializer)
+            tf.keras.initializers.serialize(self._attention_initializer),
+        "attention_axes": self._attention_axes,
     }
     base_config = super(TransformerEncoderBlock, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
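Aside from plumbing `attention_axes` through to the attention layer, the substantive change above is the einsum equation used by the two `EinsumDense` projections: inputs with rank greater than three now carry their extra leading axes through an ellipsis instead of being rejected. The following standalone sketch is not part of the commit; the shapes and variable names are illustrative assumptions, chosen only to show how the two equations behave.

import tensorflow as tf

# Illustrative sketch only: compare the old fixed-rank equation with the
# ellipsis form the layer now selects when the input has rank > 3.
dense_3d = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, 32), bias_axes="d")
dense_nd = tf.keras.layers.experimental.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, 32), bias_axes="d")

x3 = tf.random.normal([2, 13, 80])      # [batch, sequence, width]
x4 = tf.random.normal([2, 21, 13, 80])  # [batch, rows, cols, width]
print(dense_3d(x3).shape)  # (2, 13, 32): only the feature axis is projected
print(dense_nd(x4).shape)  # (2, 21, 13, 32): extra axes pass through the "..."

In both cases the kernel is the same (width, units) matrix; the ellipsis simply lets the projection broadcast over any number of leading positional axes.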
...
@@ -296,6 +296,29 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         encoder_block_config)
     self.assertEqual(encoder_block_config, new_encoder_block.get_config())
 
+  @parameterized.parameters({'attention_axes': None}, {'attention_axes': [1]},
+                            {'attention_axes': [2]}, {'attention_axes': [1, 2]})
+  def test_several_attention_axes(self, attention_axes):
+    test_layer = TransformerEncoderBlock(
+        inner_dim=32,
+        inner_activation='relu',
+        output_dropout=0.1,
+        attention_dropout=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        inner_dropout=0.1,
+        num_attention_heads=10,
+        attention_axes=attention_axes)
+    num_rows = 21
+    num_cols = 13
+    width = 80
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf.keras.Input(shape=(num_rows, num_cols, width))
+    output_tensor = test_layer(data_tensor)
+    # The default output of a transformer layer should be the same as the input.
+    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
 
 if __name__ == '__main__':
   tf.test.main()
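For reference, `attention_axes` is forwarded unchanged to `tf.keras.layers.MultiHeadAttention`, where `None` means attending over all axes except batch, heads, and features. The sketch below is not part of the commit; it uses assumed shapes (mirroring the new test) to show what restricting attention to a single axis of a 4-D input looks like in isolation.

import tensorflow as tf

# Illustrative sketch only: self-attention restricted to axis 2 (the "columns"
# axis) of a 4-D input, so positions in different rows never attend to each
# other. With attention_axes=None, attention would span both rows and columns.
mha = tf.keras.layers.MultiHeadAttention(
    num_heads=2, key_dim=8, attention_axes=(2,))
query = tf.random.normal([4, 21, 13, 80])  # [batch, rows, cols, width]
output = mha(query, query)
print(output.shape)  # (4, 21, 13, 80): the input shape is preserved

This is the same shape-preservation property the new `test_several_attention_axes` test checks end to end for `TransformerEncoderBlock`.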