"...resnet50_tensorflow.git" did not exist on "edc6571226110e8bcc1e9e88672915ad9ca523f2"
Commit 15d1514b authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Implement XLNet Pretrainer model.

PiperOrigin-RevId: 343345144
parent 18aadc2b
...@@ -21,4 +21,5 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetPretrainer
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
...@@ -23,6 +23,135 @@ from official.nlp.modeling import layers
from official.nlp.modeling import networks
class XLNetMaskedLM(tf.keras.layers.Layer):
  """XLNet pretraining head.

  Transforms encoder sequence output and decodes it against the (tied)
  embedding table, adding a learned per-token bias, to produce masked-LM
  logits over the vocabulary.

  Arguments:
    vocab_size: The size of the output vocabulary.
    hidden_size: The number of units of the dense transform layer.
    initializer: The kernel initializer for the dense transform layer.
    activation: The activation used by the dense transform layer.
    name: Optional name for the layer.
  """

  def __init__(self,
               vocab_size: int,
               hidden_size: int,
               initializer: str = 'glorot_uniform',
               activation: str = 'gelu',
               name=None,
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._vocab_size = vocab_size
    self._hidden_size = hidden_size
    self._initializer = initializer
    self._activation = activation

  def build(self, input_shape):
    self.dense = tf.keras.layers.Dense(
        units=self._hidden_size,
        activation=self._activation,
        kernel_initializer=self._initializer,
        name='transform/dense')
    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
    # Output bias over the vocabulary, added after the embedding-table decode.
    self.bias = self.add_weight(
        'output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True)
    super().build(input_shape)

  def call(self,
           sequence_data: tf.Tensor,
           embedding_table: tf.Tensor):
    """Computes MLM logits.

    Args:
      sequence_data: Sequence output from the encoder.
      embedding_table: Embedding table used (transposed) as the decoder
        weights, i.e. weight tying with the input embeddings.

    Returns:
      Logits over the vocabulary.
    """
    lm_data = self.dense(sequence_data)
    lm_data = self.layer_norm(lm_data)
    lm_data = tf.matmul(lm_data, embedding_table, transpose_b=True)
    logits = tf.nn.bias_add(lm_data, self.bias)
    return logits

  def get_config(self) -> Mapping[str, Any]:
    # Bug fix: 'activation' was previously omitted here, so a layer restored
    # via from_config silently fell back to the default activation.
    config = {
        'vocab_size': self._vocab_size,
        'hidden_size': self._hidden_size,
        'initializer': self._initializer,
        'activation': self._activation,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetPretrainer(tf.keras.Model):
  """XLNet-based pretrainer.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Arguments:
    network: An XLNet/Transformer-XL based network. This network should output a
      sequence output and list of `state` tensors.
    mlm_activation: The activation (if any) to use in the Masked LM network. If
      None, then no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
      to a Glorot uniform initializer.
    name: Optional model name.
  """

  def __init__(
      self,
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      name: str = None,
      **kwargs):
    super().__init__(name=name, **kwargs)
    self._config = {
        'network': network,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
    }
    self._network = network
    self._hidden_size = network.get_config()['hidden_size']
    self._vocab_size = network.get_config()['vocab_size']
    self._activation = mlm_activation
    self._initializer = mlm_initializer
    # Bug fix: `mlm_activation` was accepted and documented but never
    # forwarded to the masked LM head. Forward it only when set so the
    # previous default behavior (the head's own default activation) is
    # preserved for callers that pass None.
    masked_lm_kwargs = dict(
        vocab_size=self._vocab_size,
        hidden_size=self._hidden_size,
        initializer=self._initializer)
    if self._activation is not None:
      masked_lm_kwargs['activation'] = self._activation
    self._masked_lm = XLNetMaskedLM(**masked_lm_kwargs)

  def call(self, inputs: Mapping[str, Any]):
    """Runs the encoder and the masked LM head.

    Args:
      inputs: A dict with keys `input_word_ids`, `input_type_ids`,
        `masked_tokens`, `permutation_mask`, `target_mapping` and an optional
        `state` (Transformer-XL memory).

    Returns:
      A tuple of (MLM logits, updated encoder state).
    """
    input_word_ids = inputs['input_word_ids']
    input_type_ids = inputs['input_type_ids']
    masked_tokens = inputs['masked_tokens']
    permutation_mask = inputs['permutation_mask']
    target_mapping = inputs['target_mapping']
    state = inputs.get('state', None)

    attention_output, state = self._network(
        input_ids=input_word_ids,
        segment_ids=input_type_ids,
        input_mask=None,
        state=state,
        permutation_mask=permutation_mask,
        target_mapping=target_mapping,
        masked_tokens=masked_tokens)

    # Tie the MLM decoder weights to the encoder's embedding table.
    embedding_table = self._network.get_embedding_lookup_table()
    mlm_outputs = self._masked_lm(
        sequence_data=attention_output,
        embedding_table=embedding_table)
    return mlm_outputs, state

  def get_config(self) -> Mapping[str, Any]:
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    # Exposes the encoder for checkpointing under the conventional key.
    return dict(encoder=self._network)
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetClassifier(tf.keras.Model):
  """Classifier model based on XLNet.
......
...@@ -46,6 +46,104 @@ def _get_xlnet_base() -> tf.keras.layers.Layer:
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class XLNetMaskedLMTest(keras_parameterized.TestCase):
  """Tests for the XLNet masked LM head."""

  def test_xlnet_masked_lm_head(self):
    """Checks that the head maps sequence data to vocabulary logits."""
    hidden_size = 10
    seq_length = 8
    batch_size = 2
    lm_head = xlnet.XLNetMaskedLM(
        vocab_size=10,
        hidden_size=hidden_size,
        initializer='glorot_uniform')
    # Random inputs standing in for encoder output and the embedding table.
    sequence_data = np.random.uniform(size=(batch_size, seq_length))
    embedding_table = np.random.uniform(size=(hidden_size, hidden_size))
    logits = lm_head(sequence_data, embedding_table)
    self.assertAllClose(logits.shape, (batch_size, hidden_size))
@keras_parameterized.run_all_keras_modes
class XLNetPretrainerTest(keras_parameterized.TestCase):
  """Tests for the XLNetPretrainer model."""

  def test_xlnet_trainer(self):
    """Validates that the Keras object can be created."""
    seq_length = 4
    num_predictions = 2
    # Build a simple XLNet based network to use with the XLNet trainer.
    encoder = _get_xlnet_base()
    # Create an XLNet trainer with the created network.
    pretrainer = xlnet.XLNetPretrainer(network=encoder)
    model_inputs = dict(
        input_word_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
        input_type_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
        input_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_mask'),
        permutation_mask=tf.keras.layers.Input(
            shape=(seq_length, seq_length,), dtype=tf.int32,
            name='permutation_mask'),
        target_mapping=tf.keras.layers.Input(
            shape=(num_predictions, seq_length), dtype=tf.int32,
            name='target_mapping'),
        masked_tokens=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='masked_tokens'))
    logits, _ = pretrainer(model_inputs)
    # Logits over the vocabulary for each position: [None, 4, 100].
    expected_output_shape = [None, 4, 100]
    self.assertAllEqual(expected_output_shape, logits.shape.as_list())

  def test_xlnet_tensor_call(self):
    """Validates that the Keras object can be invoked."""
    seq_length = 4
    batch_size = 2
    num_predictions = 2
    # Build a simple XLNet based network to use with the XLNet trainer.
    encoder = _get_xlnet_base()
    # Create an XLNet trainer with the created network.
    pretrainer = xlnet.XLNetPretrainer(network=encoder)
    sequence_shape = (batch_size, seq_length)
    tensor_inputs = dict(
        input_word_ids=np.random.randint(
            10, size=sequence_shape, dtype='int32'),
        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
        input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
        permutation_mask=np.random.randint(
            2, size=(batch_size, seq_length, seq_length)).astype('int32'),
        target_mapping=np.random.randint(
            10, size=(num_predictions, seq_length), dtype='int32'),
        masked_tokens=np.random.randint(
            10, size=sequence_shape, dtype='int32'))
    pretrainer(tensor_inputs)

  def test_serialize_deserialize(self):
    """Validates that the XLNet trainer can be serialized and deserialized."""
    # Build a simple XLNet based network to use with the XLNet trainer.
    encoder = _get_xlnet_base()
    # Create an XLNet trainer with the created network.
    pretrainer = xlnet.XLNetPretrainer(
        network=encoder,
        mlm_activation='gelu',
        mlm_initializer='random_normal')
    # Create another XLNet trainer via serialization and deserialization.
    restored = xlnet.XLNetPretrainer.from_config(pretrainer.get_config())
    # Validate that the config can be forced to JSON.
    _ = restored.to_json()
    # If serialization was successful, then the new config should match the old.
    self.assertAllEqual(pretrainer.get_config(), restored.get_config())
@keras_parameterized.run_all_keras_modes
class XLNetClassifierTest(keras_parameterized.TestCase):
...@@ -69,13 +167,12 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
        input_type_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
        input_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_mask'),
        permutation_mask=tf.keras.layers.Input(
            shape=(seq_length, seq_length,), dtype=tf.int32,
            name='permutation_mask'),
        masked_tokens=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='masked_tokens'))
    logits = xlnet_trainer_model(inputs)
    expected_classification_shape = [None, num_classes]
...@@ -102,10 +199,11 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
        input_word_ids=np.random.randint(
            10, size=sequence_shape, dtype='int32'),
        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
        input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
        permutation_mask=np.random.randint(
            2, size=(batch_size, seq_length, seq_length)).astype('int32'),
        masked_tokens=np.random.randint(
            10, size=sequence_shape, dtype='int32'))
    xlnet_trainer_model(inputs)

  def test_serialize_deserialize(self):
...@@ -158,9 +256,9 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
        input_type_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
        input_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_mask'),
        paragraph_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='paragraph_mask'),
        class_index=tf.keras.layers.Input(
            shape=(), dtype=tf.int32, name='class_index'),
        start_positions=tf.keras.layers.Input(
...@@ -175,9 +273,9 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
        input_word_ids=np.random.randint(
            10, size=sequence_shape, dtype='int32'),
        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
        input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
        paragraph_mask=np.random.randint(
            1, size=(sequence_shape)).astype('int32'),
        class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
        start_positions=tf.random.uniform(
            shape=(batch_size,), maxval=5, dtype=tf.int32))
......
...@@ -242,7 +242,8 @@ def _compute_segment_matrix(
  if segment_ids is None:
    return None
  memory_padding = tf.zeros([batch_size, memory_length],
                            dtype=segment_ids.dtype)
  padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
  # segment_ids: [B, S]
  # padded_segment_ids: [B, S + M]
...@@ -629,7 +630,12 @@ class XLNetBase(tf.keras.layers.Layer):
          "enabled. Please enable `two_stream` to enable two "
          "stream attention.")
    if input_mask is not None:
      dtype = input_mask.dtype
    elif permutation_mask is not None:
      dtype = permutation_mask.dtype
    else:
      dtype = tf.int32
    query_attention_mask, content_attention_mask = _compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment