Unverified commit f16a7b5b authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for masked LM loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for masked LM loss."""
import numpy as np
import tensorflow as tf
......@@ -39,7 +34,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output="predictions"):
# First, create a transformer stack that we can use to get the LM's
# vocabulary weight.
xformer_stack = networks.TransformerEncoder(
xformer_stack = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=1,
sequence_length=sequence_length,
......@@ -204,5 +199,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__":
tf.test.main()
# Models
Models are combinations of `tf.keras` layers and models that can be trained.
Several pre-built canned models are provided to train encoder networks.
These models are intended as both convenience functions and canonical examples.
* [`BertClassifier`](bert_classifier.py) implements a simple classification
model containing a single classification head using the Classification network.
It can be used as a regression model as well.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head over the sequence
output embeddings.
* [`BertSpanLabeler`](bert_span_labeler.py) implements a simple single-span
start-end predictor (that is, a model that predicts two values: a start token
......@@ -20,3 +20,6 @@
index and an end token index), suitable for SQuAD-style tasks.
* [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a
classification head using the Masked LM and Classification networks,
respectively.
* [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitable for
retrieval tasks.
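As a quick orientation, here is a minimal usage sketch; the constructor
arguments shown are illustrative and may differ between releases:

```python
import tensorflow as tf
from official.nlp.modeling import models, networks

# Build a small encoder, then wrap it in a canned classification model.
encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
classifier = models.BertClassifier(encoder, num_classes=3)

# The canned models consume [word_ids, mask, type_ids], mirroring the tests
# in this change.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
logits = classifier([word_ids, mask, type_ids])  # Shape: [2, 3].
```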
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,10 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Models package definition."""
"""Models are combinations of `tf.keras` layers and models that can be trained.
Several pre-built canned models are provided to train encoder networks.
These models are intended as both convenience functions and canonical examples.
"""
from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_pretrainer import *
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetPretrainer
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
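A brief, hedged usage note on this package definition: the canned models
resolve directly off `official.nlp.modeling.models`, and `BertPretrainerV2` is
exported through the wildcard import of `bert_pretrainer` rather than by name:

```python
from official.nlp.modeling import models

# Explicitly imported names.
print(models.BertClassifier, models.DualEncoder)
# Pulled in via `from ...bert_pretrainer import *`.
print(models.BertPretrainerV2)
```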
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,18 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT cls-token classifier."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text')
......@@ -37,7 +32,10 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes`
argument. If `num_classes` is set to 1, a regression network is instantiated.
Arguments:
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
......@@ -45,8 +43,12 @@ class BertClassifier(tf.keras.Model):
initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer.
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside
the encoder.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
"""
def __init__(self,
......@@ -55,15 +57,11 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform',
dropout_rate=0.1,
use_encoder_pooler=True,
cls_head=None,
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'use_encoder_pooler': use_encoder_pooler,
}
self.num_classes = num_classes
self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -73,36 +71,73 @@ class BertClassifier(tf.keras.Model):
if use_encoder_pooler:
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
_, cls_output = network(inputs)
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
outputs = network(inputs)
if isinstance(outputs, list):
cls_inputs = outputs[1]
else:
cls_inputs = outputs['pooled_output']
cls_inputs = tf.keras.layers.Dropout(rate=dropout_rate)(cls_inputs)
else:
outputs = network(inputs)
if isinstance(outputs, list):
cls_inputs = outputs[0]
else:
cls_inputs = outputs['sequence_output']
self.classifier = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output='logits',
name='sentence_prediction')
predictions = self.classifier(cls_output)
if cls_head:
classifier = cls_head
else:
sequence_output, _ = network(inputs)
self.classifier = layers.ClassificationHead(
inner_dim=sequence_output.shape[-1],
classifier = layers.ClassificationHead(
inner_dim=0 if use_encoder_pooler else cls_inputs.shape[-1],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
predictions = self.classifier(sequence_output)
predictions = classifier(cls_inputs)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
self._network = network
self._cls_head = cls_head
config_dict = self._make_config_dict()
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.classifier = classifier
@property
def checkpoint_items(self):
return dict(encoder=self._network)
items = dict(encoder=self._network)
if hasattr(self.classifier, 'checkpoint_items'):
for key, item in self.classifier.checkpoint_items.items():
items['.'.join([self.classifier.name, key])] = item
return items
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def _make_config_dict(self):
return {
'network': self._network,
'num_classes': self.num_classes,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self._cls_head,
}
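A hedged sketch of the new `cls_head` path: when a custom head is supplied, it
replaces the default classification head, and the related constructor
arguments are ignored (the head layer below mirrors the ones used in the tests
in this change):

```python
from official.nlp.modeling import layers, networks
from official.nlp.modeling.models import bert_classifier

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)

# With `cls_head` set, `num_classes`, `initializer`, `dropout_rate` and
# `use_encoder_pooler` are ignored in favor of the supplied head.
# inner_dim=0 applies the head directly to the pooled output.
custom_head = layers.ClassificationHead(inner_dim=0, num_classes=4)
model = bert_classifier.BertClassifier(encoder, cls_head=custom_head)

# get_config()/from_config() round-trips the model, including the head.
restored = bert_classifier.BertClassifier.from_config(model.get_config())
```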
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for BERT trainer network."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier
......@@ -31,14 +28,15 @@ from official.nlp.modeling.models import bert_classifier
@keras_parameterized.run_all_keras_modes
class BertClassifierTest(keras_parameterized.TestCase):
@parameterized.parameters(1, 3)
def test_bert_trainer(self, num_classes):
@parameterized.named_parameters(('single_cls', 1, False), ('3_cls', 3, False),
('3_cls_dictoutputs', 3, True))
def test_bert_trainer(self, num_classes, dict_outputs):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
test_network = networks.BertEncoder(
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier(
......@@ -56,17 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase):
expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
@parameterized.parameters(1, 2)
def test_bert_trainer_tensor_call(self, num_classes):
@parameterized.named_parameters(
('single_cls', 1, False),
('2_cls', 2, False),
('single_cls_custom_head', 1, True),
('2_cls_custom_head', 2, True))
def test_bert_trainer_tensor_call(self, num_classes, use_custom_head):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=2)
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
cls_head = layers.GaussianProcessClassificationHead(
inner_dim=0, num_classes=num_classes) if use_custom_head else None
# Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=num_classes)
test_network, num_classes=num_classes, cls_head=cls_head)
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
......@@ -78,17 +81,21 @@ class BertClassifierTest(keras_parameterized.TestCase):
# too complex: this simply ensures we're not hitting runtime errors.)
_ = bert_trainer_model([word_ids, mask, type_ids])
def test_serialize_deserialize(self):
@parameterized.named_parameters(
('default_cls_head', None),
('sngp_cls_head', layers.GaussianProcessClassificationHead(
inner_dim=0, num_classes=4)))
def test_serialize_deserialize(self, cls_head):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=4, initializer='zeros')
test_network, num_classes=4, initializer='zeros', cls_head=cls_head)
# Create another BERT trainer via serialization and deserialization.
config = bert_trainer_model.get_config()
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT Pre-training model."""
# pylint: disable=g-classes-have-attributes
import collections
import copy
from typing import List, Optional
from absl import logging
import gin
import tensorflow as tf
......@@ -31,17 +28,18 @@ from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text')
class BertPretrainer(tf.keras.Model):
"""BERT network training model.
"""BERT pretraining model.
This is an implementation of the network structure surrounding a transformer
encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding" (https://arxiv.org/abs/1810.04805).
[Note] Please use the new `BertPretrainerV2` for your projects.
The BertPretrainer allows a user to pass in a transformer stack, and
instantiates the masked language model and classification networks that are
used to create the training objectives.
Arguments:
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output
and a classification output.
num_classes: Number of classes to predict from the classification network.
......@@ -52,8 +50,8 @@ class BertPretrainer(tf.keras.Model):
None, no activation will be used.
initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
output: The output style for this network. Can be either `logits` or
`predictions`.
"""
def __init__(self,
......@@ -65,21 +63,12 @@ class BertPretrainer(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._config = {
'network': network,
'num_classes': num_classes,
'num_token_predictions': num_token_predictions,
'activation': activation,
'initializer': initializer,
'output': output,
}
self.encoder = network
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a copy of the network inputs for use
# when we construct the Model object at the end of init. (We keep a copy
# because we'll be adding another tensor to the copy later.)
network_inputs = self.encoder.inputs
network_inputs = network.inputs
inputs = copy.copy(network_inputs)
# Because we have a copy of inputs to create this Model object, we can
......@@ -87,7 +76,7 @@ class BertPretrainer(tf.keras.Model):
# Note that, because of how deferred construction happens, we can't use
# the copy of the list here - by the time the network is invoked, the list
# object contains the additional input added below.
sequence_output, cls_output = self.encoder(network_inputs)
sequence_output, cls_output = network(network_inputs)
# The encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
......@@ -95,7 +84,8 @@ class BertPretrainer(tf.keras.Model):
if isinstance(cls_output, list):
cls_output = cls_output[-1]
sequence_output_length = sequence_output.shape.as_list()[1]
if sequence_output_length < num_token_predictions:
if sequence_output_length is not None and (sequence_output_length <
num_token_predictions):
raise ValueError(
"The passed network's output length is %s, which is less than the "
'requested num_token_predictions %s.' %
......@@ -108,48 +98,74 @@ class BertPretrainer(tf.keras.Model):
inputs.append(masked_lm_positions)
if embedding_table is None:
embedding_table = self.encoder.get_embedding_table()
self.masked_lm = layers.MaskedLM(
embedding_table = network.get_embedding_table()
masked_lm = layers.MaskedLM(
embedding_table=embedding_table,
activation=activation,
initializer=initializer,
output=output,
name='cls/predictions')
lm_outputs = self.masked_lm(
lm_outputs = masked_lm(
sequence_output, masked_positions=masked_lm_positions)
self.classification = networks.Classification(
classification = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output=output,
name='classification')
sentence_outputs = self.classification(cls_output)
sentence_outputs = classification(cls_output)
super(BertPretrainer, self).__init__(
inputs=inputs,
outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
**kwargs)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
config_dict = {
'network': network,
'num_classes': num_classes,
'num_token_predictions': num_token_predictions,
'activation': activation,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.encoder = network
self.classification = classification
self.masked_lm = masked_lm
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
# TODO(hongkuny): Migrate to BertPretrainerV2 for all usages.
@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class BertPretrainerV2(tf.keras.Model):
"""BERT pretraining model V2.
(Experimental).
Adds the masked language model head and optional classification heads upon the
transformer encoder.
Arguments:
Args:
encoder_network: A transformer network. This network should output a
sequence output and a classification output.
mlm_activation: The activation (if any) to use in the masked LM network. If
......@@ -158,11 +174,16 @@ class BertPretrainerV2(tf.keras.Model):
to a Glorot uniform initializer.
classification_heads: A list of optional head layers to transform on encoder
sequence outputs.
customized_masked_lm: A customized masked_lm layer. If None, will create
a standard layer from `layers.MaskedLM`; if not None, will use the
specified masked_lm layer. Above arguments `mlm_activation` and
`mlm_initializer` will be ignored.
name: The name of the model.
Inputs: A dictionary of the inputs defined by the encoder network, plus
`masked_lm_positions`.
Outputs: A dictionary of `lm_output` and classification head outputs keyed by
head names.
Outputs: A dictionary of `lm_output`, classification head outputs keyed by
head names, and also outputs from `encoder_network`, keyed by
`sequence_output` and `encoder_outputs` (if any).
"""
def __init__(
......@@ -171,27 +192,24 @@ class BertPretrainerV2(tf.keras.Model):
mlm_activation=None,
mlm_initializer='glorot_uniform',
classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
name: str = 'bert',
**kwargs):
self._self_setattr_tracking = False
super().__init__(self, name=name, **kwargs)
self._config = {
'encoder_network': encoder_network,
'mlm_initializer': mlm_initializer,
'classification_heads': classification_heads,
'name': name,
}
self.encoder_network = encoder_network
inputs = copy.copy(self.encoder_network.inputs)
sequence_output, _ = self.encoder_network(inputs)
self.classification_heads = classification_heads or []
if len(set([cls.name for cls in self.classification_heads])) != len(
self.classification_heads):
raise ValueError('Classification heads should have unique names.')
outputs = dict()
self.masked_lm = layers.MaskedLM(
self.masked_lm = customized_masked_lm or layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
......@@ -199,13 +217,45 @@ class BertPretrainerV2(tf.keras.Model):
masked_lm_positions = tf.keras.layers.Input(
shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions)
outputs['lm_output'] = self.masked_lm(
self.inputs = inputs
def call(self, inputs):
if isinstance(inputs, list):
logging.warning('List inputs to BertPretrainer are discouraged.')
inputs = dict([
(ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
])
outputs = dict()
encoder_network_outputs = self.encoder_network(inputs)
if isinstance(encoder_network_outputs, list):
outputs['pooled_output'] = encoder_network_outputs[1]
# When `encoder_network` was instantiated with return_all_encoder_outputs
# set to True, `encoder_network_outputs[0]` is a list containing
# all transformer layers' output.
if isinstance(encoder_network_outputs[0], list):
outputs['encoder_outputs'] = encoder_network_outputs[0]
outputs['sequence_output'] = encoder_network_outputs[0][-1]
else:
outputs['sequence_output'] = encoder_network_outputs[0]
elif isinstance(encoder_network_outputs, dict):
outputs = encoder_network_outputs
else:
raise ValueError('encoder_network\'s output should be either a list '
'or a dict, but got %s' % encoder_network_outputs)
sequence_output = outputs['sequence_output']
# Inference may not have masked_lm_positions, in which case mlm_logits is
# not needed.
if 'masked_lm_positions' in inputs:
masked_lm_positions = inputs['masked_lm_positions']
outputs['mlm_logits'] = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads:
outputs[cls_head.name] = cls_head(sequence_output)
super(BertPretrainerV2, self).__init__(
inputs=inputs, outputs=outputs, name=name, **kwargs)
cls_outputs = cls_head(sequence_output)
if isinstance(cls_outputs, dict):
outputs.update(cls_outputs)
else:
outputs[cls_head.name] = cls_outputs
return outputs
@property
def checkpoint_items(self):
......
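A hedged sketch of the new `BertPretrainerV2` calling convention shown above:
dictionary inputs are preferred (list inputs log a warning), and `mlm_logits`
is only produced when `masked_lm_positions` is present:

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

encoder = networks.BertEncoder(vocab_size=100, num_layers=2, dict_outputs=True)
pretrainer = bert_pretrainer.BertPretrainerV2(encoder_network=encoder)

inputs = dict(
    input_word_ids=tf.constant([[1, 1], [2, 2]], dtype=tf.int32),
    input_mask=tf.constant([[1, 1], [1, 0]], dtype=tf.int32),
    input_type_ids=tf.constant([[1, 1], [2, 2]], dtype=tf.int32),
    masked_lm_positions=tf.constant([[0], [1]], dtype=tf.int32))
outputs = pretrainer(inputs)
# With dict_outputs=True, the keys include 'sequence_output',
# 'pooled_output', 'encoder_outputs' and, because masked_lm_positions was
# supplied, 'mlm_logits'.
print(sorted(outputs.keys()))
```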
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for BERT pretrainer model."""
import itertools
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer
......@@ -35,8 +34,10 @@ class BertPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length)
# Create a BERT trainer with the created network.
num_classes = 3
......@@ -68,7 +69,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=2)
# Create a BERT trainer with the created network.
......@@ -90,8 +91,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, max_sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......@@ -109,36 +110,112 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self.assertAllEqual(bert_trainer_model.get_config(),
new_bert_trainer_model.get_config())
def test_bert_pretrainerv2(self):
class BertPretrainerV2Test(keras_parameterized.TestCase):
@parameterized.parameters(itertools.product(
(False, True),
(False, True),
(False, True),
(False, True),
))
def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
use_customized_masked_lm, has_masked_lm_positions):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
return_all_encoder_outputs=return_all_encoder_outputs,
dict_outputs=dict_outputs)
# Create a BERT trainer with the created network.
if use_customized_masked_lm:
customized_masked_lm = layers.MaskedLM(
embedding_table=test_network.get_embedding_table())
else:
customized_masked_lm = None
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network)
encoder_network=test_network, customized_masked_lm=customized_masked_lm)
num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
inputs = dict(
input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32))
if has_masked_lm_positions:
inputs['masked_lm_positions'] = tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
outputs = bert_trainer_model(inputs)
has_encoder_outputs = dict_outputs or return_all_encoder_outputs
expected_keys = ['sequence_output', 'pooled_output']
if has_encoder_outputs:
expected_keys.append('encoder_outputs')
if has_masked_lm_positions:
expected_keys.append('mlm_logits')
self.assertSameElements(outputs.keys(), expected_keys)
# Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size]
self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list())
if has_masked_lm_positions:
self.assertAllEqual(expected_lm_shape,
outputs['mlm_logits'].shape.as_list())
expected_sequence_output_shape = [None, sequence_length, hidden_size]
self.assertAllEqual(expected_sequence_output_shape,
outputs['sequence_output'].shape.as_list())
expected_pooled_output_shape = [None, hidden_size]
self.assertAllEqual(expected_pooled_output_shape,
outputs['pooled_output'].shape.as_list())
def test_multiple_cls_outputs(self):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
dict_outputs=True)
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network,
classification_heads=[layers.MultiClsHeads(
inner_dim=5, cls_list=[('foo', 2), ('bar', 3)])])
num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit).
inputs = dict(
input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
masked_lm_positions=tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32))
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model(inputs)
self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])
def test_v2_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,14 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT Question Answering model."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import networks
......@@ -32,17 +28,20 @@ class BertSpanLabeler(tf.keras.Model):
encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding" (https://arxiv.org/abs/1810.04805).
The BertSpanLabeler allows a user to pass in a transformer stack, and
The BertSpanLabeler allows a user to pass in a transformer encoder, and
instantiates a span labeling network based on a single dense layer.
Arguments:
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
table via a `get_embedding_table` method.
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
output: The output style for this network. Can be either `logits` or
`predictions`.
"""
def __init__(self,
......@@ -50,13 +49,6 @@ class BertSpanLabeler(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -65,16 +57,25 @@ class BertSpanLabeler(tf.keras.Model):
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
sequence_output, _ = network(inputs)
outputs = network(inputs)
if isinstance(outputs, list):
sequence_output = outputs[0]
else:
sequence_output = outputs['sequence_output']
# The input network (typically a transformer model) may get outputs from all
# layers. When this case happens, we retrieve the last layer output.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
# This is an instance variable for ease of access to the underlying task
# network.
self.span_labeling = networks.SpanLabeling(
span_labeling = networks.SpanLabeling(
input_width=sequence_output.shape[-1],
initializer=initializer,
output=output,
name='span_labeling')
start_logits, end_logits = self.span_labeling(sequence_output)
start_logits, end_logits = span_labeling(sequence_output)
# Use identity layers wrapped in lambdas to explicitly name the output
# tensors. This allows us to use string-keyed dicts in Keras fit/predict/
......@@ -88,15 +89,36 @@ class BertSpanLabeler(tf.keras.Model):
logits = [start_logits, end_logits]
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertSpanLabeler, self).__init__(
inputs=inputs, outputs=logits, **kwargs)
self._network = network
config_dict = {
'network': network,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.span_labeling = span_labeling
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
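A short usage sketch for `BertSpanLabeler` (mirroring the tests in this
change): the model returns a two-element list of start and end logits over
the sequence:

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_span_labeler

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
span_labeler = bert_span_labeler.BertSpanLabeler(encoder)

word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
# Each logits tensor has shape [batch, seq_length], one score per token.
start_logits, end_logits = span_labeler([word_ids, mask, type_ids])
```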
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for BERT trainer network."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
......@@ -30,13 +27,14 @@ from official.nlp.modeling.models import bert_span_labeler
@keras_parameterized.run_all_keras_modes
class BertSpanLabelerTest(keras_parameterized.TestCase):
def test_bert_trainer(self):
@parameterized.parameters(True, False)
def test_bert_trainer(self, dict_outputs):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
test_network = networks.BertEncoder(
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
......@@ -59,9 +57,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
"""Validate compilation using explicit output names."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
......@@ -80,8 +76,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=2)
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
......@@ -100,7 +95,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,18 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""BERT token classifier."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text')
class BertTokenClassifier(tf.keras.Model):
......@@ -36,15 +30,21 @@ class BertTokenClassifier(tf.keras.Model):
instantiates a token classification network based on the passed `num_classes`
argument.
Arguments:
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
table via a `get_embedding_table` method.
num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
output: The output style for this network. Can be either `logits` or
`predictions`.
dropout_rate: The dropout probability of the token classification head.
output_encoder_outputs: Whether to include intermediate sequence output
in the final output.
"""
def __init__(self,
......@@ -53,15 +53,8 @@ class BertTokenClassifier(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
dropout_rate=0.1,
output_encoder_outputs=False,
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -70,27 +63,70 @@ class BertTokenClassifier(tf.keras.Model):
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
sequence_output, _ = network(inputs)
sequence_output = tf.keras.layers.Dropout(
rate=dropout_rate)(sequence_output)
self.classifier = networks.TokenClassification(
input_width=sequence_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output=output,
name='classification')
predictions = self.classifier(sequence_output)
outputs = network(inputs)
if isinstance(outputs, list):
sequence_output = outputs[0]
else:
sequence_output = outputs['sequence_output']
sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
sequence_output)
classifier = tf.keras.layers.Dense(
num_classes,
activation=None,
kernel_initializer=initializer,
name='predictions/transform/logits')
logits = classifier(sequence_output)
if output == 'logits':
output_tensors = {'logits': logits}
elif output == 'predictions':
output_tensors = {
'predictions': tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
}
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
if output_encoder_outputs:
output_tensors['encoder_outputs'] = sequence_output
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertTokenClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
inputs=inputs, outputs=output_tensors, **kwargs)
self._network = network
config_dict = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
'output_encoder_outputs': output_encoder_outputs
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.classifier = classifier
self.logits = logits
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
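A hedged sketch of the new dictionary outputs of `BertTokenClassifier` (key
names per the code above):

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_token_classifier

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
model = bert_token_classifier.BertTokenClassifier(
    encoder, num_classes=3, output_encoder_outputs=True)

word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
outputs = model([word_ids, mask, type_ids])
per_token_logits = outputs['logits']          # [2, 2, num_classes]
encoder_outputs = outputs['encoder_outputs']  # [2, 2, hidden_size]
```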
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for BERT token classifier."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
......@@ -30,19 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier
@keras_parameterized.run_all_keras_modes
class BertTokenClassifierTest(keras_parameterized.TestCase):
def test_bert_trainer(self):
@parameterized.parameters((True, True), (False, False))
def test_bert_trainer(self, dict_outputs, output_encoder_outputs):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
hidden_size = 768
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length,
dict_outputs=dict_outputs,
hidden_size=hidden_size)
# Create a BERT trainer with the created network.
num_classes = 3
bert_trainer_model = bert_token_classifier.BertTokenClassifier(
test_network,
num_classes=num_classes)
num_classes=num_classes,
output_encoder_outputs=output_encoder_outputs)
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -50,19 +54,25 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
sequence_outs = bert_trainer_model([word_ids, mask, type_ids])
outputs = bert_trainer_model([word_ids, mask, type_ids])
if output_encoder_outputs:
logits = outputs['logits']
encoder_outputs = outputs['encoder_outputs']
self.assertAllEqual(encoder_outputs.shape.as_list(),
[None, sequence_length, hidden_size])
else:
logits = outputs['logits']
# Validate that the outputs are of the expected shape.
expected_classification_shape = [None, sequence_length, num_classes]
self.assertAllEqual(expected_classification_shape,
sequence_outs.shape.as_list())
self.assertAllEqual(expected_classification_shape, logits.shape.as_list())
def test_bert_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=2)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, max_sequence_length=2)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_token_classifier.BertTokenClassifier(
......@@ -82,8 +92,8 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, max_sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
"""A dual encoder model based on a transformer-based encoder.
This is an implementation of the dual encoder network structure based on the
transformer stack, as described in ["Language-agnostic BERT Sentence
Embedding"](https://arxiv.org/abs/2007.01852).
The DualEncoder allows a user to pass in a transformer stack, and builds a
dual encoder model based on that stack.
Args:
network: A transformer network which should output an encoding output.
max_seq_length: The maximum allowed sequence length for transformer.
normalize: If set to True, normalizes the encoding produced by the transformer.
logit_scale: The scaling factor of dot products when doing training.
logit_margin: The margin between positive and negative when doing training.
output: The output style for this network. Can be either `logits` or
`predictions`. If set to `predictions`, it will output the embedding
produced by the transformer network.
"""
def __init__(self,
network: tf.keras.Model,
max_seq_length: int = 32,
normalize: bool = True,
logit_scale: float = 1.0,
logit_margin: float = 0.0,
output: str = 'logits',
**kwargs) -> None:
if output == 'logits':
left_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
left_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
else:
# Keep consistent with legacy BERT hub module input names.
left_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
left_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
left_inputs = [left_word_ids, left_mask, left_type_ids]
left_outputs = network(left_inputs)
if isinstance(left_outputs, list):
left_sequence_output, left_encoded = left_outputs
else:
left_sequence_output = left_outputs['sequence_output']
left_encoded = left_outputs['pooled_output']
if normalize:
left_encoded = tf.keras.layers.Lambda(
lambda x: tf.nn.l2_normalize(x, axis=1))(
left_encoded)
if output == 'logits':
right_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
right_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
right_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')
right_inputs = [right_word_ids, right_mask, right_type_ids]
right_outputs = network(right_inputs)
if isinstance(right_outputs, list):
_, right_encoded = right_outputs
else:
right_encoded = right_outputs['pooled_output']
if normalize:
right_encoded = tf.keras.layers.Lambda(
lambda x: tf.nn.l2_normalize(x, axis=1))(
right_encoded)
dot_products = layers.MatMulWithMargin(
logit_scale=logit_scale,
logit_margin=logit_margin,
name='dot_product')
inputs = [
left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
right_type_ids
]
left_logits, right_logits = dot_products(left_encoded, right_encoded)
outputs = dict(left_logits=left_logits, right_logits=right_logits)
elif output == 'predictions':
inputs = [left_word_ids, left_mask, left_type_ids]
# To keep consistent with legacy BERT hub modules, the outputs are
# "pooled_output" and "sequence_output".
outputs = dict(
sequence_output=left_sequence_output, pooled_output=left_encoded)
else:
raise ValueError('output type %s is not supported' % output)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)
config_dict = {
'network': network,
'max_seq_length': max_seq_length,
'normalize': normalize,
'logit_scale': logit_scale,
'logit_margin': logit_margin,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.network = network
def get_config(self):
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.network)
return items
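A hedged end-to-end sketch of `DualEncoder` in its two output modes, matching
the input definitions above:

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import dual_encoder

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)

# 'logits' mode takes left and right inputs and returns similarity logits.
model = dual_encoder.DualEncoder(encoder, max_seq_length=2, output='logits')
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
outputs = model([word_ids, mask, type_ids, word_ids, mask, type_ids])
left_logits = outputs['left_logits']

# 'predictions' mode exposes a single tower with hub-style output names.
embed_model = dual_encoder.DualEncoder(
    encoder, max_seq_length=2, output='predictions')
pooled = embed_model([word_ids, mask, type_ids])['pooled_output']
```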
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dual encoder network."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import dual_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DualEncoderTest(keras_parameterized.TestCase):
@parameterized.parameters((192, 'logits'), (768, 'predictions'))
def test_dual_encoder(self, hidden_size, output):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the dual encoder model.
vocab_size = 100
sequence_length = 512
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
hidden_size=hidden_size,
sequence_length=sequence_length,
dict_outputs=True)
# Create a dual encoder model with the created network.
dual_encoder_model = dual_encoder.DualEncoder(
test_network, max_seq_length=sequence_length, output=output)
# Create a set of 2-dimensional inputs (the first dimension is implicit).
left_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
left_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
left_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
right_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
right_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
right_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
if output == 'logits':
outputs = dual_encoder_model([
left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
right_type_ids
])
_ = outputs['left_logits']
elif output == 'predictions':
outputs = dual_encoder_model([left_word_ids, left_mask, left_type_ids])
# Validate that the outputs are of the expected shape.
expected_sequence_shape = [None, sequence_length, 768]
self.assertAllEqual(expected_sequence_shape,
outputs['sequence_output'].shape.as_list())
left_encoded = outputs['pooled_output']
expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
@parameterized.parameters((192, 'logits'), (768, 'predictions'))
def test_dual_encoder_tensor_call(self, hidden_size, output):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the dual encoder model. (Here,
# we use a short sequence_length for convenience.)
sequence_length = 2
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=sequence_length)
# Create a dual encoder model with the created network.
dual_encoder_model = dual_encoder.DualEncoder(
test_network, max_seq_length=sequence_length, output=output)
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
# Invoke the model on the tensors. In Eager mode, this does the
# actual calculation. (We can't validate the outputs, since the network is
# too complex: this simply ensures we're not hitting runtime errors.)
if output == 'logits':
_ = dual_encoder_model(
[word_ids, mask, type_ids, word_ids, mask, type_ids])
elif output == 'predictions':
_ = dual_encoder_model([word_ids, mask, type_ids])
def test_serialize_deserialize(self):
"""Validate that the dual encoder model can be serialized / deserialized."""
# Build a transformer network to use within the dual encoder model. (Here,
# we use a short sequence_length for convenience.)
sequence_length = 32
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=sequence_length)
# Create a dual encoder model with the created network. (Note that all the
# args are different, so we can catch any serialization mismatches.)
dual_encoder_model = dual_encoder.DualEncoder(
test_network, max_seq_length=sequence_length, output='predictions')
# Create another dual encoder model via serialization and deserialization.
config = dual_encoder_model.get_config()
new_dual_encoder = dual_encoder.DualEncoder.from_config(config)
# Validate that the config can be forced to JSON.
_ = new_dual_encoder.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(dual_encoder_model.get_config(),
new_dual_encoder.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,15 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy
import tensorflow as tf
from official.modeling import tf_utils
......@@ -39,7 +36,10 @@ class ElectraPretrainer(tf.keras.Model):
model (at the generator side) and classification networks (at the discriminator side)
that are used to create the training objectives.
Arguments:
*Note* that the model is constructed with the Keras subclassing API, where layers are
defined inside `__init__` and `call()` implements the computation.
Args:
generator_network: A transformer network for the generator; this network
should output a sequence output and an optional classification output.
discriminator_network: A transformer network for the discriminator; this network
......@@ -47,15 +47,13 @@ class ElectraPretrainer(tf.keras.Model):
vocab_size: Size of the generator output vocabulary.
num_classes: Number of classes to predict from the classification network
for the generator network (currently unused).
sequence_length: Input sequence length
last_hidden_dim: Last hidden dim of generator transformer output
num_token_predictions: Number of tokens to predict from the masked LM.
mlm_activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either 'logits' or
'predictions'.
output_type: The output style for this network. Can be either `logits` or
`predictions`.
disallow_correct: Whether to disallow the generator from generating the
exact same token as in the original sentence.
"""
......@@ -65,8 +63,6 @@ class ElectraPretrainer(tf.keras.Model):
discriminator_network,
vocab_size,
num_classes,
sequence_length,
last_hidden_dim,
num_token_predictions,
mlm_activation=None,
mlm_initializer='glorot_uniform',
......@@ -79,8 +75,6 @@ class ElectraPretrainer(tf.keras.Model):
'discriminator_network': discriminator_network,
'vocab_size': vocab_size,
'num_classes': num_classes,
'sequence_length': sequence_length,
'last_hidden_dim': last_hidden_dim,
'num_token_predictions': num_token_predictions,
'mlm_activation': mlm_activation,
'mlm_initializer': mlm_initializer,
......@@ -94,8 +88,6 @@ class ElectraPretrainer(tf.keras.Model):
self.discriminator_network = discriminator_network
self.vocab_size = vocab_size
self.num_classes = num_classes
self.sequence_length = sequence_length
self.last_hidden_dim = last_hidden_dim
self.num_token_predictions = num_token_predictions
self.mlm_activation = mlm_activation
self.mlm_initializer = mlm_initializer
......@@ -108,10 +100,15 @@ class ElectraPretrainer(tf.keras.Model):
output=output_type,
name='generator_masked_lm')
self.classification = layers.ClassificationHead(
inner_dim=last_hidden_dim,
inner_dim=generator_network.get_config()['hidden_size'],
num_classes=num_classes,
initializer=mlm_initializer,
name='generator_classification_head')
self.discriminator_projection = tf.keras.layers.Dense(
units=discriminator_network.get_config()['hidden_size'],
activation=mlm_activation,
kernel_initializer=mlm_initializer,
name='discriminator_projection_head')
self.discriminator_head = tf.keras.layers.Dense(
units=1, kernel_initializer=mlm_initializer)
......@@ -123,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model):
Returns:
outputs: A dict of pretrainer model outputs, including
(1) lm_outputs: a [batch_size, num_token_predictions, vocab_size] tensor
indicating logits on masked positions.
(2) sentence_outputs: a [batch_size, num_classes] tensor indicating
(1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
tensor indicating logits on masked positions.
(2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
logits for nsp task.
(3) disc_logits: a [batch_size, sequence_length] tensor indicating
(3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
logits for discriminator replaced token detection task.
(4) disc_label: a [batch_size, sequence_length] tensor indicating
(4) disc_label: A `[batch_size, sequence_length]` tensor indicating
target labels for discriminator replaced token detection task.
"""
input_word_ids = inputs['input_word_ids']
......@@ -138,14 +135,11 @@ class ElectraPretrainer(tf.keras.Model):
masked_lm_positions = inputs['masked_lm_positions']
### Generator ###
sequence_output, cls_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])
sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])['sequence_output']
# The generator encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
if isinstance(cls_output, list):
cls_output = cls_output[-1]
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
sentence_outputs = self.classification(sequence_output)
......@@ -156,16 +150,17 @@ class ElectraPretrainer(tf.keras.Model):
### Discriminator ###
disc_input = fake_data['inputs']
disc_label = fake_data['is_fake_tokens']
disc_sequence_output, _ = self.discriminator_network([
disc_sequence_output = self.discriminator_network([
disc_input['input_word_ids'], disc_input['input_mask'],
disc_input['input_type_ids']
])
])['sequence_output']
# The discriminator encoder network may get outputs from all layers.
if isinstance(disc_sequence_output, list):
disc_sequence_output = disc_sequence_output[-1]
disc_logits = self.discriminator_head(disc_sequence_output)
disc_logits = self.discriminator_head(
self.discriminator_projection(disc_sequence_output))
disc_logits = tf.squeeze(disc_logits, axis=-1)
outputs = {
......@@ -181,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model):
"""Generate corrupted data for discriminator.
Args:
inputs: A dict of all inputs, same as the input of call() function
inputs: A dict of all inputs, same as the input of the `call()` function.
mlm_logits: The generator's output logits
duplicate: Whether to copy the original inputs dict during modifications
......@@ -214,6 +209,12 @@ class ElectraPretrainer(tf.keras.Model):
'sampled_tokens': sampled_tokens
}
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.discriminator_network)
return items
def get_config(self):
return self._config
......@@ -226,16 +227,18 @@ def scatter_update(sequence, updates, positions):
"""Scatter-update a sequence.
Args:
sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
updates: A tensor of size batch_size*seq_len(*depth)
positions: A [batch_size, n_positions] tensor
sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
tensor.
updates: A tensor of size `batch_size*seq_len(*depth)`.
positions: A `[batch_size, n_positions]` tensor.
Returns:
updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
tensor of "sequence" with elements at "positions" replaced by the values
at "updates". Updates to index 0 are ignored. If there are duplicated
positions the update is only applied once.
updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
updated_sequence: A `[batch_size, seq_len]` or
`[batch_size, seq_len, depth]` tensor of "sequence" with elements at
"positions" replaced by the values at "updates". Updates to index 0 are
ignored. If there are duplicated positions the update is only
applied once.
updates_mask: A `[batch_size, seq_len]` mask tensor of which inputs were
updated.
"""
shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
......@@ -288,14 +291,14 @@ def sample_from_softmax(logits, disallow=None):
"""Implement softmax sampling using gumbel softmax trick.
Args:
logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
the generator output logits for each masked position.
logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
indicating the generator output logits for each masked position.
disallow: If `None`, we directly sample tokens from the logits. Otherwise,
this is a tensor of size [batch_size, num_token_predictions, vocab_size]
this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
indicating the true word id in each masked position.
Returns:
sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
sampled_tokens: A `[batch_size, num_token_predictions, vocab_size]` one hot
tensor indicating the sampled word id in each masked position.
"""
if disallow is not None:
......
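# A minimal, standalone sketch (assumptions: plain TF 2.x; `gumbel_max_sample`
# and `example_logits` are illustrative names, not part of this library) of
# the gumbel softmax trick used by `sample_from_softmax` above: sampling from
# softmax(logits) is equivalent to taking the argmax of logits plus
# independent Gumbel(0, 1) noise, which yields one-hot sampled tokens without
# an explicit softmax.
import tensorflow as tf


def gumbel_max_sample(logits):
  uniform = tf.random.uniform(
      tf.shape(logits), minval=1e-9, maxval=1.0, dtype=logits.dtype)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))  # Gumbel(0, 1) samples.
  sampled_ids = tf.argmax(logits + gumbel_noise, axis=-1, output_type=tf.int32)
  return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1], dtype=logits.dtype)


example_logits = tf.random.normal([2, 3, 5])  # [batch, predictions, vocab].
print(gumbel_max_sample(example_logits).shape)  # (2, 3, 5), one-hot rows.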
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA pre trainer network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for ELECTRA pre trainer network."""
import tensorflow as tf
......@@ -35,10 +31,16 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the ELECTRA trainer.
vocab_size = 100
sequence_length = 512
test_generator_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
test_discriminator_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
test_generator_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length,
dict_outputs=True)
test_discriminator_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=2,
max_sequence_length=sequence_length,
dict_outputs=True)
# Create an ELECTRA trainer with the created network.
num_classes = 3
......@@ -48,8 +50,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
discriminator_network=test_discriminator_network,
vocab_size=vocab_size,
num_classes=num_classes,
sequence_length=sequence_length,
last_hidden_dim=768,
num_token_predictions=num_token_predictions,
disallow_correct=True)
......@@ -89,10 +89,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the ELECTRA trainer. (Here, we
# use a short sequence_length for convenience.)
test_generator_network = networks.TransformerEncoder(
vocab_size=100, num_layers=4, sequence_length=3)
test_discriminator_network = networks.TransformerEncoder(
vocab_size=100, num_layers=4, sequence_length=3)
test_generator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
test_discriminator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3, dict_outputs=True)
# Create an ELECTRA trainer with the created network.
electra_trainer_model = electra_pretrainer.ElectraPretrainer(
......@@ -101,7 +101,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
vocab_size=100,
num_classes=2,
sequence_length=3,
last_hidden_dim=768,
num_token_predictions=2)
# Create a set of 2-dimensional data tensors to feed into the model.
......@@ -127,10 +126,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
"""Validate that the ELECTRA trainer can be serialized and deserialized."""
# Build a transformer network to use within the ELECTRA trainer. (Here, we use
# a short sequence_length for convenience.)
test_generator_network = networks.TransformerEncoder(
vocab_size=100, num_layers=4, sequence_length=3)
test_discriminator_network = networks.TransformerEncoder(
vocab_size=100, num_layers=4, sequence_length=3)
test_generator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3)
test_discriminator_network = networks.BertEncoder(
vocab_size=100, num_layers=4, max_sequence_length=3)
# Create an ELECTRA trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......@@ -140,7 +139,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
vocab_size=100,
num_classes=2,
sequence_length=3,
last_hidden_dim=768,
num_token_predictions=2)
# Create another ELECTRA trainer via serialization and deserialization.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement Seq2Seq Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
"""
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
EOS_ID = 1
@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
"""Transformer model with Keras.
Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
The Transformer model consists of an encoder and decoder. The input is an int
sequence (or a batch of sequences). The encoder produces a continuous
representation, and the decoder uses the encoder output to generate
probabilities for the output sequence.
"""
def __init__(self,
vocab_size=33708,
embedding_width=512,
dropout_rate=0.0,
padded_decode=False,
decode_max_length=None,
extra_decode_length=0,
beam_size=4,
alpha=0.6,
encoder_layer=None,
decoder_layer=None,
eos_id=EOS_ID,
**kwargs):
"""Initialize layers to build Transformer model.
Args:
vocab_size: Size of vocabulary.
embedding_width: Size of hidden layer for embedding.
dropout_rate: Dropout probability.
padded_decode: Whether max_sequence_length padding is used. If set to
False, max_sequence_length padding is not used.
decode_max_length: Maximum number of steps to decode a sequence.
extra_decode_length: Number of extra steps beam search will run to decode.
beam_size: Number of beams for beam search.
alpha: The strength of length normalization for beam search.
encoder_layer: An initialized encoder layer.
decoder_layer: An initialized decoder layer.
eos_id: ID of the end-of-sentence token.
**kwargs: other keyword arguments.
"""
super().__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._dropout_rate = dropout_rate
self._padded_decode = padded_decode
self._decode_max_length = decode_max_length
self._extra_decode_length = extra_decode_length
self._beam_size = beam_size
self._alpha = alpha
self._eos_id = eos_id
self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=self._embedding_width,
initializer=tf.random_normal_initializer(
mean=0., stddev=self._embedding_width**-0.5),
scale_factor=self._embedding_width**0.5)
self.encoder_layer = encoder_layer
self.decoder_layer = decoder_layer
self.position_embedding = layers.RelativePositionEmbedding(
hidden_size=self._embedding_width)
self.encoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.decoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"vocab_size": self._vocab_size,
"hidden_size": self._embedding_width,
"dropout_rate": self._dropout_rate,
"padded_decode": self._padded_decode,
"decode_max_length": self._decode_max_length,
"eos_id": self._eos_id,
"extra_decode_length": self._extra_decode_length,
"beam_size": self._beam_size,
"alpha": self._alpha,
"encoder_layer": self.encoder_layer,
"decoder_layer": self.decoder_layer
}
base_config = super(Seq2SeqTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _embedding_linear(self, embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
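# As a hedged illustration (a note, not part of the model's API): with an
# embedding matrix E of shape [vocab_size, hidden_size] and decoder output x
# of shape [batch_size, length, hidden_size], `_embedding_linear` computes
# x @ E^T, i.e. the input embedding weights are reused ("tied") as the
# output projection rather than learning a separate softmax layer.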
def call(self, inputs):
"""Calculate target logits or inferred target sequences.
Args:
inputs: a dictionary of tensors.
Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
Feature `targets` (optional): None or int tensor with shape
`[batch_size, target_length]`.
Returns:
If targets is defined, then return logits for each word in the target
sequence, which is a float tensor with shape
`(batch_size, target_length, vocab_size)`. If targets is `None`, then
generate the output sequence one token at a time and return a dictionary
{outputs: `(batch_size, decoded_length)`,
scores: `(batch_size, 1)`}.
Even when `float16` is used, the output tensor(s) are always `float32`.
Raises:
NotImplementedError: If trying to use the padded decode method on CPU/GPUs.
"""
sources = inputs["inputs"]
targets = inputs.get("targets", None)
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_lookup(sources)
embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=sources.dtype)
broadcast_ones = tf.ones(
shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
attention_mask = broadcast_ones * attention_mask
pos_encoding = self.position_embedding(embedded_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
encoder_inputs = embedded_inputs + pos_encoding
encoder_inputs = self.encoder_dropout(encoder_inputs)
encoder_outputs = self.encoder_layer(
encoder_inputs, attention_mask=attention_mask)
if targets is None:
if self._padded_decode:
max_decode_length = self._decode_max_length
else:
max_decode_length = self._decode_max_length or (
tf.shape(encoder_outputs)[1] + self._extra_decode_length)
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
batch_size = tf.shape(encoder_outputs)[0]
# Create initial set of IDs that will be passed to symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
init_decode_length = (max_decode_length if self._padded_decode else 0)
num_heads = self.decoder_layer.num_attention_heads
dim_per_head = self._embedding_width // num_heads
# Cache dtype needs to match beam_search dtype.
# pylint: disable=g-complex-comprehension
cache = {
str(layer): {
"key":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.compute_dtype),
"value":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.compute_dtype)
} for layer in range(self.decoder_layer.num_layers)
}
# pylint: enable=g-complex-comprehension
# Add encoder output and attention bias to the cache.
encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=self.compute_dtype)
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_mask"] = attention_mask
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self._vocab_size,
beam_size=self._beam_size,
alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=self._eos_id,
padded_decode=self._padded_decode,
dtype=self.compute_dtype)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
decoder_inputs += pos_encoding
decoder_inputs = self.decoder_dropout(decoder_inputs)
decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = tf.cast(
tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
outputs = self.decoder_layer(
decoder_inputs,
encoder_outputs,
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask)
logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
# Model outputs should be float32 to avoid numeric issues.
# https://www.tensorflow.org/guide/mixed_precision#building_the_model
logits = tf.cast(logits, tf.float32)
return logits
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = self.position_embedding(
inputs=None, length=max_decode_length + 1)
timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
decoder_self_attention_mask = tf.linalg.band_part(
tf.ones([max_decode_length, max_decode_length],
dtype=self.compute_dtype), -1, 0)
decoder_self_attention_mask = tf.reshape(
decoder_self_attention_mask, [1, max_decode_length, max_decode_length])
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences. int tensor with shape `(batch_size *
beam_size, i + 1)`.
i: Loop index.
cache: Dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape `(batch_size * beam_size, vocab_size)`,
updated cache values)
"""
# Set decoder input to the last generated IDs
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
source_decoder_input = decoder_input
decoder_input = self.embedding_lookup(decoder_input)
embedding_mask = tf.cast(
tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
decoder_input *= tf.expand_dims(embedding_mask, -1)
decoder_input += timing_signal[i]
if self._padded_decode:
# Tensor indexing does not work on TPU, so use tf.slice instead.
bias_shape = decoder_self_attention_mask.shape.as_list()
self_attention_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
[bias_shape[0], 1, bias_shape[2]])
else:
self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = cache.get("encoder_decoder_attention_mask")
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
decoder_outputs = self.decoder_layer(
decoder_input,
cache.get("encoder_outputs"),
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask,
cache=cache,
decode_loop_step=i if self._padded_decode else None)
decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
logits = self._embedding_linear(self.embedding_lookup.embeddings,
decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
return logits, cache
return symbols_to_logits_fn
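# A hedged note on the decoding cache built in call() above (illustrative,
# not normative): each call to symbols_to_logits_fn embeds only the newest
# token, and the per-layer "key"/"value" tensors in `cache` accumulate past
# positions (or are written in place at step `i` when padded_decode=True on
# TPU), so earlier positions are never re-encoded during beam search.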
class TransformerEncoder(tf.keras.layers.Layer):
"""Transformer encoder.
Transformer encoder is made up of N identical layers. Each layer is composed
of the sublayers:
1. Self-attention layer
2. Feedforward network (which is 2 fully-connected layers)
"""
def __init__(self,
num_layers=6,
num_attention_heads=8,
intermediate_size=2048,
activation="relu",
dropout_rate=0.0,
attention_dropout_rate=0.0,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.0,
**kwargs):
"""Initialize a Transformer encoder.
Args:
num_layers: Number of layers.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate (Feedforward) layer.
activation: Activation for the intermediate layer.
dropout_rate: Dropout probability.
attention_dropout_rate: Dropout probability for attention layers.
use_bias: Whether to enable use_bias in the attention layers. If set to
False, use_bias in the attention layers is disabled.
norm_first: Whether to normalize inputs to the attention and intermediate
dense layers. If set to False, the output of the attention and
intermediate dense layers is normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
**kwargs: keyword arguments passed to tf.keras.layers.Layer.
"""
super(TransformerEncoder, self).__init__(**kwargs)
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._activation = activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
def build(self, input_shape):
"""Implements build() for the layer."""
self.encoder_layers = []
for i in range(self.num_layers):
self.encoder_layers.append(
keras_nlp.layers.TransformerEncoderBlock(
num_attention_heads=self.num_attention_heads,
inner_dim=self._intermediate_size,
inner_activation=self._activation,
output_dropout=self._dropout_rate,
attention_dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
norm_first=self._norm_first,
norm_epsilon=self._norm_epsilon,
inner_dropout=self._intermediate_dropout,
attention_initializer=attention_initializer(input_shape[2]),
name=("layer_%d" % i)))
self.output_normalization = tf.keras.layers.LayerNormalization(
epsilon=self._norm_epsilon, dtype="float32")
super(TransformerEncoder, self).build(input_shape)
def get_config(self):
config = {
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self._intermediate_size,
"activation": self._activation,
"dropout_rate": self._dropout_rate,
"attention_dropout_rate": self._attention_dropout_rate,
"use_bias": self._use_bias,
"norm_first": self._norm_first,
"norm_epsilon": self._norm_epsilon,
"intermediate_dropout": self._intermediate_dropout
}
base_config = super(TransformerEncoder, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, encoder_inputs, attention_mask=None):
"""Return the output of the encoder.
Args:
encoder_inputs: A tensor with shape `(batch_size, input_length,
hidden_size)`.
attention_mask: A mask for the encoder self-attention layer with shape
`(batch_size, input_length, input_length)`.
Returns:
Output of encoder which is a `float32` tensor with shape
`(batch_size, input_length, hidden_size)`.
"""
for layer_idx in range(self.num_layers):
encoder_inputs = self.encoder_layers[layer_idx](
[encoder_inputs, attention_mask])
output_tensor = encoder_inputs
output_tensor = self.output_normalization(output_tensor)
return output_tensor
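# A hedged note (illustrative): with norm_first=True (the default), each
# block normalizes its own inputs, so the stack applies output_normalization
# exactly once on top; its dtype is pinned to float32 so the final
# normalization stays numerically stable under mixed precision.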
class TransformerDecoder(tf.keras.layers.Layer):
"""Transformer decoder.
Like the encoder, the decoder is made up of N identical layers.
Each layer is composed of the sublayers:
1. Self-attention layer
2. Multi-headed attention layer combining encoder outputs with results from
the previous self-attention layer.
3. Feedforward network (2 fully-connected layers)
"""
def __init__(self,
num_layers=6,
num_attention_heads=8,
intermediate_size=2048,
activation="relu",
dropout_rate=0.0,
attention_dropout_rate=0.0,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.0,
**kwargs):
"""Initialize a Transformer decoder.
Args:
num_layers: Number of layers.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate (Feedforward) layer.
activation: Activation for the intermediate layer.
dropout_rate: Dropout probability.
attention_dropout_rate: Dropout probability for attention layers.
use_bias: Whether to enable use_bias in the attention layers. If set to
`False`, use_bias in the attention layers is disabled.
norm_first: Whether to normalize inputs to the attention and intermediate
dense layers. If set to `False`, the output of the attention and
intermediate dense layers is normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
**kwargs: keyword arguments passed to tf.keras.layers.Layer.
"""
super(TransformerDecoder, self).__init__(**kwargs)
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._activation = activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
def build(self, input_shape):
"""Implements build() for the layer."""
self.decoder_layers = []
for i in range(self.num_layers):
self.decoder_layers.append(
layers.TransformerDecoderBlock(
num_attention_heads=self.num_attention_heads,
intermediate_size=self._intermediate_size,
intermediate_activation=self._activation,
dropout_rate=self._dropout_rate,
attention_dropout_rate=self._attention_dropout_rate,
use_bias=self._use_bias,
norm_first=self._norm_first,
norm_epsilon=self._norm_epsilon,
intermediate_dropout=self._intermediate_dropout,
attention_initializer=attention_initializer(input_shape[2]),
name=("layer_%d" % i)))
self.output_normalization = tf.keras.layers.LayerNormalization(
epsilon=1e-6, dtype="float32")
super(TransformerDecoder, self).build(input_shape)
def get_config(self):
config = {
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self._intermediate_size,
"activation": self._activation,
"dropout_rate": self._dropout_rate,
"attention_dropout_rate": self._attention_dropout_rate,
"use_bias": self._use_bias,
"norm_first": self._norm_first,
"norm_epsilon": self._norm_epsilon,
"intermediate_dropout": self._intermediate_dropout
}
base_config = super(TransformerDecoder, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
target,
memory,
self_attention_mask=None,
cross_attention_mask=None,
cache=None,
decode_loop_step=None):
"""Return the output of the decoder layer stacks.
Args:
target: A tensor with shape `(batch_size, target_length, hidden_size)`.
memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
self_attention_mask: A tensor with shape `(batch_size, target_length,
target_length)`, the mask for the decoder self-attention layer.
cross_attention_mask: A tensor with shape `(batch_size, target_length,
input_length)` which is the mask for encoder-decoder attention layer.
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
"v": A tensor with shape `(batch_size, i, value_channels)`},
...}
decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU.
Returns:
Output of the decoder, a `float32` tensor with shape
`(batch_size, target_length, hidden_size)`.
"""
output_tensor = target
for layer_idx in range(self.num_layers):
transformer_inputs = [
output_tensor, memory, cross_attention_mask, self_attention_mask
]
# Gets the cache for decoding.
if cache is None:
output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
else:
cache_layer_idx = str(layer_idx)
output_tensor, cache[cache_layer_idx] = self.decoder_layers[layer_idx](
transformer_inputs,
cache=cache[cache_layer_idx],
decode_loop_step=decode_loop_step)
return self.output_normalization(output_tensor)
def attention_initializer(hidden_size):
"""Initializer for attention layers in Seq2SeqTransformer."""
hidden_size = int(hidden_size)
limit = math.sqrt(6.0 / (hidden_size + hidden_size))
return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
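# A minimal, standalone sketch (plain TF 2.x assumed) of the causal mask that
# tf.linalg.band_part(tf.ones([length, length]), -1, 0) builds in the decoder
# paths above: row i keeps columns 0..i, so each target position can attend
# only to itself and earlier positions.
import tensorflow as tf

length = 4
causal_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
print(causal_mask.numpy())
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]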
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Transformer model."""
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import seq2seq_transformer
class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
def _build_model(self, padded_decode, decode_max_length):
num_layers = 1
num_attention_heads = 2
intermediate_size = 32
vocab_size = 100
embedding_width = 16
encdec_kwargs = dict(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
activation="relu",
dropout_rate=0.01,
attention_dropout_rate=0.01,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.01)
encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)
return seq2seq_transformer.Seq2SeqTransformer(
vocab_size=vocab_size,
embedding_width=embedding_width,
dropout_rate=0.01,
padded_decode=padded_decode,
decode_max_length=decode_max_length,
beam_size=4,
alpha=0.6,
encoder_layer=encoder_layer,
decoder_layer=decoder_layer)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
],
mode="eager"))
def test_create_model_with_ds(self, distribution):
with distribution.scope():
padded_decode = isinstance(
distribution,
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
decode_max_length = 10
batch_size = 4
model = self._build_model(padded_decode, decode_max_length)
@tf.function
def step(inputs):
def _step_fn(inputs):
return model(inputs)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
fake_inputs = dict(
inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
fake_inputs = dict(
inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
targets=np.zeros((batch_size, 8), dtype=np.int32))
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs[0].shape, (4, 8, 100))
@parameterized.parameters(True, False)
def test_create_savedmodel(self, padded_decode):
decode_max_length = 10
model = self._build_model(padded_decode, decode_max_length)
class SaveModule(tf.Module):
def __init__(self, model):
super(SaveModule, self).__init__()
self.model = model
@tf.function
def serve(self, inputs):
return self.model.call(dict(inputs=inputs))
save_module = SaveModule(model)
if padded_decode:
tensor_shape = (4, 10)
else:
tensor_shape = (None, None)
signatures = dict(
serving_default=save_module.serve.get_concrete_function(
tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Mapping, Union
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
class XLNetMaskedLM(tf.keras.layers.Layer):
"""XLNet pretraining head."""
def __init__(self,
vocab_size: int,
hidden_size: int,
initializer: str = 'glorot_uniform',
activation: str = 'gelu',
name=None,
**kwargs):
super().__init__(name=name, **kwargs)
self._vocab_size = vocab_size
self._hidden_size = hidden_size
self._initializer = initializer
self._activation = activation
def build(self, input_shape):
self.dense = tf.keras.layers.Dense(
units=self._hidden_size,
activation=self._activation,
kernel_initializer=self._initializer,
name='transform/dense')
self.layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/LayerNorm')
self.bias = self.add_weight(
'output_bias/bias',
shape=(self._vocab_size,),
initializer='zeros',
trainable=True)
super().build(input_shape)
def call(self,
sequence_data: tf.Tensor,
embedding_table: tf.Tensor):
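# Transform the encoder output, then project it back onto the vocabulary by
# reusing (tying) the embedding table as the output weights, and add a
# learned bias over the vocabulary.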
lm_data = self.dense(sequence_data)
lm_data = self.layer_norm(lm_data)
lm_data = tf.matmul(lm_data, embedding_table, transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
return logits
def get_config(self) -> Mapping[str, Any]:
config = {
'vocab_size': self._vocab_size,
'hidden_size': self._hidden_size,
'initializer': self._initializer,
'activation': self._activation
}
base_config = super(XLNetMaskedLM, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetPretrainer(tf.keras.Model):
"""XLNet-based pretrainer.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Args:
network: An XLNet/Transformer-XL based network. This network should output a
sequence output and a list of `state` tensors.
mlm_activation: The activation (if any) to use in the Masked LM network. If
None, then no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
to a Glorot uniform initializer.
"""
def __init__(
self,
network: Union[tf.keras.layers.Layer, tf.keras.Model],
mlm_activation=None,
mlm_initializer='glorot_uniform',
name: str = None,
**kwargs):
super().__init__(name=name, **kwargs)
self._config = {
'network': network,
'mlm_activation': mlm_activation,
'mlm_initializer': mlm_initializer,
}
self._network = network
self._hidden_size = network.get_config()['hidden_size']
self._vocab_size = network.get_config()['vocab_size']
self._activation = mlm_activation
self._initializer = mlm_initializer
self._masked_lm = XLNetMaskedLM(
vocab_size=self._vocab_size,
hidden_size=self._hidden_size,
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
input_word_ids = inputs['input_word_ids']
input_type_ids = inputs['input_type_ids']
masked_tokens = inputs['masked_tokens']
permutation_mask = inputs['permutation_mask']
target_mapping = inputs['target_mapping']
state = inputs.get('state', None)
attention_output, state = self._network(
input_ids=input_word_ids,
segment_ids=input_type_ids,
input_mask=None,
state=state,
permutation_mask=permutation_mask,
target_mapping=target_mapping,
masked_tokens=masked_tokens)
embedding_table = self._network.get_embedding_lookup_table()
mlm_outputs = self._masked_lm(
sequence_data=attention_output,
embedding_table=embedding_table)
return mlm_outputs, state
def get_config(self) -> Mapping[str, Any]:
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def checkpoint_items(self):
return dict(encoder=self._network)
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetClassifier(tf.keras.Model):
"""Classifier model based on XLNet.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Note: This model does not utilize the memory mechanism used in the
original XLNet Classifier.
Args:
network: An XLNet/Transformer-XL based network. This network should output a
sequence output and a list of `state` tensors.
num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks.
Defaults to a RandomNormal initializer.
summary_type: Method used to summarize a sequence into a compact vector.
dropout_rate: The dropout probability of the cls head.
"""
def __init__(
self,
network: Union[tf.keras.layers.Layer, tf.keras.Model],
num_classes: int,
initializer: tf.keras.initializers.Initializer = 'random_normal',
summary_type: str = 'last',
dropout_rate: float = 0.1,
**kwargs):
super().__init__(**kwargs)
self._network = network
self._initializer = initializer
self._summary_type = summary_type
self._num_classes = num_classes
self._config = {
'network': network,
'initializer': initializer,
'num_classes': num_classes,
'summary_type': summary_type,
'dropout_rate': dropout_rate,
}
if summary_type == 'last':
cls_token_idx = -1
elif summary_type == 'first':
cls_token_idx = 0
else:
raise ValueError('Invalid summary type provided: %s.' % summary_type)
self.classifier = layers.ClassificationHead(
inner_dim=network.get_config()['hidden_size'],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
cls_token_idx=cls_token_idx,
name='sentence_prediction')
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_word_ids']
segment_ids = inputs['input_type_ids']
input_mask = tf.cast(inputs['input_mask'], tf.float32)
state = inputs.get('mems', None)
attention_output, _ = self._network(
input_ids=input_ids,
segment_ids=segment_ids,
input_mask=input_mask,
state=state)
logits = self.classifier(attention_output)
return logits
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def checkpoint_items(self):
items = dict(encoder=self._network)
if hasattr(self.classifier, 'checkpoint_items'):
for key, item in self.classifier.checkpoint_items.items():
items['.'.join([self.classifier.name, key])] = item
return items
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf.keras.Model):
"""Span labeler model based on XLNet.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Args:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
dropout_rate: The dropout rate for the span labeling layer.
span_labeling_activation: The activation for the span labeling head.
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
"""
def __init__(
self,
network: Union[tf.keras.layers.Layer, tf.keras.Model],
start_n_top: int = 5,
end_n_top: int = 5,
dropout_rate: float = 0.1,
span_labeling_activation: str = 'tanh',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
**kwargs):
super().__init__(**kwargs)
self._config = {
'network': network,
'start_n_top': start_n_top,
'end_n_top': end_n_top,
'dropout_rate': dropout_rate,
'span_labeling_activation': span_labeling_activation,
'initializer': initializer,
}
network_config = network.get_config()
try:
input_width = network_config['inner_size']
self._xlnet_base = True
except KeyError:
# BertEncoder uses 'intermediate_size' due to legacy naming.
input_width = network_config['intermediate_size']
self._xlnet_base = False
self._network = network
self._initializer = initializer
self._start_n_top = start_n_top
self._end_n_top = end_n_top
self._dropout_rate = dropout_rate
self._activation = span_labeling_activation
self.span_labeling = networks.XLNetSpanLabeling(
input_width=input_width,
start_n_top=self._start_n_top,
end_n_top=self._end_n_top,
activation=self._activation,
dropout_rate=self._dropout_rate,
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
input_word_ids = inputs['input_word_ids']
input_type_ids = inputs['input_type_ids']
input_mask = inputs['input_mask']
class_index = inputs['class_index']
paragraph_mask = inputs['paragraph_mask']
start_positions = inputs.get('start_positions', None)
if self._xlnet_base:
attention_output, _ = self._network(
input_ids=input_word_ids,
segment_ids=input_type_ids,
input_mask=input_mask)
else:
network_output_dict = self._network(dict(
input_word_ids=input_word_ids,
input_type_ids=input_type_ids,
input_mask=input_mask))
attention_output = network_output_dict['sequence_output']
outputs = self.span_labeling(
sequence_data=attention_output,
class_index=class_index,
paragraph_mask=paragraph_mask,
start_positions=start_positions)
return outputs
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
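# A minimal sketch (assumptions: `official.nlp.modeling` is importable, and
# the tiny hyperparameters below are arbitrary, mirroring the test helper
# later in this commit) of how `checkpoint_items` is typically consumed: the
# encoder is checkpointed under its own key, so it can later be restored into
# a different task head independently of the pretraining heads.
import tensorflow as tf

from official.nlp.modeling import networks

encoder = networks.XLNetBase(
    vocab_size=100, num_layers=2, hidden_size=4, num_attention_heads=2,
    head_size=2, inner_size=2, dropout_rate=0., attention_dropout_rate=0.,
    attention_type='bi', bi_data=True, two_stream=False,
    tie_attention_biases=True, reuse_length=0, inner_activation='relu',
    initializer=tf.keras.initializers.RandomNormal(stddev=0.1))
model = XLNetPretrainer(network=encoder)
ckpt = tf.train.Checkpoint(**model.checkpoint_items)  # {'encoder': encoder}
ckpt.write('/tmp/xlnet_encoder')  # An illustrative path; saves the encoder.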
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for XLNet classifier network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import xlnet
def _get_xlnet_base() -> tf.keras.layers.Layer:
"""Returns a trivial base XLNet model."""
return networks.XLNetBase(
vocab_size=100,
num_layers=2,
hidden_size=4,
num_attention_heads=2,
head_size=2,
inner_size=2,
dropout_rate=0.,
attention_dropout_rate=0.,
attention_type='bi',
bi_data=True,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
reuse_length=0,
inner_activation='relu')
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class XLNetMaskedLMTest(keras_parameterized.TestCase):
def test_xlnet_masked_lm_head(self):
hidden_size = 10
seq_length = 8
batch_size = 2
masked_lm = xlnet.XLNetMaskedLM(vocab_size=10,
hidden_size=hidden_size,
initializer='glorot_uniform')
sequence_data = np.random.uniform(size=(batch_size, seq_length))
embedding_table = np.random.uniform(size=(hidden_size, hidden_size))
mlm_output = masked_lm(sequence_data, embedding_table)
self.assertAllClose(mlm_output.shape, (batch_size, hidden_size))
@keras_parameterized.run_all_keras_modes
class XLNetPretrainerTest(keras_parameterized.TestCase):
def test_xlnet_trainer(self):
"""Validates that the Keras object can be created."""
seq_length = 4
num_predictions = 2
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetPretrainer(network=xlnet_base)
inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
input_type_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_mask'),
permutation_mask=tf.keras.layers.Input(
shape=(seq_length, seq_length,), dtype=tf.int32,
name='permutation_mask'),
target_mapping=tf.keras.layers.Input(
shape=(num_predictions, seq_length), dtype=tf.int32,
name='target_mapping'),
masked_tokens=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='masked_tokens'))
logits, _ = xlnet_trainer_model(inputs)
# [None, seq_length, vocab_size]
expected_output_shape = [None, 4, 100]
self.assertAllEqual(expected_output_shape, logits.shape.as_list())
def test_xlnet_tensor_call(self):
"""Validates that the Keras object can be invoked."""
seq_length = 4
batch_size = 2
num_predictions = 2
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetPretrainer(network=xlnet_base)
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_word_ids=np.random.randint(
10, size=sequence_shape, dtype='int32'),
input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
permutation_mask=np.random.randint(
2, size=(batch_size, seq_length, seq_length)).astype('int32'),
target_mapping=np.random.randint(
10, size=(num_predictions, seq_length), dtype='int32'),
masked_tokens=np.random.randint(
10, size=sequence_shape, dtype='int32'))
xlnet_trainer_model(inputs)
def test_serialize_deserialize(self):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetPretrainer(
network=xlnet_base,
mlm_activation='gelu',
mlm_initializer='random_normal')
# Create another XLNet trainer via serialization and deserialization.
config = xlnet_trainer_model.get_config()
new_xlnet_trainer_model = xlnet.XLNetPretrainer.from_config(
config)
    # Validate that the config can be serialized to JSON.
_ = new_xlnet_trainer_model.to_json()
# If serialization was successful, then the new config should match the old.
self.assertAllEqual(xlnet_trainer_model.get_config(),
                        new_xlnet_trainer_model.get_config())


@keras_parameterized.run_all_keras_modes
class XLNetClassifierTest(keras_parameterized.TestCase):
def test_xlnet_trainer(self):
"""Validate that the Keras object can be created."""
num_classes = 2
seq_length = 4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetClassifier(
network=xlnet_base,
num_classes=num_classes,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
summary_type='last',
dropout_rate=0.1)
inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
input_type_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_mask'),
permutation_mask=tf.keras.layers.Input(
shape=(seq_length, seq_length,), dtype=tf.int32,
name='permutation_mask'),
masked_tokens=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='masked_tokens'))
logits = xlnet_trainer_model(inputs)
expected_classification_shape = [None, num_classes]
    self.assertAllEqual(expected_classification_shape, logits.shape.as_list())

@parameterized.parameters(1, 2)
def test_xlnet_tensor_call(self, num_classes):
"""Validates that the Keras object can be invoked."""
seq_length = 4
batch_size = 2
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetClassifier(
network=xlnet_base,
num_classes=num_classes,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
summary_type='last',
dropout_rate=0.1)
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_word_ids=np.random.randint(
10, size=sequence_shape, dtype='int32'),
input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
permutation_mask=np.random.randint(
2, size=(batch_size, seq_length, seq_length)).astype('int32'),
masked_tokens=np.random.randint(
10, size=sequence_shape, dtype='int32'))
    xlnet_trainer_model(inputs)

def test_serialize_deserialize(self):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetClassifier(
network=xlnet_base,
num_classes=2,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
summary_type='last',
dropout_rate=0.1)
# Create another XLNet trainer via serialization and deserialization.
config = xlnet_trainer_model.get_config()
new_xlnet_trainer_model = xlnet.XLNetClassifier.from_config(
config)
    # Validate that the config can be serialized to JSON.
_ = new_xlnet_trainer_model.to_json()
# If serialization was successful, then the new config should match the old.
self.assertAllEqual(xlnet_trainer_model.get_config(),
                        new_xlnet_trainer_model.get_config())


@keras_parameterized.run_all_keras_modes
class XLNetSpanLabelerTest(keras_parameterized.TestCase):
def test_xlnet_trainer(self):
"""Validate that the Keras object can be created."""
top_n = 2
seq_length = 4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetSpanLabeler(
network=xlnet_base,
start_n_top=top_n,
end_n_top=top_n,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
span_labeling_activation='tanh',
dropout_rate=0.1)
inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
input_type_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_mask'),
paragraph_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='paragraph_mask'),
class_index=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='class_index'),
start_positions=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='start_positions'))
outputs = xlnet_trainer_model(inputs)
self.assertIsInstance(outputs, dict)
# Test tensor value calls for the created model.
batch_size = 2
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_word_ids=np.random.randint(
10, size=sequence_shape, dtype='int32'),
input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
paragraph_mask=np.random.randint(
1, size=(sequence_shape)).astype('int32'),
class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
start_positions=tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32))
common_keys = {
'start_logits', 'end_logits', 'start_predictions', 'end_predictions',
'class_logits',
}
inference_keys = {
'start_top_predictions', 'end_top_predictions', 'start_top_index',
'end_top_index',
}
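    # Descriptive note: when called without `training=True`, the span labeler
    # also returns the top-n start/end predictions used at inference time; in
    # training mode only the common outputs are produced.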
outputs = xlnet_trainer_model(inputs)
self.assertSetEqual(common_keys | inference_keys, set(outputs.keys()))
outputs = xlnet_trainer_model(inputs, training=True)
self.assertIsInstance(outputs, dict)
    self.assertSetEqual(common_keys, set(outputs.keys()))

def test_serialize_deserialize(self):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetSpanLabeler(
network=xlnet_base,
start_n_top=2,
end_n_top=2,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
span_labeling_activation='tanh',
dropout_rate=0.1)
# Create another XLNet trainer via serialization and deserialization.
config = xlnet_trainer_model.get_config()
new_xlnet_trainer_model = xlnet.XLNetSpanLabeler.from_config(
config)
    # Validate that the config can be serialized to JSON.
_ = new_xlnet_trainer_model.to_json()
# If serialization was successful, then the new config should match the old.
self.assertAllEqual(xlnet_trainer_model.get_config(),
                        new_xlnet_trainer_model.get_config())


if __name__ == '__main__':
tf.test.main()

# Networks

Networks are combinations of `tf.keras` layers (and possibly other networks).
They are `tf.keras` models that would not be trained alone: each encapsulates
a common network structure, such as a transformer encoder or a classification
head, in an easily handled object with a standardized configuration. A minimal
usage sketch follows the list below.
* [`BertEncoder`](bert_encoder.py) implements a bi-directional
Transformer-based encoder as described in ["BERT: Pre-training of Deep
Bidirectional Transformers for Language Understanding"](https://arxiv.org/abs/1810.04805).
It includes the embedding lookups, transformer layers and pooling layer.
* [`AlbertEncoder`](albert_encoder.py) implements a Transformer encoder
described in the paper ["ALBERT: A Lite BERT for Self-supervised Learning of
Language Representations"](https://arxiv.org/abs/1909.11942). Compared with
[BERT](https://arxiv.org/abs/1810.04805), ALBERT factorizes the embedding
parameters into two smaller matrices and shares parameters across layers.
* [`MobileBERTEncoder`](mobile_bert_encoder.py) implements the
MobileBERT network described in the paper ["MobileBERT: a Compact Task-Agnostic
BERT for Resource-Limited Devices"](https://arxiv.org/abs/2004.02984).
* [`Classification`](classification.py) contains a single hidden layer and is
intended for use as a classification head (or as a regression head when the
number of classes is set to 1).
* [`TokenClassification`](token_classification.py) contains a single hidden
layer, and is intended for use as a token classification head.
* [`PackedSequenceEmbedding`](packed_sequence_embedding.py) implements an
embedding network that supports packed sequences and position ids.
* [`SpanLabeling`](span_labeling.py) implements a single-span labeler
(that is, a prediction head that can predict one start and end index per batch
item) based on a single dense hidden layer. It can be used in the SQuAD task.
* [`XLNetBase`](xlnet_base.py) implements the base network used in ["XLNet:
Generalized Autoregressive Pretraining for Language Understanding"](https://arxiv.org/abs/1906.08237).
It includes embedding lookups, relative position encodings, mask computations,
segment matrix computations and Transformer-XL layers using one- or two-stream
relative self-attention.
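
A minimal usage sketch (an illustration for this README, not code from this
change): each network is a standard Keras object that is constructed from a
configuration and called on a dict of int tensors. The argument names below
follow the unit tests in this repository; the exact output structure may vary
by library version.

```python
import numpy as np

from official.nlp.modeling import networks

# Build a toy BertEncoder (sizes follow the unit-test conventions above).
encoder = networks.BertEncoder(
    vocab_size=100,
    num_layers=2,
    hidden_size=16,
    num_attention_heads=2)

batch_size, seq_length = 2, 8
inputs = dict(
    input_word_ids=np.random.randint(100, size=(batch_size, seq_length)),
    input_mask=np.ones((batch_size, seq_length), dtype='int32'),
    input_type_ids=np.zeros((batch_size, seq_length), dtype='int32'))

outputs = encoder(inputs)
# Typically a dict with 'sequence_output' of shape [batch, seq, hidden] and
# 'pooled_output' of shape [batch, hidden]; some versions return a list.
```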