Commit 3ce3a5c5 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Migrate all TF-NLP tasks to consume models with dictionary outputs and...

Migrate all TF-NLP tasks to consume models with dictionary outputs and encoders with dictionary outputs.

PiperOrigin-RevId: 332740616
parent e04dafd0
...@@ -185,7 +185,8 @@ def build_encoder( ...@@ -185,7 +185,8 @@ def build_encoder(
pooled_output_dim=encoder_cfg.hidden_size, pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range), stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs) return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
return encoder_cls(**kwargs) return encoder_cls(**kwargs)
if encoder_type == "mobilebert": if encoder_type == "mobilebert":
...@@ -221,7 +222,8 @@ def build_encoder( ...@@ -221,7 +222,8 @@ def build_encoder(
dropout_rate=encoder_cfg.dropout_rate, dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range)) stddev=encoder_cfg.initializer_range),
dict_outputs=True)
# Uses the default BERTEncoder configuration schema to create the encoder. # Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type. # If it does not match, please add a switch branch by the encoder type.
...@@ -240,4 +242,5 @@ def build_encoder( ...@@ -240,4 +242,5 @@ def build_encoder(
stddev=encoder_cfg.initializer_range), stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size, embedding_width=encoder_cfg.embedding_size,
embedding_layer=embedding_layer, embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs) return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
...@@ -84,11 +84,16 @@ class DualEncoder(tf.keras.Model): ...@@ -84,11 +84,16 @@ class DualEncoder(tf.keras.Model):
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids') shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
left_inputs = [left_word_ids, left_mask, left_type_ids] left_inputs = [left_word_ids, left_mask, left_type_ids]
left_sequence_output, left_encoded = network(left_inputs) left_outputs = network(left_inputs)
if isinstance(left_outputs, list):
left_sequence_output, left_encoded = left_outputs
else:
left_sequence_output = left_outputs['sequence_output']
left_encoded = left_outputs['pooled_output']
if normalize: if normalize:
left_encoded = tf.keras.layers.Lambda( left_encoded = tf.keras.layers.Lambda(
lambda x: tf.nn.l2_normalize(x, axis=1))(left_encoded) lambda x: tf.nn.l2_normalize(x, axis=1))(
left_encoded)
if output == 'logits': if output == 'logits':
right_word_ids = tf.keras.layers.Input( right_word_ids = tf.keras.layers.Input(
...@@ -99,33 +104,40 @@ class DualEncoder(tf.keras.Model): ...@@ -99,33 +104,40 @@ class DualEncoder(tf.keras.Model):
shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids') shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')
right_inputs = [right_word_ids, right_mask, right_type_ids] right_inputs = [right_word_ids, right_mask, right_type_ids]
_, right_encoded = network(right_inputs) right_outputs = network(right_inputs)
if isinstance(right_outputs, list):
_, right_encoded = right_outputs
else:
right_encoded = right_outputs['pooled_output']
if normalize: if normalize:
right_encoded = tf.keras.layers.Lambda( right_encoded = tf.keras.layers.Lambda(
lambda x: tf.nn.l2_normalize(x, axis=1))(right_encoded) lambda x: tf.nn.l2_normalize(x, axis=1))(
right_encoded)
dot_products = layers.MatMulWithMargin(logit_scale=logit_scale,
logit_margin=logit_margin, dot_products = layers.MatMulWithMargin(
name='dot_product') logit_scale=logit_scale,
logit_margin=logit_margin,
inputs = [left_word_ids, left_mask, left_type_ids, right_word_ids, name='dot_product')
right_mask, right_type_ids]
inputs = [
left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
right_type_ids
]
left_logits, right_logits = dot_products(left_encoded, right_encoded) left_logits, right_logits = dot_products(left_encoded, right_encoded)
outputs = [left_logits, right_logits] outputs = dict(left_logits=left_logits, right_logits=right_logits)
elif output == 'predictions': elif output == 'predictions':
inputs = [left_word_ids, left_mask, left_type_ids] inputs = [left_word_ids, left_mask, left_type_ids]
# To keep consistent with legacy BERT hub modules, the outputs are # To keep consistent with legacy BERT hub modules, the outputs are
# "pooled_output" and "sequence_output". # "pooled_output" and "sequence_output".
outputs = [left_encoded, left_sequence_output] outputs = dict(
sequence_output=left_sequence_output, pooled_output=left_encoded)
else: else:
raise ValueError('output type %s is not supported' % output) raise ValueError('output type %s is not supported' % output)
super(DualEncoder, self).__init__( super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)
inputs=inputs, outputs=outputs, **kwargs)
# Set _self_setattr_tracking to True so it can be exported with assets. # Set _self_setattr_tracking to True so it can be exported with assets.
self._self_setattr_tracking = True self._self_setattr_tracking = True
......
...@@ -14,10 +14,6 @@ ...@@ -14,10 +14,6 @@
# ============================================================================== # ==============================================================================
"""Tests for dual encoder network.""" """Tests for dual encoder network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries # Import libraries
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
...@@ -42,7 +38,8 @@ class DualEncoderTest(keras_parameterized.TestCase): ...@@ -42,7 +38,8 @@ class DualEncoderTest(keras_parameterized.TestCase):
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=2, num_layers=2,
hidden_size=hidden_size, hidden_size=hidden_size,
sequence_length=sequence_length) sequence_length=sequence_length,
dict_outputs=True)
# Create a dual encoder model with the created network. # Create a dual encoder model with the created network.
dual_encoder_model = dual_encoder.DualEncoder( dual_encoder_model = dual_encoder.DualEncoder(
...@@ -59,21 +56,19 @@ class DualEncoderTest(keras_parameterized.TestCase): ...@@ -59,21 +56,19 @@ class DualEncoderTest(keras_parameterized.TestCase):
if output == 'logits': if output == 'logits':
outputs = dual_encoder_model([ outputs = dual_encoder_model([
left_word_ids, left_mask, left_type_ids, left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
right_word_ids, right_mask, right_type_ids]) right_type_ids
])
left_encoded, _ = outputs _ = outputs['left_logits']
elif output == 'predictions': elif output == 'predictions':
left_encoded, left_sequence_output = dual_encoder_model([ outputs = dual_encoder_model([left_word_ids, left_mask, left_type_ids])
left_word_ids, left_mask, left_type_ids])
# Validate that the outputs are of the expected shape. # Validate that the outputs are of the expected shape.
expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
expected_sequence_shape = [None, sequence_length, 768] expected_sequence_shape = [None, sequence_length, 768]
self.assertAllEqual(expected_sequence_shape, self.assertAllEqual(expected_sequence_shape,
left_sequence_output.shape.as_list()) outputs['sequence_output'].shape.as_list())
left_encoded = outputs['pooled_output']
expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
@parameterized.parameters((192, 'logits'), (768, 'predictions')) @parameterized.parameters((192, 'logits'), (768, 'predictions'))
def test_dual_encoder_tensor_call(self, hidden_size, output): def test_dual_encoder_tensor_call(self, hidden_size, output):
......
...@@ -139,14 +139,11 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -139,14 +139,11 @@ class ElectraPretrainer(tf.keras.Model):
masked_lm_positions = inputs['masked_lm_positions'] masked_lm_positions = inputs['masked_lm_positions']
### Generator ### ### Generator ###
sequence_output, cls_output = self.generator_network( sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids]) [input_word_ids, input_mask, input_type_ids])['sequence_output']
# The generator encoder network may get outputs from all layers. # The generator encoder network may get outputs from all layers.
if isinstance(sequence_output, list): if isinstance(sequence_output, list):
sequence_output = sequence_output[-1] sequence_output = sequence_output[-1]
if isinstance(cls_output, list):
cls_output = cls_output[-1]
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions) lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
sentence_outputs = self.classification(sequence_output) sentence_outputs = self.classification(sequence_output)
...@@ -157,10 +154,10 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -157,10 +154,10 @@ class ElectraPretrainer(tf.keras.Model):
### Discriminator ### ### Discriminator ###
disc_input = fake_data['inputs'] disc_input = fake_data['inputs']
disc_label = fake_data['is_fake_tokens'] disc_label = fake_data['is_fake_tokens']
disc_sequence_output, _ = self.discriminator_network([ disc_sequence_output = self.discriminator_network([
disc_input['input_word_ids'], disc_input['input_mask'], disc_input['input_word_ids'], disc_input['input_mask'],
disc_input['input_type_ids'] disc_input['input_type_ids']
]) ])['sequence_output']
# The discriminator encoder network may get outputs from all layers. # The discriminator encoder network may get outputs from all layers.
if isinstance(disc_sequence_output, list): if isinstance(disc_sequence_output, list):
......
...@@ -38,11 +38,13 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -38,11 +38,13 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
test_generator_network = networks.BertEncoder( test_generator_network = networks.BertEncoder(
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=2, num_layers=2,
max_sequence_length=sequence_length) max_sequence_length=sequence_length,
dict_outputs=True)
test_discriminator_network = networks.BertEncoder( test_discriminator_network = networks.BertEncoder(
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=2, num_layers=2,
max_sequence_length=sequence_length) max_sequence_length=sequence_length,
dict_outputs=True)
# Create a ELECTRA trainer with the created network. # Create a ELECTRA trainer with the created network.
num_classes = 3 num_classes = 3
...@@ -92,9 +94,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -92,9 +94,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the ELECTRA trainer. (Here, we # Build a transformer network to use within the ELECTRA trainer. (Here, we
# use a short sequence_length for convenience.) # use a short sequence_length for convenience.)
test_generator_network = networks.BertEncoder( test_generator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3) vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
test_discriminator_network = networks.BertEncoder( test_discriminator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3) vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
# Create a ELECTRA trainer with the created network. # Create a ELECTRA trainer with the created network.
eletrca_trainer_model = electra_pretrainer.ElectraPretrainer( eletrca_trainer_model = electra_pretrainer.ElectraPretrainer(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment