"vscode:/vscode.git/clone" did not exist on "05d765191a8dd851db794f72608c1c1e143942c6"
Commit 3ce3a5c5 authored by Hongkun Yu, committed by A. Unique TensorFlower


Migrate all TF-NLP tasks to consume models with dictionary outputs and encoders with dictionary outputs.

PiperOrigin-RevId: 332740616
parent e04dafd0
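In practical terms, the migration replaces positional unpacking of encoder outputs (a [sequence_output, pooled_output] list) with lookups by name in a dictionary. A minimal sketch of the difference, assuming the BertEncoder constructor arguments that appear in the diff below; the import path, input shapes, and tensor values here are assumptions, not part of this commit:

import tensorflow as tf
from official.nlp.modeling import networks  # import path assumed

# Build a small encoder that returns dictionary outputs.
encoder = networks.BertEncoder(
    vocab_size=100, num_layers=2, max_sequence_length=16, dict_outputs=True)

word_ids = tf.ones((2, 16), dtype=tf.int32)
mask = tf.ones((2, 16), dtype=tf.int32)
type_ids = tf.zeros((2, 16), dtype=tf.int32)

# With dict_outputs=True the call returns named tensors instead of the
# legacy [sequence_output, pooled_output] list.
outputs = encoder([word_ids, mask, type_ids])
sequence_output = outputs['sequence_output']  # [batch, seq_len, hidden]
pooled_output = outputs['pooled_output']      # [batch, hidden]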
@@ -185,7 +185,8 @@ def build_encoder(
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs)
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
return encoder_cls(**kwargs)
if encoder_type == "mobilebert":
@@ -221,7 +222,8 @@ def build_encoder(
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range))
stddev=encoder_cfg.initializer_range),
dict_outputs=True)
# Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type.
@@ -240,4 +242,5 @@ def build_encoder(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size,
embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs)
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
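Every branch of build_encoder now forwards dict_outputs=True to the network it constructs, so any task built from a config receives the same named outputs regardless of encoder type. A hedged sketch of the task-side call that results (build_encoder and the output keys come from this diff; the config variable and input names are assumptions):

encoder_network = build_encoder(encoder_cfg)
encoder_outputs = encoder_network([input_word_ids, input_mask, input_type_ids])
# Task heads select tensors by key instead of by position.
pooled_output = encoder_outputs['pooled_output']      # classification-style heads
sequence_output = encoder_outputs['sequence_output']  # token-level heads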
@@ -84,11 +84,16 @@ class DualEncoder(tf.keras.Model):
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
left_inputs = [left_word_ids, left_mask, left_type_ids]
left_sequence_output, left_encoded = network(left_inputs)
left_outputs = network(left_inputs)
if isinstance(left_outputs, list):
left_sequence_output, left_encoded = left_outputs
else:
left_sequence_output = left_outputs['sequence_output']
left_encoded = left_outputs['pooled_output']
if normalize:
left_encoded = tf.keras.layers.Lambda(
lambda x: tf.nn.l2_normalize(x, axis=1))(left_encoded)
lambda x: tf.nn.l2_normalize(x, axis=1))(
left_encoded)
if output == 'logits':
right_word_ids = tf.keras.layers.Input(
@@ -99,33 +104,40 @@ class DualEncoder(tf.keras.Model):
shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')
right_inputs = [right_word_ids, right_mask, right_type_ids]
_, right_encoded = network(right_inputs)
right_outputs = network(right_inputs)
if isinstance(right_outputs, list):
_, right_encoded = right_outputs
else:
right_encoded = right_outputs['pooled_output']
if normalize:
right_encoded = tf.keras.layers.Lambda(
lambda x: tf.nn.l2_normalize(x, axis=1))(right_encoded)
lambda x: tf.nn.l2_normalize(x, axis=1))(
right_encoded)
dot_products = layers.MatMulWithMargin(logit_scale=logit_scale,
dot_products = layers.MatMulWithMargin(
logit_scale=logit_scale,
logit_margin=logit_margin,
name='dot_product')
inputs = [left_word_ids, left_mask, left_type_ids, right_word_ids,
right_mask, right_type_ids]
inputs = [
left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
right_type_ids
]
left_logits, right_logits = dot_products(left_encoded, right_encoded)
outputs = [left_logits, right_logits]
outputs = dict(left_logits=left_logits, right_logits=right_logits)
elif output == 'predictions':
inputs = [left_word_ids, left_mask, left_type_ids]
# To keep consistent with legacy BERT hub modules, the outputs are
# "pooled_output" and "sequence_output".
outputs = [left_encoded, left_sequence_output]
outputs = dict(
sequence_output=left_sequence_output, pooled_output=left_encoded)
else:
raise ValueError('output type %s is not supported' % output)
super(DualEncoder, self).__init__(
inputs=inputs, outputs=outputs, **kwargs)
super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)
# Set _self_setattr_tracking to True so it can be exported with assets.
self._self_setattr_tracking = True
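The isinstance branches above let DualEncoder accept either the legacy list outputs or the new dictionary outputs from the wrapped network. The same pattern can be factored into a small helper; this is an illustrative sketch, not code from the commit:

def unpack_encoder_outputs(outputs):
  """Returns (sequence_output, pooled_output) for list or dict encoder outputs."""
  if isinstance(outputs, list):
    # Legacy encoders return [sequence_output, pooled_output].
    sequence_output, pooled_output = outputs
  else:
    # Encoders built with dict_outputs=True return named tensors.
    sequence_output = outputs['sequence_output']
    pooled_output = outputs['pooled_output']
  return sequence_output, pooled_output

left_sequence_output, left_encoded = unpack_encoder_outputs(network(left_inputs))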
@@ -14,10 +14,6 @@
# ==============================================================================
"""Tests for dual encoder network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
@@ -42,7 +38,8 @@ class DualEncoderTest(keras_parameterized.TestCase):
vocab_size=vocab_size,
num_layers=2,
hidden_size=hidden_size,
sequence_length=sequence_length)
sequence_length=sequence_length,
dict_outputs=True)
# Create a dual encoder model with the created network.
dual_encoder_model = dual_encoder.DualEncoder(
@@ -59,21 +56,19 @@ class DualEncoderTest(keras_parameterized.TestCase):
if output == 'logits':
outputs = dual_encoder_model([
left_word_ids, left_mask, left_type_ids,
right_word_ids, right_mask, right_type_ids])
left_encoded, _ = outputs
left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
right_type_ids
])
_ = outputs['left_logits']
elif output == 'predictions':
left_encoded, left_sequence_output = dual_encoder_model([
left_word_ids, left_mask, left_type_ids])
outputs = dual_encoder_model([left_word_ids, left_mask, left_type_ids])
# Validate that the outputs are of the expected shape.
expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
expected_sequence_shape = [None, sequence_length, 768]
self.assertAllEqual(expected_sequence_shape,
left_sequence_output.shape.as_list())
outputs['sequence_output'].shape.as_list())
left_encoded = outputs['pooled_output']
expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
@parameterized.parameters((192, 'logits'), (768, 'predictions'))
def test_dual_encoder_tensor_call(self, hidden_size, output):
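With dictionary outputs, the updated test above no longer unpacks the model call positionally; it indexes the results by key. A hedged usage sketch of the 'predictions' path (the DualEncoder keyword arguments and variable names here are assumptions based on the surrounding code, not taken verbatim from the commit):

dual_encoder_model = dual_encoder.DualEncoder(
    network=test_network, max_seq_length=sequence_length, output='predictions')
outputs = dual_encoder_model([left_word_ids, left_mask, left_type_ids])
pooled_output = outputs['pooled_output']      # [batch, hidden_size]
sequence_output = outputs['sequence_output']  # [batch, sequence_length, hidden_size]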
@@ -139,14 +139,11 @@ class ElectraPretrainer(tf.keras.Model):
masked_lm_positions = inputs['masked_lm_positions']
### Generator ###
sequence_output, cls_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])
sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])['sequence_output']
# The generator encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
if isinstance(cls_output, list):
cls_output = cls_output[-1]
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
sentence_outputs = self.classification(sequence_output)
@@ -157,10 +154,10 @@ class ElectraPretrainer(tf.keras.Model):
### Discriminator ###
disc_input = fake_data['inputs']
disc_label = fake_data['is_fake_tokens']
disc_sequence_output, _ = self.discriminator_network([
disc_sequence_output = self.discriminator_network([
disc_input['input_word_ids'], disc_input['input_mask'],
disc_input['input_type_ids']
])
])['sequence_output']
# The discriminator encoder network may get outputs from all layers.
if isinstance(disc_sequence_output, list):
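After this change both ELECTRA sub-networks are read through the 'sequence_output' key, and the generator's pooled output is no longer unpacked at all. A condensed restatement of the generator path for clarity (same behavior as the code above; only the local variable name is mine):

gen_sequence_output = self.generator_network(
    [input_word_ids, input_mask, input_type_ids])['sequence_output']
if isinstance(gen_sequence_output, list):
  # Encoders configured to return all layer outputs yield a list; keep the last layer.
  gen_sequence_output = gen_sequence_output[-1]
lm_outputs = self.masked_lm(gen_sequence_output, masked_lm_positions)
sentence_outputs = self.classification(gen_sequence_output)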
@@ -38,11 +38,13 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
test_generator_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length)
max_sequence_length=sequence_length,
dict_outputs=True)
test_discriminator_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length)
max_sequence_length=sequence_length,
dict_outputs=True)
# Create a ELECTRA trainer with the created network.
num_classes = 3
@@ -92,9 +94,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the ELECTRA trainer. (Here, we
# use a short sequence_length for convenience.)
test_generator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3)
vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
test_discriminator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3)
vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
# Create a ELECTRA trainer with the created network.
eletrca_trainer_model = electra_pretrainer.ElectraPretrainer(