Commit b0ccdb11 authored by Shixin Luo

resolve conflict with master

parents e61588cd 1611a8c5
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup script."""
import os
from setuptools import find_packages
from setuptools import setup
version = '0.0.1'
def _get_requirements():
"""Parses requirements.txt file."""
install_requires_tmp = []
dependency_links_tmp = []
with open(
os.path.join(os.path.dirname(__file__), './requirements.txt'), 'r') as f:
for line in f:
package_name = line.strip()
# Skip empty lines or comments starting with "#".
if not package_name or package_name[0] == '#':
continue
if package_name.startswith('-e '):
dependency_links_tmp.append(package_name[3:].strip())
else:
install_requires_tmp.append(package_name)
return install_requires_tmp, dependency_links_tmp
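# Illustrative note (not part of the original script): a hypothetical
# requirements.txt containing
#   numpy
#   # a comment
#   -e git+https://github.com/example/pkg#egg=pkg
# would yield install_requires=['numpy'] and
# dependency_links=['git+https://github.com/example/pkg#egg=pkg'].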
install_requires, dependency_links = _get_requirements()
install_requires.append('tf-nightly')
setup(
name='keras-nlp',
version=version,
description='Keras Natural Language Processing Library',
url='https://github.com/keras-team/keras-nlp',
author='The Keras authors',
author_email='keras-team@google.com',
license='Apache License 2.0',
install_requires=install_requires,
classifiers=[
'Programming Language :: Python',
'Programming Language :: Python :: 3.6',
'Operating System :: Unix',
'Operating System :: Microsoft :: Windows',
'Operating System :: MacOS',
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering',
'Topic :: Software Development'
],
packages=find_packages(exclude=('tests',)),
exclude_package_data={'': ['*_test.py',],},
dependency_links=dependency_links,
python_requires='>=3.6',
)
@@ -29,7 +29,7 @@ assemble new layers, networks, or models.
described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
* [TransformerDecoderBlock](transformer.py) TransformerDecoderBlock is made up
of self multi-head attention, cross multi-head attention and feedforward
network.
@@ -63,3 +63,24 @@ assemble new layers, networks, or models.
* [GatedFeedforward](gated_feedforward.py) implements the gated linear layer
feedforward as described in
["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202).
* [MultiHeadRelativeAttention](relative_attention.py) implements a variant
of multi-head attention with support for relative position encodings, as
described in ["Transformer-XL: Attentive Language Models Beyond a
Fixed-Length Context"](https://arxiv.org/abs/1901.02860). It also has
extended support for segment-based attention, a re-parameterization
introduced in ["XLNet: Generalized Autoregressive Pretraining for Language
Understanding"](https://arxiv.org/abs/1906.08237). See the usage sketch
after this list.
* [TwoStreamRelativeAttention](relative_attention.py) implements a variant
of multi-head relative attention as described in ["XLNet: Generalized
Autoregressive Pretraining for Language Understanding"](https://arxiv.org/abs/1906.08237).
It takes a content stream and a query stream and applies self-attention.
* [TransformerXL](transformer_xl.py) implements Transformer-XL as introduced in
["Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"](https://arxiv.org/abs/1901.02860).
This module provides `TransformerXLBlock`, a block containing either one- or
two-stream relative self-attention followed by a feedforward network, and
`TransformerXL`, which holds the attention biases and stacks multiple
`TransformerXLBlock` layers.
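The following is a minimal usage sketch of `MultiHeadRelativeAttention`
(illustrative only; the shapes, `num_heads=4`, and `key_dim=16` below are
arbitrary assumptions, and the import assumes the exports added in this
commit):

```python
import tensorflow as tf
from official.nlp.modeling import layers

layer = layers.MultiHeadRelativeAttention(num_heads=4, key_dim=16)
query = tf.random.normal([2, 8, 64])     # [batch, target_length, width]
value = query                            # self-attention
# Relative position encoding; length 2 * target_length is a common choice.
relative_position_encoding = tf.random.normal([2, 16, 64])
content_attention_bias = tf.random.normal([4, 16])     # [num_heads, key_dim]
positional_attention_bias = tf.random.normal([4, 16])  # [num_heads, key_dim]

output = layer(
    query=query,
    value=value,
    content_attention_bias=content_attention_bias,
    positional_attention_bias=positional_attention_bias,
    relative_position_encoding=relative_position_encoding)
print(output.shape)  # (2, 8, 64): the query input width is preserved.
```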
@@ -24,8 +24,13 @@ from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
from official.nlp.modeling.layers.transformer_xl import TransformerXL
from official.nlp.modeling.layers.transformer_xl import TransformerXLBlock
@@ -16,16 +16,11 @@
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import math
import string
import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax
EinsumDense = tf.keras.layers.experimental.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
_CHR_IDX = string.ascii_lowercase
@tf.keras.utils.register_keras_serializable(package="Text")
@@ -111,277 +106,3 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):
if return_attention_scores:
return attention_output, attention_scores, cache
return attention_output, cache
def _rel_shift(x, klen=-1):
"""Performs relative shift to form the relative attention score."""
x = tf.transpose(x, perm=[1, 2, 0, 3])
x_size = tf.shape(x)
x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
x = tf.transpose(x, perm=[2, 0, 1, 3])
return x
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadRelativeAttention(MultiHeadAttention):
"""A multi-head attention layer with relative attention + position encoding.
This layer shares the same input/output projections as the common
MultiHeadAttention layer.
When it calculates attention logits, the position encoding is projected to form
relative keys. The logits are the sum of the shifted relative-position logits
and the content logits.
**Note: This layer is currently experimental.**
Arguments:
num_heads: The number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of attention head for value.
dropout: Dropout probability for attention.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
positional_attention_bias: Bias `Tensor` for position-based attention of shape
`[num_heads, dim]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]` that prevents attention
to certain positions.
"""
def _build_from_signature(self, query, value, key=None):
super(MultiHeadRelativeAttention, self)._build_from_signature(
query=query,
value=value,
key=key)
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._encoding_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="encoding",
**common_kwargs)
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
# TODO(allencwang) - replace all einsums with programmatic equations.
einsum_equation = "abcd,ecd->abe"
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
def _build_attention(self, rank):
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=[2])
self._dropout_layer = tf.keras.layers.Dropout(
rate=self._dropout)
def compute_attention(self,
query,
key,
value,
position,
content_attention_bias,
positional_attention_bias,
attention_mask=None):
"""Computes the attention.
This function defines the computation inside `call` with projected
multihead Q, K, V, R inputs.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`.
value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`.
position: Projected position `Tensor` of shape `[B, L, N, key_dim]`.
content_attention_bias: Trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query
head when calculating the position-based attention score.
attention_mask: (default None) Optional mask that is added to the attention
logits. If state is not None, the source sequence dimension of the mask
should be extended by M.
Returns:
attention_output: Multi-headed output of attention computation of shape
`[B, T, N, key_dim]`.
"""
content_attention = tf.einsum("bind,bjnd->bijn",
query + content_attention_bias,
key)
positional_attention = tf.einsum("bind,bjnd->bijn",
query + positional_attention_bias,
position)
positional_attention = _rel_shift(
positional_attention, klen=tf.shape(content_attention)[2])
attention_scores = tf.multiply((content_attention + positional_attention),
1.0 / math.sqrt(float(self._key_dim)))
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_output = self._dropout_layer(attention_scores)
attention_output = tf.einsum("bijn,bjnd->bind", attention_output, value)
return attention_output
def call(self,
query,
value,
content_attention_bias,
positional_attention_bias,
key=None,
relative_position_encoding=None,
state=None,
attention_mask=None):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
query: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
relative_position_encoding: relative positional encoding for key and
value.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask: (default None) Optional mask that is added to the attention
logits. If state is not None, the source sequence dimension of the mask
should be extended by M.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are projected to the shape specified by `output_shape`.
"""
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
if state is not None and state.shape.ndims > 1:
value = tf.concat([state, value], 1)
key = tf.concat([state, key], 1)
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S + M, N, H]
key = self._key_dense(key)
# `value` = [B, S + M, N, H]
value = self._value_dense(value)
# `position` = [B, L, N, H]
position = self._encoding_dense(relative_position_encoding)
attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
attention_mask=attention_mask)
attention_output = self._output_dense(attention_output)
return attention_output
@@ -92,38 +92,5 @@ class CachedAttentionTest(keras_parameterized.TestCase):
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
@keras_parameterized.run_all_keras_modes
class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase):
def test_attention_scores(self):
num_heads = 12
key_dim = 64
value_dim = 32
seq_length = 8
batch_size = 2
test_layer = attention.MultiHeadRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
relative_position_encoding = tf.random.normal(
shape=(batch_size, seq_length * 2, key_dim))
content_attention_bias = tf.random.normal(
shape=(num_heads, key_dim))
positional_attention_bias = tf.random.normal(
shape=(num_heads, key_dim))
output = test_layer(
query=query,
value=value,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
relative_position_encoding=relative_position_encoding,
state=None,
attention_mask=None)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
if __name__ == "__main__":
tf.test.main()
@@ -59,6 +59,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
intermediate_activation,
dropout,
use_gate=True,
apply_output_layer_norm=True,
num_blocks=1,
dropout_position="before_residual",
kernel_initializer="glorot_uniform",
@@ -75,6 +76,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
self._dropout = dropout
self._use_gate = use_gate
self._num_blocks = num_blocks
self._apply_output_layer_norm = apply_output_layer_norm
self._dropout_position = dropout_position
if self._dropout_position not in ("before_residual", "after_residual"):
raise ValueError(
@@ -140,12 +142,13 @@
**common_kwargs))
self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
# Use float32 in layernorm for numeric stability.
self._output_layer_norm.append(
tf.keras.layers.LayerNormalization(
name="output_layer_norm_%d" % i,
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
if self._apply_output_layer_norm:
self._output_layer_norm.append(
tf.keras.layers.LayerNormalization(
name="output_layer_norm_%d" % i,
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
def get_config(self):
config = {
@@ -199,7 +202,8 @@ class GatedFeedforward(tf.keras.layers.Layer):
# add.
if layer_input.dtype == tf.float32:
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm[i](layer_output + layer_input)
if self._apply_output_layer_norm:
layer_output = self._output_layer_norm[i](layer_output + layer_input)
if self._dropout_position == "after_residual":
layer_output = self._output_dropout[i](layer_output)
@@ -14,106 +14,7 @@
# ==============================================================================
"""Masked language model network."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
from official.nlp import keras_nlp
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Text')
class MaskedLM(tf.keras.layers.Layer):
"""Masked language model network head for BERT modeling.
This layer implements a masked language model head that reuses the provided
embedding table (for example, the table returned by an encoder's
"get_embedding_table()" method) as the output projection.
Arguments:
embedding_table: The embedding table of the targets.
activation: The activation, if any, for the dense layer.
initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
embedding_table,
activation=None,
initializer='glorot_uniform',
output='logits',
name='cls/predictions',
**kwargs):
super(MaskedLM, self).__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
def build(self, input_shape):
self._vocab_size, hidden_size = self.embedding_table.shape
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/dense')
self.layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/LayerNorm')
self.bias = self.add_weight(
'output_bias/bias',
shape=(self._vocab_size,),
initializer='zeros',
trainable=True)
super(MaskedLM, self).build(input_shape)
def call(self, sequence_data, masked_positions):
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
lm_data = self.dense(masked_lm_input)
lm_data = self.layer_norm(lm_data)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_shape = tf_utils.get_shape_list(
masked_positions, name='masked_positions_tensor')
logits = tf.reshape(logits,
[-1, masked_positions_shape[1], self._vocab_size])
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def get_config(self):
raise NotImplementedError('MaskedLM cannot be directly serialized because '
'it has variable sharing logic.')
def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
sequence_tensor: Sequence output of the `BertModel` layer, of shape
(`batch_size`, `seq_length`, num_hidden), where num_hidden is the number
of hidden units of the `BertModel` layer.
positions: Position ids of the tokens to predict, of shape
(batch_size, num_predictions), where `num_predictions` is the maximum
number of tokens masked out and predicted per sequence.
Returns:
Gathered sequence tensor of shape (batch_size * num_predictions,
num_hidden).
"""
sequence_shape = tf_utils.get_shape_list(
sequence_tensor, name='sequence_output_tensor')
batch_size, seq_length, width = sequence_shape
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
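# Worked example (illustrative, not from the original source): with
# batch_size=2, seq_length=3, width=4 and positions=[[0, 2], [1, 2]],
# flat_offsets is [[0], [3]], flat_positions becomes [0, 2, 4, 5], and
# tf.gather picks rows 0 and 2 of example 0 plus rows 1 and 2 of example 1
# from the flattened [6, 4] sequence tensor, giving a [4, 4] result.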
MaskedLM = keras_nlp.layers.MaskedLM
@@ -15,78 +15,7 @@
"""Keras-based one-hot embedding layer."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
from official.nlp import keras_nlp
@tf.keras.utils.register_keras_serializable(package="Text")
class OnDeviceEmbedding(tf.keras.layers.Layer):
"""Performs an embedding lookup suitable for accelerator devices.
This layer uses either tf.gather or tf.one_hot to translate integer indices to
float embeddings.
Arguments:
vocab_size: Number of elements in the vocabulary.
embedding_width: Output size of the embedding layer.
initializer: The initializer to use for the embedding weights. Defaults to
"glorot_uniform".
use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, no scaling). Setting this option to True multiplies the output
embeddings by `embedding_width ** 0.5`.
"""
def __init__(self,
vocab_size,
embedding_width,
initializer="glorot_uniform",
use_one_hot=False,
use_scale=False,
**kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._initializer = initializer
self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self):
config = {
"vocab_size": self._vocab_size,
"embedding_width": self._embedding_width,
"initializer": self._initializer,
"use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
}
base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
self.embeddings = self.add_weight(
"embeddings",
shape=[self._vocab_size, self._embedding_width],
initializer=self._initializer,
dtype=tf.float32)
super(OnDeviceEmbedding, self).build(input_shape)
def call(self, inputs):
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
one_hot_data = tf.one_hot(
flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype)
embeddings = tf.matmul(one_hot_data, self.embeddings)
else:
embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width**0.5
return embeddings
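# Usage sketch (illustrative assumption, not from the original source):
#   layer = OnDeviceEmbedding(vocab_size=100, embedding_width=16)
#   ids = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
#   embeddings = layer(ids)                    # shape (2, 3, 16)
# With use_one_hot=True the lookup becomes a one-hot matmul instead of
# tf.gather, and with use_scale=True the result is multiplied by 16 ** 0.5.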
OnDeviceEmbedding = keras_nlp.layers.OnDeviceEmbedding
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based relative attention layers."""
import math
import string
import tensorflow as tf
_CHR_IDX = string.ascii_lowercase
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
return equation, bias_axes, len(output_str)
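# Worked example (illustrative): _build_proj_equation(2, bound_dims=1,
# output_dims=2) returns ("abc,cde->abde", "de", 4); i.e. a [B, L, dim] input
# einsum-multiplied by a [dim, N, key_dim] kernel yields a [B, L, N, key_dim]
# projection, which is how the relative position encoding is projected below.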
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
def _large_compatible_negative(tensor_type):
"""Large negative number as Tensor.
This function is necessary because the standard large negative value used to
mask attention logits in this module (-1e9) cannot be represented in tf.float16.
Args:
tensor_type: a dtype to determine the type.
Returns:
a large negative number.
"""
if tensor_type == tf.float16:
return tf.float16.min
return -1e9
def _rel_shift(x, klen=-1):
"""Performs relative shift to form the relative attention score."""
x = tf.transpose(x, perm=[2, 3, 0, 1])
x_size = tf.shape(x)
x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
x = tf.transpose(x, perm=[2, 3, 0, 1])
return x
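# Shape sketch (illustrative): for position attention scores x of shape
# [B, N, T, L], e.g. [1, 1, 3, 6], and klen=3, the transpose/reshape/slice
# sequence above realigns the relative-position axis and returns a
# [B, N, T, klen] tensor, here [1, 1, 3, 3], matching the content scores.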
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
"""A multi-head attention layer with relative attention + position encoding.
This layer shares the same input/output projections as the common
MultiHeadAttention layer.
When it calculates attention logits, the position encoding is projected to form
relative keys. The logits are the sum of the shifted relative-position logits
and the content logits.
**Note: This layer is currently experimental.**
Attributes:
num_heads: The number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of attention head for value.
dropout: Dropout probability for attention.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
positional_attention_bias: Bias `Tensor` for position based attention of
shape `[num_heads, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet, of shape `[num_heads, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
attention_mask: a boolean mask of shape `[B, T, S]` that prevents attention
to certain positions.
"""
def __init__(self,
kernel_initializer="variance_scaling",
**kwargs):
super().__init__(kernel_initializer=kernel_initializer,
**kwargs)
def _build_from_signature(self, query, value, key=None):
super(MultiHeadRelativeAttention, self)._build_from_signature(
query=query,
value=value,
key=key)
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
with tf.init_scope():
einsum_equation, _, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._encoding_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=None,
name="encoding",
**common_kwargs)
def compute_attention(self,
query,
key,
value,
position,
content_attention_bias,
positional_attention_bias,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
attention_mask=None):
"""Computes the attention.
This function defines the computation inside `call` with projected
multihead Q, K, V, R inputs.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`.
value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`.
position: Projected position `Tensor` of shape `[B, L, N, key_dim]`.
content_attention_bias: Trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query
head when calculating the position-based attention score.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional trainable `Tensor` representing the
segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet.
attention_mask: (default None) Optional mask that is added to the attention
logits. If state is not None, the source sequence dimension of the mask
should be extended by M.
Returns:
attention_output: Multi-headed output of attention computation of shape
`[B, T, N, key_dim]`.
"""
content_attention = tf.einsum(self._dot_product_equation,
key,
query + content_attention_bias)
positional_attention = tf.einsum(self._dot_product_equation,
position,
query + positional_attention_bias)
positional_attention = _rel_shift(
positional_attention, klen=tf.shape(content_attention)[3])
if segment_matrix is not None:
segment_attention = tf.einsum("bind,snd->bnis",
query + segment_attention_bias,
segment_encoding)
target_shape = tf.shape(positional_attention)
segment_attention = tf.where(
tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
attention_sum = (
content_attention + positional_attention + segment_attention)
else:
attention_sum = content_attention + positional_attention
attention_scores = tf.multiply(
attention_sum, 1.0 / math.sqrt(float(self._key_dim)))
# `attention_scores`: `[B, N, S, S + M]`
if attention_mask is not None:
attention_scores += (_large_compatible_negative(attention_scores.dtype)
* attention_mask)
attention_scores = tf.nn.softmax(attention_scores, 3)
attention_output = self._dropout_layer(attention_scores)
attention_output = tf.einsum(self._combine_equation,
attention_output,
value)
return attention_output
def call(self,
query,
value,
content_attention_bias,
positional_attention_bias,
key=None,
relative_position_encoding=None,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
state=None,
attention_mask=None):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
query: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
relative_position_encoding: relative positional encoding for key and
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask: (default None) Optional mask that is added to the attention
logits. If state is not None, the source sequence dimension of the mask
should be extended by M.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are projected to the shape specified by `output_shape`.
"""
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
if state is not None and state.shape.ndims > 1:
value = tf.concat([state, value], 1)
key = tf.concat([state, key], 1)
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S + M, N, H]
key = self._key_dense(key)
# `value` = [B, S + M, N, H]
value = self._value_dense(value)
# `position` = [B, L, N, H]
position = self._encoding_dense(relative_position_encoding)
attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=attention_mask)
# `attention_output` = [B, S, N, H]
attention_output = self._output_dense(attention_output)
return attention_output
@tf.keras.utils.register_keras_serializable(package="Text")
class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
"""Two-stream relative self-attention for XLNet.
In XLNet, each token has two associated vectors at each self-attention layer,
the content stream (h) and the query stream (g).
The content stream is the self-attention stream as in Transformer XL and
represents the context and content (the token itself).
The query stream only has access to contextual information and the position,
but not the content.
This layer shares the same build signature as `MultiHeadRelativeAttention` but
has different input/output projections.
**Note: This layer is currently experimental.**
Call args:
content_stream: `Tensor` of shape `[B, T, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
positional_attention_bias: Bias `Tensor` for position based attention of
shape `[num_heads, dim]`.
query_stream: `Tensor` of shape `[B, P, dim]`.
target_mapping: `Tensor` of shape `[B, P, S]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet, of shape `[num_heads, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
content_attention_mask: a boolean mask of shape `[B, T, S]` that
prevents attention to certain positions for content attention computation.
query_attention_mask: a boolean mask of shape `[B, T, S]` that
prevents attention to certain positions for query attention computation.
"""
def call(self,
content_stream,
content_attention_bias,
positional_attention_bias,
query_stream,
relative_position_encoding,
target_mapping=None,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
state=None,
content_attention_mask=None,
query_attention_mask=None):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Number of predictions (P): the number of predictions.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
content_stream: The content representation, commonly referred to as h.
This serves a similar role to the standard hidden states in
Transformer-XL.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
query_stream: The query representation, commonly referred to as g.
This only has access to contextual information and position, but not
content. If not provided, then this is MultiHeadRelativeAttention with
self-attention.
relative_position_encoding: relative positional encoding for key and
value.
target_mapping: Optional `Tensor` representing the target mapping used
in partial prediction.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL and XLNet.
content_attention_mask: (default None) Optional mask that is added to the
content attention logits. If state is not None, the source sequence
dimension of the mask should be extended by M.
query_attention_mask: (default None) Optional mask that is added to the
query attention logits. If state is not None, the source sequence
dimension of the mask should be extended by M.
Returns:
content_attention_output, query_attention_output: the results of the
computation, both of shape [B, T, E]. `T` is for target sequence shapes,
`E` is the query input last dimension if `output_shape` is `None`.
Otherwise, the multi-head outputs are projected to the shape specified
by `output_shape`.
"""
if not self._built_from_signature:
self._build_from_signature(content_stream, content_stream, content_stream)
if state is not None and state.shape.ndims > 1:
content_and_memory_stream = tf.concat([state, content_stream], 1)
else:
content_and_memory_stream = content_stream
# `query` = [B, T, N, H]
query = self._query_dense(content_stream)
# `key` = [B, S + M, N, H]
key = self._key_dense(content_and_memory_stream)
# `value` = [B, S + M, N, H]
value = self._value_dense(content_and_memory_stream)
# `position` = [B, L, N, H]
position = self._encoding_dense(relative_position_encoding)
content_attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=content_attention_mask)
# `content_attention_output` = [B, S, N, H]
content_attention_output = self._output_dense(content_attention_output)
query_attention_output = None
if query_stream is not None:
query = self._query_dense(query_stream)
if target_mapping is not None:
query = tf.einsum("bmnd,bml->blnd", query, target_mapping)
query_attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=query_attention_mask)
query_attention_output = tf.einsum("blnd,bml->bmnd",
query_attention_output,
target_mapping)
else:
query_attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=query_attention_mask)
query_attention_output = self._output_dense(query_attention_output)
return content_attention_output, query_attention_output
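# Usage sketch for TwoStreamRelativeAttention (illustrative only; every shape
# below is an arbitrary assumption):
#   layer = TwoStreamRelativeAttention(num_heads=4, key_dim=16)
#   content_stream = tf.random.normal([2, 8, 64])
#   query_stream = tf.random.normal([2, 5, 64])       # P = 5 predictions
#   target_mapping = tf.random.normal([2, 5, 8])      # [B, P, S]
#   relative_position_encoding = tf.random.normal([2, 16, 64])
#   content_out, query_out = layer(
#       content_stream=content_stream,
#       content_attention_bias=tf.random.normal([4, 16]),
#       positional_attention_bias=tf.random.normal([4, 16]),
#       query_stream=query_stream,
#       target_mapping=target_mapping,
#       relative_position_encoding=relative_position_encoding)
#   # content_out: [2, 8, 64], query_out: [2, 5, 64]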
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the attention layer."""
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import relative_attention
def _create_mock_attention_data(
num_heads,
key_dim,
value_dim,
seq_length,
batch_size,
memory_length=0,
num_predictions=2,
two_stream=False,
include_state=False,
include_mask=False,
include_segment=False):
"""Creates mock testing data.
Args:
num_heads: `int`, Number of attention heads.
key_dim: `int`, Size of query head.
value_dim: `int`, Size of key, value dim.
seq_length: `int`, Sequence length of the input.
batch_size: `int`, the batch size.
memory_length: optional `int`, the length of the state. Defaults to 0.
num_predictions: `int`, the number of predictions used in two stream
attention.
two_stream: `bool`, whether or not to generate two stream data.
include_state: optional `bool`, whether or not to include state data.
include_mask: optional `bool`, whether or not to include mask data.
include_segment: optional `bool`, whether or not to include segment data.
Returns:
A dictionary with `str` as keys and `Tensor` as values.
"""
query_shape = (batch_size, seq_length, key_dim)
value_shape = (batch_size, seq_length, value_dim)
encoding_shape = (batch_size, seq_length * 2, key_dim)
attention_bias_shape = (num_heads, key_dim)
data = dict(
relative_position_encoding=tf.random.normal(shape=encoding_shape),
content_attention_bias=tf.random.normal(shape=attention_bias_shape),
positional_attention_bias=tf.random.normal(shape=attention_bias_shape))
if two_stream:
query_stream_shape = (batch_size, num_predictions, key_dim)
target_mapping_shape = (batch_size, num_predictions, seq_length)
stream_data = dict(
content_stream=tf.random.normal(shape=query_shape),
query_stream=tf.random.normal(shape=query_stream_shape),
target_mapping=tf.random.normal(shape=target_mapping_shape))
else:
stream_data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
data.update(stream_data)
if include_state:
total_seq_length = seq_length + memory_length
state_data = dict(
state=tf.random.normal(shape=(batch_size, memory_length, value_dim)))
data.update(state_data)
else:
total_seq_length = seq_length
if include_mask:
mask_shape = (batch_size, num_heads, seq_length, total_seq_length)
mask_data = np.random.randint(2, size=mask_shape).astype("float32")
if two_stream:
mask_data = dict(
content_attention_mask=mask_data,
query_attention_mask=mask_data)
else:
mask_data = dict(attention_mask=mask_data)
data.update(mask_data)
if include_segment:
segment_encoding_shape = (2, num_heads, key_dim)
segment_matrix = np.random.randint(
2, size=(batch_size, seq_length, total_seq_length))
segment_matrix = tf.math.equal(segment_matrix, 1)
segment_data = dict(
segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
segment_encoding=tf.random.normal(shape=segment_encoding_shape),
segment_matrix=segment_matrix)
data.update(segment_data)
return data
@keras_parameterized.run_all_keras_modes
class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
value_dim=[32, 64],
memory_length=[0, 4],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_attention_scores(self,
value_dim,
memory_length,
state,
mask,
segment):
"""Tests combinations of attention score calculations."""
batch_size, num_heads, key_dim, seq_length = 2, 12, 64, 8
test_layer = relative_attention.MultiHeadRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim)
data = _create_mock_attention_data(
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim,
seq_length=seq_length,
memory_length=memory_length,
two_stream=False,
batch_size=batch_size,
include_state=state,
include_mask=mask,
include_segment=segment)
output = test_layer(**data)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
@keras_parameterized.run_all_keras_modes
class TwoStreamRelativeAttentionTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
num_predictions=[2, 10],
memory_length=[0, 4],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_attention_scores(self,
num_predictions,
memory_length,
state,
mask,
segment):
"""Tests combinations of attention score calculations."""
batch_size, num_heads, key_dim, seq_length = 2, 12, 64, 8
test_layer = relative_attention.TwoStreamRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=key_dim)
data = _create_mock_attention_data(
num_heads=num_heads,
key_dim=key_dim,
value_dim=key_dim,
seq_length=seq_length,
memory_length=memory_length,
num_predictions=num_predictions,
two_stream=True,
batch_size=batch_size,
include_state=state,
include_mask=mask,
include_segment=segment)
content_output, query_output = test_layer(**data)
self.assertEqual(content_output.shape, [batch_size, seq_length, key_dim])
self.assertEqual(query_output.shape, [batch_size, num_predictions, key_dim])
if __name__ == "__main__":
np.random.seed(0)
tf.random.set_seed(0)
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ExpandCondense tensor network layer used in TN-BERT."""
# pylint: disable=g-classes-have-attributes
from typing import List, Optional, Text, Any, Dict
import tensorflow as tf
Layer = tf.keras.layers.Layer
activations = tf.keras.activations
initializers = tf.keras.initializers
@tf.keras.utils.register_keras_serializable(package='Text')
class TNExpandCondense(Layer):
"""A TPU-optimized TensorNetwork layer.
Designed for use in models that currently use Dense layers to achieve
up projection followed by down projection.
This layer is a TPU-optimized combination of 3 operations:
Expand, Apply Activation, and Condense. The layer projects up from
`input_shape[-1]` to `input_shape[-1] * proj_multiplier`, applies
`self.activation`, and then condenses back to `input_shape[-1]`.
Note the input shape and output shape will be identical.
Arguments:
proj_multiplier: Positive integer, multiple of input_shape[-1] to project
up to. Must be one of [2, 4, 6, 8, 10, 12].
use_bias: Boolean, whether the layer uses a bias vector.
activation: Activation function to use between Expand and Condense. If you
don't specify anything, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
kernel_initializer: Initializer for the weight matrices.
bias_initializer: Initializer for the bias vector.
Input shape:
N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
Output shape:
N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
"""
def __init__(self,
proj_multiplier: int,
use_bias: Optional[bool] = True,
activation: Optional[Text] = 'relu',
kernel_initializer: Optional[Text] = 'glorot_uniform',
bias_initializer: Optional[Text] = 'zeros',
**kwargs) -> None:
# Allow specification of input_dim instead of input_shape,
# for compatibility with Keras layers that support this
if 'input_shape' not in kwargs and 'input_dim' in kwargs:
kwargs['input_shape'] = (kwargs.pop('input_dim'),)
super(TNExpandCondense, self).__init__(**kwargs)
assert proj_multiplier in [
2, 4, 6, 8, 10, 12
], 'proj_multiplier needs to be one of [2, 4, 6, 8, 10, 12]'
self.proj_multiplier = proj_multiplier
self.use_bias = use_bias
self.activation = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
def build(self, input_shape: List[int]) -> None:
# Disable the attribute-defined-outside-init violations in this function
# pylint: disable=attribute-defined-outside-init
if input_shape[-1] is None:
raise ValueError(
'The last dimension of the inputs to `TNExpandCondense` '
'should be defined. Found `None`.')
super(TNExpandCondense, self).build(input_shape)
self.proj_size = self.proj_multiplier * input_shape[-1]
assert (self.proj_size // input_shape[-1]) * input_shape[-1] == self.proj_size, (
    f'{self.proj_size} must be evenly divisible by {input_shape[-1]}')
assert (input_shape[-1] // 128) * 128 == input_shape[-1], (
    f'{input_shape[-1]} must be evenly divisible by 128')
self.w1 = self.add_weight(
name='w1',
shape=(input_shape[-1], input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
self.w2 = self.add_weight(
name='w2',
shape=(128, (128 * (self.proj_size // input_shape[-1]))),
trainable=True,
initializer=self.kernel_initializer)
self.w3 = self.add_weight(
name='w3',
shape=(128 * (self.proj_size // input_shape[-1]), 128),
trainable=True,
initializer=self.kernel_initializer)
self.w4 = self.add_weight(
name='w4',
shape=(input_shape[-1] // 128, 128, input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
if self.use_bias:
self.bias = self.add_weight(
name='b',
shape=(input_shape[-1] // 128, 1,
128 * (self.proj_size // input_shape[-1])),
trainable=True,
initializer=self.bias_initializer)
else:
self.bias = None
def call(self, inputs: tf.Tensor, **kwargs):
orig_shape = tf.shape(inputs)
input_dim = inputs.shape[-1]
tmp = tf.reshape(inputs, (-1, input_dim))
# Shape is (BatchSeq, input_dim)
# Expansion network
tmp = tf.einsum('ab,Qb->aQ', self.w1, tmp)
# Note: Letter Q will always represent the BatchSeq axis.
tmp = tf.reshape(tmp, (input_dim // 128, 128, -1))
tmp = tf.einsum('abQ,bd->aQd', tmp, self.w2)
# Apply activation (adding the bias only if one was created), then Condense.
if self.bias is not None:
  tmp = tmp + self.bias
tmp = self.activation(tmp)
tmp = tf.einsum('aQd,db->aQb', tmp, self.w3)
tmp = tf.einsum('aQb,abd->Qd', tmp, self.w4)
out = tf.reshape(tmp, orig_shape)
return out
def compute_output_shape(self, input_shape: List[int]) -> List[int]:
return input_shape
def get_config(self) -> Dict[Any, Any]:
"""Returns the config of the layer.
The same layer can be reinstantiated later
(without its trained weights) from this configuration.
Returns:
Python dictionary containing the configuration of the layer.
"""
config = {}
# Include the layer-specific arguments
args = ['proj_multiplier', 'use_bias']
for arg in args:
config[arg] = getattr(self, arg)
# Serialize the activation
config['activation'] = activations.serialize(getattr(self, 'activation'))
# Serialize the initializers
decomp_initializers = ['kernel_initializer', 'bias_initializer']
for initializer_arg in decomp_initializers:
config[initializer_arg] = initializers.serialize(
getattr(self, initializer_arg))
# Get base config
base_config = super(TNExpandCondense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
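# Usage sketch (illustrative; the 768-wide input is an arbitrary assumption):
#   layer = TNExpandCondense(proj_multiplier=4, input_shape=(768,))
#   x = tf.random.normal([2, 12, 768])  # last dim must be a multiple of 128
#   y = layer(x)                        # output shape equals the input: (2, 12, 768)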
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ExpandCondense tensor network layer."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.testing_utils import layer_test
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
"""Unit tests for ExpandCondense TN layer.
"""
def setUp(self):
super(TNLayerTest, self).setUp()
self.labels = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))), axis=0)
def _build_model(self, data, proj_multiple=2):
model = tf.keras.models.Sequential()
model.add(
TNExpandCondense(
proj_multiplier=proj_multiple,
use_bias=True,
activation='relu',
input_shape=(data.shape[-1],)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
return model
@parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple):
data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32)
layer_test(
TNExpandCondense,
kwargs={
'proj_multiplier': proj_multiple,
'input_shape': data.shape
},
input_shape=data.shape,
input_data=data,
expected_output_shape=(None, data.shape[-1]),
expected_output_dtype=data.dtype)
@parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
tf.random.set_seed(0)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model for 5 epochs
history = model.fit(data, self.labels, epochs=5, batch_size=32)
# Check that loss decreases and accuracy increases
self.assertGreater(history.history['loss'][0], history.history['loss'][-1])
self.assertLess(
history.history['accuracy'][0], history.history['accuracy'][-1])
@parameterized.parameters((768, 6), (1024, 2))
def test_weights_change(self, input_dim, proj_multiple):
tf.random.set_seed(0)
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
before = model.get_weights()
model.fit(data, self.labels, epochs=5, batch_size=32)
after = model.get_weights()
# Make sure every layer's weights changed
for i, _ in enumerate(before):
self.assertTrue((after[i] != before[i]).any())
@parameterized.parameters((768, 6), (1024, 2))
def test_output_shape(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
input_shape = data.shape
actual_output_shape = model(data).shape
expected_output_shape = model.compute_output_shape(input_shape)
self.assertEqual(expected_output_shape, actual_output_shape)
@parameterized.parameters((768, 6), (1024, 2))
def test_expandcondense_num_parameters(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
proj_size = proj_multiple * data.shape[-1]
model = tf.keras.models.Sequential()
model.add(
TNExpandCondense(
proj_multiplier=proj_multiple,
use_bias=True,
activation='relu',
input_shape=(data.shape[-1],)))
w1_params = data.shape[-1]**2
w2_params = 128 * 128 * (proj_size // data.shape[-1])
w3_params = 128 * 128 * (proj_size // data.shape[-1])
w4_params = (data.shape[-1] // 128) * 128 * data.shape[-1]
bias_params = ((data.shape[-1] // 128) * 128 *
(proj_size // data.shape[-1]))
expected_num_parameters = (w1_params + w2_params + w3_params +
w4_params) + bias_params
self.assertEqual(expected_num_parameters, model.count_params())
@parameterized.parameters((912, 6), (200, 2))
def test_incorrect_sizes(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
with self.assertRaises(AssertionError):
model = self._build_model(data, proj_multiple)
model.compile(optimizer='adam', loss='binary_crossentropy')
@parameterized.parameters((768, 6), (1024, 2))
def test_config(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
expected_num_parameters = model.layers[0].count_params()
# Serialize model and use config to create new layer
model_config = model.get_config()
layer_config = model_config['layers'][1]['config']
new_model = TNExpandCondense.from_config(layer_config)
# Build the layer so we can count params below
new_model.build(layer_config['batch_input_shape'])
# Check that original layer had same num params as layer built from config
self.assertEqual(expected_num_parameters, new_model.count_params())
@parameterized.parameters((768, 6), (1024, 2))
def test_model_save(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model for 5 epochs
model.fit(data, self.labels, epochs=5, batch_size=32)
save_path = os.path.join(self.get_temp_dir(), 'test_model')
model.save(save_path)
loaded_model = tf.keras.models.load_model(save_path)
# Compare model predictions and loaded_model predictions
self.assertAllEqual(model.predict(data), loaded_model.predict(data))
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TN-BERT TNTransformerExpandCondense employing Expand-Condense layer instead of Dense."""
# pylint: disable=g-classes-have-attributes
# Import libraries
import gin
import tensorflow as tf
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TNTransformerExpandCondense(tf.keras.layers.Layer):
"""Transformer layer using tensor network Expand-Condense layer.
This layer implements the Transformer from transformer.py, with a single
tensor network layer replacing the usual intermediate and output Dense
layers.
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
    output_range: The sequence output range, [0, output_range), obtained by
      slicing the target sequence. `None` means the target sequence is not
      sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer. If set to False,
use_bias in attention layer is disabled.
    norm_first: Whether to normalize inputs to the attention and intermediate
      dense layers. If set to False, the outputs of the attention and
      intermediate dense layers are normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for kernel.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
attention_initializer=None,
**kwargs):
super(TNTransformerExpandCondense, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError(
"TNTransformerExpandCondense expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError(
"When passing a mask tensor to TNTransformerExpandCondense, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
# Substitute Dense layers with a single Expand-Condense layer.
self._output_dense = TNExpandCondense(
4,
use_bias=True,
activation=self._intermediate_activation,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(TNTransformerExpandCondense, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"intermediate_dropout":
self._intermediate_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer)
}
base_config = super(TNTransformerExpandCondense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
if self._output_range:
target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
target_tensor = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
layer_output = self._output_dense(attention_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
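A minimal usage sketch for `TNTransformerExpandCondense` (illustrative only; shapes mirror the tests that follow, where the hidden width is divisible by the number of attention heads and, per the TNExpandCondense tests, by 128):

import tensorflow as tf
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense

layer = TNTransformerExpandCondense(
    num_attention_heads=16,
    intermediate_size=2048,
    intermediate_activation='relu')
data = tf.keras.Input(shape=(21, 256))  # [batch, sequence, width]
mask = tf.keras.Input(shape=(21, 21))   # [batch, from_sequence, to_sequence]
outputs = layer([data, mask])           # same shape as `data`
model = tf.keras.Model([data, mask], outputs)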
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TN-BERT transformer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(('tn', TNTransformerExpandCondense))
class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should have the same shape as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
    # The default output of a transformer layer should have the same shape as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_output_range(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
    # The layer only attends to the first token and outputs the first token
    # embedding.
new_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (16 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output = test_layer(data_tensor)
    # The default output of a transformer layer should have the same shape as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 256
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
if __name__ == '__main__':
tf.test.main()
......@@ -25,7 +25,7 @@ from official.nlp.modeling.layers.util import tf_function_if_eager
@tf.keras.utils.register_keras_serializable(package="Text")
class Transformer(keras_nlp.TransformerEncoderBlock):
class Transformer(keras_nlp.layers.TransformerEncoderBlock):
"""Transformer layer.
This layer implements the Transformer from "Attention Is All You Need".
......@@ -109,7 +109,7 @@ class CompiledTransformer(Transformer):
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
class TransformerDecoderBlock(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
......@@ -163,7 +163,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
intermediate_dropout=0.0,
attention_initializer=None,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
super().__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf.keras.activations.get(
......@@ -274,7 +274,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
super(TransformerDecoderLayer, self).build(input_shape)
super().build(input_shape)
def get_config(self):
config = {
......@@ -315,7 +315,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer)
}
base_config = super(TransformerDecoderLayer, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def common_layers_with_encoder(self):
......@@ -329,11 +329,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderLayer must have 5 inputs, when it uses "
"TransformerDecoderBlock must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
"TransformerDecoderBlock must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
source_tensor = input_tensor
......
......@@ -82,6 +82,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
feedforward_cfg=None,
dropout_rate=0.0,
attention_dropout_rate=0.0,
norm_first=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
......@@ -96,6 +97,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._attention_cls = attention_cls
self._feedforward_cls = feedforward_cls
self._feedforward_cfg = feedforward_cfg
self._norm_first = norm_first
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
......@@ -115,18 +117,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
raise ValueError(
"TransformerScaffold expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerLayer, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
hidden_size = input_tensor_shape[-1]
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
......@@ -257,11 +248,23 @@ class TransformerScaffold(tf.keras.layers.Layer):
else:
input_tensor, attention_mask = (inputs, None)
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
if self._feedforward_block is None:
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
......@@ -272,8 +275,17 @@ class TransformerScaffold(tf.keras.layers.Layer):
# and is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
else:
layer_output = self._feedforward_block(attention_output)
if self._norm_first:
# if norm_first, assume the feedforward block will not apply layer norm
layer_output = self._feedforward_block(attention_output)
layer_output += source_attention_output
else:
        # if not norm_first, assume that the feedforward block does apply layer norm
layer_output = self._feedforward_block(attention_output)
return layer_output
......@@ -182,30 +182,6 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_creation_with_incorrect_mask_fails(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self):
sequence_length = 21
width = 80
......
......@@ -32,12 +32,12 @@ def _create_cache(batch_size, init_decode_length, num_heads, head_size):
@keras_parameterized.run_all_keras_modes
class TransformerDecoderLayerTest(keras_parameterized.TestCase):
class TransformerDecoderBlockTest(keras_parameterized.TestCase):
def test_decoder_block_with_cache(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
decoder_block = transformer.TransformerDecoderBlock(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
......@@ -56,7 +56,7 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
decoder_block = transformer.TransformerDecoderBlock(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
......@@ -77,7 +77,7 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
def test_get_config(self):
num_attention_heads = 2
decoder_block = transformer.TransformerDecoderLayer(
decoder_block = transformer.TransformerDecoderBlock(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
......@@ -90,7 +90,7 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
decoder_block_config = decoder_block.get_config()
new_decoder_block = transformer.TransformerDecoderLayer.from_config(
new_decoder_block = transformer.TransformerDecoderBlock.from_config(
decoder_block_config)
self.assertEqual(decoder_block_config, new_decoder_block.get_config())
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based Transformer XL layer."""
from absl import logging
import tensorflow as tf
from official.nlp.modeling.layers import relative_attention
def _cache_memory(current_state, previous_state, memory_length, reuse_length=0):
"""Caches hidden states into memory.
Arguments:
current_state: `Tensor`, the current state.
previous_state: `Tensor`, the previous state.
memory_length: `int`, the number of tokens to cache.
reuse_length: `int`, the number of tokens in the current batch to be cached
and reused in the future.
Returns:
A `Tensor`, representing the cached state with stopped gradients.
"""
if memory_length is None or memory_length == 0:
return None
else:
if reuse_length > 0:
current_state = current_state[:, :reuse_length, :]
if previous_state is None:
new_mem = current_state[:, -memory_length:, :]
else:
new_mem = tf.concat(
[previous_state, current_state], 1)[:, -memory_length:, :]
return tf.stop_gradient(new_mem)
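# A minimal shape sketch for `_cache_memory` (illustrative comment only, not
# used by the layers below): the cached memory is the last `memory_length`
# positions of [previous_state; current_state], with gradients stopped so
# cached activations are treated as constants. The shapes are assumptions:
#
#   current = tf.ones((2, 8, 16))     # [batch, S=8, hidden=16]
#   previous = tf.zeros((2, 10, 16))  # [batch, M=10, hidden=16]
#   new_mem = _cache_memory(current, previous, memory_length=10)
#   new_mem.shape  # (2, 10, 16): the last 2 positions of `previous`
#                  # followed by all 8 positions of `current`.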
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerXLBlock(tf.keras.layers.Layer):
"""Transformer XL block.
This implements a Transformer XL block from "Transformer-XL: Attentive
Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
This block is further extended to allow for the Transformer-XL
re-parameterization in "XLNet: Generalized Autoregressive Pretraining for
Language Understanding" (https://arxiv.org/abs/1906.08237).
Given an input stream, this block computes attention, applies dropouts and
layer norms and feeds into the FFN network.
  **Note: This layer is currently experimental.**
Attributes:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The inner size for the transformer layers.
dropout_rate: Dropout rate for the output of this layer.
attention_dropout_rate: Dropout rate on attention probabilities.
    two_stream: Whether or not to use `TwoStreamRelativeAttention`, as used in
      the XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
norm_epsilon: Epsilon value to initialize normalization layers.
inner_activation: The activation to use for the inner
FFN layers.
kernel_initializer: Initializer for dense layer kernels.
inner_dropout: Dropout probability for the inner dropout
layer.
"""
def __init__(self,
vocab_size,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
two_stream=False,
norm_epsilon=1e-12,
inner_activation="relu",
kernel_initializer="variance_scaling",
inner_dropout=0.0,
**kwargs):
"""Initializes TransformerXLBlock layer."""
super(TransformerXLBlock, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._num_heads = num_attention_heads
self._head_size = head_size
self._hidden_size = hidden_size
self._inner_size = inner_size
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._inner_activation = inner_activation
self._norm_epsilon = norm_epsilon
self._kernel_initializer = kernel_initializer
self._inner_dropout = inner_dropout
self._two_stream = two_stream
if two_stream:
self._attention_layer_type = relative_attention.TwoStreamRelativeAttention
else:
self._attention_layer_type = relative_attention.MultiHeadRelativeAttention
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerXLBlock, "
"the mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_layer = self._attention_layer_type(
num_heads=self._num_heads,
key_dim=self._head_size,
value_dim=self._head_size,
dropout=self._attention_dropout_rate,
use_bias=False,
kernel_initializer=self._kernel_initializer,
name="rel_attn")
self._attention_dropout = tf.keras.layers.Dropout(
rate=self._attention_dropout_rate)
self._attention_layer_norm = tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
self._inner_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._inner_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="inner")
self._inner_activation_layer = tf.keras.layers.Activation(
self._inner_activation)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon)
super(TransformerXLBlock, self).build(input_shape)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"two_stream":
self._two_stream,
"norm_epsilon":
self._norm_epsilon,
"inner_activation":
self._inner_activation,
"kernel_initializer":
self._kernel_initializer,
"inner_dropout":
self._inner_dropout,
}
base_config = super(TransformerXLBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
content_stream,
content_attention_bias,
positional_attention_bias,
relative_position_encoding=None,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
state=None,
content_attention_mask=None,
query_stream=None,
query_attention_mask=None,
target_mapping=None):
"""Implements `call` for the Layer.
Arguments:
content_stream: `Tensor`, the input content stream. This is the standard
input to Transformer XL and is commonly referred to as `h` in XLNet.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
positional_attention_bias: Bias `Tensor` for position based attention of
shape `[num_heads, dim]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
but not in Transformer XL.
segment_encoding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
in XLNet, but not in Transformer XL.
segment_attention_bias: Optional bias `Tensor` for segment based attention
of shape `[num_heads, dim]`.
state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
the state or memory. If passed, this is also attended over as in
Transformer XL.
      content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the source
        sequence dimension of the mask should be extended by M.
query_stream: Optional `Tensor`, the query stream. This is introduced in
`TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
`two_stream` is `False`.
      query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the source
        sequence dimension of the mask should be extended by M.
target_mapping: Optional `Tensor` representing the target mapping when
calculating query attention.
Returns:
A `dict` object, containing the key value pairs for `content_attention`
and (if `two_stream` is `True`) `query_attention`.
"""
if not self._two_stream and query_stream is not None:
logging.warning("`query_stream` was provided but two stream attention is "
"disabled. `query_stream` will be ignored.")
if self._two_stream:
attention_kwargs = dict(
content_stream=content_stream,
query_stream=query_stream,
query_attention_mask=query_attention_mask,
target_mapping=target_mapping,
content_attention_mask=content_attention_mask)
else:
attention_kwargs = dict(
query=content_stream,
value=content_stream,
key=content_stream,
attention_mask=content_attention_mask)
common_attention_kwargs = dict(
content_attention_bias=content_attention_bias,
relative_position_encoding=relative_position_encoding,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
state=state)
attention_kwargs.update(common_attention_kwargs)
attention_output = self._attention_layer(**attention_kwargs)
if self._two_stream:
attention_streams = attention_output
input_streams = [content_stream, query_stream]
else:
attention_streams = [attention_output]
input_streams = [content_stream]
attention_keys = ["content_attention", "query_attention"]
attention_output = {}
for attention_stream, input_stream, attention_key in zip(
attention_streams, input_streams, attention_keys):
attention_stream = self._attention_dropout(attention_stream)
attention_stream = self._attention_layer_norm(
attention_stream + input_stream)
inner_output = self._inner_dense(attention_stream)
inner_output = self._inner_activation_layer(
inner_output)
inner_output = self._inner_dropout_layer(
inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
layer_output = self._output_layer_norm(layer_output + attention_stream)
attention_output[attention_key] = layer_output
return attention_output
class TransformerXL(tf.keras.layers.Layer):
"""Transformer XL.
This layer combines multiple Transformer XL blocks from "Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
This layer handles the attention biases as well as memory caching and reuse
as in Transformer XL and XLNet.
Attributes:
vocab_size: The number of tokens in vocabulary.
num_layers: The number of layers.
hidden_size: The hidden size.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The hidden size in feed-forward layers.
dropout_rate: Dropout rate used in each Transformer XL block.
attention_dropout_rate: Dropout rate on attention probabilities.
    two_stream: Whether or not to use `TwoStreamRelativeAttention`, as used
      in the XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
initializer: The initializer to use for attention biases.
tie_attention_biases: Whether or not to tie biases together. If `True`, then
each Transformer XL block shares the same trainable attention bias. If
`False`, then each block has its own attention bias. This is usually set
to `True`.
memory_length: The number of tokens to cache.
reuse_length: The number of tokens in the current batch to be cached
and reused in the future.
inner_activation: The activation to use in the inner layers
for Transformer XL blocks. Typically "relu" or "gelu".
"""
def __init__(self,
vocab_size,
num_layers,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
initializer,
two_stream=False,
tie_attention_biases=True,
memory_length=None,
reuse_length=None,
inner_activation="relu",
**kwargs):
"""Initializes TransformerXL."""
super(TransformerXL, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._initializer = initializer
self._num_layers = num_layers
self._hidden_size = hidden_size
self._num_attention_heads = num_attention_heads
self._head_size = head_size
self._inner_size = inner_size
self._inner_activation = inner_activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._tie_attention_biases = tie_attention_biases
self._two_stream = two_stream
self._memory_length = memory_length
self._reuse_length = reuse_length
if self._tie_attention_biases:
attention_bias_shape = [self._num_attention_heads, self._head_size]
else:
attention_bias_shape = [self._num_layers, self._num_attention_heads,
self._head_size]
self.content_attention_bias = self.add_weight(
"content_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.positional_attention_bias = self.add_weight(
"positional_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.segment_attention_bias = self.add_weight(
"segment_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.transformer_xl_layers = []
for i in range(self._num_layers):
self.transformer_xl_layers.append(
TransformerXLBlock(
vocab_size=self._vocab_size,
hidden_size=self._head_size * self._num_attention_heads,
num_attention_heads=self._num_attention_heads,
head_size=self._head_size,
inner_size=self._inner_size,
dropout_rate=self._dropout_rate,
attention_dropout_rate=self._attention_dropout_rate,
norm_epsilon=1e-12,
inner_activation=self._inner_activation,
two_stream=self._two_stream,
kernel_initializer="variance_scaling",
name="layer_%d" % i))
self.output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"num_layers":
self._num_layers,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_attention_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"initializer":
self._initializer,
"two_stream":
self._two_stream,
"tie_attention_biases":
self._tie_attention_biases,
"memory_length":
self._memory_length,
"reuse_length":
self._reuse_length,
"inner_activation":
self._inner_activation,
}
base_config = super(TransformerXL, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
content_stream,
relative_position_encoding,
segment_matrix=None,
segment_embedding=None,
state=None,
content_attention_mask=None,
query_stream=None,
query_attention_mask=None,
target_mapping=None):
"""Implements call() for the layer.
Arguments:
content_stream: `Tensor`, the input content stream. This is the standard
input to Transformer XL and is commonly referred to as `h` in XLNet.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
but not in Transformer XL.
segment_embedding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
in XLNet, but not in Transformer XL.
state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
the state or memory. If passed, this is also attended over as in
Transformer XL.
      content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the source
        sequence dimension of the mask should be extended by M.
query_stream: Optional `Tensor`, the query stream. This is introduced in
`TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
`two_stream` is `False`.
      query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the source
        sequence dimension of the mask should be extended by M.
target_mapping: Optional `Tensor` representing the target mapping when
calculating query attention.
Returns:
A tuple consisting of the attention output and the list of cached memory
states.
The attention output is `content_attention` if `two_stream` is `False`,
otherwise it is `query_attention`.
"""
new_mems = []
if state is None:
state = [None] * self._num_layers
for i in range(self._num_layers):
# cache new mems
new_mems.append(
_cache_memory(content_stream, state[i],
self._memory_length, self._reuse_length))
# segment bias
if segment_matrix is None:
segment_attention_bias = None
segment_encoding = None
else:
segment_attention_bias = (self.segment_attention_bias
if self._tie_attention_biases
else self.segment_attention_bias[i])
segment_encoding = segment_embedding[i]
content_attention_bias = (self.content_attention_bias
if self._tie_attention_biases
else self.content_attention_bias[i])
positional_attention_bias = (self.positional_attention_bias
if self._tie_attention_biases
else self.positional_attention_bias[i])
transformer_xl_layer = self.transformer_xl_layers[i]
transformer_xl_output = transformer_xl_layer(
content_stream=content_stream,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
relative_position_encoding=relative_position_encoding,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
state=state[i],
content_attention_mask=content_attention_mask,
query_attention_mask=query_attention_mask,
query_stream=query_stream,
target_mapping=target_mapping)
content_stream = transformer_xl_output["content_attention"]
if self._two_stream:
query_stream = transformer_xl_output["query_attention"]
else:
query_stream = None
if self._two_stream:
output_stream = query_stream
else:
output_stream = content_stream
return output_stream, new_mems
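A minimal construction and config round-trip sketch for `TransformerXL` (illustrative only; the hyperparameter values are assumptions, and the import path is inferred from this file's location under official/nlp/modeling/layers):

import tensorflow as tf
from official.nlp.modeling.layers.transformer_xl import TransformerXL  # assumed module path

layer = TransformerXL(
    vocab_size=32000,
    num_layers=2,
    hidden_size=64,  # typically num_attention_heads * head_size
    num_attention_heads=4,
    head_size=16,
    inner_size=128,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False,
    memory_length=8)
# Round-trip through the config; the restored layer holds the same number of
# TransformerXLBlock instances.
restored = TransformerXL.from_config(layer.get_config())
assert len(restored.transformer_xl_layers) == 2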