Commit 801ac678 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339095008
parent b70019f0
@@ -136,6 +136,31 @@ class BigBirdEncoderConfig(hyperparams.Config):
embedding_size: Optional[int] = None
@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
"""XLNet encoder configuration."""
vocab_size: int = 32000
num_layers: int = 24
hidden_size: int = 1024
num_attention_heads: int = 16
head_size: int = 64
inner_size: int = 4096
inner_activation: str = "gelu"
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
attention_type: str = "bi"
bi_data: bool = False
tie_attention_biases: bool = False
memory_length: int = 0
same_length: bool = False
clamp_length: int = -1
reuse_length: int = 0
use_cls_mask: bool = False
embedding_width: int = 1024
initializer_range: float = 0.02
two_stream: bool = False
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
@@ -144,6 +169,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
bert: BertEncoderConfig = BertEncoderConfig()
bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
ENCODER_CLS = {
@@ -151,6 +177,7 @@ ENCODER_CLS = {
"mobilebert": networks.MobileBERTEncoder,
"albert": networks.AlbertEncoder,
"bigbird": bigbird_encoder.BigBirdEncoder,
"xlnet": networks.XLNetBase,
}
@@ -266,6 +293,29 @@ def build_encoder(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size)
if encoder_type == "xlnet":
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
num_layers=encoder_cfg.num_layers,
hidden_size=encoder_cfg.hidden_size,
num_attention_heads=encoder_cfg.num_attention_heads,
head_size=encoder_cfg.head_size,
inner_size=encoder_cfg.inner_size,
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
attention_type=encoder_cfg.attention_type,
bi_data=encoder_cfg.bi_data,
two_stream=encoder_cfg.two_stream,
tie_attention_biases=encoder_cfg.tie_attention_biases,
memory_length=encoder_cfg.memory_length,
clamp_length=encoder_cfg.clamp_length,
reuse_length=encoder_cfg.reuse_length,
inner_activation=encoder_cfg.inner_activation,
use_cls_mask=encoder_cfg.use_cls_mask,
embedding_width=encoder_cfg.embedding_width,
initializer=tf.keras.initializers.RandomNormal(
stddev=encoder_cfg.initializer_range))
# Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type.
return encoder_cls(
......
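With the `xlnet` entry registered in both `EncoderConfig` and `ENCODER_CLS`, the new encoder can be selected purely through configuration. A minimal sketch, assuming this diff lives in the Model Garden `official.nlp.configs.encoders` module and using small, illustrative hyperparameters:

```python
# Hedged sketch: select and build the new XLNet encoder via the OneOfConfig.
# The import path and the tiny hyperparameters are illustrative assumptions;
# the field names follow XLNetEncoderConfig from this change.
from official.nlp.configs import encoders

encoder_cfg = encoders.EncoderConfig(
    type="xlnet",
    xlnet=encoders.XLNetEncoderConfig(
        vocab_size=32000,
        num_layers=2,
        hidden_size=64,
        num_attention_heads=2,
        head_size=32,
        inner_size=128,
        embedding_width=64))

# build_encoder dispatches on encoder_cfg.type ("xlnet") and returns an
# instance of networks.XLNetBase.
xlnet_encoder = encoders.build_encoder(encoder_cfg)
```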
@@ -54,23 +54,6 @@ def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
def _large_compatible_negative(tensor_type):
"""Large negative number as Tensor.
This function is necessary because the standard value for epsilon
in this module (-1e9) cannot be represented using tf.float16
Args:
tensor_type: a dtype to determine the type.
Returns:
a large negative number.
"""
if tensor_type == tf.float16:
return tf.float16.min
return -1e9
def _rel_shift(x, klen=-1):
"""Performs relative shift to form the relative attention score."""
@@ -101,13 +84,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
**Note: This layer is currently experimental.
Attributes:
num_heads: The number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of attention head for value.
dropout: Dropout probability for attention.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_initializer: The kernel initializer. Defaults to variance_scaling.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
@@ -242,12 +220,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
attention_scores = tf.multiply(
attention_sum, 1.0 / math.sqrt(float(self._key_dim)))
# `attention_scores`: `[B, N, S, S + M]`
if attention_mask is not None:
attention_scores += (_large_compatible_negative(attention_scores.dtype)
* attention_mask)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores = tf.nn.softmax(attention_scores, 3)
attention_output = self._dropout_layer(attention_scores)
attention_output = tf.einsum(self._combine_equation,
......
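One behavioral detail in the hunk above is worth spelling out: the removed code added a large negative value at positions where `attention_mask` was 1 (mask meant "drop"), while the Keras `_masked_softmax` path keeps positions where the mask is 1 and suppresses the rest, consistent with the "real values are 1s" convention adopted throughout this change. A rough numerical sketch of the effect (not the library internals):

```python
# Rough sketch of masking before softmax under the new convention:
# positions with keep == 1 participate, positions with keep == 0 end up with
# (near) zero probability. This mirrors the behavior, not the exact code in
# tf.keras.layers.MultiHeadAttention.
import tensorflow as tf

scores = tf.constant([[1.0, 2.0, 3.0]])
keep = tf.constant([[1.0, 1.0, 0.0]])  # 1 = attend, 0 = mask out

masked_scores = scores + (1.0 - keep) * -1e9
probs = tf.nn.softmax(masked_scores, axis=-1)
print(probs.numpy())  # last entry is ~0; the first two sum to ~1
```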
@@ -85,7 +85,6 @@ class TransformerXLBlock(tf.keras.layers.Layer):
kernel_initializer: Initializer for dense layer kernels.
inner_dropout: Dropout probability for the inner dropout
layer.
"""
def __init__(self,
......
@@ -31,6 +31,9 @@ class XLNetClassifier(tf.keras.Model):
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Note: This model does not utilize the memory mechanism used in the
original XLNet Classifier.
Arguments:
network: An XLNet/Transformer-XL based network. This network should output a
sequence output and list of `state` tensors.
@@ -70,7 +73,7 @@ class XLNetClassifier(tf.keras.Model):
raise ValueError('Invalid summary type provided: %s.' % summary_type)
self.classifier = layers.ClassificationHead(
inner_dim=network.get_config()['inner_size'],
inner_dim=network.get_config()['hidden_size'],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
@@ -78,12 +81,12 @@ class XLNetClassifier(tf.keras.Model):
name='sentence_prediction')
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_ids']
segment_ids = inputs['segment_ids']
input_mask = inputs['input_mask']
input_ids = inputs['input_word_ids']
segment_ids = inputs['input_type_ids']
input_mask = tf.cast(inputs['input_mask'], tf.float32)
state = inputs.get('mems', None)
attention_output, new_states = self._network(
attention_output, _ = self._network(
input_ids=input_ids,
segment_ids=segment_ids,
input_mask=input_mask,
@@ -91,7 +94,7 @@ class XLNetClassifier(tf.keras.Model):
logits = self.classifier(attention_output)
return logits, new_states
return logits
def get_config(self):
return self._config
@@ -100,6 +103,14 @@ class XLNetClassifier(tf.keras.Model):
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def checkpoint_items(self):
items = dict(encoder=self._network)
if hasattr(self.classifier, 'checkpoint_items'):
for key, item in self.classifier.checkpoint_items.items():
items['.'.join([self.classifier.name, key])] = item
return items
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf.keras.Model):
......
@@ -64,10 +64,10 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
summary_type='last',
dropout_rate=0.1)
inputs = dict(
input_ids=tf.keras.layers.Input(
input_word_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
segment_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
input_type_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='input_mask'),
permutation_mask=tf.keras.layers.Input(
@@ -76,7 +76,7 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
masked_tokens=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='masked_tokens'))
logits, _ = xlnet_trainer_model(inputs)
logits = xlnet_trainer_model(inputs)
expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_classification_shape, logits.shape.as_list())
@@ -99,8 +99,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_word_ids=np.random.randint(
10, size=sequence_shape, dtype='int32'),
input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
permutation_mask=np.random.randint(
2, size=(batch_size, seq_length, seq_length)).astype('float32'),
......
@@ -49,6 +49,9 @@ def _create_causal_attention_mask(
concatenating 0s (representing memory positions) with a strictly upper
triangular matrix of 1s.
We then flip the matrix values in order to match the representation where
real values are 1s.
Arguments:
seq_length: int, The length of each sequence.
memory_length: int, The length of memory blocks.
@@ -59,10 +62,10 @@ def _create_causal_attention_mask(
A unidirectional attention mask of shape
`[seq_length, seq_length + memory_length]`. E.g.:
[[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
[[1. 1. 1. 0. 0. 0.]
[1. 1. 1. 1. 0. 0.]
[1. 1. 1. 1. 1. 0.]
[1. 1. 1. 1. 1. 1.]]
"""
ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
@@ -78,7 +81,32 @@ def _create_causal_attention_mask(
[causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
causal_attention_mask[:, seq_length:]], 1)
return causal_attention_mask
return 1 - causal_attention_mask
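For a concrete picture of the flipped convention, here is a small standalone sketch of the same construction for `seq_length=3` and `memory_length=2`, ignoring the `same_length` branch; it reproduces the expected outputs in the updated tests further down:

```python
# Standalone sketch (not the function itself) of the causal mask construction
# for seq_length=3, memory_length=2, without the same_length branch.
import tensorflow as tf

seq_length, memory_length = 3, 2
ones_matrix = tf.ones([seq_length, seq_length])
strictly_upper = (tf.linalg.band_part(ones_matrix, 0, -1)
                  - tf.linalg.band_part(ones_matrix, 0, 0))
mask = tf.concat([tf.zeros([seq_length, memory_length]), strictly_upper], 1)
print((1 - mask).numpy())
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]
```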
def _combine_masks(mask1, mask2, dtype, how="and"):
"""Combines two masks.
Use "and" if trying to combine two existing masks.
Use "or" if trying to flip a few positions to "real".
Args:
mask1: tf.Tensor, input mask 1
mask2: tf.Tensor, input mask 2
dtype: tf.dtype
how: Which logical operation should run.
Returns:
The combined input masks.
"""
if how == "and":
operator = tf.math.logical_and
else:
operator = tf.math.logical_or
return tf.cast(operator(
tf.cast(mask1, tf.bool),
tf.cast(mask2, tf.bool)), dtype=dtype)
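A quick illustration of the two modes of the helper above, with hypothetical inputs (assuming `_combine_masks` from this hunk is in scope):

```python
# Hypothetical inputs for _combine_masks: "and" keeps a position only if both
# masks mark it real; "or" marks it real if either mask does.
import tensorflow as tf

m1 = tf.constant([[1., 1., 0., 0.]])
m2 = tf.constant([[1., 0., 1., 0.]])
print(_combine_masks(m1, m2, tf.float32, how="and").numpy())  # [[1. 0. 0. 0.]]
print(_combine_masks(m1, m2, tf.float32, how="or").numpy())   # [[1. 1. 1. 0.]]
```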
def _compute_attention_mask(
@@ -140,8 +168,7 @@ def _compute_attention_mask(
# input_mask: [B, S]
# permutation_mask: [B, S, S]
if input_mask is not None and permutation_mask is not None:
data_mask = input_mask[:, None, :] + permutation_mask
data_mask = _combine_masks(input_mask[:, None, :], permutation_mask, dtype)
elif input_mask is not None and permutation_mask is None:
data_mask = input_mask[:, None, :]
elif input_mask is None and permutation_mask is not None:
@@ -153,28 +180,28 @@ def _compute_attention_mask(
if data_mask is not None:
# All positions within state can be attended to.
state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
dtype=dtype)
state_mask = tf.ones([batch_size, tf.shape(data_mask)[1], memory_length],
dtype=dtype)
# state_mask: [B, 1, M] or [B, S, M]
data_mask = tf.concat([state_mask, data_mask], 2)
# data_mask: [B, 1, S + M] or [B, S, S + M]
if attention_type == "uni":
attention_mask = causal_attention_mask + data_mask[:, None, :, :]
attention_mask = _combine_masks(causal_attention_mask,
data_mask[:, None, :, :],
dtype=dtype)
else:
attention_mask = data_mask[:, None, :, :]
# Construct the content attention mask.
if attention_mask is not None:
attention_mask = tf.cast(attention_mask > 0, dtype=dtype)
non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
non_tgt_mask = tf.concat(
# Construct the content attention mask.
# This ensures that the mask allows the model to attend to content
# positions (e.g. the content diagonal).
non_target_mask = tf.concat(
[tf.zeros([seq_length, memory_length], dtype=dtype),
non_tgt_mask], axis=-1)
content_attention_mask = tf.cast(
(attention_mask + non_tgt_mask[None, None, :, :]) > 0,
dtype=dtype)
tf.eye(seq_length, dtype=dtype)], axis=-1)
content_attention_mask = _combine_masks(
attention_mask, non_target_mask, how="or", dtype=dtype)
else:
content_attention_mask = None
......
@@ -85,9 +85,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
seq_length=seq_length,
memory_length=memory_length)
expected_output = np.array([[0, 1, 1],
[0, 0, 1],
[0, 0, 0]])
expected_output = np.array([[1, 0, 0],
[1, 1, 0],
[1, 1, 1]])
self.assertAllClose(causal_attention_mask, expected_output)
def test_casual_attention_mask_with_memory(self):
@@ -96,9 +96,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
seq_length=seq_length,
memory_length=memory_length)
expected_output = np.array([[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0]])
expected_output = np.array([[1, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
[1, 1, 1, 1, 1]])
self.assertAllClose(causal_attention_mask, expected_output)
def test_causal_attention_mask_with_same_length(self):
@@ -108,9 +108,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
memory_length=memory_length,
same_length=True)
expected_output = np.array([[0, 0, 0, 1, 1],
[1, 0, 0, 0, 1],
[1, 1, 0, 0, 0]])
expected_output = np.array([[1, 1, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 1, 1]])
self.assertAllClose(causal_attention_mask, expected_output)
@@ -179,15 +179,15 @@ class MaskComputationTests(keras_parameterized.TestCase):
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 1, 1]])
input_mask = np.array([[1, 1, 0, 0]])
permutation_mask = None
expected_query_mask = input_mask[None, None, :, :]
expected_content_mask = np.array([[[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 1, 0]]]])
[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 0, 1]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
@@ -209,14 +209,14 @@ class MaskComputationTests(keras_parameterized.TestCase):
input_mask = None
permutation_mask = np.array([
[[0, 1],
[0, 1]],
[[1, 0],
[1, 0]],
])
expected_query_mask = permutation_mask[:, None, :, :]
expected_content_mask = np.array([[[
[0, 1],
[0, 0]]]])
[1, 0],
[1, 1]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
@@ -236,24 +236,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 1, 1]])
input_mask = np.array([[1, 1, 0, 0]])
permutation_mask = np.array([[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 1, 1, 1],
[1, 0, 1, 1],
[1, 1, 0, 1],
[1, 1, 1, 0],
]])
expected_query_mask = np.array([[[
[1, 0, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 1, 1]]]])
[0, 1, 0, 0],
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 0, 0]]]])
expected_content_mask = np.array([[[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 1, 0]]]])
[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 0, 1]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
@@ -272,24 +272,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 0, 1]])
input_mask = np.array([[1, 1, 1, 0]])
permutation_mask = np.array([[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 1, 1, 1],
[1, 0, 1, 1],
[1, 1, 0, 1],
[1, 1, 1, 0],
]])
expected_query_mask = np.array([[[
[1, 1, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1]]]])
[0, 0, 0, 0],
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0]]]])
expected_content_mask = np.array([[[
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]]]])
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
......
@@ -81,13 +81,19 @@ class SentencePredictionTask(base_task.Task):
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
# Currently, we only support bert-style sentence prediction finetuning.
return models.BertClassifier(
network=encoder_network,
num_classes=self.task_config.model.num_classes,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
use_encoder_pooler=self.task_config.model.use_encoder_pooler)
if self.task_config.model.encoder.type == 'xlnet':
return models.XLNetClassifier(
network=encoder_network,
num_classes=self.task_config.model.num_classes,
initializer=tf.keras.initializers.RandomNormal(
stddev=encoder_cfg.initializer_range))
else:
return models.BertClassifier(
network=encoder_network,
num_classes=self.task_config.model.num_classes,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
use_encoder_pooler=self.task_config.model.use_encoder_pooler)
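With this branch in place, the sentence-prediction task can be pointed at XLNet purely through configuration. A hedged sketch; the task and model config class names are assumptions based on the surrounding code, and the encoder fields follow `XLNetEncoderConfig` from this change:

```python
# Hedged sketch: configuring the sentence prediction task with the new
# 'xlnet' encoder type. Config class names/paths are assumptions.
from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction

config = sentence_prediction.SentencePredictionConfig(
    model=sentence_prediction.ModelConfig(
        num_classes=2,
        encoder=encoders.EncoderConfig(
            type='xlnet',
            xlnet=encoders.XLNetEncoderConfig(
                num_layers=2,
                hidden_size=64,
                num_attention_heads=2,
                head_size=32,
                inner_size=128,
                embedding_width=64))))

task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()  # an XLNetClassifier, per the branch above
```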
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
if self.task_config.model.num_classes == 1:
......
@@ -15,6 +15,7 @@
"""Keras layers of XLNet model in TF 2.0."""
import copy
import warnings
import tensorflow as tf
@@ -416,6 +417,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
"""
super(TransformerXLModel, self).__init__(**kwargs)
warnings.warn(
"`TransformerXLModel` is deprecated, please use `XLNetBase` instead",
DeprecationWarning, stacklevel=2)
self.n_token = n_token
self.initializer = initializer
@@ -745,11 +749,13 @@ class PretrainingXLNetModel(tf.keras.Model):
"""
def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
def __init__(self, use_proj, xlnet_config, run_config, use_legacy_mask=True,
**kwargs):
super(PretrainingXLNetModel, self).__init__(**kwargs)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
self.xlnet_config = copy.deepcopy(xlnet_config)
self._use_legacy_mask = use_legacy_mask
self.xlnet_model = networks.XLNetBase(
vocab_size=self.xlnet_config.n_token,
@@ -788,7 +794,10 @@ class PretrainingXLNetModel(tf.keras.Model):
input_ids = features["input_ids"]
masked_tokens = features["input_q"]
seg_ids = features["seg_id"]
perm_mask = features["perm_mask"]
if self._use_legacy_mask:
perm_mask = 1 - features["perm_mask"]
else:
perm_mask = features["perm_mask"]
target_mapping = features["target_mapping"]
# target for LM loss
@@ -823,11 +832,16 @@ class ClassificationXLNetModel(tf.keras.Model):
"""
def __init__(self, xlnet_config, run_config, n_class, summary_type, **kwargs):
def __init__(self, xlnet_config, run_config, n_class, summary_type,
use_legacy_mask=True, **kwargs):
super(ClassificationXLNetModel, self).__init__(**kwargs)
warnings.warn(
"`ClassificationXLNetModel` is deprecated, please use `XLNetClassifier` "
"instead.", DeprecationWarning, stacklevel=2)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
self.xlnet_config = copy.deepcopy(xlnet_config)
self._use_legacy_mask = use_legacy_mask
self.xlnet_model = networks.XLNetBase(
vocab_size=self.xlnet_config.n_token,
@@ -870,7 +884,10 @@ class ClassificationXLNetModel(tf.keras.Model):
input_ids = features["input_ids"]
segment_ids = features["segment_ids"]
input_mask = features["input_mask"]
if self._use_legacy_mask:
input_mask = 1 - features["input_mask"]
else:
input_mask = features["input_mask"]
label = tf.reshape(features["label_ids"], [batch_size_per_core])
@@ -1068,11 +1085,15 @@ class QAXLNetModel(tf.keras.Model):
"""
def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
**kwargs):
use_legacy_mask=True, **kwargs):
super(QAXLNetModel, self).__init__(**kwargs)
warnings.warn(
"`QAXLNetModel` is deprecated, please use `XLNetSpanLabeler` instead.",
DeprecationWarning, stacklevel=2)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
self.xlnet_config = copy.deepcopy(xlnet_config)
self._use_legacy_mask = use_legacy_mask
self.xlnet_model = networks.XLNetBase(
vocab_size=self.xlnet_config.n_token,
@@ -1108,7 +1129,10 @@ class QAXLNetModel(tf.keras.Model):
input_ids = features["input_ids"]
segment_ids = features["segment_ids"]
input_mask = features["input_mask"]
if self._use_legacy_mask:
input_mask = 1 - features["input_mask"]
else:
input_mask = features["input_mask"]
cls_index = tf.reshape(features["cls_index"], [-1])
p_mask = features["p_mask"]
......
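The `use_legacy_mask` flag threaded through the legacy models above exists because older XLNet input pipelines mark padded or masked positions with 1, while the refactored layers now treat 1 as a real (attendable) position; when the flag is set, the features are simply flipped. Illustration with hypothetical values:

```python
# Hypothetical values: converting a legacy-style mask (1 = padding/masked)
# into the new convention (1 = real position), as done when
# use_legacy_mask=True.
import numpy as np

legacy_input_mask = np.array([[0, 0, 1, 1]], dtype='float32')  # 1 = padding
input_mask = 1 - legacy_input_mask                              # 1 = real token
print(input_mask)  # [[1. 1. 0. 0.]]
```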