Commit 52b16a1a authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332806032
parent 71a2fc91
@@ -161,8 +161,9 @@ class BertPretrainerV2(tf.keras.Model):
     name: The name of the model.
   Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
     dictionary.
-  Outputs: A dictionary of `lm_output` and classification head outputs keyed by
-    head names.
+  Outputs: A dictionary of `lm_output`, classification head outputs keyed by
+    head names, and also outputs from `encoder_network`, keyed by
+    `pooled_output`, `sequence_output` and `encoder_outputs` (if any).
   """

   def __init__(
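For orientation, here is a minimal calling sketch of the richer output dictionary the updated docstring describes. It mirrors the list-style invocation used in the test further down; the already-built `pretrainer` model and the input tensor names (`word_ids`, `mask`, `type_ids`, `lm_positions`) are assumptions for illustration, not part of this change.

```python
# Sketch only: `pretrainer` is an already-constructed BertPretrainerV2 and the
# four inputs follow the encoder's expected order, as in the test below.
outputs = pretrainer([word_ids, mask, type_ids, lm_positions])

lm_logits = outputs['lm_output']        # [batch, num_predictions, vocab_size]
sequence = outputs['sequence_output']   # [batch, seq_length, hidden_size]
pooled = outputs['pooled_output']       # [batch, hidden_size]
# Only present when the encoder exposes per-layer outputs:
all_layer_outputs = outputs.get('encoder_outputs')
# Any classification head outputs appear under their head names.
```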
@@ -180,21 +181,32 @@ class BertPretrainerV2(tf.keras.Model):
         'classification_heads': classification_heads,
         'name': name,
     }
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
-    outputs = self.encoder_network(inputs)
-    if isinstance(outputs, list):
-      sequence_output = outputs[0]
-    else:
-      sequence_output = outputs['sequence_output']
+    outputs = dict()
+    encoder_network_outputs = self.encoder_network(inputs)
+    if isinstance(encoder_network_outputs, list):
+      outputs['pooled_output'] = encoder_network_outputs[1]
+      # When `encoder_network` was instantiated with return_all_encoder_outputs
+      # set to True, `encoder_network_outputs[0]` is a list containing
+      # all transformer layers' output.
+      if isinstance(encoder_network_outputs[0], list):
+        outputs['encoder_outputs'] = encoder_network_outputs[0]
+        outputs['sequence_output'] = encoder_network_outputs[0][-1]
+      else:
+        outputs['sequence_output'] = encoder_network_outputs[0]
+    elif isinstance(encoder_network_outputs, dict):
+      outputs = encoder_network_outputs
+    else:
+      raise ValueError('encoder_network\'s output should be either a list '
+                       'or a dict, but got %s' % encoder_network_outputs)
+    sequence_output = outputs['sequence_output']
     self.classification_heads = classification_heads or []
     if len(set([cls.name for cls in self.classification_heads])) != len(
         self.classification_heads):
       raise ValueError('Classification heads should have unique names.')
-    outputs = dict()
     self.masked_lm = layers.MaskedLM(
         embedding_table=self.encoder_network.get_embedding_table(),
         activation=mlm_activation,
...
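Read inside the diff, the new branching can be hard to follow at a glance. The sketch below restates the same dispatch as a standalone helper in plain Python; the function name `normalize_encoder_outputs` is hypothetical and only used here for illustration.

```python
def normalize_encoder_outputs(encoder_network_outputs):
  """Folds list- or dict-style encoder outputs into one keyed dictionary."""
  outputs = {}
  if isinstance(encoder_network_outputs, list):
    # List layout: [sequence_output_or_all_layer_outputs, pooled_output].
    outputs['pooled_output'] = encoder_network_outputs[1]
    if isinstance(encoder_network_outputs[0], list):
      # return_all_encoder_outputs=True: element 0 holds every layer's output.
      outputs['encoder_outputs'] = encoder_network_outputs[0]
      outputs['sequence_output'] = encoder_network_outputs[0][-1]
    else:
      outputs['sequence_output'] = encoder_network_outputs[0]
  elif isinstance(encoder_network_outputs, dict):
    # dict_outputs=True: the encoder already returns the keyed dictionary.
    outputs = encoder_network_outputs
  else:
    raise ValueError('encoder_network\'s output should be either a list '
                     'or a dict, but got %s' % encoder_network_outputs)
  return outputs
```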
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for BERT pretrainer model."""
+import itertools
 from absl.testing import parameterized
 import tensorflow as tf
@@ -108,16 +109,23 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     self.assertAllEqual(bert_trainer_model.get_config(),
                         new_bert_trainer_model.get_config())

-  @parameterized.parameters(True, False)
-  def test_bert_pretrainerv2(self, dict_outputs):
+  @parameterized.parameters(itertools.product(
+      (False, True),
+      (False, True),
+  ))
+  def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
+    hidden_size = 48
+    num_layers = 2
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
-        num_layers=2,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
         max_sequence_length=sequence_length,
+        return_all_encoder_outputs=return_all_encoder_outputs,
         dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
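The decorator change replaces the single boolean parameter with the full cross product of two booleans, so the test now runs four times. A quick check of what `itertools.product` expands to:

```python
import itertools

combos = list(itertools.product((False, True), (False, True)))
print(combos)
# [(False, False), (False, True), (True, False), (True, True)]
# Each tuple is passed to the test as (dict_outputs, return_all_encoder_outputs).
```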
@@ -133,10 +141,28 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Invoke the trainer model on the inputs. This causes the layer to be built.
     outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    has_encoder_outputs = dict_outputs or return_all_encoder_outputs
+    if has_encoder_outputs:
+      self.assertSameElements(
+          outputs.keys(),
+          ['sequence_output', 'pooled_output', 'lm_output', 'encoder_outputs'])
+      self.assertLen(outputs['encoder_outputs'], num_layers)
+    else:
+      self.assertSameElements(outputs.keys(),
+                              ['sequence_output', 'pooled_output', 'lm_output'])

     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
     self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list())
+    expected_sequence_output_shape = [None, sequence_length, hidden_size]
+    self.assertAllEqual(expected_sequence_output_shape,
+                        outputs['sequence_output'].shape.as_list())
+    expected_pooled_output_shape = [None, hidden_size]
+    self.assertAllEqual(expected_pooled_output_shape,
+                        outputs['pooled_output'].shape.as_list())

   def test_v2_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
...
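For reference, the shapes asserted above follow directly from the test constants; a small sketch of that bookkeeping (the batch dimension is unknown at build time, hence `None`):

```python
# Shape bookkeeping behind the new assertions (constants taken from the test).
sequence_length, hidden_size, vocab_size, num_layers = 512, 48, 100, 2

# sequence_output: one hidden vector per token position.
expected_sequence_output_shape = [None, sequence_length, hidden_size]  # [None, 512, 48]
# pooled_output: a single hidden vector per example.
expected_pooled_output_shape = [None, hidden_size]                     # [None, 48]
# lm_output: [None, num_token_predictions, vocab_size], where
# num_token_predictions is defined earlier in the (collapsed) test body.
# encoder_outputs, when present, is a list of num_layers tensors shaped like
# sequence_output.
```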