Commit 44fc54a1 authored by Allen Wang, committed by A. Unique TensorFlower

Implement XLNet QA model.

PiperOrigin-RevId: 337134894
parent 41bcd7d0
......@@ -21,3 +21,4 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
......@@ -20,6 +20,7 @@ from typing import Any, Mapping, Union
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text')
......@@ -98,3 +99,84 @@ class XLNetClassifier(tf.keras.Model):
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf.keras.Model):
"""Span labeler model based on XLNet.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Arguments:
network: A transformer network. This network should output a sequence output
  and updated memory states (e.g. a Transformer-XL based encoder such as
  `XLNetBase`).
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
dropout_rate: The dropout rate for the span labeling layer.
span_labeling_activation: The activation used in the span labeling dense
  layers. Defaults to 'tanh'.
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
"""
def __init__(
self,
network: Union[tf.keras.layers.Layer, tf.keras.Model],
start_n_top: int,
end_n_top: int,
dropout_rate: float,
span_labeling_activation: str = 'tanh',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
**kwargs):
super().__init__(**kwargs)
self._config = {
'network': network,
'start_n_top': start_n_top,
'end_n_top': end_n_top,
'dropout_rate': dropout_rate,
'span_labeling_activation': span_labeling_activation,
'initializer': initializer,
}
self._network = network
self._initializer = initializer
self._start_n_top = start_n_top
self._end_n_top = end_n_top
self._dropout_rate = dropout_rate
self._activation = span_labeling_activation
self.span_labeling = networks.XLNetSpanLabeling(
input_width=network.get_config()['hidden_size'],
start_n_top=self._start_n_top,
end_n_top=self._end_n_top,
activation=self._activation,
dropout_rate=self._dropout_rate,
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_ids']
segment_ids = inputs['segment_ids']
input_mask = inputs['input_mask']
class_index = tf.reshape(inputs['class_index'], [-1])
position_mask = inputs['position_mask']
start_positions = inputs['start_positions']
attention_output, new_states = self._network(
input_ids=input_ids,
segment_ids=segment_ids,
input_mask=input_mask)
outputs = self.span_labeling(
sequence_data=attention_output,
class_index=class_index,
position_mask=position_mask,
start_positions=start_positions)
return outputs, new_states
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
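For context, here is a minimal usage sketch (not part of the commit) showing how `XLNetSpanLabeler` can be wired to an `XLNetBase` encoder and called on dummy inputs. The tiny encoder hyperparameters are arbitrary illustrative values (the keyword names follow the `get_xlnet_base` helper later in this commit), and the input dictionary keys follow the `call` signature above.

import numpy as np
import tensorflow as tf

from official.nlp.modeling import networks
from official.nlp.modeling.models import xlnet

# Tiny encoder purely for illustration; real configurations are much larger.
encoder = networks.XLNetBase(
    vocab_size=100,
    num_layers=2,
    hidden_size=8,
    num_attention_heads=2,
    head_size=4,
    inner_size=16,
    dropout_rate=0.0,
    attention_dropout_rate=0.0,
    attention_type='bi',
    bi_data=False,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False,
    tie_attention_biases=True,
    memory_length=0,
    clamp_length=-1,
    reuse_length=0,
    inner_activation='relu',
    use_cls_mask=False)

span_labeler = xlnet.XLNetSpanLabeler(
    network=encoder,
    start_n_top=2,
    end_n_top=2,
    dropout_rate=0.1,
    span_labeling_activation='tanh')

batch_size, seq_length = 2, 8
inputs = dict(
    input_ids=np.random.randint(100, size=(batch_size, seq_length), dtype='int32'),
    segment_ids=np.random.randint(2, size=(batch_size, seq_length), dtype='int32'),
    input_mask=np.ones((batch_size, seq_length), dtype='float32'),
    class_index=np.zeros((batch_size,), dtype='int32'),
    position_mask=np.zeros((batch_size, seq_length), dtype='float32'),
    start_positions=np.zeros((batch_size,), dtype='int32'))

# Inference-style call: outputs include beam-searched start/end indices.
outputs, memory_states = span_labeler(inputs)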
......@@ -133,5 +133,93 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
new_xlnet_trainer_model.get_config())
@keras_parameterized.run_all_keras_modes
class XLNetSpanLabelerTest(keras_parameterized.TestCase):
@parameterized.parameters(1, 2)
def test_xlnet_trainer(self, top_n):
"""Validate that the Keras object can be created."""
seq_length = 4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetSpanLabeler(
network=xlnet_base,
start_n_top=top_n,
end_n_top=top_n,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
span_labeling_activation='tanh',
dropout_rate=0.1)
inputs = dict(
input_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
segment_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='input_mask'),
position_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='position_mask'),
class_index=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='class_index'),
start_positions=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='start_positions'))
outputs, _ = xlnet_trainer_model(inputs)
self.assertIsInstance(outputs, dict)
# Test tensor value calls for the created model.
batch_size = 2
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
position_mask=np.random.randint(
1, size=(sequence_shape)).astype('float32'),
class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
start_positions=tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32))
outputs, _ = xlnet_trainer_model(inputs)
expected_inference_keys = {
'start_top_log_probs', 'end_top_log_probs', 'class_logits',
'start_top_index', 'end_top_index',
}
self.assertSetEqual(expected_inference_keys, set(outputs.keys()))
outputs, _ = xlnet_trainer_model(inputs, training=True)
self.assertIsInstance(outputs, dict)
expected_train_keys = {
'start_log_probs', 'end_log_probs', 'class_logits'
}
self.assertSetEqual(expected_train_keys, set(outputs.keys()))
self.assertIsInstance(outputs, dict)
def test_serialize_deserialize(self):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetSpanLabeler(
network=xlnet_base,
start_n_top=2,
end_n_top=2,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
span_labeling_activation='tanh',
dropout_rate=0.1)
# Create another XLNet trainer via serialization and deserialization.
config = xlnet_trainer_model.get_config()
new_xlnet_trainer_model = xlnet.XLNetSpanLabeler.from_config(
config)
# Validate that the config can be forced to JSON.
_ = new_xlnet_trainer_model.to_json()
# If serialization was successful, then the new config should match the old.
self.assertAllEqual(xlnet_trainer_model.get_config(),
new_xlnet_trainer_model.get_config())
if __name__ == '__main__':
tf.test.main()
......@@ -113,6 +113,9 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
positions, and then uses either the true start positions (if training) or
beam search to predict the end positions.
**Note: `compute_with_beam_search` will not work with the Functional API
(https://www.tensorflow.org/guide/keras/functional).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
start_n_top: Beam size for span start.
......@@ -150,6 +153,7 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
self.end_logits_inner_dense = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
activation=activation,
name='predictions/transform/end_logits/inner')
self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12,
......@@ -172,13 +176,33 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
name='predictions/transform/answer_logits/output')
def end_logits(self, inputs):
"""Computes the end logits."""
"""Computes the end logits.
Input shapes into the inner, layer norm, output layers should match.
During training, inputs shape should be
[batch_size, seq_length, input_width].
During inference, input shapes should be
[batch_size, seq_length, start_n_top, input_width].
Args:
inputs: The input for end logits.
Returns:
Calculated end logits.
"""
if inputs.shape.rank == 3:
# Training inputs: [B, S, H] -> [B, S, 1, H].
inputs = tf.expand_dims(inputs, axis=2)
end_logits = self.end_logits_inner_dense(inputs)
end_logits = self.end_logits_layer_norm(end_logits)
end_logits = self.end_logits_output_dense(end_logits)
# Squeeze out the trailing size-1 dims:
# training: [B, S, 1, 1] -> [B, S]; inference: [B, S, K, 1] -> [B, S, K].
end_logits = tf.squeeze(end_logits)
if end_logits.shape.rank > 2:
# Inference only: [B, S, K] -> [B, K, S], where K = start_n_top.
end_logits = tf.transpose(end_logits, [0, 2, 1])
return end_logits
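To make the branching above easier to follow, here is an illustrative shape walkthrough (not part of the commit), with B = batch_size, S = seq_length, H = input_width and K = start_n_top:

# Training (true start positions are available):
#   inputs                     [B, S, H]
#   expand_dims(axis=2)     -> [B, S, 1, H]
#   inner dense + layer norm   [B, S, 1, H]
#   output dense (1 unit)   -> [B, S, 1, 1]
#   squeeze                 -> [B, S]        (rank 2, no transpose)
#
# Inference (beam search over the top-K predicted start positions):
#   inputs                     [B, S, K, H]
#   inner dense + layer norm   [B, S, K, H]
#   output dense (1 unit)   -> [B, S, K, 1]
#   squeeze                 -> [B, S, K]
#   transpose([0, 2, 1])    -> [B, K, S]     (rank 3 triggers the transpose)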
......
......@@ -234,8 +234,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
}
self.assertSetEqual(expected_keys, set(output.keys()))
def test_subclass_invocation(self):
"""Tests basic invocation of this layer wrapped in a subclass."""
seq_length = 8
hidden_size = 4
batch_size = 2
......@@ -244,7 +244,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
dtype=tf.float32)
class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
position_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
start_positions = tf.keras.Input(shape=(), dtype=tf.int32)
layer = span_labeling.XLNetSpanLabeling(
input_width=hidden_size,
......@@ -272,7 +272,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
position_mask = tf.random.uniform(
shape=(batch_size, seq_length), dtype=tf.float32)
class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
start_positions = tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32)
inputs = dict(sequence_data=sequence_data,
position_mask=position_mask,
......@@ -282,14 +283,16 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
output = model(inputs)
self.assertIsInstance(output, dict)
# Test `call` without training flag.
output = model(inputs, training=False)
self.assertIsInstance(output, dict)
# Test `call` with training flag.
# Note: this fails due to incompatibility with the functional API.
with self.assertRaisesRegexp(AssertionError,
'Could not compute output KerasTensor'):
model(inputs, training=True)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
network = span_labeling.XLNetSpanLabeling(
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet models that are compatible with TF 2.x."""
from typing import Optional, Tuple

import tensorflow as tf
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.xlnet import xlnet_config
def _get_initializer(
initialization_method: str,
initialization_range: float,
initialization_std: float) -> tf.keras.initializers.Initializer:
"""Gets variable initializer."""
if initialization_method == 'uniform':
initializer = tf.keras.initializers.RandomUniform(
minval=-initialization_range, maxval=initialization_range)
elif initialization_method == 'normal':
initializer = tf.keras.initializers.RandomNormal(stddev=initialization_std)
else:
raise ValueError('Initializer {} not supported'.format(
initialization_method))
return initializer
def get_xlnet_base(model_config: xlnet_config.XLNetConfig,
run_config: xlnet_config.RunConfig,
attention_type: str,
two_stream: bool,
use_cls_mask: bool) -> tf.keras.Model:
"""Gets an 'XLNetBase' object.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
attention_type: the attention type for the base XLNet model, "uni" or "bi".
two_stream: whether or not to use two-stream attention.
use_cls_mask: whether or not cls mask is included in the input sequences.
Returns:
An XLNetBase object.
"""
initializer = _get_initializer(initialization_method=run_config.init_method,
initialization_range=run_config.init_range,
initialization_std=run_config.init_std)
kwargs = dict(
vocab_size=model_config.n_token,
num_layers=model_config.n_layer,
hidden_size=model_config.d_model,
num_attention_heads=model_config.n_head,
head_size=model_config.d_head,
inner_size=model_config.d_inner,
dropout_rate=run_config.dropout,
attention_dropout_rate=run_config.dropout_att,
attention_type=attention_type,
bi_data=run_config.bi_data,
initializer=initializer,
two_stream=two_stream,
tie_attention_biases=not model_config.untie_r,
memory_length=run_config.mem_len,
clamp_length=run_config.clamp_len,
reuse_length=run_config.reuse_len,
inner_activation=model_config.ff_activation,
use_cls_mask=use_cls_mask)
return networks.XLNetBase(**kwargs)
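As a hedged usage sketch (mirroring the small test configuration added later in this commit, not a recommended production setup), `get_xlnet_base` is driven entirely by the two config objects:

model_config = xlnet_config.XLNetConfig(
    args_dict=dict(
        n_layer=2, d_model=4, n_head=1, d_head=2, d_inner=4,
        ff_activation='gelu', untie_r=True, n_token=32000))
run_config = xlnet_config.RunConfig(
    is_training=True, use_tpu=False, dropout=0.0, dropout_att=0.0,
    init_method='normal', init_range=0.1, init_std=0.02, mem_len=0,
    reuse_len=4, bi_data=False, clamp_len=-1, same_length=False)

# Builds a networks.XLNetBase with the kwargs mapped above.
xlnet_base = get_xlnet_base(
    model_config=model_config,
    run_config=run_config,
    attention_type='bi',
    two_stream=False,
    use_cls_mask=False)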
def classifier_model(
model_config: xlnet_config.XLNetConfig,
run_config: xlnet_config.RunConfig,
num_labels: int,
final_layer_initializer: Optional[tf.keras.initializers.Initializer] = None
) -> Tuple[tf.keras.Model, tf.keras.Model]:
"""Returns a TF2 Keras XLNet classifier model.
Constructs a Keras model for predicting `num_labels` classes from the outputs
of an XLNet encoder.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
num_labels: integer, the number of classes.
final_layer_initializer: Initializer for the final dense layer. If `None`,
  defaults to a RandomNormal initializer with stddev 0.02.
Returns:
  A tuple of (classifier model, xlnet_base): an `XLNetClassifier` that maps the
  encoder inputs to classification logits, and the underlying XLNet encoder
  network it wraps.
"""
if final_layer_initializer is not None:
initializer = final_layer_initializer
else:
initializer = tf.keras.initializers.RandomNormal(
mean=0., stddev=.02)
xlnet_base = get_xlnet_base(
model_config=model_config,
run_config=run_config,
attention_type='bi',
two_stream=False,
use_cls_mask=False)
return models.XLNetClassifier(
network=xlnet_base,
num_classes=num_labels,
dropout_rate=run_config.dropout,
summary_type='last',
initializer=initializer), xlnet_base
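Continuing the sketch above, `classifier_model` returns both the classifier and the encoder it wraps, so callers unpack a tuple (as the test at the end of this commit does):

xlnet_classifier, xlnet_base = classifier_model(
    model_config=model_config,
    run_config=run_config,
    num_labels=2)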
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_models
class XLNetModelsTest(tf.test.TestCase):
def setUp(self):
super(XLNetModelsTest, self).setUp()
self._xlnet_test_config = xlnet_config.XLNetConfig(
args_dict=dict(
n_layer=2,
d_model=4,
n_head=1,
d_head=2,
d_inner=4,
ff_activation='gelu',
untie_r=True,
n_token=32000))
self._run_config = xlnet_config.RunConfig(
is_training=True,
use_tpu=False,
dropout=0.0,
dropout_att=0.0,
init_method='normal',
init_range=0.1,
init_std=0.02,
mem_len=0,
reuse_len=4,
bi_data=False,
clamp_len=-1,
same_length=False)
def test_xlnet_base(self):
xlnet_base = xlnet_models.get_xlnet_base(
model_config=self._xlnet_test_config,
run_config=self._run_config,
attention_type='bi',
two_stream=False,
use_cls_mask=False)
self.assertIsInstance(xlnet_base, tf.keras.layers.Layer)
def test_xlnet_classifier(self):
xlnet_classifier, xlnet_base = xlnet_models.classifier_model(
model_config=self._xlnet_test_config,
run_config=self._run_config,
num_labels=2)
self.assertIsInstance(xlnet_classifier, tf.keras.Model)
self.assertIsInstance(xlnet_base, tf.keras.layers.Layer)
if __name__ == '__main__':
tf.test.main()