Commit 52b16a1a authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 332806032
parent 71a2fc91
......@@ -161,8 +161,9 @@ class BertPretrainerV2(tf.keras.Model):
name: The name of the model.
Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
dictionary.
Outputs: A dictionary of `lm_output` and classification head outputs keyed by
head names.
Outputs: A dictionary of `lm_output`, classification head outputs keyed by
head names, and also outputs from `encoder_network`, keyed by
`pooled_output`, `sequence_output` and `encoder_outputs` (if any).
"""
def __init__(
......@@ -180,21 +181,32 @@ class BertPretrainerV2(tf.keras.Model):
'classification_heads': classification_heads,
'name': name,
}
self.encoder_network = encoder_network
inputs = copy.copy(self.encoder_network.inputs)
outputs = self.encoder_network(inputs)
if isinstance(outputs, list):
sequence_output = outputs[0]
outputs = dict()
encoder_network_outputs = self.encoder_network(inputs)
if isinstance(encoder_network_outputs, list):
outputs['pooled_output'] = encoder_network_outputs[1]
# When `encoder_network` was instantiated with return_all_encoder_outputs
# set to True, `encoder_network_outputs[0]` is a list containing
# all transformer layers' output.
if isinstance(encoder_network_outputs[0], list):
outputs['encoder_outputs'] = encoder_network_outputs[0]
outputs['sequence_output'] = encoder_network_outputs[0][-1]
else:
outputs['sequence_output'] = encoder_network_outputs[0]
elif isinstance(encoder_network_outputs, dict):
outputs = encoder_network_outputs
else:
sequence_output = outputs['sequence_output']
raise ValueError('encoder_network\'s output should be either a list '
'or a dict, but got %s' % encoder_network_outputs)
sequence_output = outputs['sequence_output']
self.classification_heads = classification_heads or []
if len(set([cls.name for cls in self.classification_heads])) != len(
self.classification_heads):
raise ValueError('Classification heads should have unique names.')
outputs = dict()
self.masked_lm = layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests for BERT pretrainer model."""
import itertools
from absl.testing import parameterized
import tensorflow as tf
......@@ -108,16 +109,23 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self.assertAllEqual(bert_trainer_model.get_config(),
new_bert_trainer_model.get_config())
@parameterized.parameters(True, False)
def test_bert_pretrainerv2(self, dict_outputs):
@parameterized.parameters(itertools.product(
(False, True),
(False, True),
))
def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
return_all_encoder_outputs=return_all_encoder_outputs,
dict_outputs=dict_outputs)
# Create a BERT trainer with the created network.
......@@ -133,10 +141,28 @@ class BertPretrainerTest(keras_parameterized.TestCase):
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
has_encoder_outputs = dict_outputs or return_all_encoder_outputs
if has_encoder_outputs:
self.assertSameElements(
outputs.keys(),
['sequence_output', 'pooled_output', 'lm_output', 'encoder_outputs'])
self.assertLen(outputs['encoder_outputs'], num_layers)
else:
self.assertSameElements(outputs.keys(),
['sequence_output', 'pooled_output', 'lm_output'])
# Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size]
self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list())
expected_sequence_output_shape = [None, sequence_length, hidden_size]
self.assertAllEqual(expected_sequence_output_shape,
outputs['sequence_output'].shape.as_list())
expected_pooled_output_shape = [None, hidden_size]
self.assertAllEqual(expected_pooled_output_shape,
outputs['pooled_output'].shape.as_list())
def test_v2_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment