Commit 886e188a authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 473023517
parent 7d14c7ca
@@ -22,7 +22,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
_Activation = Union[str, Callable[..., Any]]
@@ -75,6 +74,10 @@ class BertEncoderV2(tf.keras.layers.Layer):
layers. If set False, output of attention and intermediate dense layers is
normalized.
with_dense_inputs: Whether to accept dense embeddings as the input.
return_attention_scores: Whether to add an additional output containing the
attention scores of all transformer layers. This will be a list of length
`num_layers`, and each element will be in the shape [batch_size,
num_attention_heads, seq_dim, seq_dim].
"""
def __init__(
@@ -96,6 +99,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
with_dense_inputs: bool = False,
return_attention_scores: bool = False,
**kwargs):
# Pops kwargs that are used in V1 implementation.
if 'dict_outputs' in kwargs:
@@ -167,6 +171,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
return_attention_scores=return_attention_scores,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
@@ -195,6 +200,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'with_dense_inputs': with_dense_inputs,
'return_attention_scores': return_attention_scores,
}
if with_dense_inputs:
self.inputs = dict(
@@ -249,19 +255,26 @@ class BertEncoderV2(tf.keras.layers.Layer):
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
attention_outputs = []
x = embeddings
for layer in self._transformer_layers:
x = layer([x, attention_mask])
if self._config['return_attention_scores']:
x, attention_scores = x
attention_outputs.append(attention_scores)
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
output = dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
if self._config['return_attention_scores']:
output['attention_scores'] = attention_outputs
return output
def get_embedding_table(self):
return self._embedding_layer.embeddings
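As a usage sketch (not part of the change itself), the new V2 flag surfaces each transformer layer's attention scores in the output dict under `attention_scores`. The snippet below uses hypothetical small sizes mirroring the tests further down and assumes the Model Garden import path `official.nlp.modeling.networks.bert_encoder`:

```python
import tensorflow as tf
from official.nlp.modeling.networks import bert_encoder  # assumed import path

# Hypothetical small configuration, mirroring the unit tests below.
encoder = bert_encoder.BertEncoderV2(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=4,
    num_layers=2,
    return_attention_scores=True)

batch_size, seq_len = 2, 16
outputs = encoder(
    dict(input_word_ids=tf.ones((batch_size, seq_len), dtype=tf.int32),
         input_mask=tf.ones((batch_size, seq_len), dtype=tf.int32),
         input_type_ids=tf.zeros((batch_size, seq_len), dtype=tf.int32)))

# `attention_scores` is a list of length num_layers; each element has shape
# [batch_size, num_attention_heads, seq_dim, seq_dim].
scores = outputs['attention_scores']
assert len(scores) == 2
assert scores[0].shape == (batch_size, 4, seq_len, seq_len)

# Optionally stack the per-layer scores into a single tensor of shape
# [batch_size, num_layers, num_attention_heads, seq_dim, seq_dim].
stacked = tf.stack(scores, axis=1)
```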
@@ -324,13 +337,13 @@ class BertEncoder(tf.keras.Model):
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
@@ -341,16 +354,20 @@ class BertEncoder(tf.keras.Model):
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
dict_outputs: Whether to use a dictionary as the model outputs.
return_all_encoder_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers. Note: when the following `dict_outputs`
argument is True, all encoder outputs are always returned in the dict,
keyed by `encoder_outputs`.
return_attention_scores: Whether to add an additional output containing the
attention scores of all transformer layers. This will be a list of length
`num_layers`, and each element will be in the shape [batch_size,
num_attention_heads, seq_dim, seq_dim].
"""
def __init__(
@@ -372,6 +389,7 @@ class BertEncoder(tf.keras.Model):
norm_first=False,
dict_outputs=False,
return_all_encoder_outputs=False,
return_attention_scores: bool = False,
**kwargs):
if 'sequence_length' in kwargs:
kwargs.pop('sequence_length')
@@ -455,6 +473,7 @@ class BertEncoder(tf.keras.Model):
data = embeddings
attention_mask = layers.SelfAttentionMask()(data, mask)
encoder_outputs = []
attention_outputs = []
for i in range(num_layers):
if i == num_layers - 1 and output_range is not None:
transformer_output_range = output_range
@@ -467,11 +486,15 @@ class BertEncoder(tf.keras.Model):
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
return_attention_scores=return_attention_scores,
output_range=transformer_output_range,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
transformer_layers.append(layer)
data = layer([data, attention_mask])
if return_attention_scores:
data, attention_scores = data
attention_outputs.append(attention_scores)
encoder_outputs.append(data)
last_encoder_output = encoder_outputs[-1]
@@ -491,6 +514,8 @@ class BertEncoder(tf.keras.Model):
pooled_output=cls_output,
encoder_outputs=encoder_outputs,
)
if return_attention_scores:
outputs['attention_scores'] = attention_outputs
if dict_outputs:
super().__init__(
@@ -503,6 +528,8 @@ class BertEncoder(tf.keras.Model):
else:
sequence_output = outputs['sequence_output']
outputs = [sequence_output, cls_output]
if return_attention_scores:
outputs.append(attention_outputs)
super().__init__( # pylint: disable=bad-super-call
inputs=[word_ids, mask, type_ids],
outputs=outputs,
@@ -534,6 +561,7 @@ class BertEncoder(tf.keras.Model):
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'dict_outputs': dict_outputs,
'return_attention_scores': return_attention_scores,
}
# pylint: disable=protected-access
self._setattr_tracking = False
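For the V1 `BertEncoder`, what the caller gets back depends on `dict_outputs`: with dict outputs the scores are returned under the same `attention_scores` key, while the list-outputs path appends them as a trailing element, which the new test below exercises. A rough sketch of the list path, under the same assumed import path and hypothetical sizes as above:

```python
import tensorflow as tf
from official.nlp.modeling.networks import bert_encoder  # assumed import path

encoder = bert_encoder.BertEncoder(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=4,
    num_layers=2,
    return_attention_scores=True)  # dict_outputs stays at its default (False)

seq_len = 16
word_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)

# On the list-outputs path the attention scores are appended after
# sequence_output and pooled_output as a list of length num_layers.
sequence_output, pooled_output, attention_scores = encoder(
    [word_ids, mask, type_ids])
print(len(attention_scores))        # 2
print(attention_scores[0].shape)    # (None, 4, 16, 16)
```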
@@ -106,6 +106,42 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
)
def test_dict_outputs_network_creation_return_attention_scores(
self, encoder_cls):
hidden_size = 32
sequence_length = 21
num_attention_heads = 5
num_layers = 3
# Create a small BertEncoder for testing.
test_network = encoder_cls(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
return_attention_scores=True,
dict_outputs=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
all_attention_outputs = dict_outputs["attention_scores"]
expected_data_shape = [
None, num_attention_heads, sequence_length, sequence_length
]
self.assertLen(all_attention_outputs, num_layers)
for data in all_attention_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_attention_outputs[-1].dtype)
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
@@ -369,6 +405,34 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_attention_scores_output_network_creation(self):
hidden_size = 32
sequence_length = 21
num_attention_heads = 5
num_layers = 3
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
return_attention_scores=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
_, _, all_attention_outputs = test_network([word_ids, mask, type_ids])
expected_data_shape = [
None, num_attention_heads, sequence_length, sequence_length
]
self.assertLen(all_attention_outputs, num_layers)
for data in all_attention_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_attention_outputs[-1].dtype)
def test_network_creation_with_float16_dtype(self):
hidden_size = 32
sequence_length = 21