"mmdet3d/datasets/transforms/transforms_3d.py" did not exist on "7fda1f661fb23c559ad40d346315e9d4f39181f7"
Commit 0864d66e authored by Allen Wang, committed by A. Unique TensorFlower

Implement XLNetBase. This is the underlying network shared between XLNet models.

PiperOrigin-RevId: 333183470
parent 4c226604
@@ -26,3 +26,4 @@ to 1) head.
* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is, a prediction head that can predict one start and end index per batch item) based on a single dense hidden layer. It can be used in the SQuAD task.
* [`XLNetBase`](xlnet_base.py) implements the base network used in "XLNet: Generalized Autoregressive Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237). It includes embedding lookups, relative position encodings, mask computations, segment matrix computations and Transformer XL layers using one- or two-stream relative self-attention.
@@ -19,5 +19,6 @@ from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.mobile_bert_encoder import MobileBERTEncoder
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.xlnet_base import XLNetBase
# Backward compatibility. The modules are deprecated.
TransformerEncoder = BertEncoder
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based XLNet Model."""
from absl import logging
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer_xl
_SEG_ID_CLS = 2
def _create_causal_attention_mask(
seq_length,
memory_length,
dtype=tf.float32,
same_length=False):
"""Creates a causal attention mask with a single-sided context.
When applying the attention mask in `MultiHeadRelativeAttention`, the
attention scores are of shape `[(batch dimensions), S, S + M]`, where:
- S = sequence length.
- M = memory length.
In a simple case where S = 2, M = 1, here is a simple illustration of the
`attention_scores` matrix, where `a` represents an attention function:
token_0 [[a(token_0, mem_0)  a(token_0, token_0)  a(token_0, token_1)],
token_1  [a(token_1, mem_0)  a(token_1, token_0)  a(token_1, token_1)]]
(columns: mem_0, token_0, token_1)
For uni-directional attention, we want to mask out values in the attention
scores that represent a(token_i, token_j) where j > i. We can achieve this by
concatenating 0s (representing memory positions) with a strictly upper
triangular matrix of 1s.
Arguments:
seq_length: int, The length of each sequence.
memory_length: int, The length of memory blocks.
dtype: dtype of the mask.
same_length: bool, whether to use the same attention length for each token.
Returns:
A unidirectional attention mask of shape
`[seq_length, seq_length + memory_length]`. E.g.:
[[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
"""
ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
diagonal = tf.linalg.band_part(ones_matrix, 0, 0)
padding = tf.zeros([seq_length, memory_length], dtype=dtype)
causal_attention_mask = tf.concat(
[padding, upper_triangular - diagonal], 1)
if same_length:
lower_triangular = tf.linalg.band_part(ones_matrix, -1, 0)
strictly_lower_triangular = lower_triangular - diagonal
causal_attention_mask = tf.concat(
[causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
causal_attention_mask[:, seq_length:]], 1)
return causal_attention_mask
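As a quick illustration (not part of the original file), the helper can be exercised directly; this sketch mirrors `CausalAttentionMaskTests` in the test file further below and assumes the `official.nlp` package is importable:

```python
import tensorflow as tf

from official.nlp.modeling.networks import xlnet_base

# 3 query tokens attending over 2 memory tokens plus themselves.
# Memory positions are always visible (0); strictly-future tokens are masked (1).
mask = xlnet_base._create_causal_attention_mask(seq_length=3, memory_length=2)
print(mask.numpy())
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]
```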
def _compute_attention_mask(
input_mask,
permutation_mask,
attention_type,
seq_length,
memory_length,
batch_size,
dtype=tf.float32):
"""Combines all input attention masks for XLNet.
In XLNet modeling, `0` represents tokens that can be attended to, and `1`
represents tokens that cannot be attended to.
For XLNet pre-training and fine tuning, there are a few masks used:
- Causal attention mask: If the attention type is unidirectional, then all
tokens after the current position cannot be attended to.
- Input mask: when generating data, padding is added to a max sequence length
to make all sequences the same length. This masks out real tokens (`0`) from
padding tokens (`1`).
- Permutation mask: during XLNet pretraining, the input sequence is factorized
into a factorization sequence `z`. During partial prediction, `z` is split
at a cutting point `c` (an index of the factorization sequence) and
prediction is only applied to all tokens after `c`. Therefore, tokens at
factorization positions `i` > `c` can be attended to and tokens at
factorization positions `i` <= `c` cannot be attended to.
This function broadcasts and combines all attention masks to produce the
query attention mask and the content attention mask.
Args:
input_mask: Tensor, the input mask related to padding. Input shape:
`(B, S)`.
permutation_mask: Tensor, the permutation mask used in partial prediction.
Input shape: `(B, S, S)`.
attention_type: str, the attention type. Can be "uni" (directional) or
"bi" (directional).
seq_length: int, the length of each sequence.
memory_length: int, the length of memory blocks.
batch_size: int, the batch size.
dtype: The dtype of the masks.
Returns:
attention_mask, content_attention_mask: The query attention mask (position-
and context-based) and the content attention mask, respectively.
"""
attention_mask = None
# `1` values mean do not attend to this position.
if attention_type == "uni":
causal_attention_mask = _create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length,
dtype=dtype)
causal_attention_mask = causal_attention_mask[None, None, :, :]
# `causal_attention_mask`: [1, 1, S, S + M]
# input_mask: [B, S]
# permutation_mask: [B, S, S]
if input_mask is not None and permutation_mask is not None:
data_mask = input_mask[:, None, :] + permutation_mask
elif input_mask is not None and permutation_mask is None:
data_mask = input_mask[:, None, :]
elif input_mask is None and permutation_mask is not None:
data_mask = permutation_mask
else:
data_mask = None
# data_mask: [B, S, S] or [B, 1, S]
if data_mask is not None:
# All positions within state can be attended to.
state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
dtype=dtype)
# state_mask: [B, 1, M] or [B, S, M]
data_mask = tf.concat([state_mask, data_mask], 2)
# data_mask: [B, 1, S + M] or [B, S, S + M]
if attention_type == "uni":
attention_mask = causal_attention_mask + data_mask[:, None, :, :]
else:
attention_mask = data_mask[:, None, :, :]
# Construct the content attention mask.
if attention_mask is not None:
attention_mask = tf.cast(attention_mask > 0, dtype=dtype)
non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
non_tgt_mask = tf.concat(
[tf.zeros([seq_length, memory_length], dtype=dtype),
non_tgt_mask], axis=-1)
content_attention_mask = tf.cast(
(attention_mask + non_tgt_mask[None, None, :, :]) > 0,
dtype=dtype)
else:
content_attention_mask = None
return attention_mask, content_attention_mask
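A minimal sketch (not in the original commit) of how the two masks relate when only a padding mask is supplied; it mirrors `test_input_mask_no_permutation` in the test file below and assumes `official.nlp` is installed:

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling.networks import xlnet_base

# Bidirectional attention, batch of 1, last two positions are padding.
query_mask, content_mask = xlnet_base._compute_attention_mask(
    input_mask=np.array([[0, 0, 1, 1]], dtype="float32"),
    permutation_mask=None,
    attention_type="bi",
    seq_length=4,
    memory_length=0,
    batch_size=1,
    dtype=tf.float32)
# query_mask has shape [1, 1, 1, 4]: the padding mask broadcast over queries.
# content_mask has shape [1, 1, 4, 4]: the same mask with a zeroed diagonal,
# so every token can still attend to its own content.
```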
def _compute_segment_matrix(
segment_ids,
memory_length,
batch_size,
use_cls_mask):
"""Computes the segment embedding matrix.
XLNet introduced segment-based attention for attention calculations. This
extends the idea of relative encodings in Transformer XL by considering
whether or not two positions are within the same segment, rather than
which segments they come from.
This function generates a segment matrix by broadcasting provided segment IDs
in two different dimensions and checking where values are equal. This output
matrix shows `True` whenever two tokens are NOT in the same segment and
`False` whenever they are.
Args:
segment_ids: A Tensor of size `[B, S]` that represents which segment
each token belongs to.
memory_length: int, the length of memory blocks.
batch_size: int, the batch size.
use_cls_mask: bool, whether or not to introduce cls mask in
input sequences.
Returns:
A boolean Tensor of size `[B, S, S + M]`, where `True` means that two
tokens are NOT in the same segment, and `False` means they are in the same
segment.
"""
if segment_ids is None:
return None
memory_padding = tf.zeros([batch_size, memory_length], dtype=tf.int32)
padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
# segment_ids: [B, S]
# padded_segment_ids: [B, S + M]
if use_cls_mask:
# `1` indicates not in the same segment.
# Target result: [B, S, S + M]
# segment_ids: [B, S]
# padded_segment_ids: [B, S + M]
broadcasted_segment_class_indices = (
tf.equal(segment_ids,
tf.constant([_SEG_ID_CLS]))[:, :, None])
broadcasted_padded_class_indices = (
tf.equal(
padded_segment_ids,
tf.constant([_SEG_ID_CLS]))[:, None, :])
class_index_matrix = tf.logical_or(broadcasted_segment_class_indices,
broadcasted_padded_class_indices)
segment_matrix = tf.equal(segment_ids[:, :, None],
padded_segment_ids[:, None, :])
segment_matrix = tf.logical_or(class_index_matrix, segment_matrix)
else:
# TODO(allencwang) - address this legacy mismatch from `use_cls_mask`.
segment_matrix = tf.logical_not(
tf.equal(segment_ids[:, :, None], padded_segment_ids[:, None, :]))
return segment_matrix
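A minimal sketch (not in the original commit) of the segment matrix for a single sequence, mirroring `SegmentMatrixTests.test_basic` below:

```python
import numpy as np

from official.nlp.modeling.networks import xlnet_base

segment_matrix = xlnet_base._compute_segment_matrix(
    segment_ids=np.array([[1, 1, 2, 1]], dtype="int32"),
    memory_length=0,
    batch_size=1,
    use_cls_mask=False)
# `True` marks token pairs that are NOT in the same segment:
# [[False False  True False]
#  [False False  True False]
#  [ True  True False  True]
#  [False False  True False]]
```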
def _compute_positional_encoding(
attention_type,
position_encoding_layer,
hidden_size,
batch_size,
total_length,
seq_length,
clamp_length,
bi_data,
dtype=tf.float32):
"""Computes the relative position encoding.
Args:
attention_type: str, the attention type. Can be "uni" (directional) or
"bi" (directional).
position_encoding_layer: An instance of `RelativePositionEncoding`.
hidden_size: int, the hidden size.
batch_size: int, the batch size.
total_length: int, the sequence length added to the memory length.
seq_length: int, the length of each sequence.
clamp_length: int, clamp all relative distances larger than clamp_length. -1
means no clamping.
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
dtype: the dtype of the encoding.
Returns:
A Tensor, representing the position encoding.
"""
freq_seq = tf.range(0, hidden_size, 2.0)
if dtype is not None and dtype != tf.float32:
freq_seq = tf.cast(freq_seq, dtype=dtype)
if attention_type == "bi":
beg, end = total_length, -seq_length
elif attention_type == "uni":
beg, end = total_length, -1
else:
raise ValueError("Unknown `attention_type` {}.".format(attention_type))
if bi_data:
forward_position_sequence = tf.range(beg, end, -1.0)
backward_position_sequence = tf.range(-beg, -end, 1.0)
if dtype is not None and dtype != tf.float32:
forward_position_sequence = tf.cast(forward_position_sequence,
dtype=dtype)
backward_position_sequence = tf.cast(backward_position_sequence,
dtype=dtype)
if clamp_length > 0:
forward_position_sequence = tf.clip_by_value(
forward_position_sequence,
-clamp_length,
clamp_length)
backward_position_sequence = tf.clip_by_value(
backward_position_sequence,
-clamp_length,
clamp_length)
if batch_size is not None:
forward_positional_encoding = position_encoding_layer(
forward_position_sequence, batch_size // 2)
backward_positional_encoding = position_encoding_layer(
backward_position_sequence, batch_size // 2)
else:
forward_positional_encoding = position_encoding_layer(
forward_position_sequence, None)
backward_positional_encoding = position_encoding_layer(
backward_position_sequence, None)
relative_position_encoding = tf.concat(
[forward_positional_encoding, backward_positional_encoding], axis=0)
else:
forward_position_sequence = tf.range(beg, end, -1.0)
if dtype is not None and dtype != tf.float32:
forward_position_sequence = tf.cast(
forward_position_sequence, dtype=dtype)
if clamp_length > 0:
forward_position_sequence = tf.clip_by_value(
forward_position_sequence,
-clamp_length,
clamp_length)
relative_position_encoding = position_encoding_layer(
forward_position_sequence, batch_size)
return relative_position_encoding
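For reference, a hedged sketch (not in the original commit) of the helper above, mirroring `ComputePositionEncodingTest` in the test file; it uses the `RelativePositionEncoding` layer defined just below:

```python
import tensorflow as tf

from official.nlp.modeling.networks import xlnet_base

encoding = xlnet_base._compute_positional_encoding(
    attention_type="bi",
    position_encoding_layer=xlnet_base.RelativePositionEncoding(hidden_size=4),
    hidden_size=4,
    batch_size=4,
    total_length=8,   # memory length + sequence length
    seq_length=4,
    clamp_length=2,
    bi_data=True,
    dtype=tf.float32)
# encoding.shape -> (4, 12, 4): with `bi_data=True`, the forward and backward
# encodings are each tiled over half the batch and concatenated along axis 0;
# the middle dimension covers the 12 relative positions from `total_length`
# down to `-seq_length + 1`.
```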
class RelativePositionEncoding(tf.keras.layers.Layer):
"""Creates a relative positional encoding.
This layer creates a relative positional encoding as described in
"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
Rather than an absolute position embedding as in Transformer, this
formulation represents position as the relative distance between tokens using
sinusoidal positional embeddings.
Note: This layer is currently experimental.
Attributes:
hidden_size: The dimensionality of the input embeddings.
"""
def __init__(self, hidden_size, **kwargs):
super(RelativePositionEncoding, self).__init__(**kwargs)
self._hidden_size = hidden_size
self._inv_freq = 1.0 / (10000.0**(
tf.range(0, self._hidden_size, 2.0) / self._hidden_size))
def call(self, pos_seq, batch_size=None):
"""Implements call() for the layer.
Arguments:
pos_seq: A 1-D `Tensor`
batch_size: The optionally provided batch size that tiles the relative
positional encoding.
Returns:
The relative positional encoding of shape:
[batch_size, len(pos_seq), hidden_size] if batch_size is provided, else
[1, len(pos_seq), hidden_size].
"""
sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq)
relative_position_encoding = tf.concat([tf.sin(sinusoid_input),
tf.cos(sinusoid_input)], -1)
relative_position_encoding = relative_position_encoding[None, :, :]
if batch_size is not None:
relative_position_encoding = tf.tile(relative_position_encoding,
[batch_size, 1, 1])
return relative_position_encoding
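A small sketch (not in the original commit) showing the layer's output shape and tiling behaviour, matching `RelativePositionEncodingTest` in the test file below:

```python
import tensorflow as tf

from official.nlp.modeling.networks import xlnet_base

layer = xlnet_base.RelativePositionEncoding(hidden_size=4)
pos_seq = tf.range(1, -1, -1.0)       # relative positions [1., 0.]
encoding = layer(pos_seq)             # shape (1, 2, 4)
tiled = layer(pos_seq, batch_size=3)  # shape (3, 2, 4)
# Each row concatenates sin and cos features:
# [sin(1), sin(0.01), cos(1), cos(0.01)] for position 1, and
# [sin(0), sin(0),    cos(0), cos(0)]    for position 0.
```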
@tf.keras.utils.register_keras_serializable(package="Text")
class XLNetBase(tf.keras.layers.Layer):
"""Base XLNet model.
Attributes:
vocab_size: int, the number of tokens in vocabulary.
num_layers: int, the number of layers.
hidden_size: int, the hidden size.
num_attention_heads: int, the number of attention heads.
head_size: int, the dimension size of each attention head.
inner_size: int, the hidden size in feed-forward layers.
dropout_rate: float, dropout rate.
attention_dropout_rate: float, dropout rate on attention probabilities.
attention_type: str, "uni" or "bi".
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
initializer: A tf initializer.
two_stream: bool, whether or not to use `TwoStreamRelativeAttention` used
in the XLNet pretrainer. If `False`, then it will use
`MultiHeadRelativeAttention` as in Transformer XL.
tie_attention_biases: bool, whether or not to tie the biases together.
Usually set to `True`. Used for backwards compatibility.
memory_length: int, the number of tokens to cache.
same_length: bool, whether to use the same attention length for each
token.
clamp_length: int, clamp all relative distances larger than clamp_length. -1
means no clamping.
reuse_length: int, the number of tokens in the current batch to be cached
and reused in the future.
inner_activation: str, "relu" or "gelu".
use_cls_mask: bool, whether or not cls mask is included in the
input sequences.
embedding_width: The width of the word embeddings. If the embedding width
is not equal to hidden size, embedding parameters will be factorized
into two matrices in the shape of ["vocab_size", "embedding_width"] and
["embedding_width", "hidden_size"] ("embedding_width" is usually much
smaller than "hidden_size").
embedding_layer: The word embedding layer. `None` means we will create a
new embedding layer. Otherwise, we will reuse the given embedding layer.
This parameter is originally added for ELECTRA model which needs to tie
the generator embeddings with the discriminator embeddings.
"""
def __init__(self,
vocab_size,
num_layers,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
attention_type,
bi_data,
initializer,
two_stream=False,
tie_attention_biases=True,
memory_length=None,
clamp_length=-1,
reuse_length=None,
inner_activation="relu",
use_cls_mask=False,
embedding_width=None,
**kwargs):
super(XLNetBase, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._initializer = initializer
self._attention_type = attention_type
self._num_layers = num_layers
self._hidden_size = hidden_size
self._num_attention_heads = num_attention_heads
self._head_size = head_size
self._inner_size = inner_size
self._inner_activation = inner_activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._tie_attention_biases = tie_attention_biases
self._two_stream = two_stream
self._memory_length = memory_length
self._reuse_length = reuse_length
self._bi_data = bi_data
self._clamp_length = clamp_length
self._use_cls_mask = use_cls_mask
self._segment_embedding = None
self._mask_embedding = None
self._embedding_width = embedding_width
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=embedding_width,
initializer=self._initializer,
dtype=tf.float32,
name="word_embedding")
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.embedding_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.position_encoding = RelativePositionEncoding(self._hidden_size)
self._transformer_xl = transformer_xl.TransformerXL(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
head_size=head_size,
inner_size=inner_size,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
initializer=initializer,
two_stream=two_stream,
tie_attention_biases=tie_attention_biases,
memory_length=memory_length,
reuse_length=reuse_length,
inner_activation=inner_activation,
name="transformer_xl")
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"num_layers":
self._num_layers,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_attention_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"attention_type":
self._attention_type,
"bi_data":
self._bi_data,
"initializer":
self._initializer,
"two_stream":
self._two_stream,
"tie_attention_biases":
self._tie_attention_biases,
"memory_length":
self._memory_length,
"clamp_length":
self._clamp_length,
"reuse_length":
self._reuse_length,
"inner_activation":
self._inner_activation,
"use_cls_mask":
self._use_cls_mask,
"embedding_width":
self._embedding_width,
}
base_config = super(XLNetBase, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def get_embedding_lookup_table(self):
"""Returns the embedding layer weights."""
return self._embedding_layer.embeddings
def __call__(self,
input_ids,
segment_ids=None,
input_mask=None,
state=None,
permutation_mask=None,
target_mapping=None,
masked_tokens=None,
**kwargs):
# Uses dict to feed inputs into call() in order to keep state as a python
# list.
inputs = {
"input_ids": input_ids,
"segment_ids": segment_ids,
"input_mask": input_mask,
"state": state,
"permutation_mask": permutation_mask,
"target_mapping": target_mapping,
"masked_tokens": masked_tokens
}
return super(XLNetBase, self).__call__(inputs, **kwargs)
def call(self, inputs):
"""Implements call() for the layer."""
input_ids = inputs["input_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
state = inputs["state"]
permutation_mask = inputs["permutation_mask"]
target_mapping = inputs["target_mapping"]
masked_tokens = inputs["masked_tokens"]
batch_size = tf.shape(input_ids)[0]
seq_length = input_ids.shape.as_list()[1]
memory_length = state[0].shape.as_list()[1] if state is not None else 0
total_length = memory_length + seq_length
if self._two_stream and masked_tokens is None:
raise ValueError("`masked_tokens` must be provided in order to "
"initialize the query stream in "
"`TwoStreamRelativeAttention`.")
if masked_tokens is not None and not self._two_stream:
logging.warning("`masked_tokens` is provided but `two_stream` is not "
"enabled. Please enable `two_stream` to enable two "
"stream attention.")
query_attention_mask, content_attention_mask = _compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type=self._attention_type,
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
relative_position_encoding = _compute_positional_encoding(
attention_type=self._attention_type,
position_encoding_layer=self.position_encoding,
hidden_size=self._hidden_size,
batch_size=batch_size,
total_length=total_length,
seq_length=seq_length,
clamp_length=self._clamp_length,
bi_data=self._bi_data,
dtype=tf.float32)
relative_position_encoding = self.embedding_dropout(
relative_position_encoding)
if segment_ids is None:
segment_embedding = None
segment_matrix = None
else:
if self._segment_embedding is None:
self._segment_embedding = self.add_weight(
"seg_embed",
shape=[self._num_layers, 2, self._num_attention_heads,
self._head_size],
dtype=tf.float32,
initializer=self._initializer)
segment_embedding = self._segment_embedding
segment_matrix = _compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=self._use_cls_mask)
word_embeddings = self._embedding_layer(input_ids)
content_stream = self._dropout(word_embeddings)
if self._two_stream:
if self._mask_embedding is None:
self._mask_embedding = self.add_weight(
"mask_emb/mask_emb",
shape=[1, 1, self._hidden_size],
dtype=tf.float32)
if target_mapping is None:
masked_tokens = masked_tokens[:, :, None]
masked_token_embedding = (
masked_tokens * self._mask_embedding +
(1 - masked_tokens) * word_embeddings)
else:
masked_token_embedding = tf.tile(
self._mask_embedding,
[batch_size, tf.shape(target_mapping)[1], 1])
query_stream = self._dropout(masked_token_embedding)
else:
query_stream = None
return self._transformer_xl(
content_stream=content_stream,
query_stream=query_stream,
target_mapping=target_mapping,
state=state,
relative_position_encoding=relative_position_encoding,
segment_matrix=segment_matrix,
segment_embedding=segment_embedding,
content_attention_mask=content_attention_mask,
query_attention_mask=query_attention_mask)
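To see the pieces together, here is a minimal end-to-end sketch (not part of the commit) that mirrors `XLNetModelTests.test_xlnet_model` in the test file below; the hyperparameters are deliberately tiny:

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling.networks import xlnet_base

batch_size, seq_length, hidden_size = 2, 8, 4
network = xlnet_base.XLNetBase(
    vocab_size=32000,
    num_layers=2,
    hidden_size=hidden_size,
    num_attention_heads=2,
    head_size=2,
    inner_size=2,
    dropout_rate=0.,
    attention_dropout_rate=0.,
    attention_type="bi",
    bi_data=True,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
    two_stream=False,
    reuse_length=0,
    inner_activation="relu")

outputs = network(
    input_ids=np.random.randint(10, size=(batch_size, seq_length), dtype="int32"),
    segment_ids=np.random.randint(2, size=(batch_size, seq_length), dtype="int32"),
    input_mask=np.random.randint(2, size=(batch_size, seq_length)).astype("float32"))
# outputs[0] is the final hidden-state sequence, shape (2, 8, 4).
```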
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import xlnet_base
@keras_parameterized.run_all_keras_modes
class RelativePositionEncodingTest(keras_parameterized.TestCase):
def test_positional_embedding(self):
"""A low-dimensional example is tested.
With len(pos_seq)=2 and d_model=4:
pos_seq = [[1.], [0.]]
inv_freq = [1., 0.01]
pos_seq x inv_freq = [[1, 0.01], [0., 0.]]
pos_emb = [[sin(1.), sin(0.01), cos(1.), cos(0.01)],
[sin(0.), sin(0.), cos(0.), cos(0.)]]
= [[0.84147096, 0.00999983, 0.54030228, 0.99994999],
[0., 0., 1., 1.]]
"""
target = np.array([[[0.84147096, 0.00999983, 0.54030228, 0.99994999],
[0., 0., 1., 1.]]])
hidden_size = 4
pos_seq = tf.range(1, -1, -1.0) # [1., 0.]
encoding_layer = xlnet_base.RelativePositionEncoding(
hidden_size=hidden_size)
encoding = encoding_layer(pos_seq, batch_size=None).numpy().astype(float)
self.assertAllClose(encoding, target)
class ComputePositionEncodingTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
attention_type=["uni", "bi"],
bi_data=[False, True],
))
def test_compute_position_encoding_smoke(self, attention_type, bi_data):
hidden_size = 4
batch_size = 4
total_length = 8
seq_length = 4
position_encoding_layer = xlnet_base.RelativePositionEncoding(
hidden_size=hidden_size)
encoding = xlnet_base._compute_positional_encoding(
attention_type=attention_type,
position_encoding_layer=position_encoding_layer,
hidden_size=hidden_size,
batch_size=batch_size,
total_length=total_length,
seq_length=seq_length,
clamp_length=2,
bi_data=bi_data,
dtype=tf.float32)
self.assertEqual(encoding.shape[0], batch_size)
self.assertEqual(encoding.shape[2], hidden_size)
class CausalAttentionMaskTests(tf.test.TestCase):
def test_causal_attention_mask_with_no_memory(self):
seq_length, memory_length = 3, 0
causal_attention_mask = xlnet_base._create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length)
expected_output = np.array([[0, 1, 1],
[0, 0, 1],
[0, 0, 0]])
self.assertAllClose(causal_attention_mask, expected_output)
def test_causal_attention_mask_with_memory(self):
seq_length, memory_length = 3, 2
causal_attention_mask = xlnet_base._create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length)
expected_output = np.array([[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0]])
self.assertAllClose(causal_attention_mask, expected_output)
def test_causal_attention_mask_with_same_length(self):
seq_length, memory_length = 3, 2
causal_attention_mask = xlnet_base._create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length,
same_length=True)
expected_output = np.array([[0, 0, 0, 1, 1],
[1, 0, 0, 0, 1],
[1, 1, 0, 0, 0]])
self.assertAllClose(causal_attention_mask, expected_output)
class MaskComputationTests(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
use_input_mask=[False, True],
use_permutation_mask=[False, True],
attention_type=["uni", "bi"],
memory_length=[0, 4],
))
def test_compute_attention_mask_smoke(self,
use_input_mask,
use_permutation_mask,
attention_type,
memory_length):
"""Tests coverage and functionality for different configurations."""
batch_size = 2
seq_length = 8
if use_input_mask:
input_mask = tf.zeros(shape=(batch_size, seq_length))
else:
input_mask = None
if use_permutation_mask:
permutation_mask = tf.zeros(shape=(batch_size, seq_length, seq_length))
else:
permutation_mask = None
_, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type=attention_type,
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
expected_mask_shape = (batch_size, 1,
seq_length, seq_length + memory_length)
if use_input_mask or use_permutation_mask:
self.assertEqual(content_mask.shape, expected_mask_shape)
def test_no_input_masks(self):
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=None,
permutation_mask=None,
attention_type="uni",
seq_length=8,
memory_length=2,
batch_size=2,
dtype=tf.float32)
self.assertIsNone(query_mask)
self.assertIsNone(content_mask)
def test_input_mask_no_permutation(self):
"""Tests if an input mask is provided but not permutation.
In the case that only one of input mask or permutation mask is provided
and the attention type is bidirectional, the query mask should be
a broadcasted version of the provided mask.
Content mask should be a broadcasted version of the query mask, where the
diagonal is 0s.
"""
seq_length = 4
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 1, 1]])
permutation_mask = None
expected_query_mask = input_mask[None, None, :, :]
expected_content_mask = np.array([[[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 1, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="bi",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
def test_permutation_mask_no_input_mask(self):
"""Tests if a permutation mask is provided but not input."""
seq_length = 2
batch_size = 1
memory_length = 0
input_mask = None
permutation_mask = np.array([
[[0, 1],
[0, 1]],
])
expected_query_mask = permutation_mask[:, None, :, :]
expected_content_mask = np.array([[[
[0, 1],
[0, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="bi",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
def test_permutation_and_input_mask(self):
"""Tests if both an input and permutation mask are provided."""
seq_length = 4
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 1, 1]])
permutation_mask = np.array([[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]])
expected_query_mask = np.array([[[
[1, 0, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 1, 1]]]])
expected_content_mask = np.array([[[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 1, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="bi",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
def test_permutation_input_uni_mask(self):
"""Tests if an input, permutation and causal mask are provided."""
seq_length = 4
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 0, 1]])
permutation_mask = np.array([[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]])
expected_query_mask = np.array([[[
[1, 1, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1]]]])
expected_content_mask = np.array([[[
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="uni",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
class SegmentMatrixTests(tf.test.TestCase):
def test_no_segment_ids(self):
segment_matrix = xlnet_base._compute_segment_matrix(
segment_ids=None,
memory_length=2,
batch_size=1,
use_cls_mask=False)
self.assertIsNone(segment_matrix)
def test_basic(self):
batch_size = 1
memory_length = 0
segment_ids = np.array([
[1, 1, 2, 1]
])
expected_segment_matrix = np.array([[
[False, False, True, False],
[False, False, True, False],
[True, True, False, True],
[False, False, True, False]
]])
segment_matrix = xlnet_base._compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=False)
self.assertAllClose(segment_matrix, expected_segment_matrix)
def test_basic_with_memory(self):
batch_size = 1
memory_length = 1
segment_ids = np.array([
[1, 1, 2, 1]
])
expected_segment_matrix = np.array([[
[True, False, False, True, False],
[True, False, False, True, False],
[True, True, True, False, True],
[True, False, False, True, False]
]]).astype(int)
segment_matrix = tf.cast(xlnet_base._compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=False), dtype=tf.uint8)
self.assertAllClose(segment_matrix, expected_segment_matrix)
def dont_test_basic_with_class_mask(self):
# TODO(allencwang) - this test should pass but illustrates the legacy issue
# of using class mask. Enable once addressed.
batch_size = 1
memory_length = 0
segment_ids = np.array([
[1, 1, 2, 1]
])
expected_segment_matrix = np.array([[
[False, False, True, False],
[False, False, True, False],
[True, True, False, True],
[False, False, True, False]
]]).astype(int)
segment_matrix = tf.cast(xlnet_base._compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=True), dtype=tf.uint8)
self.assertAllClose(segment_matrix, expected_segment_matrix)
class XLNetModelTests(tf.test.TestCase):
def _generate_data(self,
batch_size,
seq_length,
num_predictions=None):
"""Generates sample XLNet data for testing."""
sequence_shape = (batch_size, seq_length)
if num_predictions is not None:
target_mapping = tf.random.uniform(
shape=(batch_size, num_predictions, seq_length))
else:
target_mapping = None
return {
"input_ids": np.random.randint(10, size=sequence_shape, dtype="int32"),
"segment_ids":
np.random.randint(2, size=sequence_shape, dtype="int32"),
"input_mask":
np.random.randint(2, size=sequence_shape).astype("float32"),
"permutation_mask":
np.random.randint(
2, size=(batch_size, seq_length, seq_length)).astype("float32"),
"target_mapping": target_mapping,
"masked_tokens": tf.random.uniform(shape=sequence_shape),
}
def test_xlnet_model(self):
batch_size = 2
seq_length = 8
num_predictions = 2
hidden_size = 4
xlnet_model = xlnet_base.XLNetBase(
vocab_size=32000,
num_layers=2,
hidden_size=hidden_size,
num_attention_heads=2,
head_size=2,
inner_size=2,
dropout_rate=0.,
attention_dropout_rate=0.,
attention_type="bi",
bi_data=True,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
reuse_length=0,
inner_activation="relu")
input_data = self._generate_data(batch_size=batch_size,
seq_length=seq_length,
num_predictions=num_predictions)
model_output = xlnet_model(**input_data)
self.assertEqual(model_output[0].shape,
(batch_size, seq_length, hidden_size))
def test_get_config(self):
xlnet_model = xlnet_base.XLNetBase(
vocab_size=32000,
num_layers=12,
hidden_size=36,
num_attention_heads=12,
head_size=12,
inner_size=12,
dropout_rate=0.,
attention_dropout_rate=0.,
attention_type="bi",
bi_data=True,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
memory_length=0,
reuse_length=0,
inner_activation="relu")
config = xlnet_model.get_config()
new_xlnet = xlnet_base.XLNetBase.from_config(config)
self.assertEqual(config, new_xlnet.get_config())
if __name__ == "__main__":
tf.random.set_seed(0)
tf.test.main()