Unverified Commit f16a7b5b authored by vedanshu's avatar vedanshu Committed by GitHub
Browse files

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,13 +11,8 @@ ...@@ -11,13 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for masked LM loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for masked LM loss."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -39,7 +34,7 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -39,7 +34,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output="predictions"): output="predictions"):
# First, create a transformer stack that we can use to get the LM's # First, create a transformer stack that we can use to get the LM's
# vocabulary weight. # vocabulary weight.
xformer_stack = networks.TransformerEncoder( xformer_stack = networks.BertEncoder(
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=1, num_layers=1,
sequence_length=sequence_length, sequence_length=sequence_length,
...@@ -204,5 +199,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -204,5 +199,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
expected_loss_data = 6.4222 expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3) self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
# Models # Models
Models are combinations of layers and networks that would be trained. Models are combinations of `tf.keras` layers and models that can be trained.
Several pre-built canned models are provided to train encoder networks. These Several pre-built canned models are provided to train encoder networks.
models are intended as both convenience functions and canonical examples. These models are intended as both convenience functions and canonical examples.
* [`BertClassifier`](bert_classifier.py) implements a simple classification * [`BertClassifier`](bert_classifier.py) implements a simple classification
model containing a single classification head using the Classification network. model containing a single classification head using the Classification network.
It can be used as a regression model as well. It can be used as a regression model as well.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head using the classification model containing a single classification head over the sequence
TokenClassification network. output embeddings.
* [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span * [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span
start-end predictor (that is, a model that predicts two values: a start token start-end predictor (that is, a model that predicts two values: a start token
...@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks. ...@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks.
* [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a * [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a
classification head using the Masked LM and Classification networks, classification head using the Masked LM and Classification networks,
respectively. respectively.
* [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitbale for
retrieval tasks.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,10 +11,19 @@ ...@@ -11,10 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Models package definition.""" """Models are combinations of `tf.keras` layers and models that can be trained.
Several pre-built canned models are provided to train encoder networks.
These models are intended as both convenience functions and canonical examples.
"""
from official.nlp.modeling.models.bert_classifier import BertClassifier from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer from official.nlp.modeling.models.bert_pretrainer import *
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetPretrainer
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,18 +11,13 @@ ...@@ -11,18 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT cls-token classifier."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import layers from official.nlp.modeling import layers
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
...@@ -37,7 +32,10 @@ class BertClassifier(tf.keras.Model): ...@@ -37,7 +32,10 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes` instantiates a classification network based on the passed `num_classes`
argument. If `num_classes` is set to 1, a regression network is instantiated. argument. If `num_classes` is set to 1, a regression network is instantiated.
Arguments: *Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method. table via a "get_embedding_table" method.
...@@ -45,8 +43,12 @@ class BertClassifier(tf.keras.Model): ...@@ -45,8 +43,12 @@ class BertClassifier(tf.keras.Model):
initializer: The initializer (if any) to use in the classification networks. initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer. Defaults to a Glorot uniform initializer.
dropout_rate: The dropout probability of the cls head. dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
the encoder. encoder.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
""" """
def __init__(self, def __init__(self,
...@@ -55,15 +57,11 @@ class BertClassifier(tf.keras.Model): ...@@ -55,15 +57,11 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform', initializer='glorot_uniform',
dropout_rate=0.1, dropout_rate=0.1,
use_encoder_pooler=True, use_encoder_pooler=True,
cls_head=None,
**kwargs): **kwargs):
self._self_setattr_tracking = False self.num_classes = num_classes
self._network = network self.initializer = initializer
self._config = { self.use_encoder_pooler = use_encoder_pooler
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'use_encoder_pooler': use_encoder_pooler,
}
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use # Model. To do this, we need to keep a handle to the network inputs for use
...@@ -73,36 +71,73 @@ class BertClassifier(tf.keras.Model): ...@@ -73,36 +71,73 @@ class BertClassifier(tf.keras.Model):
if use_encoder_pooler: if use_encoder_pooler:
# Because we have a copy of inputs to create this Model object, we can # Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model. # invoke the Network object with its own input tensors to start the Model.
_, cls_output = network(inputs) outputs = network(inputs)
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output) if isinstance(outputs, list):
cls_inputs = outputs[1]
else:
cls_inputs = outputs['pooled_output']
cls_inputs = tf.keras.layers.Dropout(rate=dropout_rate)(cls_inputs)
else:
outputs = network(inputs)
if isinstance(outputs, list):
cls_inputs = outputs[0]
else:
cls_inputs = outputs['sequence_output']
self.classifier = networks.Classification( if cls_head:
input_width=cls_output.shape[-1], classifier = cls_head
num_classes=num_classes,
initializer=initializer,
output='logits',
name='sentence_prediction')
predictions = self.classifier(cls_output)
else: else:
sequence_output, _ = network(inputs) classifier = layers.ClassificationHead(
self.classifier = layers.ClassificationHead( inner_dim=0 if use_encoder_pooler else cls_inputs.shape[-1],
inner_dim=sequence_output.shape[-1],
num_classes=num_classes, num_classes=num_classes,
initializer=initializer, initializer=initializer,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
name='sentence_prediction') name='sentence_prediction')
predictions = self.classifier(sequence_output)
predictions = classifier(cls_inputs)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertClassifier, self).__init__( super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs) inputs=inputs, outputs=predictions, **kwargs)
self._network = network
self._cls_head = cls_head
config_dict = self._make_config_dict()
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.classifier = classifier
@property @property
def checkpoint_items(self): def checkpoint_items(self):
return dict(encoder=self._network) items = dict(encoder=self._network)
if hasattr(self.classifier, 'checkpoint_items'):
for key, item in self.classifier.checkpoint_items.items():
items['.'.join([self.classifier.name, key])] = item
return items
def get_config(self): def get_config(self):
return self._config return dict(self._config._asdict())
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
return cls(**config) return cls(**config)
def _make_config_dict(self):
return {
'network': self._network,
'num_classes': self.num_classes,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self._cls_head,
}
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,17 +11,14 @@ ...@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import """Tests for BERT trainer network."""
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier from official.nlp.modeling.models import bert_classifier
...@@ -31,14 +28,15 @@ from official.nlp.modeling.models import bert_classifier ...@@ -31,14 +28,15 @@ from official.nlp.modeling.models import bert_classifier
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class BertClassifierTest(keras_parameterized.TestCase): class BertClassifierTest(keras_parameterized.TestCase):
@parameterized.parameters(1, 3) @parameterized.named_parameters(('single_cls', 1, False), ('3_cls', 3, False),
def test_bert_trainer(self, num_classes): ('3_cls_dictoutputs', 3, True))
def test_bert_trainer(self, num_classes, dict_outputs):
"""Validate that the Keras object can be created.""" """Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 sequence_length = 512
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier( bert_trainer_model = bert_classifier.BertClassifier(
...@@ -56,17 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase): ...@@ -56,17 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase):
expected_classification_shape = [None, num_classes] expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list()) self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
@parameterized.parameters(1, 2) @parameterized.named_parameters(
def test_bert_trainer_tensor_call(self, num_classes): ('single_cls', 1, False),
('2_cls', 2, False),
('single_cls_custom_head', 1, True),
('2_cls_custom_head', 2, True))
def test_bert_trainer_tensor_call(self, num_classes, use_custom_head):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
vocab_size=100, num_layers=2, sequence_length=2) cls_head = layers.GaussianProcessClassificationHead(
inner_dim=0, num_classes=num_classes) if use_custom_head else None
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier( bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=num_classes) test_network, num_classes=num_classes, cls_head=cls_head)
# Create a set of 2-dimensional data tensors to feed into the model. # Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32) word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
...@@ -78,17 +81,21 @@ class BertClassifierTest(keras_parameterized.TestCase): ...@@ -78,17 +81,21 @@ class BertClassifierTest(keras_parameterized.TestCase):
# too complex: this simply ensures we're not hitting runtime errors.) # too complex: this simply ensures we're not hitting runtime errors.)
_ = bert_trainer_model([word_ids, mask, type_ids]) _ = bert_trainer_model([word_ids, mask, type_ids])
def test_serialize_deserialize(self): @parameterized.named_parameters(
('default_cls_head', None),
('sngp_cls_head', layers.GaussianProcessClassificationHead(
inner_dim=0, num_classes=4)))
def test_serialize_deserialize(self, cls_head):
"""Validate that the BERT trainer can be serialized and deserialized.""" """Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5) vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args # Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.) # are different, so we can catch any serialization mismatches.)
bert_trainer_model = bert_classifier.BertClassifier( bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=4, initializer='zeros') test_network, num_classes=4, initializer='zeros', cls_head=cls_head)
# Create another BERT trainer via serialization and deserialization. # Create another BERT trainer via serialization and deserialization.
config = bert_trainer_model.get_config() config = bert_trainer_model.get_config()
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,17 +11,14 @@ ...@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT Pre-training model."""
# pylint: disable=g-classes-have-attributes
import collections
import copy import copy
from typing import List, Optional from typing import List, Optional
from absl import logging
import gin import gin
import tensorflow as tf import tensorflow as tf
...@@ -31,17 +28,18 @@ from official.nlp.modeling import networks ...@@ -31,17 +28,18 @@ from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class BertPretrainer(tf.keras.Model): class BertPretrainer(tf.keras.Model):
"""BERT network training model. """BERT pretraining model.
This is an implementation of the network structure surrounding a transformer [Note] Please use the new `BertPretrainerV2` for your projects.
encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding" (https://arxiv.org/abs/1810.04805).
The BertPretrainer allows a user to pass in a transformer stack, and The BertPretrainer allows a user to pass in a transformer stack, and
instantiates the masked language model and classification networks that are instantiates the masked language model and classification networks that are
used to create the training objectives. used to create the training objectives.
Arguments: *Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. and a classification output.
num_classes: Number of classes to predict from the classification network. num_classes: Number of classes to predict from the classification network.
...@@ -52,8 +50,8 @@ class BertPretrainer(tf.keras.Model): ...@@ -52,8 +50,8 @@ class BertPretrainer(tf.keras.Model):
None, no activation will be used. None, no activation will be used.
initializer: The initializer (if any) to use in the masked LM and initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer. classification networks. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either `logits` or
'predictions'. `predictions`.
""" """
def __init__(self, def __init__(self,
...@@ -65,21 +63,12 @@ class BertPretrainer(tf.keras.Model): ...@@ -65,21 +63,12 @@ class BertPretrainer(tf.keras.Model):
initializer='glorot_uniform', initializer='glorot_uniform',
output='logits', output='logits',
**kwargs): **kwargs):
self._self_setattr_tracking = False
self._config = {
'network': network,
'num_classes': num_classes,
'num_token_predictions': num_token_predictions,
'activation': activation,
'initializer': initializer,
'output': output,
}
self.encoder = network
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a copy of the network inputs for use # Model. To do this, we need to keep a copy of the network inputs for use
# when we construct the Model object at the end of init. (We keep a copy # when we construct the Model object at the end of init. (We keep a copy
# because we'll be adding another tensor to the copy later.) # because we'll be adding another tensor to the copy later.)
network_inputs = self.encoder.inputs network_inputs = network.inputs
inputs = copy.copy(network_inputs) inputs = copy.copy(network_inputs)
# Because we have a copy of inputs to create this Model object, we can # Because we have a copy of inputs to create this Model object, we can
...@@ -87,7 +76,7 @@ class BertPretrainer(tf.keras.Model): ...@@ -87,7 +76,7 @@ class BertPretrainer(tf.keras.Model):
# Note that, because of how deferred construction happens, we can't use # Note that, because of how deferred construction happens, we can't use
# the copy of the list here - by the time the network is invoked, the list # the copy of the list here - by the time the network is invoked, the list
# object contains the additional input added below. # object contains the additional input added below.
sequence_output, cls_output = self.encoder(network_inputs) sequence_output, cls_output = network(network_inputs)
# The encoder network may get outputs from all layers. # The encoder network may get outputs from all layers.
if isinstance(sequence_output, list): if isinstance(sequence_output, list):
...@@ -95,7 +84,8 @@ class BertPretrainer(tf.keras.Model): ...@@ -95,7 +84,8 @@ class BertPretrainer(tf.keras.Model):
if isinstance(cls_output, list): if isinstance(cls_output, list):
cls_output = cls_output[-1] cls_output = cls_output[-1]
sequence_output_length = sequence_output.shape.as_list()[1] sequence_output_length = sequence_output.shape.as_list()[1]
if sequence_output_length < num_token_predictions: if sequence_output_length is not None and (sequence_output_length <
num_token_predictions):
raise ValueError( raise ValueError(
"The passed network's output length is %s, which is less than the " "The passed network's output length is %s, which is less than the "
'requested num_token_predictions %s.' % 'requested num_token_predictions %s.' %
...@@ -108,48 +98,74 @@ class BertPretrainer(tf.keras.Model): ...@@ -108,48 +98,74 @@ class BertPretrainer(tf.keras.Model):
inputs.append(masked_lm_positions) inputs.append(masked_lm_positions)
if embedding_table is None: if embedding_table is None:
embedding_table = self.encoder.get_embedding_table() embedding_table = network.get_embedding_table()
self.masked_lm = layers.MaskedLM( masked_lm = layers.MaskedLM(
embedding_table=embedding_table, embedding_table=embedding_table,
activation=activation, activation=activation,
initializer=initializer, initializer=initializer,
output=output, output=output,
name='cls/predictions') name='cls/predictions')
lm_outputs = self.masked_lm( lm_outputs = masked_lm(
sequence_output, masked_positions=masked_lm_positions) sequence_output, masked_positions=masked_lm_positions)
self.classification = networks.Classification( classification = networks.Classification(
input_width=cls_output.shape[-1], input_width=cls_output.shape[-1],
num_classes=num_classes, num_classes=num_classes,
initializer=initializer, initializer=initializer,
output=output, output=output,
name='classification') name='classification')
sentence_outputs = self.classification(cls_output) sentence_outputs = classification(cls_output)
super(BertPretrainer, self).__init__( super(BertPretrainer, self).__init__(
inputs=inputs, inputs=inputs,
outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs), outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
**kwargs) **kwargs)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
config_dict = {
'network': network,
'num_classes': num_classes,
'num_token_predictions': num_token_predictions,
'activation': activation,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.encoder = network
self.classification = classification
self.masked_lm = masked_lm
def get_config(self): def get_config(self):
return self._config return dict(self._config._asdict())
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
return cls(**config) return cls(**config)
# TODO(hongkuny): Migrate to BertPretrainerV2 for all usages.
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable @gin.configurable
class BertPretrainerV2(tf.keras.Model): class BertPretrainerV2(tf.keras.Model):
"""BERT pretraining model V2. """BERT pretraining model V2.
(Experimental).
Adds the masked language model head and optional classification heads upon the Adds the masked language model head and optional classification heads upon the
transformer encoder. transformer encoder.
Arguments: Args:
encoder_network: A transformer network. This network should output a encoder_network: A transformer network. This network should output a
sequence output and a classification output. sequence output and a classification output.
mlm_activation: The activation (if any) to use in the masked LM network. If mlm_activation: The activation (if any) to use in the masked LM network. If
...@@ -158,11 +174,16 @@ class BertPretrainerV2(tf.keras.Model): ...@@ -158,11 +174,16 @@ class BertPretrainerV2(tf.keras.Model):
to a Glorot uniform initializer. to a Glorot uniform initializer.
classification_heads: A list of optional head layers to transform on encoder classification_heads: A list of optional head layers to transform on encoder
sequence outputs. sequence outputs.
customized_masked_lm: A customized masked_lm layer. If None, will create
a standard layer from `layers.MaskedLM`; if not None, will use the
specified masked_lm layer. Above arguments `mlm_activation` and
`mlm_initializer` will be ignored.
name: The name of the model. name: The name of the model.
Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
dictionary. dictionary.
Outputs: A dictionary of `lm_output` and classification head outputs keyed by Outputs: A dictionary of `lm_output`, classification head outputs keyed by
head names. head names, and also outputs from `encoder_network`, keyed by
`sequence_output` and `encoder_outputs` (if any).
""" """
def __init__( def __init__(
...@@ -171,27 +192,24 @@ class BertPretrainerV2(tf.keras.Model): ...@@ -171,27 +192,24 @@ class BertPretrainerV2(tf.keras.Model):
mlm_activation=None, mlm_activation=None,
mlm_initializer='glorot_uniform', mlm_initializer='glorot_uniform',
classification_heads: Optional[List[tf.keras.layers.Layer]] = None, classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
name: str = 'bert', name: str = 'bert',
**kwargs): **kwargs):
self._self_setattr_tracking = False super().__init__(self, name=name, **kwargs)
self._config = { self._config = {
'encoder_network': encoder_network, 'encoder_network': encoder_network,
'mlm_initializer': mlm_initializer, 'mlm_initializer': mlm_initializer,
'classification_heads': classification_heads, 'classification_heads': classification_heads,
'name': name, 'name': name,
} }
self.encoder_network = encoder_network self.encoder_network = encoder_network
inputs = copy.copy(self.encoder_network.inputs) inputs = copy.copy(self.encoder_network.inputs)
sequence_output, _ = self.encoder_network(inputs)
self.classification_heads = classification_heads or [] self.classification_heads = classification_heads or []
if len(set([cls.name for cls in self.classification_heads])) != len( if len(set([cls.name for cls in self.classification_heads])) != len(
self.classification_heads): self.classification_heads):
raise ValueError('Classification heads should have unique names.') raise ValueError('Classification heads should have unique names.')
outputs = dict() self.masked_lm = customized_masked_lm or layers.MaskedLM(
self.masked_lm = layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(), embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation, activation=mlm_activation,
initializer=mlm_initializer, initializer=mlm_initializer,
...@@ -199,13 +217,45 @@ class BertPretrainerV2(tf.keras.Model): ...@@ -199,13 +217,45 @@ class BertPretrainerV2(tf.keras.Model):
masked_lm_positions = tf.keras.layers.Input( masked_lm_positions = tf.keras.layers.Input(
shape=(None,), name='masked_lm_positions', dtype=tf.int32) shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions) inputs.append(masked_lm_positions)
outputs['lm_output'] = self.masked_lm( self.inputs = inputs
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads: def call(self, inputs):
outputs[cls_head.name] = cls_head(sequence_output) if isinstance(inputs, list):
logging.warning('List inputs to BertPretrainer are discouraged.')
inputs = dict([
(ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
])
super(BertPretrainerV2, self).__init__( outputs = dict()
inputs=inputs, outputs=outputs, name=name, **kwargs) encoder_network_outputs = self.encoder_network(inputs)
if isinstance(encoder_network_outputs, list):
outputs['pooled_output'] = encoder_network_outputs[1]
# When `encoder_network` was instantiated with return_all_encoder_outputs
# set to True, `encoder_network_outputs[0]` is a list containing
# all transformer layers' output.
if isinstance(encoder_network_outputs[0], list):
outputs['encoder_outputs'] = encoder_network_outputs[0]
outputs['sequence_output'] = encoder_network_outputs[0][-1]
else:
outputs['sequence_output'] = encoder_network_outputs[0]
elif isinstance(encoder_network_outputs, dict):
outputs = encoder_network_outputs
else:
raise ValueError('encoder_network\'s output should be either a list '
'or a dict, but got %s' % encoder_network_outputs)
sequence_output = outputs['sequence_output']
# Inference may not have masked_lm_positions and mlm_logits is not needed.
if 'masked_lm_positions' in inputs:
masked_lm_positions = inputs['masked_lm_positions']
outputs['mlm_logits'] = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads:
cls_outputs = cls_head(sequence_output)
if isinstance(cls_outputs, dict):
outputs.update(cls_outputs)
else:
outputs[cls_head.name] = cls_outputs
return outputs
@property @property
def checkpoint_items(self): def checkpoint_items(self):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,16 +11,15 @@ ...@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import """Tests for BERT pretrainer model."""
from __future__ import division import itertools
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer from official.nlp.modeling.models import bert_pretrainer
...@@ -35,8 +34,10 @@ class BertPretrainerTest(keras_parameterized.TestCase): ...@@ -35,8 +34,10 @@ class BertPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 sequence_length = 512
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
num_classes = 3 num_classes = 3
...@@ -68,7 +69,7 @@ class BertPretrainerTest(keras_parameterized.TestCase): ...@@ -68,7 +69,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=2) vocab_size=100, num_layers=2, sequence_length=2)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
...@@ -90,8 +91,8 @@ class BertPretrainerTest(keras_parameterized.TestCase): ...@@ -90,8 +91,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
"""Validate that the BERT trainer can be serialized and deserialized.""" """Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5) vocab_size=100, num_layers=2, max_sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args # Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.) # are different, so we can catch any serialization mismatches.)
...@@ -109,36 +110,112 @@ class BertPretrainerTest(keras_parameterized.TestCase): ...@@ -109,36 +110,112 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self.assertAllEqual(bert_trainer_model.get_config(), self.assertAllEqual(bert_trainer_model.get_config(),
new_bert_trainer_model.get_config()) new_bert_trainer_model.get_config())
def test_bert_pretrainerv2(self):
class BertPretrainerV2Test(keras_parameterized.TestCase):
@parameterized.parameters(itertools.product(
(False, True),
(False, True),
(False, True),
(False, True),
))
def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
use_customized_masked_lm, has_masked_lm_positions):
"""Validate that the Keras object can be created.""" """Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 sequence_length = 512
test_network = networks.TransformerEncoder( hidden_size = 48
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) num_layers = 2
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
return_all_encoder_outputs=return_all_encoder_outputs,
dict_outputs=dict_outputs)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
if use_customized_masked_lm:
customized_masked_lm = layers.MaskedLM(
embedding_table=test_network.get_embedding_table())
else:
customized_masked_lm = None
bert_trainer_model = bert_pretrainer.BertPretrainerV2( bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network) encoder_network=test_network, customized_masked_lm=customized_masked_lm)
num_token_predictions = 20 num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit). # Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) inputs = dict(
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32) input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32))
if has_masked_lm_positions:
inputs['masked_lm_positions'] = tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built. # Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask]) outputs = bert_trainer_model(inputs)
has_encoder_outputs = dict_outputs or return_all_encoder_outputs
expected_keys = ['sequence_output', 'pooled_output']
if has_encoder_outputs:
expected_keys.append('encoder_outputs')
if has_masked_lm_positions:
expected_keys.append('mlm_logits')
self.assertSameElements(outputs.keys(), expected_keys)
# Validate that the outputs are of the expected shape. # Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size] expected_lm_shape = [None, num_token_predictions, vocab_size]
self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list()) if has_masked_lm_positions:
self.assertAllEqual(expected_lm_shape,
outputs['mlm_logits'].shape.as_list())
expected_sequence_output_shape = [None, sequence_length, hidden_size]
self.assertAllEqual(expected_sequence_output_shape,
outputs['sequence_output'].shape.as_list())
expected_pooled_output_shape = [None, hidden_size]
self.assertAllEqual(expected_pooled_output_shape,
outputs['pooled_output'].shape.as_list())
def test_multiple_cls_outputs(self):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
dict_outputs=True)
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network,
classification_heads=[layers.MultiClsHeads(
inner_dim=5, cls_list=[('foo', 2), ('bar', 3)])])
num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit).
inputs = dict(
input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
masked_lm_positions=tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32))
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model(inputs)
self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])
def test_v2_serialize_deserialize(self): def test_v2_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized.""" """Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5) vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args # Create a BERT trainer with the created network. (Note that all the args
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,14 +11,10 @@ ...@@ -11,14 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT Question Answering model."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import networks from official.nlp.modeling import networks
...@@ -32,17 +28,20 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -32,17 +28,20 @@ class BertSpanLabeler(tf.keras.Model):
encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding" (https://arxiv.org/abs/1810.04805). for Language Understanding" (https://arxiv.org/abs/1810.04805).
The BertSpanLabeler allows a user to pass in a transformer stack, and The BertSpanLabeler allows a user to pass in a transformer encoder, and
instantiates a span labeling network based on a single dense layer. instantiates a span labeling network based on a single dense layer.
Arguments: *Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method. table via a `get_embedding_table` method.
initializer: The initializer (if any) to use in the span labeling network. initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either `logit`' or
'predictions'. `predictions`.
""" """
def __init__(self, def __init__(self,
...@@ -50,13 +49,6 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -50,13 +49,6 @@ class BertSpanLabeler(tf.keras.Model):
initializer='glorot_uniform', initializer='glorot_uniform',
output='logits', output='logits',
**kwargs): **kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use # Model. To do this, we need to keep a handle to the network inputs for use
...@@ -65,16 +57,25 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -65,16 +57,25 @@ class BertSpanLabeler(tf.keras.Model):
# Because we have a copy of inputs to create this Model object, we can # Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model. # invoke the Network object with its own input tensors to start the Model.
sequence_output, _ = network(inputs) outputs = network(inputs)
if isinstance(outputs, list):
sequence_output = outputs[0]
else:
sequence_output = outputs['sequence_output']
# The input network (typically a transformer model) may get outputs from all
# layers. When this case happens, we retrieve the last layer output.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
# This is an instance variable for ease of access to the underlying task # This is an instance variable for ease of access to the underlying task
# network. # network.
self.span_labeling = networks.SpanLabeling( span_labeling = networks.SpanLabeling(
input_width=sequence_output.shape[-1], input_width=sequence_output.shape[-1],
initializer=initializer, initializer=initializer,
output=output, output=output,
name='span_labeling') name='span_labeling')
start_logits, end_logits = self.span_labeling(sequence_output) start_logits, end_logits = span_labeling(sequence_output)
# Use identity layers wrapped in lambdas to explicitly name the output # Use identity layers wrapped in lambdas to explicitly name the output
# tensors. This allows us to use string-keyed dicts in Keras fit/predict/ # tensors. This allows us to use string-keyed dicts in Keras fit/predict/
...@@ -88,15 +89,36 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -88,15 +89,36 @@ class BertSpanLabeler(tf.keras.Model):
logits = [start_logits, end_logits] logits = [start_logits, end_logits]
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertSpanLabeler, self).__init__( super(BertSpanLabeler, self).__init__(
inputs=inputs, outputs=logits, **kwargs) inputs=inputs, outputs=logits, **kwargs)
self._network = network
config_dict = {
'network': network,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.span_labeling = span_labeling
@property @property
def checkpoint_items(self): def checkpoint_items(self):
return dict(encoder=self._network) return dict(encoder=self._network)
def get_config(self): def get_config(self):
return self._config return dict(self._config._asdict())
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,13 +11,10 @@ ...@@ -11,13 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import """Tests for BERT trainer network."""
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
...@@ -30,13 +27,14 @@ from official.nlp.modeling.models import bert_span_labeler ...@@ -30,13 +27,14 @@ from official.nlp.modeling.models import bert_span_labeler
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class BertSpanLabelerTest(keras_parameterized.TestCase): class BertSpanLabelerTest(keras_parameterized.TestCase):
def test_bert_trainer(self): @parameterized.parameters(True, False)
def test_bert_trainer(self, dict_outputs):
"""Validate that the Keras object can be created.""" """Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 sequence_length = 512
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network) bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...@@ -59,9 +57,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase): ...@@ -59,9 +57,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
"""Validate compilation using explicit output names.""" """Validate compilation using explicit output names."""
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network) bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...@@ -80,8 +76,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase): ...@@ -80,8 +76,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
vocab_size=100, num_layers=2, sequence_length=2)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network) bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...@@ -100,7 +95,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase): ...@@ -100,7 +95,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
"""Validate that the BERT trainer can be serialized and deserialized.""" """Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5) vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args # Create a BERT trainer with the created network. (Note that all the args
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,18 +11,12 @@ ...@@ -11,18 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT token classifier."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class BertTokenClassifier(tf.keras.Model): class BertTokenClassifier(tf.keras.Model):
...@@ -36,15 +30,21 @@ class BertTokenClassifier(tf.keras.Model): ...@@ -36,15 +30,21 @@ class BertTokenClassifier(tf.keras.Model):
instantiates a token classification network based on the passed `num_classes` instantiates a token classification network based on the passed `num_classes`
argument. argument.
Arguments: *Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method. table via a `get_embedding_table` method.
num_classes: Number of classes to predict from the classification network. num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks. initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either `logits` or
'predictions'. `predictions`.
dropout_rate: The dropout probability of the token classification head.
output_encoder_outputs: Whether to include intermediate sequence output
in the final output.
""" """
def __init__(self, def __init__(self,
...@@ -53,15 +53,8 @@ class BertTokenClassifier(tf.keras.Model): ...@@ -53,15 +53,8 @@ class BertTokenClassifier(tf.keras.Model):
initializer='glorot_uniform', initializer='glorot_uniform',
output='logits', output='logits',
dropout_rate=0.1, dropout_rate=0.1,
output_encoder_outputs=False,
**kwargs): **kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use # Model. To do this, we need to keep a handle to the network inputs for use
...@@ -70,27 +63,70 @@ class BertTokenClassifier(tf.keras.Model): ...@@ -70,27 +63,70 @@ class BertTokenClassifier(tf.keras.Model):
# Because we have a copy of inputs to create this Model object, we can # Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model. # invoke the Network object with its own input tensors to start the Model.
sequence_output, _ = network(inputs) outputs = network(inputs)
sequence_output = tf.keras.layers.Dropout( if isinstance(outputs, list):
rate=dropout_rate)(sequence_output) sequence_output = outputs[0]
else:
self.classifier = networks.TokenClassification( sequence_output = outputs['sequence_output']
input_width=sequence_output.shape[-1], sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
num_classes=num_classes, sequence_output)
initializer=initializer,
output=output, classifier = tf.keras.layers.Dense(
name='classification') num_classes,
predictions = self.classifier(sequence_output) activation=None,
kernel_initializer=initializer,
name='predictions/transform/logits')
logits = classifier(sequence_output)
if output == 'logits':
output_tensors = {'logits': logits}
elif output == 'predictions':
output_tensors = {
'predictions': tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
}
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
if output_encoder_outputs:
output_tensors['encoder_outputs'] = sequence_output
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertTokenClassifier, self).__init__( super(BertTokenClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs) inputs=inputs, outputs=output_tensors, **kwargs)
self._network = network
config_dict = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
'output_encoder_outputs': output_encoder_outputs
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.classifier = classifier
self.logits = logits
@property @property
def checkpoint_items(self): def checkpoint_items(self):
return dict(encoder=self._network) return dict(encoder=self._network)
def get_config(self): def get_config(self):
return self._config return dict(self._config._asdict())
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,13 +11,10 @@ ...@@ -11,13 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import """Tests for BERT token classifier."""
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
...@@ -30,19 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier ...@@ -30,19 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class BertTokenClassifierTest(keras_parameterized.TestCase): class BertTokenClassifierTest(keras_parameterized.TestCase):
def test_bert_trainer(self): @parameterized.parameters((True, True), (False, False))
def test_bert_trainer(self, dict_outputs, output_encoder_outputs):
"""Validate that the Keras object can be created.""" """Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 sequence_length = 512
test_network = networks.TransformerEncoder( hidden_size = 768
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length,
dict_outputs=dict_outputs,
hidden_size=hidden_size)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
num_classes = 3 num_classes = 3
bert_trainer_model = bert_token_classifier.BertTokenClassifier( bert_trainer_model = bert_token_classifier.BertTokenClassifier(
test_network, test_network,
num_classes=num_classes) num_classes=num_classes,
output_encoder_outputs=output_encoder_outputs)
# Create a set of 2-dimensional inputs (the first dimension is implicit). # Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...@@ -50,19 +54,25 @@ class BertTokenClassifierTest(keras_parameterized.TestCase): ...@@ -50,19 +54,25 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built. # Invoke the trainer model on the inputs. This causes the layer to be built.
sequence_outs = bert_trainer_model([word_ids, mask, type_ids]) outputs = bert_trainer_model([word_ids, mask, type_ids])
if output_encoder_outputs:
logits = outputs['logits']
encoder_outputs = outputs['encoder_outputs']
self.assertAllEqual(encoder_outputs.shape.as_list(),
[None, sequence_length, hidden_size])
else:
logits = outputs['logits']
# Validate that the outputs are of the expected shape. # Validate that the outputs are of the expected shape.
expected_classification_shape = [None, sequence_length, num_classes] expected_classification_shape = [None, sequence_length, num_classes]
self.assertAllEqual(expected_classification_shape, self.assertAllEqual(expected_classification_shape, logits.shape.as_list())
sequence_outs.shape.as_list())
def test_bert_trainer_tensor_call(self): def test_bert_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=2) vocab_size=100, num_layers=2, max_sequence_length=2)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_token_classifier.BertTokenClassifier( bert_trainer_model = bert_token_classifier.BertTokenClassifier(
...@@ -82,8 +92,8 @@ class BertTokenClassifierTest(keras_parameterized.TestCase): ...@@ -82,8 +92,8 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
"""Validate that the BERT trainer can be serialized and deserialized.""" """Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_network = networks.TransformerEncoder( test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5) vocab_size=100, num_layers=2, max_sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args # Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.) # are different, so we can catch any serialization mismatches.)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on the
  transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852).

  The DualEncoder allows a user to pass in a transformer stack, and builds a
  dual encoder model based on that transformer stack.

  Args:
    network: A transformer network which should output both a sequence output
      and a pooled output (either as a `[sequence_output, pooled_output]`
      list, or as a dict with 'sequence_output' and 'pooled_output' keys).
    max_seq_length: The maximum allowed sequence length for the transformer.
    normalize: If set to True, L2-normalize the encoding produced by the
      transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either `logits` (two
      input towers, outputs similarity logits for training) or `predictions`
      (single input tower, outputs the embedding produced by the transformer
      network).
  """

  def __init__(self,
               network: tf.keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:
    # In 'logits' (training) mode the model takes two towers of inputs
    # ("left_*" and "right_*"); in 'predictions' (inference/export) mode it
    # takes a single tower named after the legacy BERT hub module inputs.
    if output == 'logits':
      left_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
      left_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
      left_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
    else:
      # Keep consistent with legacy BERT hub module input names.
      left_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
      left_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
      left_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

    left_inputs = [left_word_ids, left_mask, left_type_ids]
    left_outputs = network(left_inputs)
    # The encoder may return either a two-element list or a dict of outputs.
    if isinstance(left_outputs, list):
      left_sequence_output, left_encoded = left_outputs
    else:
      left_sequence_output = left_outputs['sequence_output']
      left_encoded = left_outputs['pooled_output']
    if normalize:
      # L2-normalize so that the downstream dot product is a cosine
      # similarity.
      left_encoded = tf.keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(
              left_encoded)

    if output == 'logits':
      # Second (right) tower; shares `network` weights with the left tower.
      right_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
      right_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
      right_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')

      right_inputs = [right_word_ids, right_mask, right_type_ids]
      right_outputs = network(right_inputs)
      if isinstance(right_outputs, list):
        _, right_encoded = right_outputs
      else:
        right_encoded = right_outputs['pooled_output']
      if normalize:
        right_encoded = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1))(
                right_encoded)

      # Scaled dot products with an additive margin on the positive pairs.
      dot_products = layers.MatMulWithMargin(
          logit_scale=logit_scale,
          logit_margin=logit_margin,
          name='dot_product')

      inputs = [
          left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
          right_type_ids
      ]
      left_logits, right_logits = dot_products(left_encoded, right_encoded)

      outputs = dict(left_logits=left_logits, right_logits=right_logits)
    elif output == 'predictions':
      inputs = [left_word_ids, left_mask, left_type_ids]

      # To keep consistent with legacy BERT hub modules, the outputs are
      # "pooled_output" and "sequence_output".
      outputs = dict(
          sequence_output=left_sequence_output, pooled_output=left_encoded)
    else:
      raise ValueError('output type %s is not supported' % output)

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)
    config_dict = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }
    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self.network = network

  def get_config(self):
    """Returns the constructor arguments as a plain, serializable dict."""
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates a `DualEncoder` from the output of `get_config`."""
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.network)
    return items
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dual encoder network."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import dual_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DualEncoderTest(keras_parameterized.TestCase):

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder(self, hidden_size, output):
    """Validate that the Keras object can be created."""
    # Encoder stack to embed inside the dual encoder under test.
    seq_len = 512
    encoder = networks.BertEncoder(
        vocab_size=100,
        num_layers=2,
        hidden_size=hidden_size,
        sequence_length=seq_len,
        dict_outputs=True)

    model = dual_encoder.DualEncoder(
        encoder, max_seq_length=seq_len, output=output)

    # Symbolic 2-D inputs (the batch dimension is implicit).
    def make_input():
      return tf.keras.Input(shape=(seq_len,), dtype=tf.int32)

    l_word_ids, l_mask, l_type_ids = make_input(), make_input(), make_input()
    r_word_ids, r_mask, r_type_ids = make_input(), make_input(), make_input()

    if output == 'logits':
      outputs = model(
          [l_word_ids, l_mask, l_type_ids, r_word_ids, r_mask, r_type_ids])
      _ = outputs['left_logits']
    elif output == 'predictions':
      outputs = model([l_word_ids, l_mask, l_type_ids])

      # The symbolic outputs should carry the expected static shapes.
      self.assertAllEqual([None, seq_len, 768],
                          outputs['sequence_output'].shape.as_list())
      pooled = outputs['pooled_output']
      self.assertAllEqual([None, 768], pooled.shape.as_list())

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder_tensor_call(self, hidden_size, output):
    """Validate that the Keras object can be invoked."""
    # Use a tiny sequence length for convenience.
    seq_len = 2
    encoder = networks.BertEncoder(
        vocab_size=100, num_layers=2, sequence_length=seq_len)

    model = dual_encoder.DualEncoder(
        encoder, max_seq_length=seq_len, output=output)

    # Concrete 2-D input tensors for a batch of two examples.
    word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
    mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
    type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)

    # In Eager mode this performs the actual computation. The network is too
    # complex to validate outputs here; we only check for runtime errors.
    if output == 'logits':
      _ = model([word_ids, mask, type_ids, word_ids, mask, type_ids])
    elif output == 'predictions':
      _ = model([word_ids, mask, type_ids])

  def test_serialize_deserialize(self):
    """Validate that the dual encoder model can be serialized / deserialized."""
    # Use a short sequence length for convenience.
    seq_len = 32
    encoder = networks.BertEncoder(
        vocab_size=100, num_layers=2, sequence_length=seq_len)

    model = dual_encoder.DualEncoder(
        encoder, max_seq_length=seq_len, output='predictions')

    # Round-trip another model through the config machinery.
    config = model.get_config()
    restored = dual_encoder.DualEncoder.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = restored.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(model.get_config(), restored.get_config())
