Commit b47eb2a2 authored by Hongkun Yu, committed by A. Unique TensorFlower

Add dict outputs to BertEncoder and EncoderScaffold

PiperOrigin-RevId: 330817736
parent b47e5fe9
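
For context: with dict_outputs=True, the encoders return a dict keyed by
'sequence_output', 'pooled_output' and, for BertEncoder/EncoderScaffold,
'encoder_outputs', instead of the positional [sequence_output, pooled_output]
list; the model classes below accept either form. A minimal usage sketch, not
part of the commit itself (sizes and hyperparameters are illustrative):

    import numpy as np
    from official.nlp.modeling import networks

    # Illustrative configuration; any small encoder behaves the same way.
    encoder = networks.BertEncoder(
        vocab_size=100, num_layers=2, dict_outputs=True)
    batch_size, seq_len = 2, 16
    data = dict(
        input_word_ids=np.random.randint(100, size=(batch_size, seq_len)),
        input_mask=np.ones((batch_size, seq_len), dtype=np.int32),
        input_type_ids=np.zeros((batch_size, seq_len), dtype=np.int32))
    outputs = encoder(data)
    sequence_output = outputs['sequence_output']  # [batch, seq_len, hidden]
    pooled_output = outputs['pooled_output']      # [batch, hidden]
    all_layers = outputs['encoder_outputs']       # one entry per layer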
@@ -72,7 +72,11 @@ class BertClassifier(tf.keras.Model):
     if use_encoder_pooler:
       # Because we have a copy of inputs to create this Model object, we can
       # invoke the Network object with its own input tensors to start the Model.
-      _, cls_output = network(inputs)
+      outputs = network(inputs)
+      if isinstance(outputs, list):
+        cls_output = outputs[1]
+      else:
+        cls_output = outputs['pooled_output']
       cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
       self.classifier = networks.Classification(
@@ -83,7 +87,11 @@ class BertClassifier(tf.keras.Model):
           name='sentence_prediction')
       predictions = self.classifier(cls_output)
     else:
-      sequence_output, _ = network(inputs)
+      outputs = network(inputs)
+      if isinstance(outputs, list):
+        sequence_output = outputs[0]
+      else:
+        sequence_output = outputs['sequence_output']
       self.classifier = layers.ClassificationHead(
           inner_dim=sequence_output.shape[-1],
           num_classes=num_classes,
......
@@ -31,13 +31,15 @@ from official.nlp.modeling.models import bert_classifier
 @keras_parameterized.run_all_keras_modes
 class BertClassifierTest(keras_parameterized.TestCase):
 
-  @parameterized.parameters(1, 3)
-  def test_bert_trainer(self, num_classes):
+  @parameterized.named_parameters(('single_cls', 1, False), ('3_cls', 3, False),
+                                  ('3_cls_dictoutputs', 3, True))
+  def test_bert_trainer(self, num_classes, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
......
@@ -183,7 +183,11 @@ class BertPretrainerV2(tf.keras.Model):
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
-    sequence_output, _ = self.encoder_network(inputs)
+    outputs = self.encoder_network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']
 
     self.classification_heads = classification_heads or []
     if len(set([cls.name for cls in self.classification_heads])) != len(
......
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for BERT pretrainer model."""
 
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
@@ -111,7 +108,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     self.assertAllEqual(bert_trainer_model.get_config(),
                         new_bert_trainer_model.get_config())
 
-  def test_bert_pretrainerv2(self):
+  @parameterized.parameters(True, False)
+  def test_bert_pretrainerv2(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
@@ -119,7 +117,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
-        max_sequence_length=sequence_length)
+        max_sequence_length=sequence_length,
+        dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
......
@@ -64,7 +64,11 @@ class BertSpanLabeler(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
-    sequence_output, _ = network(inputs)
+    outputs = network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']
 
     # This is an instance variable for ease of access to the underlying task
     # network.
......
@@ -14,10 +14,7 @@
 # ==============================================================================
 """Tests for BERT trainer network."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
@@ -30,12 +27,14 @@ from official.nlp.modeling.models import bert_span_labeler
 @keras_parameterized.run_all_keras_modes
 class BertSpanLabelerTest(keras_parameterized.TestCase):
 
-  def test_bert_trainer(self):
+  @parameterized.parameters(True, False)
+  def test_bert_trainer(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
......
@@ -67,7 +67,11 @@ class BertTokenClassifier(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
-    sequence_output, _ = network(inputs)
+    outputs = network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']
     sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
         sequence_output)
......
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for BERT token classifier."""
 
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
@@ -30,7 +27,8 @@ from official.nlp.modeling.models import bert_token_classifier
 @keras_parameterized.run_all_keras_modes
 class BertTokenClassifierTest(keras_parameterized.TestCase):
 
-  def test_bert_trainer(self):
+  @parameterized.parameters(True, False)
+  def test_bert_trainer(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
@@ -38,7 +36,8 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
-        max_sequence_length=sequence_length)
+        max_sequence_length=sequence_length,
+        dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     num_classes = 3
......
@@ -14,10 +14,6 @@
 # ==============================================================================
 """ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 
 import tensorflow as tf
@@ -64,6 +60,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
     attention_dropout_rate: The dropout rate to use for the attention layers
       within the transformer layers.
     initializer: The initializer to use for all weights in this encoder.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """
 
   def __init__(self,
@@ -79,6 +76,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
                dropout_rate=0.1,
                attention_dropout_rate=0.1,
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+               dict_outputs=False,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -172,9 +170,16 @@ class AlbertTransformerEncoder(tf.keras.Model):
         kernel_initializer=initializer,
         name='pooler_transform')(
             first_token_tensor)
 
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=data,
+          pooled_output=cls_output,
+      )
+    else:
+      outputs = [data, cls_output]
+
     super(AlbertTransformerEncoder, self).__init__(
-        inputs=[word_ids, mask, type_ids], outputs=[data, cls_output], **kwargs)
+        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
 
   def get_embedding_table(self):
     return self._embedding_layer.embeddings
......
@@ -109,7 +109,7 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     type_id_data = np.random.randint(
         num_types, size=(batch_size, sequence_length))
-    _ = model.predict([word_id_data, mask_data, type_id_data])
+    list_outputs = model.predict([word_id_data, mask_data, type_id_data])
 
     # Creates a TransformerEncoder with max_sequence_length != sequence_length
     max_sequence_length = 128
@@ -124,6 +124,27 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
     _ = model.predict([word_id_data, mask_data, type_id_data])
 
+    # Tests dictionary outputs.
+    test_network_dict = albert_transformer_encoder.AlbertTransformerEncoder(
+        vocab_size=vocab_size,
+        embedding_width=8,
+        hidden_size=hidden_size,
+        max_sequence_length=max_sequence_length,
+        num_attention_heads=2,
+        num_layers=3,
+        type_vocab_size=num_types,
+        dict_outputs=True)
+    _ = test_network_dict([word_ids, mask, type_ids])
+    test_network_dict.set_weights(test_network.get_weights())
+    list_outputs = test_network([word_id_data, mask_data, type_id_data])
+    dict_outputs = test_network_dict(
+        dict(
+            input_word_ids=word_id_data,
+            input_mask=mask_data,
+            input_type_ids=type_id_data))
+    self.assertAllEqual(list_outputs[0], dict_outputs["sequence_output"])
+    self.assertAllEqual(list_outputs[1], dict_outputs["pooled_output"])
+
   def test_serialize_deserialize(self):
     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a network object that sets all of its config options.
......
@@ -71,6 +71,7 @@ class BertEncoder(tf.keras.Model):
       embedding layer. Otherwise, we will reuse the given embedding layer. This
      parameter is originally added for ELECTRA model which needs to tie the
      generator embeddings with the discriminator embeddings.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """
 
   def __init__(self,
@@ -90,6 +91,7 @@
                output_range=None,
                embedding_width=None,
                embedding_layer=None,
+               dict_outputs=False,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -110,6 +112,7 @@
         'return_all_encoder_outputs': return_all_encoder_outputs,
         'output_range': output_range,
         'embedding_width': embedding_width,
+        'dict_outputs': dict_outputs,
     }
 
     word_ids = tf.keras.layers.Input(
@@ -197,11 +200,16 @@
         name='pooler_transform')
     cls_output = self._pooler_layer(first_token_tensor)
 
-    if return_all_encoder_outputs:
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=encoder_outputs[-1],
+          pooled_output=cls_output,
+          encoder_outputs=encoder_outputs,
+      )
+    elif return_all_encoder_outputs:
       outputs = [encoder_outputs, cls_output]
     else:
       outputs = [encoder_outputs[-1], cls_output]
 
     super(BertEncoder, self).__init__(
         inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
......
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for transformer-based text encoder network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for transformer-based bert encoder network."""
 
 # Import libraries
 from absl.testing import parameterized
@@ -64,6 +60,35 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(tf.float32, pooled.dtype)
 
+    test_network_dict = bert_encoder.BertEncoder(
+        vocab_size=100,
+        hidden_size=hidden_size,
+        num_attention_heads=2,
+        num_layers=3,
+        dict_outputs=True)
+    # Create the inputs (note that the first dimension is implicit).
+    inputs = dict(
+        input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)
+    _ = test_network_dict(inputs)
+    test_network_dict.set_weights(test_network.get_weights())
+
+    batch_size = 2
+    vocab_size = 100
+    num_types = 2
+    word_id_data = np.random.randint(
+        vocab_size, size=(batch_size, sequence_length))
+    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+    type_id_data = np.random.randint(
+        num_types, size=(batch_size, sequence_length))
+    list_outputs = test_network([word_id_data, mask_data, type_id_data])
+    dict_outputs = test_network_dict(
+        dict(
+            input_word_ids=word_id_data,
+            input_mask=mask_data,
+            input_type_ids=type_id_data))
+    self.assertAllEqual(list_outputs[0], dict_outputs["sequence_output"])
+    self.assertAllEqual(list_outputs[1], dict_outputs["pooled_output"])
+
   def test_all_encoder_outputs_network_creation(self):
     hidden_size = 32
     sequence_length = 21
@@ -199,7 +224,8 @@
         initializer="glorot_uniform",
         return_all_encoder_outputs=False,
         output_range=-1,
-        embedding_width=16)
+        embedding_width=16,
+        dict_outputs=True)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
......
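Note that dict_outputs is also recorded in the encoder configs (the
'dict_outputs': dict_outputs entry in bert_encoder.py above and the matching
entry in encoder_scaffold.py below), so the flag should survive a get_config()
round trip. A hedged sketch, assuming the standard Keras from_config path and
not part of the commit itself:

    from official.nlp.modeling import networks

    encoder = networks.BertEncoder(
        vocab_size=100, num_layers=2, dict_outputs=True)
    config = encoder.get_config()
    assert config['dict_outputs'] is True
    # Rebuilding from the config should preserve dict-style outputs.
    restored = networks.BertEncoder.from_config(config)
    assert restored.get_config()['dict_outputs'] is True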
@@ -93,6 +93,7 @@ class EncoderScaffold(tf.keras.Model):
       "kernel_initializer": The initializer for the transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """
 
   def __init__(self,
@@ -106,6 +107,7 @@
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
                return_all_layer_outputs=False,
+               dict_outputs=False,
                **kwargs):
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls
@@ -117,6 +119,7 @@
     self._embedding_cfg = embedding_cfg
     self._embedding_data = embedding_data
     self._return_all_layer_outputs = return_all_layer_outputs
+    self._dict_outputs = dict_outputs
     self._kwargs = kwargs
 
     if embedding_cls:
@@ -200,7 +203,13 @@
         name='cls_transform')
     cls_output = self._pooler_layer(first_token_tensor)
 
-    if return_all_layer_outputs:
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=layer_output_data[-1],
+          pooled_output=cls_output,
+          encoder_outputs=layer_output_data,
+      )
+    elif return_all_layer_outputs:
       outputs = [layer_output_data, cls_output]
     else:
       outputs = [layer_output_data[-1], cls_output]
@@ -219,6 +228,7 @@
         'embedding_cfg': self._embedding_cfg,
         'hidden_cfg': self._hidden_cfg,
         'return_all_layer_outputs': self._return_all_layer_outputs,
+        'dict_outputs': self._dict_outputs,
     }
     if inspect.isclass(self._hidden_cls):
       config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
......
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for transformer-based text encoder network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for EncoderScaffold network."""
 
 from absl.testing import parameterized
 import numpy as np
@@ -218,16 +214,17 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cfg=hidden_cfg,
-        embedding_cfg=embedding_cfg)
+        embedding_cfg=embedding_cfg,
+        dict_outputs=True)
 
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
 
-    data, pooled = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])
 
     # Create a model based off of this network:
-    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)
 
     # Invoke the model. We can't validate the output data here (the model is too
     # complex) but this will catch structural runtime errors.
@@ -237,7 +234,8 @@
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     type_id_data = np.random.randint(
         num_types, size=(batch_size, sequence_length))
-    _ = model.predict([word_id_data, mask_data, type_id_data])
+    preds = model.predict([word_id_data, mask_data, type_id_data])
+    self.assertEqual(preds["pooled_output"].shape, (3, hidden_size))
 
     # Creates an EncoderScaffold with max_sequence_length != sequence_length
     num_types = 7
@@ -272,8 +270,8 @@
             stddev=0.02),
         hidden_cfg=hidden_cfg,
         embedding_cfg=embedding_cfg)
-    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)
     _ = model.predict([word_id_data, mask_data, type_id_data])
 
   def test_serialize_deserialize(self):
......