Commit 2c98b4b0 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower

Implement Transformer XL

PiperOrigin-RevId: 332486600
parent c7e31961
......@@ -63,3 +63,24 @@ assemble new layers, networks, or models.
* [GatedFeedforward](gated_feedforward.py) implements the gated linear layer
feedforward as described in
["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202).
* [MultiHeadRelativeAttention](relative_attention.py) implements a variant
  of multi-head attention with support for relative position encodings, as
  described in
  ["Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"](https://arxiv.org/abs/1901.02860).
  It also has extended support for segment-based attention, a
  re-parameterization introduced in
  ["XLNet: Generalized Autoregressive Pretraining for Language Understanding"](https://arxiv.org/abs/1906.08237).
* [TwoStreamRelativeAttention](relative_attention.py) implements a variant
  of multi-head relative attention as described in
  ["XLNet: Generalized Autoregressive Pretraining for Language Understanding"](https://arxiv.org/abs/1906.08237).
  It takes in a query stream and a content stream and applies self-attention.
* [TransformerXL](transformer_xl.py) implements Transformer-XL as introduced in
  ["Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"](https://arxiv.org/abs/1901.02860).
  It provides `TransformerXLBlock`, a block with either one- or two-stream
  relative self-attention followed by a feedforward network, and
  `TransformerXL`, which stacks multiple `TransformerXLBlock` layers and holds
  the shared attention biases, as shown in the sketch below.
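A minimal usage sketch of the new `TransformerXL` layer (the hyperparameters,
shapes, and initializer below are illustrative choices, not library defaults):

```python
import tensorflow as tf
from official.nlp.modeling import layers

transformer_xl = layers.TransformerXL(
    vocab_size=32000,
    num_layers=2,
    hidden_size=64,
    num_attention_heads=4,
    head_size=16,
    inner_size=128,
    dropout_rate=0.0,
    attention_dropout_rate=0.0,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False,
    memory_length=8,
    reuse_length=0)

batch_size, seq_length, hidden_size = 2, 8, 64
content_stream = tf.random.normal([batch_size, seq_length, hidden_size])
# The relative position encoding covers the current segment plus the memory.
relative_position_encoding = tf.random.normal(
    [batch_size, seq_length * 2, hidden_size])

output, cached_memories = transformer_xl(
    content_stream=content_stream,
    relative_position_encoding=relative_position_encoding)
# output: [batch_size, seq_length, hidden_size].
# cached_memories: one [batch_size, memory_length, hidden_size] tensor per layer.
```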
......@@ -24,9 +24,13 @@ from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
from official.nlp.modeling.layers.transformer_xl import TransformerXL
from official.nlp.modeling.layers.transformer_xl import TransformerXLBlock
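With these exports, the new layers can be referenced directly from the package
namespace. A small sketch (assuming the TF Model Garden `official` package is
on the Python path):

from official.nlp.modeling import layers

# Both names below refer to the classes defined in transformer_xl.py.
xl_block_cls = layers.TransformerXLBlock
xl_cls = layers.TransformerXL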
......@@ -22,6 +22,35 @@ import tensorflow as tf
from official.nlp.modeling.layers import relative_attention
def _cache_memory(current_state, previous_state, memory_length, reuse_length=0):
"""Caches hidden states into memory.
Arguments:
current_state: `Tensor`, the current state.
previous_state: `Tensor`, the previous state.
memory_length: `int`, the number of tokens to cache.
reuse_length: `int`, the number of tokens in the current batch to be cached
and reused in the future.
Returns:
A `Tensor`, representing the cached state with stopped gradients.
"""
if memory_length is None or memory_length == 0:
return None
else:
if reuse_length > 0:
current_state = current_state[:, :reuse_length, :]
if previous_state is None:
new_mem = current_state[:, -memory_length:, :]
else:
new_mem = tf.concat(
[previous_state, current_state], 1)[:, -memory_length:, :]
return tf.stop_gradient(new_mem)
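# Illustrative example (not part of the module): with `memory_length=3` and the
# default `reuse_length=0`, the cached memory is the last 3 positions of
# tf.concat([previous_state, current_state], axis=1), with gradients stopped:
#
#   previous_state = tf.zeros([1, 3, 8])   # hypothetical shapes
#   current_state = tf.ones([1, 4, 8])
#   new_mem = _cache_memory(current_state, previous_state, memory_length=3)
#   # new_mem.shape == [1, 3, 8]; every value comes from `current_state`.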
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerXLBlock(tf.keras.layers.Layer):
"""Transformer XL block.
......@@ -290,3 +319,243 @@ class TransformerXLBlock(tf.keras.layers.Layer):
return attention_output
class TransformerXL(tf.keras.layers.Layer):
"""Transformer XL.
This layer combines multiple Transformer XL blocks from "Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
This layer handles the attention biases as well as memory caching and reuse
as in Transformer XL and XLNet.
Attributes:
vocab_size: The number of tokens in vocabulary.
num_layers: The number of layers.
hidden_size: The hidden size.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The hidden size in feed-forward layers.
dropout_rate: Dropout rate used in each Transformer XL block.
attention_dropout_rate: Dropout rate on attention probabilities.
two_stream: Whether or not to use `TwoStreamRelativeAttention`, as used
in the XLNet pretrainer. If `False`, `MultiHeadRelativeAttention` is used
instead, as in Transformer XL.
initializer: The initializer to use for attention biases.
tie_attention_biases: Whether or not to tie biases together. If `True`, then
each Transformer XL block shares the same trainable attention bias. If
`False`, then each block has its own attention bias. This is usually set
to `True`.
memory_length: The number of tokens to cache.
reuse_length: The number of tokens in the current batch to be cached
and reused in the future.
inner_activation: The activation to use in the inner layers
for Transformer XL blocks. Typically "relu" or "gelu".
"""
def __init__(self,
vocab_size,
num_layers,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
initializer,
two_stream=False,
tie_attention_biases=True,
memory_length=None,
reuse_length=None,
inner_activation="relu",
**kwargs):
"""Initializes TransformerXL."""
super(TransformerXL, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._initializer = initializer
self._num_layers = num_layers
self._hidden_size = hidden_size
self._num_attention_heads = num_attention_heads
self._head_size = head_size
self._inner_size = inner_size
self._inner_activation = inner_activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._tie_attention_biases = tie_attention_biases
self._two_stream = two_stream
self._memory_length = memory_length
self._reuse_length = reuse_length
if self._tie_attention_biases:
attention_bias_shape = [self._num_attention_heads, self._head_size]
else:
attention_bias_shape = [self._num_layers, self._num_attention_heads,
self._head_size]
self.content_attention_bias = self.add_weight(
"content_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.positional_attention_bias = self.add_weight(
"positional_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.segment_attention_bias = self.add_weight(
"segment_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.transformer_xl_layers = []
for i in range(self._num_layers):
self.transformer_xl_layers.append(
TransformerXLBlock(
vocab_size=self._vocab_size,
hidden_size=self._head_size * self._num_attention_heads,
num_attention_heads=self._num_attention_heads,
head_size=self._head_size,
inner_size=self._inner_size,
dropout_rate=self._dropout_rate,
attention_dropout_rate=self._attention_dropout_rate,
norm_epsilon=1e-12,
inner_activation=self._inner_activation,
two_stream=self._two_stream,
kernel_initializer="variance_scaling",
name="layer_%d" % i))
self.output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"num_layers":
self._num_layers,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_attention_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"initializer":
self._initializer,
"two_stream":
self._two_stream,
"tie_attention_biases":
self._tie_attention_biases,
"memory_length":
self._memory_length,
"reuse_length":
self._reuse_length,
"inner_activation":
self._inner_activation,
}
base_config = super(TransformerXL, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
content_stream,
relative_position_encoding,
segment_matrix=None,
segment_embedding=None,
state=None,
content_attention_mask=None,
query_stream=None,
query_attention_mask=None,
target_mapping=None):
"""Implements call() for the layer.
Arguments:
content_stream: `Tensor`, the input content stream. This is the standard
input to Transformer XL and is commonly referred to as `h` in XLNet.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
but not in Transformer XL.
segment_embedding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
in XLNet, but not in Transformer XL.
state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
the state or memory. If passed, this is also attended over as in
Transformer XL.
content_attention_mask: Optional `Tensor` representing the mask that is
added to the content attention logits. If `state` is not None, the mask's
source sequence dimension should be extended by M to also cover the memory.
query_stream: Optional `Tensor`, the query stream. This is introduced in
`TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
`two_stream` is `False`.
query_attention_mask: Optional `Tensor` representing the mask that is
added to the query attention logits. If `state` is not None, the mask's
source sequence dimension should be extended by M to also cover the memory.
target_mapping: Optional `Tensor` representing the target mapping when
calculating query attention.
Returns:
A tuple consisting of the attention output and the list of cached memory
states.
The attention output is `content_attention` if `two_stream` is `False`,
otherwise it is `query_attention`.
"""
new_mems = []
if state is None:
state = [None] * self._num_layers
for i in range(self._num_layers):
# Cache this block's input (combined with any previous state) as memory
# for the next segment.
new_mems.append(
_cache_memory(content_stream, state[i],
self._memory_length, self._reuse_length))
# segment bias
if segment_matrix is None:
segment_attention_bias = None
segment_encoding = None
else:
segment_attention_bias = (self.segment_attention_bias
if self._tie_attention_biases
else self.segment_attention_bias[i])
segment_encoding = segment_embedding[i]
content_attention_bias = (self.content_attention_bias
if self._tie_attention_biases
else self.content_attention_bias[i])
positional_attention_bias = (self.positional_attention_bias
if self._tie_attention_biases
else self.positional_attention_bias[i])
transformer_xl_layer = self.transformer_xl_layers[i]
transformer_xl_output = transformer_xl_layer(
content_stream=content_stream,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
relative_position_encoding=relative_position_encoding,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
state=state[i],
content_attention_mask=content_attention_mask,
query_attention_mask=query_attention_mask,
query_stream=query_stream,
target_mapping=target_mapping)
content_stream = transformer_xl_output["content_attention"]
if self._two_stream:
query_stream = transformer_xl_output["query_attention"]
else:
query_stream = None
if self._two_stream:
output_stream = query_stream
else:
output_stream = content_stream
return output_stream, new_mems
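To make the memory mechanism concrete, here is a small sketch (shapes,
hyperparameters, and `reuse_length=0` are illustrative assumptions) of how the
per-layer memories returned by one call can be passed back as `state` for the
next segment:

import tensorflow as tf
from official.nlp.modeling.layers import transformer_xl

batch_size, seq_length, hidden_size, memory_length = 2, 8, 64, 8

xl = transformer_xl.TransformerXL(
    vocab_size=32000,
    num_layers=2,
    hidden_size=hidden_size,
    num_attention_heads=4,
    head_size=16,
    inner_size=128,
    dropout_rate=0.0,
    attention_dropout_rate=0.0,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    memory_length=memory_length,
    reuse_length=0)

# The relative position encoding spans the current segment plus the memory.
rel_pos = tf.random.normal(
    [batch_size, seq_length + memory_length, hidden_size])
segment_1 = tf.random.normal([batch_size, seq_length, hidden_size])
segment_2 = tf.random.normal([batch_size, seq_length, hidden_size])

_, memories = xl(content_stream=segment_1,
                 relative_position_encoding=rel_pos)
# `memories` holds one [batch_size, memory_length, hidden_size] tensor per
# layer; feeding it back lets segment 2 attend over segment 1 as well.
output, memories = xl(content_stream=segment_2,
                      relative_position_encoding=rel_pos,
                      state=memories)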
......@@ -32,6 +32,8 @@ def create_mock_transformer_xl_data(
memory_length=0,
num_predictions=2,
two_stream=False,
num_layers=1,
include_biases=True,
include_state=False,
include_mask=False,
include_segment=False):
......@@ -47,6 +49,8 @@ def create_mock_transformer_xl_data(
num_predictions: `int`, the number of predictions used in two stream
attention.
two_stream: `bool`, whether or not to generate two stream data.
num_layers: `int`, the number of Transformer XL blocks.
include_biases: optional `bool`, whether or not to include attention biases.
include_state: optional `bool`, whether or not to include state data.
include_mask: optional `bool`, whether or not to include mask data.
include_segment: optional `bool`, whether or not to include segment data.
......@@ -55,27 +59,34 @@ def create_mock_transformer_xl_data(
A dictionary with `str` as keys and `Tensor` as values.
"""
encoding_shape = (batch_size, seq_length * 2, hidden_size)
data = dict(
    relative_position_encoding=tf.random.normal(shape=encoding_shape),
    content_stream=tf.random.normal(
        shape=(batch_size, seq_length, hidden_size)))
if include_biases:
  attention_bias_shape = (num_heads, head_size)
  data.update(dict(
      content_attention_bias=tf.random.normal(shape=attention_bias_shape),
      segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
      positional_attention_bias=tf.random.normal(shape=attention_bias_shape)))
if two_stream:
  data.update(dict(
      query_stream=tf.random.normal(
          shape=(batch_size, num_predictions, hidden_size)),
      target_mapping=tf.random.normal(
          shape=(batch_size, num_predictions, seq_length))))
if include_state:
  total_seq_length = seq_length + memory_length
  if num_layers > 1:
    state_shape = (num_layers, batch_size, memory_length, hidden_size)
  else:
    state_shape = (batch_size, memory_length, hidden_size)
  data.update(dict(
      state=tf.random.normal(shape=state_shape)))
else:
  total_seq_length = seq_length
......@@ -87,15 +98,19 @@ def create_mock_transformer_xl_data(
data["query_attention_mask"] = mask_data
if include_segment:
  # A transformer XL block takes an individual segment "encoding" from the
  # entirety of the Transformer XL segment "embedding".
  if num_layers > 1:
    segment_encoding_shape = (num_layers, 2, num_heads, head_size)
    segment_encoding_name = "segment_embedding"
  else:
    segment_encoding_shape = (2, num_heads, head_size)
    segment_encoding_name = "segment_encoding"
  segment_matrix = np.random.randint(
      2, size=(batch_size, seq_length, total_seq_length))
  data["segment_matrix"] = tf.math.equal(segment_matrix, 1)
  data[segment_encoding_name] = tf.random.normal(shape=segment_encoding_shape)
return data
......@@ -109,17 +124,19 @@ class TransformerXLBlockTest(keras_parameterized.TestCase):
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_transformer_xl_block(
    self,
    two_stream,
    memory_length,
    state,
    mask,
    segment):
  """Tests combinations of Transformer XL block calculations."""
batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
hidden_size, num_predictions, inner_size = 24, 8, 12
data = create_mock_transformer_xl_data(
include_biases=True,
num_heads=num_heads,
head_size=head_size,
hidden_size=hidden_size,
......@@ -169,6 +186,90 @@ class TransformerXLBlockTest(keras_parameterized.TestCase):
self.assertEqual(transformer_xl_block_config, new_block.get_config())
@keras_parameterized.run_all_keras_modes
class TransformerXLTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
two_stream=[True, False],
memory_length=[0, 4],
reuse_length=[0, 4],
tie_attention_biases=[True, False],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_transformer_xl(
self,
two_stream,
memory_length,
reuse_length,
tie_attention_biases,
state,
mask,
segment):
batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
hidden_size, num_predictions, inner_size = 24, 8, 12
num_layers = 3
data = create_mock_transformer_xl_data(
include_biases=False,
num_heads=num_heads,
head_size=head_size,
hidden_size=hidden_size,
seq_length=seq_length,
batch_size=batch_size,
memory_length=memory_length,
num_predictions=num_predictions,
two_stream=two_stream,
num_layers=num_layers,
include_state=state,
include_mask=mask,
include_segment=segment)
transformer_xl_layer = transformer_xl.TransformerXL(
vocab_size=32000,
num_layers=num_layers,
head_size=head_size,
hidden_size=hidden_size,
num_attention_heads=num_heads,
inner_size=inner_size,
dropout_rate=0.,
attention_dropout_rate=0.,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=two_stream,
tie_attention_biases=tie_attention_biases,
memory_length=memory_length,
reuse_length=reuse_length,
inner_activation="relu")
attention_output, cached_memory_states = transformer_xl_layer(**data)
if two_stream:
self.assertEqual(attention_output.shape,
[batch_size, num_predictions, hidden_size])
else:
self.assertEqual(attention_output.shape,
[batch_size, seq_length, hidden_size])
self.assertEqual(len(cached_memory_states), num_layers)
def test_get_config(self):
transformer_xl_layer = transformer_xl.TransformerXL(
vocab_size=32000,
num_layers=12,
hidden_size=36,
head_size=12,
num_attention_heads=12,
inner_size=12,
dropout_rate=0.,
attention_dropout_rate=0.,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
memory_length=0,
reuse_length=0,
inner_activation="relu")
transformer_xl_config = transformer_xl_layer.get_config()
new_transformer_xl = transformer_xl.TransformerXL.from_config(
transformer_xl_config)
self.assertEqual(transformer_xl_config, new_transformer_xl.get_config())
if __name__ == "__main__":
np.random.seed(0)
tf.random.set_seed(0)