Commit 7a6a8741 authored by Allen Wang, committed by A. Unique TensorFlower

Implement XLNet QA model.

PiperOrigin-RevId: 337134894
parent dd04e547
...@@ -21,3 +21,4 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
...@@ -20,6 +20,7 @@ from typing import Any, Mapping, Union
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text')
...@@ -98,3 +99,84 @@ class XLNetClassifier(tf.keras.Model):
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf.keras.Model):
"""Span labeler model based on XLNet.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Arguments:
network: A transformer network based on `XLNetBase`. When called with input
ids, segment ids, and an input mask, this network should return a sequence
of hidden states and the updated memory states.
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
dropout_rate: The dropout rate for the span labeling layer.
span_labeling_activation: The activation used in the span labeling head.
Defaults to 'tanh'.
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
"""
def __init__(
self,
network: Union[tf.keras.layers.Layer, tf.keras.Model],
start_n_top: int,
end_n_top: int,
dropout_rate: float,
span_labeling_activation: str = 'tanh',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
**kwargs):
super().__init__(**kwargs)
self._config = {
'network': network,
'start_n_top': start_n_top,
'end_n_top': end_n_top,
'dropout_rate': dropout_rate,
'span_labeling_activation': span_labeling_activation,
'initializer': initializer,
}
self._network = network
self._initializer = initializer
self._start_n_top = start_n_top
self._end_n_top = end_n_top
self._dropout_rate = dropout_rate
self._activation = span_labeling_activation
self.span_labeling = networks.XLNetSpanLabeling(
input_width=network.get_config()['hidden_size'],
start_n_top=self._start_n_top,
end_n_top=self._end_n_top,
activation=self._activation,
dropout_rate=self._dropout_rate,
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_ids']
segment_ids = inputs['segment_ids']
input_mask = inputs['input_mask']
class_index = tf.reshape(inputs['class_index'], [-1])
position_mask = inputs['position_mask']
start_positions = inputs['start_positions']
attention_output, new_states = self._network(
input_ids=input_ids,
segment_ids=segment_ids,
input_mask=input_mask)
outputs = self.span_labeling(
sequence_data=attention_output,
class_index=class_index,
position_mask=position_mask,
start_positions=start_positions)
return outputs, new_states
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
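For reference, a minimal usage sketch of the new span labeler (illustrative only; the encoder hyperparameters simply mirror the small test configuration in this change, and the input values are random placeholders):

import numpy as np
import tensorflow as tf

from official.nlp.modeling import models
from official.nlp.modeling import networks

# A tiny Transformer-XL encoder; kwargs mirror `get_xlnet_base` further below.
encoder = networks.XLNetBase(
    vocab_size=32000, num_layers=2, hidden_size=4, num_attention_heads=1,
    head_size=2, inner_size=4, dropout_rate=0.0, attention_dropout_rate=0.0,
    attention_type='bi', bi_data=False,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False, tie_attention_biases=True, memory_length=0,
    clamp_length=-1, reuse_length=4, inner_activation='gelu',
    use_cls_mask=False)

span_labeler = models.XLNetSpanLabeler(
    network=encoder, start_n_top=2, end_n_top=2, dropout_rate=0.1,
    span_labeling_activation='tanh')

batch_size, seq_length = 2, 4
inputs = dict(
    input_ids=np.random.randint(10, size=(batch_size, seq_length), dtype='int32'),
    segment_ids=np.random.randint(2, size=(batch_size, seq_length), dtype='int32'),
    input_mask=np.ones((batch_size, seq_length), dtype='float32'),
    position_mask=np.zeros((batch_size, seq_length), dtype='float32'),
    class_index=np.zeros((batch_size,), dtype='uint8'),
    start_positions=np.zeros((batch_size,), dtype='int32'))

# Inference call: `outputs` holds beam-search keys such as 'start_top_log_probs'.
outputs, new_states = span_labeler(inputs)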
...@@ -133,5 +133,93 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
new_xlnet_trainer_model.get_config())
@keras_parameterized.run_all_keras_modes
class XLNetSpanLabelerTest(keras_parameterized.TestCase):
@parameterized.parameters(1, 2)
def test_xlnet_trainer(self, top_n):
"""Validate that the Keras object can be created."""
seq_length = 4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetSpanLabeler(
network=xlnet_base,
start_n_top=top_n,
end_n_top=top_n,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
span_labeling_activation='tanh',
dropout_rate=0.1)
inputs = dict(
input_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
segment_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='input_mask'),
position_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='position_mask'),
class_index=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='class_index'),
start_positions=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='start_positions'))
outputs, _ = xlnet_trainer_model(inputs)
self.assertIsInstance(outputs, dict)
# Test tensor value calls for the created model.
batch_size = 2
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
position_mask=np.random.randint(
1, size=(sequence_shape)).astype('float32'),
class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
start_positions=tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32))
outputs, _ = xlnet_trainer_model(inputs)
expected_inference_keys = {
'start_top_log_probs', 'end_top_log_probs', 'class_logits',
'start_top_index', 'end_top_index',
}
self.assertSetEqual(expected_inference_keys, set(outputs.keys()))
outputs, _ = xlnet_trainer_model(inputs, training=True)
self.assertIsInstance(outputs, dict)
expected_train_keys = {
'start_log_probs', 'end_log_probs', 'class_logits'
}
self.assertSetEqual(expected_train_keys, set(outputs.keys()))
self.assertIsInstance(outputs, dict)
def test_serialize_deserialize(self):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
# Create an XLNet trainer with the created network.
xlnet_trainer_model = xlnet.XLNetSpanLabeler(
network=xlnet_base,
start_n_top=2,
end_n_top=2,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
span_labeling_activation='tanh',
dropout_rate=0.1)
# Create another XLNet trainer via serialization and deserialization.
config = xlnet_trainer_model.get_config()
new_xlnet_trainer_model = xlnet.XLNetSpanLabeler.from_config(
config)
# Validate that the config can be forced to JSON.
_ = new_xlnet_trainer_model.to_json()
# If serialization was successful, then the new config should match the old.
self.assertAllEqual(xlnet_trainer_model.get_config(),
new_xlnet_trainer_model.get_config())
if __name__ == '__main__':
tf.test.main()
...@@ -113,6 +113,9 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
positions, and then uses either the true start positions (if training) or
beam search to predict the end positions.
**Note: `compute_with_beam_search` will not work with the Functional API
(https://www.tensorflow.org/guide/keras/functional).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
start_n_top: Beam size for span start.
...@@ -150,6 +153,7 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
self.end_logits_inner_dense = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
activation=activation,
name='predictions/transform/end_logits/inner')
self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12,
...@@ -172,13 +176,33 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
name='predictions/transform/answer_logits/output')
def end_logits(self, inputs):
"""Computes the end logits.""" """Computes the end logits.
Input shapes into the inner, layer norm, output layers should match.
During training, inputs shape should be
[batch_size, seq_length, input_width].
During inference, input shapes should be
[batch_size, seq_length, start_n_top, input_width].
Args:
inputs: The input for end logits.
Returns:
Calculated end logits.
"""
if len(tf.shape(inputs)) == 3:
# inputs: [B, S, H] -> [B, S, 1, H]
inputs = tf.expand_dims(inputs, axis=2)
end_logits = self.end_logits_inner_dense(inputs)
end_logits = self.end_logits_layer_norm(end_logits)
end_logits = self.end_logits_output_dense(end_logits)
end_logits = tf.squeeze(end_logits)
if tf.rank(end_logits) > 2:
# shape = [B, S, K] -> [B, K, S]
end_logits = tf.transpose(end_logits, [0, 2, 1])
return end_logits
...
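To make the `end_logits` shape contract described above concrete, here is a standalone toy stand-in for the dense stack (illustrative only; `toy_end_logits` and its two plain `Dense` layers are not part of this change), showing how the two input ranks flow through the same expand/squeeze/transpose logic:

import tensorflow as tf

inner = tf.keras.layers.Dense(4)    # stands in for end_logits_inner_dense (units=input_width)
output = tf.keras.layers.Dense(1)   # stands in for end_logits_output_dense

def toy_end_logits(inputs):
  if len(tf.shape(inputs)) == 3:
    # Training path: [B, S, H] -> [B, S, 1, H]
    inputs = tf.expand_dims(inputs, axis=2)
  x = output(inner(inputs))         # last dim becomes 1
  x = tf.squeeze(x)                 # drop the size-1 dims
  if tf.rank(x) > 2:
    x = tf.transpose(x, [0, 2, 1])  # [B, S, K] -> [B, K, S]
  return x

print(toy_end_logits(tf.zeros([2, 8, 4])).shape)     # (2, 8): training, one logit per position
print(toy_end_logits(tf.zeros([2, 8, 3, 4])).shape)  # (2, 3, 8): inference with start_n_top=3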
...@@ -234,8 +234,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
}
self.assertSetEqual(expected_keys, set(output.keys()))
def test_subclass_invocation(self):
"""Tests basic invocation of this layer wrapped in a subclass."""
seq_length = 8
hidden_size = 4
batch_size = 2
...@@ -244,7 +244,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
dtype=tf.float32)
class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
position_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
start_positions = tf.keras.Input(shape=(), dtype=tf.int32)
layer = span_labeling.XLNetSpanLabeling(
input_width=hidden_size,
...@@ -272,7 +272,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
position_mask = tf.random.uniform(
shape=(batch_size, seq_length), dtype=tf.float32)
class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
start_positions = tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32)
inputs = dict(sequence_data=sequence_data,
position_mask=position_mask,
...@@ -282,14 +283,16 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
output = model(inputs)
self.assertIsInstance(output, dict)
# Test `call` without training flag.
output = model(inputs, training=False)
self.assertIsInstance(output, dict)
# Test `call` with training flag.
# Note: this fails due to incompatibility with the functional API.
with self.assertRaisesRegexp(AssertionError,
'Could not compute output KerasTensor'):
model(inputs, training=True)
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
network = span_labeling.XLNetSpanLabeling( network = span_labeling.XLNetSpanLabeling(
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet models that are compatible with TF 2.x."""
from typing import Tuple

import tensorflow as tf
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.xlnet import xlnet_config
def _get_initializer(
initialization_method: str,
initialization_range: float,
initialization_std: float) -> tf.keras.initializers.Initializer:
"""Gets variable initializer."""
if initialization_method == 'uniform':
initializer = tf.keras.initializers.RandomUniform(
minval=-initialization_range, maxval=initialization_range)
elif initialization_method == 'normal':
initializer = tf.keras.initializers.RandomNormal(stddev=initialization_std)
else:
raise ValueError('Initializer {} not supported'.format(
initialization_method))
return initializer
def get_xlnet_base(model_config: xlnet_config.XLNetConfig,
run_config: xlnet_config.RunConfig,
attention_type: str,
two_stream: bool,
use_cls_mask: bool) -> tf.keras.Model:
"""Gets an 'XLNetBase' object.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
attention_type: the attention type for the base XLNet model, "uni" or "bi".
two_stream: whether or not to use two-stream attention.
use_cls_mask: whether or not cls mask is included in the input sequences.
Returns:
An XLNetBase object.
"""
initializer = _get_initializer(initialization_method=run_config.init_method,
initialization_range=run_config.init_range,
initialization_std=run_config.init_std)
kwargs = dict(
vocab_size=model_config.n_token,
num_layers=model_config.n_layer,
hidden_size=model_config.d_model,
num_attention_heads=model_config.n_head,
head_size=model_config.d_head,
inner_size=model_config.d_inner,
dropout_rate=run_config.dropout,
attention_dropout_rate=run_config.dropout_att,
attention_type=attention_type,
bi_data=run_config.bi_data,
initializer=initializer,
two_stream=two_stream,
tie_attention_biases=not model_config.untie_r,
memory_length=run_config.mem_len,
clamp_length=run_config.clamp_len,
reuse_length=run_config.reuse_len,
inner_activation=model_config.ff_activation,
use_cls_mask=use_cls_mask)
return networks.XLNetBase(**kwargs)
def classifier_model(
model_config: xlnet_config.XLNetConfig,
run_config: xlnet_config.RunConfig,
num_labels: int,
final_layer_initializer: tf.keras.initializers.Initializer = None
) -> Tuple[tf.keras.Model, tf.keras.layers.Layer]:
"""Returns a TF2 Keras XLNet classifier model.
Construct a Keras model for predicting `num_labels` outputs from an input with
maximum sequence length `max_seq_length`.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
num_labels: integer, the number of classes.
final_layer_initializer: Initializer for the final dense layer. If `None`,
defaults to a random normal initializer with stddev 0.02.
Returns:
A `(classifier, xlnet_base)` tuple: `classifier` is the combined Keras
classification model and `xlnet_base` is the underlying `XLNetBase`
network it wraps.
"""
if final_layer_initializer is not None:
initializer = final_layer_initializer
else:
initializer = tf.keras.initializers.RandomNormal(
mean=0., stddev=.02)
xlnet_base = get_xlnet_base(
model_config=model_config,
run_config=run_config,
attention_type='bi',
two_stream=False,
use_cls_mask=False)
return models.XLNetClassifier(
network=xlnet_base,
num_classes=num_labels,
dropout_rate=run_config.dropout,
summary_type='last',
initializer=initializer), xlnet_base
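This commit does not add an analogous builder for the QA model in this file; as a hypothetical sketch (the name `span_labeler_model` and its default beam sizes are illustrative, not part of this change), the span labeler could be wired onto the same base network like so:

def span_labeler_model(
    model_config: xlnet_config.XLNetConfig,
    run_config: xlnet_config.RunConfig,
    start_n_top: int = 5,
    end_n_top: int = 5,
    initializer: tf.keras.initializers.Initializer = None):
  """Hypothetical sketch: builds an XLNetSpanLabeler plus its base network."""
  if initializer is None:
    initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=.02)
  xlnet_base = get_xlnet_base(
      model_config=model_config,
      run_config=run_config,
      attention_type='bi',
      two_stream=False,
      use_cls_mask=False)
  return models.XLNetSpanLabeler(
      network=xlnet_base,
      start_n_top=start_n_top,
      end_n_top=end_n_top,
      dropout_rate=run_config.dropout,
      span_labeling_activation='tanh',
      initializer=initializer), xlnet_base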
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_models
class XLNetModelsTest(tf.test.TestCase):
def setUp(self):
super(XLNetModelsTest, self).setUp()
self._xlnet_test_config = xlnet_config.XLNetConfig(
args_dict=dict(
n_layer=2,
d_model=4,
n_head=1,
d_head=2,
d_inner=4,
ff_activation='gelu',
untie_r=True,
n_token=32000))
self._run_config = xlnet_config.RunConfig(
is_training=True,
use_tpu=False,
dropout=0.0,
dropout_att=0.0,
init_method='normal',
init_range=0.1,
init_std=0.02,
mem_len=0,
reuse_len=4,
bi_data=False,
clamp_len=-1,
same_length=False)
def test_xlnet_base(self):
xlnet_base = xlnet_models.get_xlnet_base(
model_config=self._xlnet_test_config,
run_config=self._run_config,
attention_type='bi',
two_stream=False,
use_cls_mask=False)
self.assertIsInstance(xlnet_base, tf.keras.layers.Layer)
def test_xlnet_classifier(self):
xlnet_classifier, xlnet_base = xlnet_models.classifier_model(
model_config=self._xlnet_test_config,
run_config=self._run_config,
num_labels=2)
self.assertIsInstance(xlnet_classifier, tf.keras.Model)
self.assertIsInstance(xlnet_base, tf.keras.layers.Layer)
if __name__ == '__main__':
tf.test.main()