Commit 7d14c7ca authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 472919387
parent c0c99b23
@@ -61,6 +61,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
value_dim=None,
output_last_dim=None,
diff_q_kv_att_layer_norm=False,
return_attention_scores=False,
**kwargs):
"""Initializes `TransformerEncoderBlock`.
@@ -117,13 +118,16 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
`None`, we use the first `input_shape`'s last dim.
value_dim: `value_dim` for the `tf.keras.layers.MultiHeadAttention`.
output_last_dim: Final dimension of the output of this module. This also
  dictates the value for the final dimension of the multi-head-attention.
  When it's `None`, we use, in order of decreasing precedence, `key_dim` *
  `num_heads` or the first `input_shape`'s last dim as the output's last
  dim.
diff_q_kv_att_layer_norm: If `True`, create a separate attention layer
  norm layer for query and key-value if `norm_first` is `True`. Invalid to
  set to `True` if `norm_first` is `False`.
return_attention_scores: If `True`, the output of this layer will be a
  tuple that additionally contains the attention scores, with shape
  `[batch_size, num_attention_heads, seq_dim, seq_dim]`.
**kwargs: keyword arguments.
"""
util.filter_kwargs(kwargs)
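As a usage sketch of the new flag (a minimal example assuming eager execution; all sizes below are illustrative, not taken from this change):

import tensorflow as tf
from official.nlp.modeling.layers.transformer_encoder_block import TransformerEncoderBlock

block = TransformerEncoderBlock(
    num_attention_heads=4,
    inner_dim=256,
    inner_activation='relu',
    return_attention_scores=True)

embeddings = tf.random.uniform([2, 16, 64])  # [batch_size, seq_dim, hidden]
# With return_attention_scores=True the call returns a tuple.
layer_output, attention_scores = block(embeddings)
# layer_output:     [2, 16, 64]
# attention_scores: [2, 4, 16, 16] == [batch, heads, seq_dim, seq_dim]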
@@ -156,6 +160,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._value_dim = value_dim
self._output_last_dim = output_last_dim
self._diff_q_kv_att_layer_norm = diff_q_kv_att_layer_norm
self._return_attention_scores = return_attention_scores
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
@@ -303,7 +308,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._inner_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes": self._attention_axes,
"attention_axes":
self._attention_axes,
"use_query_residual":
self._use_query_residual,
"key_dim":
@@ -322,13 +328,11 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
"""Transformer self-attention encoder block call.
Args:
  inputs: a single tensor or a list of tensors.
    `input tensor` as the single sequence of embeddings.
    [`input tensor`, `attention mask`] to have the additional attention mask.
    [`query tensor`, `key value tensor`, `attention mask`] to have separate
      input streams for the query, and key/value to the multi-head attention.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced. If you
would like to have no change to the model training, it is better to only
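For illustration, a hedged sketch of the three-tensor input form described above, reusing the imports from the first sketch (the single-tensor and [tensor, mask] forms follow the same pattern; shapes are hypothetical):

layer = TransformerEncoderBlock(
    num_attention_heads=2, inner_dim=128, inner_activation='relu')

query = tf.ones([2, 4, 16])      # [batch, target_len, dim]
key_value = tf.ones([2, 8, 16])  # [batch, source_len, dim]
mask = tf.ones([2, 4, 8])        # [batch, target_len, source_len], 1 = attend

# Separate query and key/value streams plus an attention mask.
output = layer([query, key_value, mask])  # -> [2, 4, 16]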
@@ -370,8 +374,16 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
if key_value is None:
  key_value = input_tensor

if self._return_attention_scores:
  attention_output, attention_scores = self._attention_layer(
      query=target_tensor,
      value=key_value,
      attention_mask=attention_mask,
      return_attention_scores=True)
else:
  attention_output = self._attention_layer(
      query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
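The branch above forwards the flag to `tf.keras.layers.MultiHeadAttention`, whose call contract it mirrors; a standalone sketch of that contract:

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=7, key_dim=10)
x = tf.ones([2, 21, 80])
output, scores = mha(x, x, return_attention_scores=True)
# output: [2, 21, 80]; scores: [2, 7, 21, 21] = [batch, heads, query, key]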
@@ -395,9 +407,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
layer_output = self._output_dropout(layer_output)
if self._norm_first:
  layer_output = source_attention_output + layer_output
else:
  # During mixed precision training, layer norm output is always fp32 for
  # now. Casts fp32 for the subsequent add.
  layer_output = tf.cast(layer_output, tf.float32)
  layer_output = self._output_layer_norm(layer_output + attention_output)

if self._return_attention_scores:
  return layer_output, attention_scores
else:
  return layer_output
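A short sketch of why the fp32 cast precedes the add in the post-norm branch (assuming a mixed_float16 policy; everything below is illustrative):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
# Layer norm is kept in float32 under mixed precision, as in this block.
ln = tf.keras.layers.LayerNormalization(dtype=tf.float32)
x = tf.ones([2, 4, 8], tf.float16)  # activations are float16
normed = ln(x)                      # float32 output
# Adding float16 activations to the float32 norm output raises a dtype
# error, hence the explicit cast before the add.
summed = tf.cast(x, tf.float32) + normed
tf.keras.mixed_precision.set_global_policy('float32')  # restore default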
@@ -23,8 +23,7 @@ from official.nlp.modeling.layers.transformer_encoder_block import TransformerEn
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(('base', TransformerEncoderBlock))
class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
def tearDown(self):
@@ -130,8 +129,10 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls(
    num_attention_heads=10,
    inner_dim=2048,
    inner_activation='relu',
    norm_first=True)
sequence_length = 21
width = 80
@@ -156,8 +157,10 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
def test_layer_output_range_with_pre_norm(self, transformer_cls):
test_layer = transformer_cls(
    num_attention_heads=10,
    inner_dim=2048,
    inner_activation='relu',
    norm_first=True)
sequence_length = 21
width = 80
@@ -259,8 +262,8 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes
class TransformerEncoderBlockLayerTestWithoutParams(
    keras_parameterized.TestCase):
def tearDown(self):
super(TransformerEncoderBlockLayerTestWithoutParams, self).tearDown()
@@ -280,12 +283,10 @@ class TransformerEncoderBlockLayerTestWithoutParams(
with self.assertRaises(tf.errors.InvalidArgumentError):
test_layer(inputs)
@parameterized.named_parameters(('output_range_not_none', 2),
                                ('output_range_none', None))
def test_needs_diff_q_kv_att_layer_norm_to_be_true_for_diff_q_and_kv_dims(
    self, output_range):
test_layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=128,
@@ -309,9 +310,8 @@ class TransformerEncoderBlockLayerTestWithoutParams(
# Forward path.
test_layer(inputs)
@parameterized.named_parameters(('norm_first_is_true', True),
                                ('norm_first_is_false', False))
def test_use_query_residual_false_removes_add_op(self, norm_first):
graph_with_res = tf.Graph()
with graph_with_res.as_default():
@@ -344,15 +344,10 @@ class TransformerEncoderBlockLayerTestWithoutParams(
list(graph_with_res_names - graph_without_res_names)[0])
self.assertEmpty(graph_without_res_names - graph_with_res_names)
@parameterized.named_parameters(('key_dim_is_none', None, 128, 2, 128 // 2),
                                ('key_dim_is_not_none', 30, 128, 2, 30))
def test_key_dim(self, key_dim, q_tensor_last_dim, some_num_attention_heads,
                 expected):
some_inner_dim = 32
some_inner_activation = 'relu'
test_layer = TransformerEncoderBlock(
@@ -366,28 +361,16 @@ class TransformerEncoderBlockLayerTestWithoutParams(
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
test_layer([q_tensor, kv_tensor, dummy_mask])
self.assertEqual(expected,
                 test_layer._attention_layer.get_config()['key_dim'])
@parameterized.named_parameters(
    ('output_last_dim_is_none_use_query_residual_false', False, None, 128,
     128),
    ('output_last_dim_is_none_use_query_residual_true', True, None, 128, 128),
    ('output_last_dim_is_not_none', False, 30, 128, 30))
def test_output_last_dim(self, use_query_residual, output_last_dim,
                         q_tensor_last_dim, expected):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
@@ -407,15 +390,10 @@ class TransformerEncoderBlockLayerTestWithoutParams(
self.assertEqual(output.numpy().shape[-1], expected)
@parameterized.named_parameters(('value_dim_is_none', None, 128, 2, 128 // 2),
                                ('value_dim_is_not_none', 30, 128, 2, 30))
def test_value_dim(self, value_dim, q_tensor_last_dim,
                   some_num_attention_heads, expected):
some_inner_dim = 32
some_inner_activation = 'relu'
test_layer = TransformerEncoderBlock(
@@ -429,9 +407,8 @@ class TransformerEncoderBlockLayerTestWithoutParams(
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
test_layer([q_tensor, kv_tensor, dummy_mask])
self.assertEqual(expected,
                 test_layer._attention_layer.get_config()['value_dim'])
@keras_parameterized.run_all_keras_modes
@@ -638,21 +615,25 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
@parameterized.parameters(
    {'output_dropout': 0.1, 'attention_dropout': 0.2, 'inner_dropout': 0.3},
    {'output_dropout': 0.0, 'attention_dropout': 0.2, 'inner_dropout': 0.3},
    {'output_dropout': 0.1, 'attention_dropout': 0.0, 'inner_dropout': 0.3},
    {'output_dropout': 0.1, 'attention_dropout': 0.2, 'inner_dropout': 0.0})
def test_dropout_config(self, output_dropout, attention_dropout,
                        inner_dropout):
test_layer = TransformerEncoderBlock(
num_attention_heads=2,
@@ -673,6 +654,49 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
self.assertEqual(true_attention_dropout, attention_dropout)
self.assertEqual(true_inner_dropout, inner_dropout)
@parameterized.named_parameters(
    ('return_attention_scores_is_false', False),
    ('return_attention_scores_is_true', True))
def test_return_attention_scores(self, return_attention_scores):
  num_attention_heads = 7
  sequence_length = 21
  width = 80
  test_layer = TransformerEncoderBlock(
      num_attention_heads=num_attention_heads,
      inner_dim=2048,
      inner_activation='relu',
      return_attention_scores=return_attention_scores)
  # Create a 3-dimensional input (the first dimension is implicit).
  data_tensor = tf.keras.Input(shape=(sequence_length, width))
  output_tensor = test_layer(data_tensor)

  expected_layer_output_shape = [None, sequence_length, width]
  expected_attention_scores_shape = [
      None, num_attention_heads, sequence_length, sequence_length
  ]
  if return_attention_scores:
    self.assertIsInstance(output_tensor, tuple)
    self.assertEqual(len(output_tensor), 2)
    # First is the standard output.
    self.assertEqual(output_tensor[0].shape.as_list(),
                     expected_layer_output_shape)
    # Second is the attention scores.
    self.assertEqual(output_tensor[1].shape.as_list(),
                     expected_attention_scores_shape)
  else:
    # Only the standard layer output.
    self.assertEqual(output_tensor.shape.as_list(),
                     expected_layer_output_shape)
if __name__ == '__main__':
  tf.test.main()
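One practical use of the new flag, sketched under assumptions (names and sizes below are hypothetical, not from this change): exposing the attention scores as an extra model output for inspection.

import tensorflow as tf
from official.nlp.modeling.layers.transformer_encoder_block import TransformerEncoderBlock

block = TransformerEncoderBlock(
    num_attention_heads=7,
    inner_dim=2048,
    inner_activation='relu',
    return_attention_scores=True)

inputs = tf.keras.Input(shape=(21, 80))
outputs, scores = block(inputs)
model = tf.keras.Model(inputs, [outputs, scores])

# Inspect attention patterns for a batch of data.
out, att = model(tf.random.uniform([3, 21, 80]))
# att: [3, 7, 21, 21] = [batch, heads, seq_dim, seq_dim]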