Commit 0af0abe9 authored by Rami Al-Rfou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374890121
parent b1cf38dc
@@ -51,6 +51,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
                attention_dropout=0.0,
                inner_dropout=0.0,
                attention_initializer=None,
+               attention_axes=None,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
@@ -83,6 +84,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       attention_initializer: Initializer for kernels of attention layers. If set
         `None`, attention layers use kernel_initializer as initializer for
         kernel.
+      attention_axes: Axes over which the attention is applied. `None` means
+        attention over all axes except batch, heads, and features.
       **kwargs: keyword arguments.
     """
     super().__init__(**kwargs)
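The new `attention_axes` argument is passed straight through to `tf.keras.layers.MultiHeadAttention` (see the `build` hunk below). A minimal sketch of what it controls on a rank-4 input, with the shapes, `num_heads`, and `key_dim` chosen purely for illustration:

import tensorflow as tf

# Restrict self-attention to axis 2 (the "columns") of a [batch, rows, cols, width] input.
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8, attention_axes=(2,))
x = tf.random.uniform([4, 21, 13, 80])
y = mha(x, x)  # output keeps the query shape: [4, 21, 13, 80]

# attention_axes=None (the default) attends jointly over axes 1 and 2.
mha_all = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
y_all = mha_all(x, x)  # also [4, 21, 13, 80]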
@@ -111,6 +114,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
           attention_initializer)
     else:
       self._attention_initializer = self._kernel_initializer
+    self._attention_axes = attention_axes
 
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
@@ -121,9 +125,9 @@
       raise ValueError(
           "The type of input shape argument is not supported, got: %s" %
           type(input_shape))
-    if len(input_tensor_shape.as_list()) != 3:
-      raise ValueError("TransformerEncoderBlock expects a three-dimensional "
-                       "input of shape [batch, sequence, width].")
+    einsum_equation = "abc,cd->abd"
+    if len(input_tensor_shape.as_list()) > 3:
+      einsum_equation = "...bc,cd->...bd"
     hidden_size = input_tensor_shape[-1]
     if hidden_size % self._num_heads != 0:
       raise ValueError(
@@ -143,6 +147,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         dropout=self._attention_dropout,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        attention_axes=self._attention_axes,
         name="self_attention",
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
@@ -155,7 +160,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
             epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
-        "abc,cd->abd",
+        einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
         kernel_initializer=self._kernel_initializer,
@@ -172,7 +177,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
     self._output_dense = tf.keras.layers.experimental.EinsumDense(
-        "abc,cd->abd",
+        einsum_equation,
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
@@ -225,7 +230,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         "inner_dropout":
             self._inner_dropout,
         "attention_initializer":
-            tf.keras.initializers.serialize(self._attention_initializer)
+            tf.keras.initializers.serialize(self._attention_initializer),
+        "attention_axes": self._attention_axes,
     }
     base_config = super(TransformerEncoderBlock, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
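Replacing the hard-coded "abc,cd->abd" projection with a conditional `einsum_equation` is what removes the old three-dimensional input restriction: for inputs of rank greater than 3, the feed-forward `EinsumDense` layers use an ellipsis equation that broadcasts over the extra leading axes. A rough sketch of the difference with plain `tf.einsum`, using invented shapes:

import tensorflow as tf

kernel = tf.random.uniform([80, 32])      # [width, inner_dim]
x3 = tf.random.uniform([4, 21, 80])       # [batch, seq, width]
x4 = tf.random.uniform([4, 21, 13, 80])   # [batch, rows, cols, width]

# Original equation: only valid for rank-3 inputs.
h3 = tf.einsum("abc,cd->abd", x3, kernel)        # [4, 21, 32]

# Ellipsis form selected when the input rank exceeds 3: the leading
# axes are carried through unchanged.
h4 = tf.einsum("...bc,cd->...bd", x4, kernel)    # [4, 21, 13, 32]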
@@ -296,6 +296,29 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         encoder_block_config)
     self.assertEqual(encoder_block_config, new_encoder_block.get_config())
 
+  @parameterized.parameters({'attention_axes': None}, {'attention_axes': [1]},
+                            {'attention_axes': [2]}, {'attention_axes': [1, 2]})
+  def test_several_attention_axes(self, attention_axes):
+    test_layer = TransformerEncoderBlock(
+        inner_dim=32,
+        inner_activation='relu',
+        output_dropout=0.1,
+        attention_dropout=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        inner_dropout=0.1,
+        num_attention_heads=10,
+        attention_axes=attention_axes)
+    num_rows = 21
+    num_cols = 13
+    width = 80
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf.keras.Input(shape=(num_rows, num_cols, width))
+    output_tensor = test_layer(data_tensor)
+    # The default output of a transformer layer should be the same as the input.
+    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
 
 if __name__ == '__main__':
   tf.test.main()
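Because get_config now records the new field, `attention_axes` also survives serialization. A small usage sketch, assuming `TransformerEncoderBlock` is in scope as in the test above:

block = TransformerEncoderBlock(
    num_attention_heads=10,
    inner_dim=32,
    inner_activation='relu',
    attention_axes=[2])

config = block.get_config()
assert config['attention_axes'] == [2]

# from_config rebuilds an equivalent block, now including attention_axes.
rebuilt = TransformerEncoderBlock.from_config(config)
assert rebuilt.get_config()['attention_axes'] == [2]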