"src/vscode:/vscode.git/clone" did not exist on "268ebcb015d8816257f24e3cd930fb48964314b4"
Commit 44a5367a authored by Allen Wang, committed by A. Unique TensorFlower
Browse files

Implement XLNet Classifier model.

PiperOrigin-RevId: 335069348
parent baacb20d
...@@ -19,3 +19,4 @@ from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler ...@@ -19,3 +19,4 @@ from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.dual_encoder import DualEncoder from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.xlnet import XLNetClassifier
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet cls-token classifier."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Mapping, Union
import tensorflow as tf
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetClassifier(tf.keras.Model):
  """Classifier model based on XLNet.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: An XLNet/Transformer-XL based network. This network should output
      a sequence output and list of `state` tensors. Its `get_config()` must
      contain an `inner_size` entry, which sizes the classification head.
    num_classes: Number of classes to predict from the classification network.
    initializer: The initializer (an instance or an initializer name) to use
      in the classification networks. Defaults to `'random_normal'`.
    summary_type: Method used to summarize a sequence into a compact vector.
      Either `'last'` (use the last token) or `'first'` (use the first token).
    dropout_rate: The dropout probability of the cls head.
    **kwargs: Forwarded to the `tf.keras.Model` constructor.

  Raises:
    ValueError: If `summary_type` is neither `'last'` nor `'first'`.
  """

  def __init__(
      self,
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      num_classes: int,
      initializer: Union[
          str, tf.keras.initializers.Initializer] = 'random_normal',
      summary_type: str = 'last',
      dropout_rate: float = 0.1,
      **kwargs):
    super().__init__(**kwargs)
    self._network = network
    self._initializer = initializer
    self._summary_type = summary_type
    self._num_classes = num_classes
    # Constructor arguments, kept so the model can be rebuilt by `from_config`.
    self._config = {
        'network': network,
        'initializer': initializer,
        'num_classes': num_classes,
        'summary_type': summary_type,
        'dropout_rate': dropout_rate,
    }

    # Position of the token whose representation summarizes the sequence.
    if summary_type == 'last':
      cls_token_idx = -1
    elif summary_type == 'first':
      cls_token_idx = 0
    else:
      raise ValueError('Invalid summary type provided: %s.' % summary_type)

    # NOTE(review): the head's inner dimension is read from the network's
    # `inner_size` config entry -- confirm that this (rather than the hidden
    # size) is the intended width.
    self.classifier = layers.ClassificationHead(
        inner_dim=network.get_config()['inner_size'],
        num_classes=num_classes,
        initializer=initializer,
        dropout_rate=dropout_rate,
        cls_token_idx=cls_token_idx,
        name='sentence_prediction')

  def call(self, inputs: Mapping[str, Any]):
    """Runs the encoder network and the classification head.

    Args:
      inputs: A mapping with the keys `'input_ids'`, `'segment_ids'` and
        `'input_mask'`, plus an optional `'mems'` entry that is passed to the
        network as its `state` argument. Any other keys are ignored.

    Returns:
      A `(logits, new_states)` tuple: the classification head's output and the
      second output of the wrapped network.
    """
    input_ids = inputs['input_ids']
    segment_ids = inputs['segment_ids']
    input_mask = inputs['input_mask']
    state = inputs.get('mems', None)

    attention_output, new_states = self._network(
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        state=state)

    logits = self.classifier(attention_output)

    return logits, new_states

  def get_config(self):
    """Returns the constructor arguments as a new dict.

    A shallow copy is returned so that callers who modify the result cannot
    alter the configuration stored on the model.
    """
    return dict(self._config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a model from a `get_config()` dict; `custom_objects` is unused."""
    return cls(**config)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLNet classifier network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import xlnet
def _get_xlnet_base() -> tf.keras.layers.Layer:
  """Builds a tiny two-layer `XLNetBase` network for use in the tests."""
  base_kwargs = {
      'vocab_size': 100,
      'num_layers': 2,
      'hidden_size': 4,
      'num_attention_heads': 2,
      'head_size': 2,
      'inner_size': 2,
      'dropout_rate': 0.,
      'attention_dropout_rate': 0.,
      'attention_type': 'bi',
      'bi_data': True,
      'initializer': tf.keras.initializers.RandomNormal(stddev=0.1),
      'two_stream': False,
      'tie_attention_biases': True,
      'reuse_length': 0,
      'inner_activation': 'relu',
  }
  return networks.XLNetBase(**base_kwargs)
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class XLNetClassifierTest(keras_parameterized.TestCase):
  """Construction, invocation and serialization tests for XLNetClassifier."""

  def _build_classifier(self, num_classes):
    """Wraps a fresh trivial XLNet base network in an XLNetClassifier."""
    base_network = _get_xlnet_base()
    return xlnet.XLNetClassifier(
        network=base_network,
        num_classes=num_classes,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        summary_type='last',
        dropout_rate=0.1)

  def test_xlnet_trainer(self):
    """Validate that the Keras object can be created."""
    num_classes = 2
    seq_length = 4
    classifier = self._build_classifier(num_classes)

    def symbolic_input(shape, dtype, name):
      return tf.keras.layers.Input(shape=shape, dtype=dtype, name=name)

    # Symbolic inputs; the batch dimension is left unspecified.
    inputs = {
        'input_ids': symbolic_input(
            (seq_length,), tf.int32, 'input_word_ids'),
        'segment_ids': symbolic_input(
            (seq_length,), tf.int32, 'segment_ids'),
        'input_mask': symbolic_input(
            (seq_length,), tf.float32, 'input_mask'),
        'permutation_mask': symbolic_input(
            (seq_length, seq_length,), tf.float32, 'permutation_mask'),
        'masked_tokens': symbolic_input(
            (seq_length,), tf.float32, 'masked_tokens'),
    }

    logits, _ = classifier(inputs)

    self.assertAllEqual([None, num_classes], logits.shape.as_list())

  @parameterized.parameters(1, 2)
  def test_xlnet_tensor_call(self, num_classes):
    """Validates that the Keras object can be invoked."""
    seq_length = 4
    batch_size = 2
    classifier = self._build_classifier(num_classes)

    sequence_shape = (batch_size, seq_length)
    pairwise_shape = (batch_size, seq_length, seq_length)
    # Concrete (non-symbolic) inputs filled with random values.
    inputs = {
        'input_ids': np.random.randint(
            10, size=sequence_shape, dtype='int32'),
        'segment_ids': np.random.randint(
            2, size=sequence_shape, dtype='int32'),
        'input_mask': np.random.randint(
            2, size=sequence_shape).astype('float32'),
        'permutation_mask': np.random.randint(
            2, size=pairwise_shape).astype('float32'),
        'masked_tokens': tf.random.uniform(shape=sequence_shape),
    }

    # Only checks that the call completes; the outputs are not inspected.
    classifier(inputs)

  def test_serialize_deserialize(self):
    """Validates that the XLNet trainer can be serialized and deserialized."""
    original_model = self._build_classifier(2)

    # Round-trip the model through its config.
    restored_model = xlnet.XLNetClassifier.from_config(
        original_model.get_config())

    # Validate that the config can be forced to JSON.
    _ = restored_model.to_json()

    # If serialization was successful, then the new config should match the old.
    self.assertAllEqual(original_model.get_config(),
                        restored_model.get_config())
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
...@@ -586,8 +586,11 @@ class XLNetBase(tf.keras.layers.Layer): ...@@ -586,8 +586,11 @@ class XLNetBase(tf.keras.layers.Layer):
masked_tokens = inputs["masked_tokens"] masked_tokens = inputs["masked_tokens"]
batch_size = tf.shape(input_ids)[0] batch_size = tf.shape(input_ids)[0]
seq_length = input_ids.shape.as_list()[1] seq_length = tf.shape(input_ids)[1]
memory_length = state[0].shape.as_list()[1] if state is not None else 0 if state is not None:
memory_length = tf.shape(state[0])[1]
else:
memory_length = 0
total_length = memory_length + seq_length total_length = memory_length + seq_length
if self._two_stream and masked_tokens is None: if self._two_stream and masked_tokens is None:
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet models that are compatible with TF 2.x."""
from typing import Optional, Tuple

import tensorflow as tf

from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.xlnet import xlnet_config
def _get_initializer(
    initialization_method: str,
    initialization_range: float,
    initialization_std: float) -> tf.keras.initializers.Initializer:
  """Returns the variable initializer selected by `initialization_method`.

  Args:
    initialization_method: Either 'uniform' or 'normal'.
    initialization_range: Half-width of the interval used by the 'uniform'
      method; values are drawn from [-range, range].
    initialization_std: Standard deviation used by the 'normal' method.

  Returns:
    A `RandomUniform` or `RandomNormal` Keras initializer.

  Raises:
    ValueError: If `initialization_method` is not a supported method.
  """
  if initialization_method == 'normal':
    return tf.keras.initializers.RandomNormal(stddev=initialization_std)

  if initialization_method == 'uniform':
    return tf.keras.initializers.RandomUniform(
        minval=-initialization_range, maxval=initialization_range)

  raise ValueError('Initializer {} not supported'.format(
      initialization_method))
def get_xlnet_base(model_config: xlnet_config.XLNetConfig,
                   run_config: xlnet_config.RunConfig,
                   attention_type: str,
                   two_stream: bool,
                   use_cls_mask: bool) -> tf.keras.layers.Layer:
  """Gets an 'XLNetBase' object.

  Args:
    model_config: the config that defines the core XLNet model.
    run_config: separate runtime configuration with extra parameters.
    attention_type: the attention type for the base XLNet model, "uni" or "bi".
    two_stream: whether or not to use two stream attention.
    use_cls_mask: whether or not cls mask is included in the input sequences.

  Returns:
    An XLNetBase object.

  Raises:
    ValueError: If `run_config.init_method` is not a supported initializer.
  """
  # The base network's weights use the initializer described by `run_config`.
  initializer = _get_initializer(initialization_method=run_config.init_method,
                                 initialization_range=run_config.init_range,
                                 initialization_std=run_config.init_std)
  return networks.XLNetBase(
      vocab_size=model_config.n_token,
      num_layers=model_config.n_layer,
      hidden_size=model_config.d_model,
      num_attention_heads=model_config.n_head,
      head_size=model_config.d_head,
      inner_size=model_config.d_inner,
      dropout_rate=run_config.dropout,
      attention_dropout_rate=run_config.dropout_att,
      attention_type=attention_type,
      bi_data=run_config.bi_data,
      initializer=initializer,
      two_stream=two_stream,
      # The model config stores the inverse flag (`untie_r`).
      tie_attention_biases=not model_config.untie_r,
      memory_length=run_config.mem_len,
      clamp_length=run_config.clamp_len,
      reuse_length=run_config.reuse_len,
      inner_activation=model_config.ff_activation,
      use_cls_mask=use_cls_mask)
def classifier_model(
    model_config: xlnet_config.XLNetConfig,
    run_config: xlnet_config.RunConfig,
    num_labels: int,
    final_layer_initializer: Optional[
        tf.keras.initializers.Initializer] = None
) -> Tuple[tf.keras.Model, tf.keras.layers.Layer]:
  """Returns a TF2 Keras XLNet classifier model and its base network.

  Constructs a Keras model that predicts `num_labels` outputs on top of a
  bidirectional, single-stream XLNet base network.

  Args:
    model_config: the config that defines the core XLNet model.
    run_config: separate runtime configuration with extra parameters.
    num_labels: integer, the number of classes.
    final_layer_initializer: Initializer for the classification head. If
      `None`, a `RandomNormal(mean=0., stddev=.02)` initializer is used; the
      initializer settings in `run_config` only apply to the base network.

  Returns:
    A `(classifier, xlnet_base)` tuple, where `classifier` is the
    `XLNetClassifier` Keras model and `xlnet_base` is the `XLNetBase` network
    it wraps.
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.RandomNormal(
        mean=0., stddev=.02)

  xlnet_base = get_xlnet_base(
      model_config=model_config,
      run_config=run_config,
      attention_type='bi',
      two_stream=False,
      use_cls_mask=False)

  classifier = models.XLNetClassifier(
      network=xlnet_base,
      num_classes=num_labels,
      dropout_rate=run_config.dropout,
      summary_type='last',
      initializer=initializer)

  # The base network is returned alongside the classifier so that callers can
  # access it directly.
  return classifier, xlnet_base
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_models
class XLNetModelsTest(tf.test.TestCase):
  """Smoke tests for the XLNet model factory functions."""

  def setUp(self):
    """Creates the small model and run configs shared by the tests."""
    super(XLNetModelsTest, self).setUp()

    model_args = {
        'n_layer': 2,
        'd_model': 4,
        'n_head': 1,
        'd_head': 2,
        'd_inner': 4,
        'ff_activation': 'gelu',
        'untie_r': True,
        'n_token': 32000,
    }
    self._xlnet_test_config = xlnet_config.XLNetConfig(args_dict=model_args)

    run_args = {
        'is_training': True,
        'use_tpu': False,
        'dropout': 0.0,
        'dropout_att': 0.0,
        'init_method': 'normal',
        'init_range': 0.1,
        'init_std': 0.02,
        'mem_len': 0,
        'reuse_len': 4,
        'bi_data': False,
        'clamp_len': -1,
        'same_length': False,
    }
    self._run_config = xlnet_config.RunConfig(**run_args)

  def test_xlnet_base(self):
    """Checks that `get_xlnet_base` returns a Keras layer."""
    base_network = xlnet_models.get_xlnet_base(
        model_config=self._xlnet_test_config,
        run_config=self._run_config,
        attention_type='bi',
        two_stream=False,
        use_cls_mask=False)
    self.assertIsInstance(base_network, tf.keras.layers.Layer)

  def test_xlnet_classifier(self):
    """Checks the types of the pair returned by `classifier_model`."""
    built = xlnet_models.classifier_model(
        model_config=self._xlnet_test_config,
        run_config=self._run_config,
        num_labels=2)
    classifier, base_network = built
    self.assertIsInstance(classifier, tf.keras.Model)
    self.assertIsInstance(base_network, tf.keras.layers.Layer)
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment