Commit 886e188a authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 473023517
parent 7d14c7ca
@@ -22,7 +22,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
_Activation = Union[str, Callable[..., Any]]
@@ -75,6 +74,10 @@ class BertEncoderV2(tf.keras.layers.Layer):
layers. If set False, output of attention and intermediate dense layers is
normalized.
with_dense_inputs: Whether to accept dense embeddings as the input.
return_attention_scores: Whether to add an additional output containing the
attention scores of all transformer layers. This will be a list of length
`num_layers`, and each element will be in the shape [batch_size,
num_attention_heads, seq_dim, seq_dim].
""" """
def __init__( def __init__(
...@@ -96,6 +99,7 @@ class BertEncoderV2(tf.keras.layers.Layer): ...@@ -96,6 +99,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
embedding_layer: Optional[tf.keras.layers.Layer] = None, embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False, norm_first: bool = False,
with_dense_inputs: bool = False, with_dense_inputs: bool = False,
return_attention_scores: bool = False,
**kwargs):
# Pops kwargs that are used in V1 implementation.
if 'dict_outputs' in kwargs:
@@ -167,6 +171,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
return_attention_scores=return_attention_scores,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
@@ -195,6 +200,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'with_dense_inputs': with_dense_inputs,
'return_attention_scores': return_attention_scores,
}
if with_dense_inputs:
self.inputs = dict(
@@ -249,19 +255,26 @@ class BertEncoderV2(tf.keras.layers.Layer):
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
attention_outputs = []
x = embeddings
for layer in self._transformer_layers:
x = layer([x, attention_mask])
if self._config['return_attention_scores']:
x, attention_scores = x
attention_outputs.append(attention_scores)
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
output = dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
if self._config['return_attention_scores']:
output['attention_scores'] = attention_outputs
return output
def get_embedding_table(self):
return self._embedding_layer.embeddings
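
To illustrate the new option, here is a minimal usage sketch with `BertEncoderV2` (not part of this change; the import path and input keys mirror the tests further down, and the hyperparameters are arbitrary placeholders):

import tensorflow as tf

from official.nlp.modeling.networks import bert_encoder  # assumed import path

# Small encoder purely for illustration; sizes are placeholders.
encoder = bert_encoder.BertEncoderV2(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=4,
    num_layers=2,
    return_attention_scores=True)

batch_size, seq_len = 8, 16
outputs = encoder(
    dict(
        input_word_ids=tf.ones((batch_size, seq_len), dtype=tf.int32),
        input_mask=tf.ones((batch_size, seq_len), dtype=tf.int32),
        input_type_ids=tf.zeros((batch_size, seq_len), dtype=tf.int32)))

# With return_attention_scores=True, the output dict gains an
# 'attention_scores' entry: a list of num_layers tensors, each shaped
# [batch_size, num_attention_heads, seq_len, seq_len].
attention_scores = outputs['attention_scores']
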
@@ -324,13 +337,13 @@ class BertEncoder(tf.keras.Model):
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
@@ -341,16 +354,20 @@ class BertEncoder(tf.keras.Model):
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
dict_outputs: Whether to use a dictionary as the model outputs.
return_all_encoder_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers. Note: when the following `dict_outputs`
argument is True, all encoder outputs are always returned in the dict,
keyed by `encoder_outputs`.
return_attention_scores: Whether to add an additional output containing the
attention scores of all transformer layers. This will be a list of length
`num_layers`, and each element will be in the shape [batch_size,
num_attention_heads, seq_dim, seq_dim].
""" """
def __init__( def __init__(
...@@ -372,6 +389,7 @@ class BertEncoder(tf.keras.Model): ...@@ -372,6 +389,7 @@ class BertEncoder(tf.keras.Model):
norm_first=False, norm_first=False,
dict_outputs=False, dict_outputs=False,
return_all_encoder_outputs=False, return_all_encoder_outputs=False,
return_attention_scores: bool = False,
**kwargs):
if 'sequence_length' in kwargs:
kwargs.pop('sequence_length')
@@ -455,6 +473,7 @@ class BertEncoder(tf.keras.Model):
data = embeddings
attention_mask = layers.SelfAttentionMask()(data, mask)
encoder_outputs = []
attention_outputs = []
for i in range(num_layers):
if i == num_layers - 1 and output_range is not None:
transformer_output_range = output_range
@@ -467,11 +486,15 @@ class BertEncoder(tf.keras.Model):
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
return_attention_scores=return_attention_scores,
output_range=transformer_output_range,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
transformer_layers.append(layer)
data = layer([data, attention_mask])
if return_attention_scores:
data, attention_scores = data
attention_outputs.append(attention_scores)
encoder_outputs.append(data)
last_encoder_output = encoder_outputs[-1]
@@ -491,6 +514,8 @@ class BertEncoder(tf.keras.Model):
pooled_output=cls_output,
encoder_outputs=encoder_outputs,
)
if return_attention_scores:
outputs['attention_scores'] = attention_outputs
if dict_outputs:
super().__init__(
@@ -503,6 +528,8 @@ class BertEncoder(tf.keras.Model):
else:
sequence_output = outputs['sequence_output']
outputs = [sequence_output, cls_output]
if return_attention_scores:
outputs.append(attention_outputs)
super().__init__(  # pylint: disable=bad-super-call
inputs=[word_ids, mask, type_ids],
outputs=outputs,
@@ -534,6 +561,7 @@ class BertEncoder(tf.keras.Model):
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'dict_outputs': dict_outputs,
'return_attention_scores': return_attention_scores,
}
# pylint: disable=protected-access
self._setattr_tracking = False
...
@@ -106,6 +106,42 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
)
def test_dict_outputs_network_creation_return_attention_scores(
self, encoder_cls):
hidden_size = 32
sequence_length = 21
num_attention_heads = 5
num_layers = 3
# Create a small BertEncoder for testing.
test_network = encoder_cls(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
return_attention_scores=True,
dict_outputs=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
all_attention_outputs = dict_outputs["attention_scores"]
expected_data_shape = [
None, num_attention_heads, sequence_length, sequence_length
]
self.assertLen(all_attention_outputs, num_layers)
for data in all_attention_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_attention_outputs[-1].dtype)
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
@@ -369,6 +405,34 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_attention_scores_output_network_creation(self):
hidden_size = 32
sequence_length = 21
num_attention_heads = 5
num_layers = 3
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
return_attention_scores=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
_, _, all_attention_outputs = test_network([word_ids, mask, type_ids])
expected_data_shape = [
None, num_attention_heads, sequence_length, sequence_length
]
self.assertLen(all_attention_outputs, num_layers)
for data in all_attention_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_attention_outputs[-1].dtype)
def test_network_creation_with_float16_dtype(self):
hidden_size = 32
sequence_length = 21
...
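
As a companion sketch for the v1 `BertEncoder` path exercised by the test above: when `dict_outputs` is left False, the attention scores come back as a third element of the output list. Hyperparameters and tensor sizes below are placeholders, not values from this change.

import tensorflow as tf

from official.nlp.modeling.networks import bert_encoder  # assumed import path

encoder = bert_encoder.BertEncoder(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=4,
    num_layers=2,
    return_attention_scores=True)

seq_len = 16
word_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)

# Without dict_outputs, the model returns
# [sequence_output, pooled_output, attention_scores].
sequence_output, pooled_output, attention_scores = encoder(
    [word_ids, mask, type_ids])

# attention_scores is a list of num_layers tensors, each with static shape
# [None, num_attention_heads, seq_len, seq_len].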