Commit 801ac678 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339095008
parent b70019f0
@@ -136,6 +136,31 @@ class BigBirdEncoderConfig(hyperparams.Config):
   embedding_size: Optional[int] = None


+@dataclasses.dataclass
+class XLNetEncoderConfig(hyperparams.Config):
+  """XLNet encoder configuration."""
+  vocab_size: int = 32000
+  num_layers: int = 24
+  hidden_size: int = 1024
+  num_attention_heads: int = 16
+  head_size: int = 64
+  inner_size: int = 4096
+  inner_activation: str = "gelu"
+  dropout_rate: float = 0.1
+  attention_dropout_rate: float = 0.1
+  attention_type: str = "bi"
+  bi_data: bool = False
+  tie_attention_biases: bool = False
+  memory_length: int = 0
+  same_length: bool = False
+  clamp_length: int = -1
+  reuse_length: int = 0
+  use_cls_mask: bool = False
+  embedding_width: int = 1024
+  initializer_range: float = 0.02
+  two_stream: bool = False
+
+
 @dataclasses.dataclass
 class EncoderConfig(hyperparams.OneOfConfig):
   """Encoder configuration."""
@@ -144,6 +169,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   bert: BertEncoderConfig = BertEncoderConfig()
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
+  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
@@ -151,6 +177,7 @@ ENCODER_CLS = {
     "mobilebert": networks.MobileBERTEncoder,
     "albert": networks.AlbertEncoder,
     "bigbird": bigbird_encoder.BigBirdEncoder,
+    "xlnet": networks.XLNetBase,
 }
@@ -266,6 +293,29 @@ def build_encoder(
             stddev=encoder_cfg.initializer_range),
         embedding_width=encoder_cfg.embedding_size)
+  if encoder_type == "xlnet":
+    return encoder_cls(
+        vocab_size=encoder_cfg.vocab_size,
+        num_layers=encoder_cfg.num_layers,
+        hidden_size=encoder_cfg.hidden_size,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        head_size=encoder_cfg.head_size,
+        inner_size=encoder_cfg.inner_size,
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        attention_type=encoder_cfg.attention_type,
+        bi_data=encoder_cfg.bi_data,
+        two_stream=encoder_cfg.two_stream,
+        tie_attention_biases=encoder_cfg.tie_attention_biases,
+        memory_length=encoder_cfg.memory_length,
+        clamp_length=encoder_cfg.clamp_length,
+        reuse_length=encoder_cfg.reuse_length,
+        inner_activation=encoder_cfg.inner_activation,
+        use_cls_mask=encoder_cfg.use_cls_mask,
+        embedding_width=encoder_cfg.embedding_width,
+        initializer=tf.keras.initializers.RandomNormal(
+            stddev=encoder_cfg.initializer_range))
+
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
   return encoder_cls(
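For context, a minimal usage sketch of the new one-of branch (the module path official.nlp.configs.encoders and the chosen field values are illustrative, not part of this commit):

from official.nlp.configs import encoders

# Select the new "xlnet" branch of the one-of EncoderConfig and build it.
encoder_config = encoders.EncoderConfig(
    type="xlnet",
    xlnet=encoders.XLNetEncoderConfig(num_layers=12))
xlnet_encoder = encoders.build_encoder(encoder_config)  # a networks.XLNetBase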
@@ -54,23 +54,6 @@ def _get_output_shape(output_rank, known_last_dims):
   return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


-def _large_compatible_negative(tensor_type):
-  """Large negative number as Tensor.
-
-  This function is necessary because the standard value for epsilon
-  in this module (-1e9) cannot be represented using tf.float16.
-
-  Args:
-    tensor_type: a dtype to determine the type.
-
-  Returns:
-    a large negative number.
-  """
-  if tensor_type == tf.float16:
-    return tf.float16.min
-  return -1e9
-
-
 def _rel_shift(x, klen=-1):
   """Performs relative shift to form the relative attention score."""
@@ -101,13 +84,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
   **Note: This layer is currently experimental.

   Attributes:
-    num_heads: The number of attention heads.
-    key_dim: Size of each attention head for query and key.
-    value_dim: Size of attention head for value.
-    dropout: Dropout probability for attention.
-    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
-    kernel_initializer: Initializer for dense layer kernels.
-    bias_initializer: Initializer for dense layer biases.
+    kernel_initializer: The kernel initializer. Defaults to variance_scaling.

   Call args:
     query: Query `Tensor` of shape `[B, T, dim]`.
     value: Value `Tensor` of shape `[B, S, dim]`.
@@ -242,12 +220,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
     attention_scores = tf.multiply(
         attention_sum, 1.0 / math.sqrt(float(self._key_dim)))

-    # `attention_scores`: `[B, N, S, S + M]`
-    if attention_mask is not None:
-      attention_scores += (_large_compatible_negative(attention_scores.dtype)
-                           * attention_mask)
-    attention_scores = tf.nn.softmax(attention_scores, 3)
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)

     attention_output = self._dropout_layer(attention_scores)
     attention_output = tf.einsum(self._combine_equation,
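As background for the switch to the base class's _masked_softmax: the removed helper existed because an additive -1e9 bias overflows tf.float16. A standalone illustration of the pitfall (not commit code):

import tensorflow as tf

# -1e9 is not representable in float16: it casts to -inf, and -inf * 0.0
# then poisons even the unmasked positions with NaN.
print(tf.float16.min)                          # -65504.0
print(tf.cast(tf.constant(-1e9), tf.float16))  # -inf

scores = tf.zeros([1, 4], tf.float16)
masked = tf.constant([[0., 0., 1., 1.]], tf.float16)  # 1 = position to hide
bad = tf.nn.softmax(scores + tf.cast(-1e9, tf.float16) * masked)  # all NaN
safe = tf.nn.softmax(scores + tf.float16.min * masked)            # [.5 .5 0 0]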
@@ -85,7 +85,6 @@ class TransformerXLBlock(tf.keras.layers.Layer):
     kernel_initializer: Initializer for dense layer kernels.
     inner_dropout: Dropout probability for the inner dropout
       layer.
-
   """

   def __init__(self,
@@ -31,6 +31,9 @@ class XLNetClassifier(tf.keras.Model):
   Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
   Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

+  Note: This model does not utilize the memory mechanism used in the
+  original XLNet Classifier.
+
   Arguments:
     network: An XLNet/Transformer-XL based network. This network should output a
       sequence output and list of `state` tensors.
@@ -70,7 +73,7 @@ class XLNetClassifier(tf.keras.Model):
       raise ValueError('Invalid summary type provided: %s.' % summary_type)

     self.classifier = layers.ClassificationHead(
-        inner_dim=network.get_config()['inner_size'],
+        inner_dim=network.get_config()['hidden_size'],
         num_classes=num_classes,
         initializer=initializer,
         dropout_rate=dropout_rate,
@@ -78,12 +81,12 @@ class XLNetClassifier(tf.keras.Model):
         name='sentence_prediction')

   def call(self, inputs: Mapping[str, Any]):
-    input_ids = inputs['input_ids']
-    segment_ids = inputs['segment_ids']
-    input_mask = inputs['input_mask']
+    input_ids = inputs['input_word_ids']
+    segment_ids = inputs['input_type_ids']
+    input_mask = tf.cast(inputs['input_mask'], tf.float32)
     state = inputs.get('mems', None)

-    attention_output, new_states = self._network(
+    attention_output, _ = self._network(
         input_ids=input_ids,
         segment_ids=segment_ids,
         input_mask=input_mask,
@@ -91,7 +94,7 @@ class XLNetClassifier(tf.keras.Model):
     logits = self.classifier(attention_output)

-    return logits, new_states
+    return logits

   def get_config(self):
     return self._config
@@ -100,6 +103,14 @@ class XLNetClassifier(tf.keras.Model):
   def from_config(cls, config, custom_objects=None):
     return cls(**config)

+  @property
+  def checkpoint_items(self):
+    items = dict(encoder=self._network)
+    if hasattr(self.classifier, 'checkpoint_items'):
+      for key, item in self.classifier.checkpoint_items.items():
+        items['.'.join([self.classifier.name, key])] = item
+    return items
+

 @tf.keras.utils.register_keras_serializable(package='Text')
 class XLNetSpanLabeler(tf.keras.Model):
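With the renamed features, calling the classifier now mirrors the BERT-style models and yields logits alone. A hypothetical sketch (assumes xlnet_classifier is an XLNetClassifier built as in the updated test below):

import numpy as np

batch_size, seq_length = 2, 8
features = dict(
    input_word_ids=np.random.randint(
        32000, size=(batch_size, seq_length), dtype='int32'),
    input_type_ids=np.random.randint(
        2, size=(batch_size, seq_length), dtype='int32'),
    input_mask=np.ones((batch_size, seq_length), dtype='float32'))
logits = xlnet_classifier(features)  # no longer a (logits, new_states) tuple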
@@ -64,10 +64,10 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         summary_type='last',
         dropout_rate=0.1)
     inputs = dict(
-        input_ids=tf.keras.layers.Input(
+        input_word_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
-        segment_ids=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
         input_mask=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='input_mask'),
         permutation_mask=tf.keras.layers.Input(
@@ -76,7 +76,7 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         masked_tokens=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='masked_tokens'))

-    logits, _ = xlnet_trainer_model(inputs)
+    logits = xlnet_trainer_model(inputs)

     expected_classification_shape = [None, num_classes]
     self.assertAllEqual(expected_classification_shape, logits.shape.as_list())
@@ -99,8 +99,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
     sequence_shape = (batch_size, seq_length)
     inputs = dict(
-        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
-        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
+        input_word_ids=np.random.randint(
+            10, size=sequence_shape, dtype='int32'),
+        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
         input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
         permutation_mask=np.random.randint(
             2, size=(batch_size, seq_length, seq_length)).astype('float32'),
@@ -49,6 +49,9 @@ def _create_causal_attention_mask(
   concatenating 0s (representing memory positions) with a strictly upper
   triangular matrix of 1s.

+  We then flip the matrix values in order to match the representation where
+  real values are 1s.
+
   Arguments:
     seq_length: int, The length of each sequence.
     memory_length: int, The length of memory blocks.
@@ -59,10 +62,10 @@ def _create_causal_attention_mask(
     A unidirectional attention mask of shape
     `[seq_length, seq_length + memory_length]`. E.g.:
-    [[0. 0. 0. 1. 1. 1.]
-     [0. 0. 0. 0. 1. 1.]
-     [0. 0. 0. 0. 0. 1.]
-     [0. 0. 0. 0. 0. 0.]]
+    [[1. 1. 1. 0. 0. 0.]
+     [1. 1. 1. 1. 0. 0.]
+     [1. 1. 1. 1. 1. 0.]
+     [1. 1. 1. 1. 1. 1.]]
   """
   ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
   upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
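For intuition, the flipped convention can be reproduced with plain TensorFlow (a standalone sketch, not commit code):

import tensorflow as tf

seq_length, memory_length = 3, 2
ones = tf.ones([seq_length, seq_length])
# Strictly upper-triangular 1s mark the disallowed future positions...
future = tf.linalg.band_part(ones, 0, -1) - tf.linalg.band_part(ones, 0, 0)
mask = tf.concat([tf.zeros([seq_length, memory_length]), future], axis=1)
# ...and the final flip makes 1 mean "attendable", matching the docstring.
print((1 - mask).numpy())
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]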
@@ -78,7 +81,32 @@ def _create_causal_attention_mask(
         [causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
          causal_attention_mask[:, seq_length:]], 1)

-  return causal_attention_mask
+  return 1 - causal_attention_mask
+
+
+def _combine_masks(mask1, mask2, dtype, how="and"):
+  """Combines two masks.
+
+  Use "and" if trying to combine two existing masks.
+  Use "or" if trying to flip a few positions to "real".
+
+  Args:
+    mask1: tf.Tensor, input mask 1
+    mask2: tf.Tensor, input mask 2
+    dtype: tf.dtype
+    how: Which logical operation should run.
+
+  Returns:
+    The combined input masks.
+  """
+  if how == "and":
+    operator = tf.math.logical_and
+  else:
+    operator = tf.math.logical_or
+  return tf.cast(operator(
+      tf.cast(mask1, tf.bool),
+      tf.cast(mask2, tf.bool)), dtype=dtype)
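To make the new semantics concrete, a small illustration of _combine_masks under the 1-means-real convention (input values made up):

import tensorflow as tf

padding = tf.constant([1., 1., 0., 0.])  # last two tokens are padding
allowed = tf.constant([0., 1., 1., 1.])  # e.g. one row of a permutation mask

# "and": a position must be real in *both* masks to stay attendable.
print(_combine_masks(padding, allowed, tf.float32).numpy())  # [0. 1. 0. 0.]
# "or": flips chosen positions back to real (used for the content diagonal).
print(_combine_masks(padding, allowed, tf.float32, how="or").numpy())  # [1. 1. 1. 1.]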
@@ -140,8 +168,7 @@ def _compute_attention_mask(
   # input_mask: [B, S]
   # permutation_mask: [B, S, S]
   if input_mask is not None and permutation_mask is not None:
-    data_mask = input_mask[:, None, :] + permutation_mask
+    data_mask = _combine_masks(input_mask[:, None, :], permutation_mask, dtype)
   elif input_mask is not None and permutation_mask is None:
     data_mask = input_mask[:, None, :]
   elif input_mask is None and permutation_mask is not None:
@@ -153,28 +180,28 @@ def _compute_attention_mask(
   if data_mask is not None:
     # All positions within state can be attended to.
-    state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
-                          dtype=dtype)
+    state_mask = tf.ones([batch_size, tf.shape(data_mask)[1], memory_length],
+                         dtype=dtype)
     # state_mask: [B, 1, M] or [B, S, M]
     data_mask = tf.concat([state_mask, data_mask], 2)
     # data_mask: [B, 1, S + M] or [B, S, S + M]
     if attention_type == "uni":
-      attention_mask = causal_attention_mask + data_mask[:, None, :, :]
+      attention_mask = _combine_masks(causal_attention_mask,
+                                      data_mask[:, None, :, :],
+                                      dtype=dtype)
     else:
       attention_mask = data_mask[:, None, :, :]

-  # Construct the content attention mask.
   if attention_mask is not None:
-    attention_mask = tf.cast(attention_mask > 0, dtype=dtype)
-
-    non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
-    non_tgt_mask = tf.concat(
-        [tf.zeros([seq_length, memory_length], dtype=dtype),
-         non_tgt_mask], axis=-1)
-    content_attention_mask = tf.cast(
-        (attention_mask + non_tgt_mask[None, None, :, :]) > 0,
-        dtype=dtype)
+    # Construct the content attention mask. This ensures that the model can
+    # attend to its own position in the content stream (i.e. the content
+    # diagonal).
+    non_target_mask = tf.concat(
+        [tf.zeros([seq_length, memory_length], dtype=dtype),
+         tf.eye(seq_length, dtype=dtype)], axis=-1)
+    content_attention_mask = _combine_masks(
+        attention_mask, non_target_mask, how="or", dtype=dtype)
   else:
     content_attention_mask = None
@@ -85,9 +85,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         seq_length=seq_length,
         memory_length=memory_length)
-    expected_output = np.array([[0, 1, 1],
-                                [0, 0, 1],
-                                [0, 0, 0]])
+    expected_output = np.array([[1, 0, 0],
+                                [1, 1, 0],
+                                [1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

   def test_casual_attention_mask_with_memory(self):
@@ -96,9 +96,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         seq_length=seq_length,
         memory_length=memory_length)
-    expected_output = np.array([[0, 0, 0, 1, 1],
-                                [0, 0, 0, 0, 1],
-                                [0, 0, 0, 0, 0]])
+    expected_output = np.array([[1, 1, 1, 0, 0],
+                                [1, 1, 1, 1, 0],
+                                [1, 1, 1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

   def test_causal_attention_mask_with_same_length(self):
@@ -108,9 +108,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         memory_length=memory_length,
         same_length=True)
-    expected_output = np.array([[0, 0, 0, 1, 1],
-                                [1, 0, 0, 0, 1],
-                                [1, 1, 0, 0, 0]])
+    expected_output = np.array([[1, 1, 1, 0, 0],
+                                [0, 1, 1, 1, 0],
+                                [0, 0, 1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)
@@ -179,15 +179,15 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 1, 1]])
+    input_mask = np.array([[1, 1, 0, 0]])
     permutation_mask = None
     expected_query_mask = input_mask[None, None, :, :]
     expected_content_mask = np.array([[[
-        [0, 0, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 1, 0]]]])
+        [1, 1, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 0, 1]]]])

     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
@@ -209,14 +209,14 @@ class MaskComputationTests(keras_parameterized.TestCase):
     input_mask = None
     permutation_mask = np.array([
-        [[0, 1],
-         [0, 1]],
+        [[1, 0],
+         [1, 0]],
     ])
     expected_query_mask = permutation_mask[:, None, :, :]
     expected_content_mask = np.array([[[
-        [0, 1],
-        [0, 0]]]])
+        [1, 0],
+        [1, 1]]]])

     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
@@ -236,24 +236,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 1, 1]])
+    input_mask = np.array([[1, 1, 0, 0]])
     permutation_mask = np.array([[
-        [1, 0, 0, 0],
-        [0, 1, 0, 0],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1],
+        [0, 1, 1, 1],
+        [1, 0, 1, 1],
+        [1, 1, 0, 1],
+        [1, 1, 1, 0],
     ]])
     expected_query_mask = np.array([[[
-        [1, 0, 1, 1],
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 1, 1]]]])
+        [0, 1, 0, 0],
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 0, 0]]]])
     expected_content_mask = np.array([[[
-        [0, 0, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 1, 0]]]])
+        [1, 1, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 0, 1]]]])

     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
@@ -272,24 +272,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 0, 1]])
+    input_mask = np.array([[1, 1, 1, 0]])
     permutation_mask = np.array([[
-        [1, 0, 0, 0],
-        [0, 1, 0, 0],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1],
+        [0, 1, 1, 1],
+        [1, 0, 1, 1],
+        [1, 1, 0, 1],
+        [1, 1, 1, 0],
     ]])
     expected_query_mask = np.array([[[
-        [1, 1, 1, 1],
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1]]]])
+        [0, 0, 0, 0],
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0]]]])
     expected_content_mask = np.array([[[
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 0, 0]]]])
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 1, 1]]]])

     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
@@ -81,7 +81,13 @@ class SentencePredictionTask(base_task.Task):
     else:
       encoder_network = encoders.build_encoder(self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
-    # Currently, we only support bert-style sentence prediction finetuning.
+    if self.task_config.model.encoder.type == 'xlnet':
+      return models.XLNetClassifier(
+          network=encoder_network,
+          num_classes=self.task_config.model.num_classes,
+          initializer=tf.keras.initializers.RandomNormal(
+              stddev=encoder_cfg.initializer_range))
+    else:
       return models.BertClassifier(
           network=encoder_network,
           num_classes=self.task_config.model.num_classes,
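A short sketch of selecting this branch from a config override (the override method is the standard hyperparams.Config API; the field values are made up):

from official.nlp.configs import encoders

encoder = encoders.EncoderConfig()
encoder.override({'type': 'xlnet', 'xlnet': {'memory_length': 0}})
# build_model() then dispatches on encoder.type: 'xlnet' yields
# models.XLNetClassifier; anything else falls through to models.BertClassifier.
assert encoder.type == 'xlnet'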
@@ -15,6 +15,7 @@
 """Keras layers of XLNet model in TF 2.0."""

 import copy
+import warnings

 import tensorflow as tf
@@ -416,6 +417,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
     """
     super(TransformerXLModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`TransformerXLModel` is deprecated, please use `XLNetBase` instead.",
+        DeprecationWarning, stacklevel=2)

     self.n_token = n_token
     self.initializer = initializer
@@ -745,11 +749,13 @@ class PretrainingXLNetModel(tf.keras.Model):
   """

-  def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
+  def __init__(self, use_proj, xlnet_config, run_config, use_legacy_mask=True,
+               **kwargs):
     super(PretrainingXLNetModel, self).__init__(**kwargs)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,
input_ids = features["input_ids"] input_ids = features["input_ids"]
masked_tokens = features["input_q"] masked_tokens = features["input_q"]
seg_ids = features["seg_id"] seg_ids = features["seg_id"]
if self._use_legacy_mask:
perm_mask = 1 - features["perm_mask"]
else:
perm_mask = features["perm_mask"] perm_mask = features["perm_mask"]
target_mapping = features["target_mapping"] target_mapping = features["target_mapping"]
@@ -823,11 +832,16 @@ class ClassificationXLNetModel(tf.keras.Model):
   """

-  def __init__(self, xlnet_config, run_config, n_class, summary_type, **kwargs):
+  def __init__(self, xlnet_config, run_config, n_class, summary_type,
+               use_legacy_mask=True, **kwargs):
     super(ClassificationXLNetModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`ClassificationXLNetModel` is deprecated, please use "
+        "`XLNetClassifier` instead.", DeprecationWarning, stacklevel=2)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,
@@ -870,6 +884,9 @@ class ClassificationXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     segment_ids = features["segment_ids"]

+    if self._use_legacy_mask:
+      input_mask = 1 - features["input_mask"]
+    else:
       input_mask = features["input_mask"]
     label = tf.reshape(features["label_ids"], [batch_size_per_core])
@@ -1068,11 +1085,15 @@ class QAXLNetModel(tf.keras.Model):
   """

   def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
-               **kwargs):
+               use_legacy_mask=True, **kwargs):
     super(QAXLNetModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`QAXLNetModel` is deprecated, please use `XLNetSpanLabeler` instead.",
+        DeprecationWarning, stacklevel=2)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,
@@ -1108,6 +1129,9 @@ class QAXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     segment_ids = features["segment_ids"]

+    if self._use_legacy_mask:
+      input_mask = 1 - features["input_mask"]
+    else:
       input_mask = features["input_mask"]

     cls_index = tf.reshape(features["cls_index"], [-1])
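The new use_legacy_mask flag (default True) preserves the old input convention, in which 1 marks padded positions, while the refactored XLNetBase expects 1 to mark real tokens. A tiny illustration (data made up):

import numpy as np

legacy_input_mask = np.array([0., 0., 1., 1.], dtype=np.float32)  # 2 real, 2 pad
print(1 - legacy_input_mask)  # [1. 1. 0. 0.] -- what XLNetBase now expects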