if __name__ == '__main__':
  # Run all test cases defined above when this file is executed as a script.
  tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,15 +11,12 @@ ...@@ -11,15 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Trainer network for ELECTRA models.""" """Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy import copy
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -39,7 +36,10 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -39,7 +36,10 @@ class ElectraPretrainer(tf.keras.Model):
model (at generator side) and classification networks (at discriminator side) model (at generator side) and classification networks (at discriminator side)
that are used to create the training objectives. that are used to create the training objectives.
Arguments: *Note* that the model is constructed by Keras Subclass API, where layers are
defined inside `__init__` and `call()` implements the computation.
Args:
generator_network: A transformer network for generator, this network should generator_network: A transformer network for generator, this network should
output a sequence output and an optional classification output. output a sequence output and an optional classification output.
discriminator_network: A transformer network for discriminator, this network discriminator_network: A transformer network for discriminator, this network
...@@ -47,15 +47,13 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -47,15 +47,13 @@ class ElectraPretrainer(tf.keras.Model):
vocab_size: Size of generator output vocabulary vocab_size: Size of generator output vocabulary
num_classes: Number of classes to predict from the classification network num_classes: Number of classes to predict from the classification network
for the generator network (not used now) for the generator network (not used now)
sequence_length: Input sequence length
last_hidden_dim: Last hidden dim of generator transformer output
num_token_predictions: Number of tokens to predict from the masked LM. num_token_predictions: Number of tokens to predict from the masked LM.
mlm_activation: The activation (if any) to use in the masked LM and mlm_activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used. classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer. classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either 'logits' or output_type: The output style for this network. Can be either `logits` or
'predictions'. `predictions`.
disallow_correct: Whether to disallow the generator to generate the exact disallow_correct: Whether to disallow the generator to generate the exact
same token in the original sentence same token in the original sentence
""" """
...@@ -65,8 +63,6 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -65,8 +63,6 @@ class ElectraPretrainer(tf.keras.Model):
discriminator_network, discriminator_network,
vocab_size, vocab_size,
num_classes, num_classes,
sequence_length,
last_hidden_dim,
num_token_predictions, num_token_predictions,
mlm_activation=None, mlm_activation=None,
mlm_initializer='glorot_uniform', mlm_initializer='glorot_uniform',
...@@ -79,8 +75,6 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -79,8 +75,6 @@ class ElectraPretrainer(tf.keras.Model):
'discriminator_network': discriminator_network, 'discriminator_network': discriminator_network,
'vocab_size': vocab_size, 'vocab_size': vocab_size,
'num_classes': num_classes, 'num_classes': num_classes,
'sequence_length': sequence_length,
'last_hidden_dim': last_hidden_dim,
'num_token_predictions': num_token_predictions, 'num_token_predictions': num_token_predictions,
'mlm_activation': mlm_activation, 'mlm_activation': mlm_activation,
'mlm_initializer': mlm_initializer, 'mlm_initializer': mlm_initializer,
...@@ -94,8 +88,6 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -94,8 +88,6 @@ class ElectraPretrainer(tf.keras.Model):
self.discriminator_network = discriminator_network self.discriminator_network = discriminator_network
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.num_classes = num_classes self.num_classes = num_classes
self.sequence_length = sequence_length
self.last_hidden_dim = last_hidden_dim
self.num_token_predictions = num_token_predictions self.num_token_predictions = num_token_predictions
self.mlm_activation = mlm_activation self.mlm_activation = mlm_activation
self.mlm_initializer = mlm_initializer self.mlm_initializer = mlm_initializer
...@@ -108,10 +100,15 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -108,10 +100,15 @@ class ElectraPretrainer(tf.keras.Model):
output=output_type, output=output_type,
name='generator_masked_lm') name='generator_masked_lm')
self.classification = layers.ClassificationHead( self.classification = layers.ClassificationHead(
inner_dim=last_hidden_dim, inner_dim=generator_network.get_config()['hidden_size'],
num_classes=num_classes, num_classes=num_classes,
initializer=mlm_initializer, initializer=mlm_initializer,
name='generator_classification_head') name='generator_classification_head')
self.discriminator_projection = tf.keras.layers.Dense(
units=discriminator_network.get_config()['hidden_size'],
activation=mlm_activation,
kernel_initializer=mlm_initializer,
name='discriminator_projection_head')
self.discriminator_head = tf.keras.layers.Dense( self.discriminator_head = tf.keras.layers.Dense(
units=1, kernel_initializer=mlm_initializer) units=1, kernel_initializer=mlm_initializer)
...@@ -123,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -123,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model):
Returns: Returns:
outputs: A dict of pretrainer model outputs, including outputs: A dict of pretrainer model outputs, including
(1) lm_outputs: a [batch_size, num_token_predictions, vocab_size] tensor (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
indicating logits on masked positions. tensor indicating logits on masked positions.
(2) sentence_outputs: a [batch_size, num_classes] tensor indicating (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
logits for nsp task. logits for nsp task.
(3) disc_logits: a [batch_size, sequence_length] tensor indicating (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
logits for discriminator replaced token detection task. logits for discriminator replaced token detection task.
(4) disc_label: a [batch_size, sequence_length] tensor indicating (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
target labels for discriminator replaced token detection task. target labels for discriminator replaced token detection task.
""" """
input_word_ids = inputs['input_word_ids'] input_word_ids = inputs['input_word_ids']
...@@ -138,14 +135,11 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -138,14 +135,11 @@ class ElectraPretrainer(tf.keras.Model):
masked_lm_positions = inputs['masked_lm_positions'] masked_lm_positions = inputs['masked_lm_positions']
### Generator ### ### Generator ###
sequence_output, cls_output = self.generator_network( sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids]) [input_word_ids, input_mask, input_type_ids])['sequence_output']
# The generator encoder network may get outputs from all layers. # The generator encoder network may get outputs from all layers.
if isinstance(sequence_output, list): if isinstance(sequence_output, list):
sequence_output = sequence_output[-1] sequence_output = sequence_output[-1]
if isinstance(cls_output, list):
cls_output = cls_output[-1]
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions) lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
sentence_outputs = self.classification(sequence_output) sentence_outputs = self.classification(sequence_output)
...@@ -156,16 +150,17 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -156,16 +150,17 @@ class ElectraPretrainer(tf.keras.Model):
### Discriminator ### ### Discriminator ###
disc_input = fake_data['inputs'] disc_input = fake_data['inputs']
disc_label = fake_data['is_fake_tokens'] disc_label = fake_data['is_fake_tokens']
disc_sequence_output, _ = self.discriminator_network([ disc_sequence_output = self.discriminator_network([
disc_input['input_word_ids'], disc_input['input_mask'], disc_input['input_word_ids'], disc_input['input_mask'],
disc_input['input_type_ids'] disc_input['input_type_ids']
]) ])['sequence_output']
# The discriminator encoder network may get outputs from all layers. # The discriminator encoder network may get outputs from all layers.
if isinstance(disc_sequence_output, list): if isinstance(disc_sequence_output, list):
disc_sequence_output = disc_sequence_output[-1] disc_sequence_output = disc_sequence_output[-1]
disc_logits = self.discriminator_head(disc_sequence_output) disc_logits = self.discriminator_head(
self.discriminator_projection(disc_sequence_output))
disc_logits = tf.squeeze(disc_logits, axis=-1) disc_logits = tf.squeeze(disc_logits, axis=-1)
outputs = { outputs = {
...@@ -181,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -181,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model):
"""Generate corrupted data for discriminator. """Generate corrupted data for discriminator.
Args: Args:
inputs: A dict of all inputs, same as the input of call() function inputs: A dict of all inputs, same as the input of `call()` function
mlm_logits: The generator's output logits mlm_logits: The generator's output logits
duplicate: Whether to copy the original inputs dict during modifications duplicate: Whether to copy the original inputs dict during modifications
...@@ -214,6 +209,12 @@ class ElectraPretrainer(tf.keras.Model): ...@@ -214,6 +209,12 @@ class ElectraPretrainer(tf.keras.Model):
'sampled_tokens': sampled_tokens 'sampled_tokens': sampled_tokens
} }
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.discriminator_network)
return items
def get_config(self): def get_config(self):
return self._config return self._config
...@@ -226,16 +227,18 @@ def scatter_update(sequence, updates, positions): ...@@ -226,16 +227,18 @@ def scatter_update(sequence, updates, positions):
"""Scatter-update a sequence. """Scatter-update a sequence.
Args: Args:
sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
updates: A tensor of size batch_size*seq_len(*depth) tensor.
positions: A [batch_size, n_positions] tensor updates: A tensor of size `batch_size*seq_len(*depth)`.
positions: A `[batch_size, n_positions]` tensor.
Returns: Returns:
updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] updated_sequence: A `[batch_size, seq_len]` or
tensor of "sequence" with elements at "positions" replaced by the values `[batch_size, seq_len, depth]` tensor of "sequence" with elements at
at "updates". Updates to index 0 are ignored. If there are duplicated "positions" replaced by the values at "updates". Updates to index 0 are
positions the update is only applied once. ignored. If there are duplicated positions the update is only
updates_mask: A [batch_size, seq_len] mask tensor of which inputs were applied once.
updates_mask: A `[batch_size, seq_len]` mask tensor of which inputs were
updated. updated.
""" """
shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3]) shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
...@@ -288,14 +291,14 @@ def sample_from_softmax(logits, disallow=None): ...@@ -288,14 +291,14 @@ def sample_from_softmax(logits, disallow=None):
"""Implement softmax sampling using gumbel softmax trick. """Implement softmax sampling using gumbel softmax trick.
Args: Args:
logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
the generator output logits for each masked position. indicating the generator output logits for each masked position.
disallow: If `None`, we directly sample tokens from the logits. Otherwise, disallow: If `None`, we directly sample tokens from the logits. Otherwise,
this is a tensor of size [batch_size, num_token_predictions, vocab_size] this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
indicating the true word id in each masked position. indicating the true word id in each masked position.
Returns: Returns:
sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot sampled_tokens: A `[batch_size, num_token_predictions, vocab_size]` one hot
tensor indicating the sampled word id in each masked position. tensor indicating the sampled word id in each masked position.
""" """
if disallow is not None: if disallow is not None:
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,12 +11,8 @@ ...@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for ELECTRA pre trainer network."""
from __future__ import absolute_import """Tests for ELECTRA pre trainer network."""
from __future__ import division
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
...@@ -35,10 +31,16 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -35,10 +31,16 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the ELECTRA trainer. # Build a transformer network to use within the ELECTRA trainer.
vocab_size = 100 vocab_size = 100
sequence_length = 512 sequence_length = 512
test_generator_network = networks.TransformerEncoder( test_generator_network = networks.BertEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) vocab_size=vocab_size,
test_discriminator_network = networks.TransformerEncoder( num_layers=2,
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) max_sequence_length=sequence_length,
dict_outputs=True)
test_discriminator_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length,
dict_outputs=True)
# Create a ELECTRA trainer with the created network. # Create a ELECTRA trainer with the created network.
num_classes = 3 num_classes = 3
...@@ -48,8 +50,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -48,8 +50,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
discriminator_network=test_discriminator_network, discriminator_network=test_discriminator_network,
vocab_size=vocab_size, vocab_size=vocab_size,
num_classes=num_classes, num_classes=num_classes,
sequence_length=sequence_length,
last_hidden_dim=768,
num_token_predictions=num_token_predictions, num_token_predictions=num_token_predictions,
disallow_correct=True) disallow_correct=True)
...@@ -89,10 +89,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -89,10 +89,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
# Build a transformer network to use within the ELECTRA trainer. (Here, we # Build a transformer network to use within the ELECTRA trainer. (Here, we
# use a short sequence_length for convenience.) # use a short sequence_length for convenience.)
test_generator_network = networks.TransformerEncoder( test_generator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, sequence_length=3) vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
test_discriminator_network = networks.TransformerEncoder( test_discriminator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, sequence_length=3) vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
# Create a ELECTRA trainer with the created network. # Create a ELECTRA trainer with the created network.
eletrca_trainer_model = electra_pretrainer.ElectraPretrainer( eletrca_trainer_model = electra_pretrainer.ElectraPretrainer(
...@@ -101,7 +101,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -101,7 +101,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
vocab_size=100, vocab_size=100,
num_classes=2, num_classes=2,
sequence_length=3, sequence_length=3,
last_hidden_dim=768,
num_token_predictions=2) num_token_predictions=2)
# Create a set of 2-dimensional data tensors to feed into the model. # Create a set of 2-dimensional data tensors to feed into the model.
...@@ -127,10 +126,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -127,10 +126,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
"""Validate that the ELECTRA trainer can be serialized and deserialized.""" """Validate that the ELECTRA trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
test_generator_network = networks.TransformerEncoder( test_generator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, sequence_length=3) vocab_size=100, num_layers=4, max_sequence_length=3)
test_discriminator_network = networks.TransformerEncoder( test_discriminator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, sequence_length=3) vocab_size=100, num_layers=4, max_sequence_length=3)
# Create a ELECTRA trainer with the created network. (Note that all the args # Create a ELECTRA trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.) # are different, so we can catch any serialization mismatches.)
...@@ -140,7 +139,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase): ...@@ -140,7 +139,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
vocab_size=100, vocab_size=100,
num_classes=2, num_classes=2,
sequence_length=3, sequence_length=3,
last_hidden_dim=768,
num_token_predictions=2) num_token_predictions=2)
# Create another BERT trainer via serialization and deserialization. # Create another BERT trainer via serialization and deserialization.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement Seq2Seq Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
"""
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
EOS_ID = 1
@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
"""Transformer model with Keras.
Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
The Transformer model consists of an encoder and decoder. The input is an int
sequence (or a batch of sequences). The encoder produces a continuous
representation, and the decoder uses the encoder output to generate
probabilities for the output sequence.
"""
def __init__(self,
vocab_size=33708,
embedding_width=512,
dropout_rate=0.0,
padded_decode=False,
decode_max_length=None,
extra_decode_length=0,
beam_size=4,
alpha=0.6,
encoder_layer=None,
decoder_layer=None,
eos_id=EOS_ID,
**kwargs):
"""Initialize layers to build Transformer model.
Args:
vocab_size: Size of vocabulary.
embedding_width: Size of hidden layer for embedding.
dropout_rate: Dropout probability.
padded_decode: Whether to max_sequence_length padding is used. If set
False, max_sequence_length padding is not used.
decode_max_length: maximum number of steps to decode a sequence.
extra_decode_length: Beam search will run extra steps to decode.
beam_size: Number of beams for beam search
alpha: The strength of length normalization for beam search.
encoder_layer: An initialized encoder layer.
decoder_layer: An initialized decoder layer.
eos_id: Id of end of sentence token.
**kwargs: other keyword arguments.
"""
super().__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._dropout_rate = dropout_rate
self._padded_decode = padded_decode
self._decode_max_length = decode_max_length
self._extra_decode_length = extra_decode_length
self._beam_size = beam_size
self._alpha = alpha
self._eos_id = eos_id
self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=self._embedding_width,
initializer=tf.random_normal_initializer(
mean=0., stddev=self._embedding_width**-0.5),
scale_factor=self._embedding_width**0.5)
self.encoder_layer = encoder_layer
self.decoder_layer = decoder_layer
self.position_embedding = layers.RelativePositionEmbedding(
hidden_size=self._embedding_width)
self.encoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.decoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"vocab_size": self._vocab_size,
"hidden_size": self._embedding_width,
"dropout_rate": self._dropout_rate,
"padded_decode": self._padded_decode,
"decode_max_length": self._decode_max_length,
"eos_id": self._eos_id,
"extra_decode_length": self._extra_decode_length,
"beam_size": self._beam_size,
"alpha": self._alpha,
"encoder_layer": self.encoder_layer,
"decoder_layer": self.decoder_layer
}
base_config = super(Seq2SeqTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _embedding_linear(self, embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
def call(self, inputs):
"""Calculate target logits or inferred target sequences.
Args:
inputs: a dictionary of tensors.
Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
Feature `targets` (optional): None or int tensor with shape
`[batch_size, target_length]`.
Returns:
If targets is defined, then return logits for each word in the target
sequence, which is a float tensor with shape
`(batch_size, target_length, vocab_size)`. If target is `None`, then
generate output sequence one token at a time and
returns a dictionary {
outputs: `(batch_size, decoded_length)`
scores: `(batch_size, 1)`}
Even when `float16` is used, the output tensor(s) are always `float32`.
Raises:
NotImplementedError: If try to use padded decode method on CPU/GPUs.
"""
sources = inputs["inputs"]
targets = inputs.get("targets", None)
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_lookup(sources)
embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=sources.dtype)
broadcast_ones = tf.ones(
shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
attention_mask = broadcast_ones * attention_mask
pos_encoding = self.position_embedding(embedded_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
encoder_inputs = embedded_inputs + pos_encoding
encoder_inputs = self.encoder_dropout(encoder_inputs)
encoder_outputs = self.encoder_layer(
encoder_inputs, attention_mask=attention_mask)
if targets is None:
if self._padded_decode:
max_decode_length = self._decode_max_length
else:
max_decode_length = self._decode_max_length or (
tf.shape(encoder_outputs)[1] + self._extra_decode_length)
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
batch_size = tf.shape(encoder_outputs)[0]
# Create initial set of IDs that will be passed to symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
init_decode_length = (max_decode_length if self._padded_decode else 0)
num_heads = self.decoder_layer.num_attention_heads
dim_per_head = self._embedding_width // num_heads
# Cache dtype needs to match beam_search dtype.
# pylint: disable=g-complex-comprehension
cache = {
str(layer): {
"key":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.compute_dtype),
"value":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.compute_dtype)
} for layer in range(self.decoder_layer.num_layers)
}
# pylint: enable=g-complex-comprehension
# Add encoder output and attention bias to the cache.
encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=self.compute_dtype)
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_mask"] = attention_mask
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self._vocab_size,
beam_size=self._beam_size,
alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=self._eos_id,
padded_decode=self._padded_decode,
dtype=self.compute_dtype)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
decoder_inputs += pos_encoding
decoder_inputs = self.decoder_dropout(decoder_inputs)
decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = tf.cast(
tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
outputs = self.decoder_layer(
decoder_inputs,
encoder_outputs,
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask)
logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
# Model outputs should be float32 to avoid numeric issues.
# https://www.tensorflow.org/guide/mixed_precision#building_the_model
logits = tf.cast(logits, tf.float32)
return logits
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = self.position_embedding(
inputs=None, length=max_decode_length + 1)
timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
decoder_self_attention_mask = tf.linalg.band_part(
tf.ones([max_decode_length, max_decode_length],
dtype=self.compute_dtype), -1, 0)
decoder_self_attention_mask = tf.reshape(
decoder_self_attention_mask, [1, max_decode_length, max_decode_length])
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences. int tensor with shape `(batch_size *
beam_size, i + 1)`.
i: Loop index.
cache: Dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape `(batch_size * beam_size, vocab_size)`,
updated cache values)
"""
# Set decoder input to the last generated IDs
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
# decoder_input = self.embedding_softmax_layer(decoder_input)
source_decoder_input = decoder_input
decoder_input = self.embedding_lookup(decoder_input)
embedding_mask = tf.cast(
tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
decoder_input *= tf.expand_dims(embedding_mask, -1)
decoder_input += timing_signal[i]
if self._padded_decode:
# indexing does not work on TPU.
bias_shape = decoder_self_attention_mask.shape.as_list()
self_attention_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
[bias_shape[0], 1, bias_shape[2]])
else:
self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = cache.get("encoder_decoder_attention_mask")
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
decoder_outputs = self.decoder_layer(
decoder_input,
cache.get("encoder_outputs"),
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask,
cache=cache,
decode_loop_step=i if self._padded_decode else None)
decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
logits = self._embedding_linear(self.embedding_lookup.embeddings,
decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
return logits, cache
return symbols_to_logits_fn
class TransformerEncoder(tf.keras.layers.Layer):
  """Stack of N identical Transformer encoder layers.

  Each layer in the stack is composed of two sublayers:
    1. Self-attention layer
    2. Feedforward network (which is 2 fully-connected layers)
  The stack output is passed through a final layer normalization.
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Initialize a Transformer encoder.

    Args:
      num_layers: Number of stacked encoder layers.
      num_attention_heads: Number of attention heads per layer.
      intermediate_size: Size of the intermediate (Feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability.
      attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      **kwargs: keyword arguments passed to tf.keras.layers.Layer.
    """
    super(TransformerEncoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Implements build() for the layer."""
    # The attention initializer is scaled by the model width (last input dim).
    self.encoder_layers = [
        keras_nlp.layers.TransformerEncoderBlock(
            num_attention_heads=self.num_attention_heads,
            inner_dim=self._intermediate_size,
            inner_activation=self._activation,
            output_dropout=self._dropout_rate,
            attention_dropout=self._attention_dropout_rate,
            use_bias=self._use_bias,
            norm_first=self._norm_first,
            norm_epsilon=self._norm_epsilon,
            inner_dropout=self._intermediate_dropout,
            attention_initializer=attention_initializer(input_shape[2]),
            name=("layer_%d" % idx))
        for idx in range(self.num_layers)
    ]
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super(TransformerEncoder, self).build(input_shape)

  def get_config(self):
    config = super(TransformerEncoder, self).get_config()
    config.update({
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    })
    return config

  def call(self, encoder_inputs, attention_mask=None):
    """Return the output of the encoder.

    Args:
      encoder_inputs: A tensor with shape `(batch_size, input_length,
        hidden_size)`.
      attention_mask: A mask for the encoder self-attention layer with shape
        `(batch_size, input_length, input_length)`.

    Returns:
      Output of encoder which is a `float32` tensor with shape
      `(batch_size, input_length, hidden_size)`.
    """
    hidden_states = encoder_inputs
    for encoder_block in self.encoder_layers:
      hidden_states = encoder_block([hidden_states, attention_mask])
    return self.output_normalization(hidden_states)
class TransformerDecoder(tf.keras.layers.Layer):
  """Transformer decoder.

  Like the encoder, the decoder is made up of N identical layers.
  Each layer is composed of the sublayers:
    1. Self-attention layer
    2. Multi-headed attention layer combining encoder outputs with results from
       the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Initialize a Transformer decoder.

    Args:
      num_layers: Number of layers.
      num_attention_heads: Number of attention heads.
      intermediate_size: Size of the intermediate (Feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability.
      attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable use_bias in attention layer. If set `False`,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set `False`, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      **kwargs: keyword arguments passed to tf.keras.layers.Layer.
    """
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Implements build() for the layer."""
    self.decoder_layers = []
    for i in range(self.num_layers):
      self.decoder_layers.append(
          layers.TransformerDecoderBlock(
              num_attention_heads=self.num_attention_heads,
              intermediate_size=self._intermediate_size,
              intermediate_activation=self._activation,
              dropout_rate=self._dropout_rate,
              attention_dropout_rate=self._attention_dropout_rate,
              use_bias=self._use_bias,
              norm_first=self._norm_first,
              norm_epsilon=self._norm_epsilon,
              intermediate_dropout=self._intermediate_dropout,
              attention_initializer=attention_initializer(input_shape[2]),
              name=("layer_%d" % i)))
    # Use the configured epsilon for the output normalization, consistent with
    # the per-layer blocks and with TransformerEncoder. (The previous
    # hard-coded 1e-6 ignored the `norm_epsilon` argument; the default value
    # is unchanged.)
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super(TransformerDecoder, self).build(input_shape)

  def get_config(self):
    config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    base_config = super(TransformerDecoder, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           target,
           memory,
           self_attention_mask=None,
           cross_attention_mask=None,
           cache=None,
           decode_loop_step=None):
    """Return the output of the decoder layer stacks.

    Args:
      target: A tensor with shape `(batch_size, target_length, hidden_size)`.
      memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
      self_attention_mask: A tensor with shape `(batch_size, target_len,
        target_length)`, the mask for decoder self-attention layer.
      cross_attention_mask: A tensor with shape `(batch_size, target_length,
        input_length)` which is the mask for encoder-decoder attention layer.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values. The items are:
          {layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
                     "v": A tensor with shape `(batch_size, i, value_channels)`},
                       ...}
      decode_loop_step: An integer, the step number of the decoding loop. Used
        only for autoregressive inference on TPU.

    Returns:
      Output of decoder.
      float32 tensor with shape `(batch_size, target_length, hidden_size`).
    """
    output_tensor = target
    for layer_idx in range(self.num_layers):
      transformer_inputs = [
          output_tensor, memory, cross_attention_mask, self_attention_mask
      ]
      # Gets the cache for decoding.
      if cache is None:
        output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
      else:
        cache_layer_idx = str(layer_idx)
        output_tensor, cache[cache_layer_idx] = self.decoder_layers[layer_idx](
            transformer_inputs,
            cache=cache[cache_layer_idx],
            decode_loop_step=decode_loop_step)
    return self.output_normalization(output_tensor)
def attention_initializer(hidden_size):
  """Initializer for attention layers in Seq2SeqTransformer."""
  # Glorot/Xavier-style uniform bound with fan_in == fan_out == hidden_size.
  fan = int(hidden_size)
  bound = math.sqrt(6.0 / (fan + fan))
  return tf.keras.initializers.RandomUniform(minval=-bound, maxval=bound)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Transformer model."""
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import seq2seq_transformer
class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

  def _build_model(self, padded_decode, decode_max_length):
    """Builds a tiny Seq2SeqTransformer suitable for unit tests."""
    # Encoder and decoder share the same (small) hyperparameters.
    shared_kwargs = dict(
        num_layers=1,
        num_attention_heads=2,
        intermediate_size=32,
        activation="relu",
        dropout_rate=0.01,
        attention_dropout_rate=0.01,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        intermediate_dropout=0.01)
    return seq2seq_transformer.Seq2SeqTransformer(
        vocab_size=100,
        embedding_width=16,
        dropout_rate=0.01,
        padded_decode=padded_decode,
        decode_max_length=decode_max_length,
        beam_size=4,
        alpha=0.6,
        encoder_layer=seq2seq_transformer.TransformerEncoder(**shared_kwargs),
        decoder_layer=seq2seq_transformer.TransformerDecoder(**shared_kwargs))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode="eager"))
  def test_create_model_with_ds(self, distribution):
    with distribution.scope():
      # TPU strategies require static shapes, i.e. padded decoding.
      padded_decode = isinstance(
          distribution,
          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
      decode_max_length = 10
      batch_size = 4
      model = self._build_model(padded_decode, decode_max_length)

      @tf.function
      def run_step(features):

        def _forward(features):
          return model(features)

        per_replica = distribution.run(_forward, args=(features,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     per_replica)

      # Inference path: no targets, so the model beam-search decodes.
      features = dict(
          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
      results = run_step(features)
      logging.info("local_outputs=%s", results)
      self.assertEqual(results["outputs"][0].shape, (4, 10))

      # Training path: targets provided, so the model returns logits.
      features = dict(
          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
          targets=np.zeros((batch_size, 8), dtype=np.int32))
      results = run_step(features)
      logging.info("local_outputs=%s", results)
      self.assertEqual(results[0].shape, (4, 8, 100))

  @parameterized.parameters(True, False)
  def test_create_savedmodel(self, padded_decode):
    decode_max_length = 10
    model = self._build_model(padded_decode, decode_max_length)

    # Wrap the model so its call() is exported under a serving signature.
    class SaveModule(tf.Module):

      def __init__(self, model):
        super(SaveModule, self).__init__()
        self.model = model

      @tf.function
      def serve(self, inputs):
        return self.model.call(dict(inputs=inputs))

    module_to_save = SaveModule(model)
    # Padded decode requires fully-static input shapes.
    tensor_shape = (4, 10) if padded_decode else (None, None)
    signatures = dict(
        serving_default=module_to_save.serve.get_concrete_function(
            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
    tf.saved_model.save(
        module_to_save, self.get_temp_dir(), signatures=signatures)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Mapping, Union
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
class XLNetMaskedLM(tf.keras.layers.Layer):
  """XLNet pretraining head.

  Transforms the encoder's sequence output with a dense layer + layer norm,
  then projects onto the vocabulary using the (tied) embedding table plus a
  learned output bias.
  """

  def __init__(self,
               vocab_size: int,
               hidden_size: int,
               initializer: str = 'glorot_uniform',
               activation: str = 'gelu',
               name=None,
               **kwargs):
    """Initializes the masked-LM head.

    Args:
      vocab_size: Number of tokens in the vocabulary (output bias size).
      hidden_size: Width of the transform dense layer.
      initializer: Kernel initializer for the transform dense layer.
      activation: Activation of the transform dense layer.
      name: Layer name.
      **kwargs: Keyword arguments passed to tf.keras.layers.Layer.
    """
    super().__init__(name=name, **kwargs)
    self._vocab_size = vocab_size
    self._hidden_size = hidden_size
    self._initializer = initializer
    self._activation = activation

  def build(self, input_shape):
    """Creates the transform layers and the per-token output bias."""
    self.dense = tf.keras.layers.Dense(
        units=self._hidden_size,
        activation=self._activation,
        kernel_initializer=self._initializer,
        name='transform/dense')
    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
    self.bias = self.add_weight(
        'output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True)
    super().build(input_shape)

  def call(self,
           sequence_data: tf.Tensor,
           embedding_table: tf.Tensor):
    """Returns vocabulary logits for each position in `sequence_data`."""
    lm_data = self.dense(sequence_data)
    lm_data = self.layer_norm(lm_data)
    # Tie the output projection to the input embedding table.
    lm_data = tf.matmul(lm_data, embedding_table, transpose_b=True)
    logits = tf.nn.bias_add(lm_data, self.bias)
    return logits

  def get_config(self) -> Mapping[str, Any]:
    config = {
        'vocab_size': self._vocab_size,
        'hidden_size': self._hidden_size,
        'initializer': self._initializer,
        # 'activation' was previously omitted, so a deserialized layer
        # silently fell back to the default activation.
        'activation': self._activation,
    }
    base_config = super(XLNetMaskedLM, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetPretrainer(tf.keras.Model):
  """XLNet-based pretrainer.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: An XLNet/Transformer-XL based network. This network should output a
      sequence output and list of `state` tensors.
    mlm_activation: The activation to use in the Masked LM head's dense
      transform. Defaults to 'gelu' when None.
    mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
      to a Glorot uniform initializer.
  """

  def __init__(
      self,
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      name: str = None,
      **kwargs):
    super().__init__(name=name, **kwargs)
    # Keep the raw constructor arguments so `get_config` round-trips.
    self._config = {
        'network': network,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
    }
    self._network = network
    self._hidden_size = network.get_config()['hidden_size']
    self._vocab_size = network.get_config()['vocab_size']
    self._activation = mlm_activation
    self._initializer = mlm_initializer
    # Bug fix: `mlm_activation` was accepted but never forwarded, so the
    # masked-LM head always used its own default ('gelu'). Forward it now,
    # falling back to 'gelu' when None so default callers keep the previous
    # behavior.
    self._masked_lm = XLNetMaskedLM(
        vocab_size=self._vocab_size,
        hidden_size=self._hidden_size,
        initializer=self._initializer,
        activation=self._activation or 'gelu')

  def call(self, inputs: Mapping[str, Any]):
    """Runs the encoder and the masked-LM head.

    Args:
      inputs: A dict with keys `input_word_ids`, `input_type_ids`,
        `masked_tokens`, `permutation_mask`, `target_mapping`, and
        optionally `state` (Transformer-XL memory).

    Returns:
      A `(mlm_outputs, state)` tuple: vocabulary logits and the updated
      memory state returned by the encoder.
    """
    input_word_ids = inputs['input_word_ids']
    input_type_ids = inputs['input_type_ids']
    masked_tokens = inputs['masked_tokens']
    permutation_mask = inputs['permutation_mask']
    target_mapping = inputs['target_mapping']
    state = inputs.get('state', None)

    attention_output, state = self._network(
        input_ids=input_word_ids,
        segment_ids=input_type_ids,
        input_mask=None,
        state=state,
        permutation_mask=permutation_mask,
        target_mapping=target_mapping,
        masked_tokens=masked_tokens)

    # The embedding table is shared with the output projection (weight tying).
    embedding_table = self._network.get_embedding_lookup_table()
    mlm_outputs = self._masked_lm(
        sequence_data=attention_output,
        embedding_table=embedding_table)
    return mlm_outputs, state

  def get_config(self) -> Mapping[str, Any]:
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    # Expose the encoder for object-based checkpointing.
    return dict(encoder=self._network)
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetClassifier(tf.keras.Model):
  """Classifier model based on XLNet.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Note: This model does not utilize the memory mechanism used in the
  original XLNet Classifier.

  Args:
    network: An XLNet/Transformer-XL based network. This network should output a
      sequence output and list of `state` tensors.
    num_classes: Number of classes to predict from the classification network.
    initializer: The initializer (if any) to use in the classification networks.
      Defaults to a RandomNormal initializer.
    summary_type: Method used to summarize a sequence into a compact vector.
      Either 'last' (use the final token) or 'first' (use the first token).
    dropout_rate: The dropout probability of the cls head.
  """

  def __init__(
      self,
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      num_classes: int,
      initializer: tf.keras.initializers.Initializer = 'random_normal',
      summary_type: str = 'last',
      dropout_rate: float = 0.1,
      **kwargs):
    super().__init__(**kwargs)
    self._network = network
    self._initializer = initializer
    self._summary_type = summary_type
    self._num_classes = num_classes
    # Retain constructor arguments verbatim for `get_config` round-trips.
    self._config = {
        'network': network,
        'initializer': initializer,
        'num_classes': num_classes,
        'summary_type': summary_type,
        'dropout_rate': dropout_rate,
    }
    # Map the summary strategy onto the token index the head should read.
    token_index_by_summary = {'last': -1, 'first': 0}
    if summary_type not in token_index_by_summary:
      raise ValueError('Invalid summary type provided: %s.' % summary_type)
    self.classifier = layers.ClassificationHead(
        inner_dim=network.get_config()['hidden_size'],
        num_classes=num_classes,
        initializer=initializer,
        dropout_rate=dropout_rate,
        cls_token_idx=token_index_by_summary[summary_type],
        name='sentence_prediction')

  def call(self, inputs: Mapping[str, Any]):
    """Encodes the inputs and classifies the summarized sequence.

    Args:
      inputs: A dict with keys `input_word_ids`, `input_type_ids`,
        `input_mask`, and optionally `mems` (encoder memory state).

    Returns:
      Classification logits.
    """
    word_ids = inputs['input_word_ids']
    type_ids = inputs['input_type_ids']
    mask = tf.cast(inputs['input_mask'], tf.float32)
    memory_state = inputs.get('mems', None)

    # The encoder also returns updated memory, which this model discards.
    encoded, _ = self._network(
        input_ids=word_ids,
        segment_ids=type_ids,
        input_mask=mask,
        state=memory_state)
    return self.classifier(encoded)

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    # Expose the encoder plus any nested head items for checkpointing.
    items = dict(encoder=self._network)
    if hasattr(self.classifier, 'checkpoint_items'):
      for key, item in self.classifier.checkpoint_items.items():
        items['.'.join([self.classifier.name, key])] = item
    return items
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf.keras.Model):
  """Span labeler model based on XLNet.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: A transformer network. This network should output a sequence output
      and a classification output. Furthermore, it should expose its embedding
      table via a "get_embedding_table" method.
    start_n_top: Beam size for span start.
    end_n_top: Beam size for span end.
    dropout_rate: The dropout rate for the span labeling layer.
    span_labeling_activation: The activation for the span labeling head.
    initializer: The initializer (if any) to use in the span labeling network.
      Defaults to a Glorot uniform initializer.
  """

  def __init__(
      self,
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      start_n_top: int = 5,
      end_n_top: int = 5,
      dropout_rate: float = 0.1,
      span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
      initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
      **kwargs):
    super().__init__(**kwargs)
    # Retain constructor arguments verbatim for `get_config` round-trips.
    self._config = {
        'network': network,
        'start_n_top': start_n_top,
        'end_n_top': end_n_top,
        'dropout_rate': dropout_rate,
        'span_labeling_activation': span_labeling_activation,
        'initializer': initializer,
    }
    # XLNet encoders report 'inner_size'; BertEncoder uses
    # 'intermediate_size' due to legacy naming. Detect which one we have.
    encoder_config = network.get_config()
    self._xlnet_base = 'inner_size' in encoder_config
    if self._xlnet_base:
      input_width = encoder_config['inner_size']
    else:
      input_width = encoder_config['intermediate_size']
    self._network = network
    self._initializer = initializer
    self._start_n_top = start_n_top
    self._end_n_top = end_n_top
    self._dropout_rate = dropout_rate
    self._activation = span_labeling_activation
    self.span_labeling = networks.XLNetSpanLabeling(
        input_width=input_width,
        start_n_top=self._start_n_top,
        end_n_top=self._end_n_top,
        activation=self._activation,
        dropout_rate=self._dropout_rate,
        initializer=self._initializer)

  def call(self, inputs: Mapping[str, Any]):
    """Encodes the inputs and predicts span start/end positions.

    Args:
      inputs: A dict with keys `input_word_ids`, `input_type_ids`,
        `input_mask`, `class_index`, `paragraph_mask`, and optionally
        `start_positions` (used at training time).

    Returns:
      The span-labeling head outputs.
    """
    word_ids = inputs['input_word_ids']
    type_ids = inputs['input_type_ids']
    mask = inputs['input_mask']
    class_index = inputs['class_index']
    paragraph_mask = inputs['paragraph_mask']
    start_positions = inputs.get('start_positions', None)

    # XLNet-style encoders take keyword tensors and also return memory
    # state; BERT-style encoders take a single dict and return a dict.
    if self._xlnet_base:
      sequence_output, _ = self._network(
          input_ids=word_ids,
          segment_ids=type_ids,
          input_mask=mask)
    else:
      encoder_outputs = self._network(dict(
          input_word_ids=word_ids,
          input_type_ids=type_ids,
          input_mask=mask))
      sequence_output = encoder_outputs['sequence_output']

    return self.span_labeling(
        sequence_data=sequence_output,
        class_index=class_index,
        paragraph_mask=paragraph_mask,
        start_positions=start_positions)

  @property
  def checkpoint_items(self):
    # Expose the encoder for object-based checkpointing.
    return dict(encoder=self._network)

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment