Commit 3335e25d authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal changes.

PiperOrigin-RevId: 328225451
parent 20db22c7
......@@ -16,12 +16,16 @@
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import math
import string
import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax
EinsumDense = tf.keras.layers.experimental.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
_CHR_IDX = string.ascii_lowercase
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -107,3 +111,277 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):
if return_attention_scores:
return attention_output, attention_scores, cache
return attention_output, cache
def _rel_shift(x, klen=-1):
"""Performs relative shift to form the relative attention score."""
x = tf.transpose(x, perm=[1, 2, 0, 3])
x_size = tf.shape(x)
x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
x = tf.transpose(x, perm=[2, 0, 1, 3])
return x
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadRelativeAttention(MultiHeadAttention):
"""A multi-head attention layer with relative attention + position encoding.
This layer shares the same input/output projections as the common
MultiHeadAttention layer.
When it calculates attention logits, position encoding is projected to form
relative keys. The logits are composed by shifted relative logits and content
logits.
**Note: This layer is currently experimental.
Arguments:
num_heads: The number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of attention head for value.
dropout: Dropout probability for attention.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
position_attention_bias: Bias `Tensor` for position based attention of shape
`[num_heads, dim]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def _build_from_signature(self, query, value, key=None):
super(MultiHeadRelativeAttention, self)._build_from_signature(
query=query,
value=value,
key=key)
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._encoding_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="encoding",
**common_kwargs)
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
# TODO(allencwang) - replace all einsums with programmatic equations.
einsum_equation = "abcd,ecd->abe"
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
def _build_attention(self, rank):
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=[2])
self._dropout_layer = tf.keras.layers.Dropout(
rate=self._dropout)
def compute_attention(self,
query,
key,
value,
position,
content_attention_bias,
positional_attention_bias,
attention_mask=None):
"""Computes the attention.
This function defines the computation inside `call` with projected
multihead Q, K, V, R inputs.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`.
value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`.
position: Projected position `Tensor` of shape `[B, L, N, key_dim]`.
content_attention_bias: Trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query
head when calculating the position-based attention score.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask source sequence dimension should
extend M.
Returns:
attention_output: Multi-headed output of attention computation of shape
`[B, T, N, key_dim]`.
"""
content_attention = tf.einsum("bind,bjnd->bijn",
query + content_attention_bias,
key)
positional_attention = tf.einsum("bind,bjnd->bijn",
query + positional_attention_bias,
position)
positional_attention = _rel_shift(
positional_attention, klen=tf.shape(content_attention)[2])
attention_scores = tf.multiply((content_attention + positional_attention),
1.0 / math.sqrt(float(self._key_dim)))
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_output = self._dropout_layer(attention_scores)
attention_output = tf.einsum("bijn,bjnd->bind", attention_output, value)
return attention_output
def call(self,
query,
value,
content_attention_bias,
positional_attention_bias,
key=None,
relative_position_encoding=None,
state=None,
attention_mask=None):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
query: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
relative_position_encoding: relative positional encoding for key and
value.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask source sequence dimension should
extend M.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are projected to the shape specified by `output_shape`.
"""
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
if state is not None and state.shape.ndims > 1:
value = tf.concat([state, value], 1)
key = tf.concat([state, key], 1)
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S + M, N, H]
key = self._key_dense(key)
# `value` = [B, S + M, N, H]
value = self._value_dense(value)
# `position` = [B, L, N, H]
position = self._encoding_dense(relative_position_encoding)
attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
attention_mask=attention_mask)
attention_output = self._output_dense(attention_output)
return attention_output
......@@ -92,5 +92,38 @@ class CachedAttentionTest(keras_parameterized.TestCase):
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
@keras_parameterized.run_all_keras_modes
class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase):
def test_attention_scores(self):
num_heads = 12
key_dim = 64
value_dim = 32
seq_length = 8
batch_size = 2
test_layer = attention.MultiHeadRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
relative_position_encoding = tf.random.normal(
shape=(batch_size, seq_length * 2, key_dim))
content_attention_bias = tf.random.normal(
shape=(num_heads, key_dim))
positional_attention_bias = tf.random.normal(
shape=(num_heads, key_dim))
output = test_layer(
query=query,
value=value,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
relative_position_encoding=relative_position_encoding,
state=None,
attention_mask=None)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
if __name__ == "__main__":
tf.test.main()
......@@ -196,3 +196,4 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
position_embeddings = tf.concat(
[tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
return position_embeddings
......@@ -84,27 +84,49 @@ def is_special_none_tensor(tensor):
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
class PositionalEmbedding(tf.keras.layers.Layer):
"""Generates relative positional embeddings used in Transformer-XL and XLNet."""
@tf.keras.utils.register_keras_serializable(package='Text')
class RelativePositionEncoding(tf.keras.layers.Layer):
"""Creates a relative positional encoding.
def __init__(self, dim, **kwargs):
super(PositionalEmbedding, self).__init__(**kwargs)
self.dim = dim
This layer creates a relative positional encoding as described in
"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
def build(self, unused_input_shapes):
"""Constructs inversed frequency vector for positional embedding layer."""
self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
super(PositionalEmbedding, self).build(unused_input_shapes)
Rather than an absolute position embedding as in Transformer, this
formulation represents position as the relative distance between tokens using
sinusoidal positional embeddings.
def call(self, pos_seq, batch_size):
"""Implements call() for the layer."""
sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
Note: This layer is currently experimental.
Attributes:
hidden_size: The dimensionality of the input embeddings.
"""
def __init__(self, hidden_size, **kwargs):
super(RelativePositionEncoding, self).__init__(**kwargs)
self._hidden_size = hidden_size
self._inv_freq = 1.0 / (10000.0**(
tf.range(0, self._hidden_size, 2.0) / self._hidden_size))
def call(self, pos_seq, batch_size=None):
"""Implements call() for the layer.
Arguments:
pos_seq: A 1-D `Tensor`
batch_size: The optionally provided batch size that tiles the relative
positional encoding.
Returns:
The relative positional encoding of shape:
[len(pos_seq), batch_size, hidden_size] if batch_size is provided, else
[len(pos_seq), 1, hidden_size].
"""
sinusoid_input = tf.einsum('i,d->id', pos_seq, self._inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_input), tf.cos(sinusoid_input)], -1)
pos_emb = pos_emb[:, None, :]
if batch_size is not None:
pos_emb = tf.tile(pos_emb, [1, batch_size, 1])
return pos_emb
......@@ -475,8 +497,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
'mask_emb/mask_emb', shape=[1, 1, self.d_model], dtype=self.tf_float)
self.emb_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.fwd_position_embedding = PositionalEmbedding(self.d_model)
self.bwd_position_embedding = PositionalEmbedding(self.d_model)
self.fwd_position_embedding = RelativePositionEncoding(self.d_model)
self.bwd_position_embedding = RelativePositionEncoding(self.d_model)
self.rel_multihead_layers = []
self.h_positionwise_ffn_layers = []
......
......@@ -42,7 +42,7 @@ class PositionalEmbeddingLayerTest(tf.test.TestCase):
[[0., 0., 1., 1.]]])
d_model = 4
pos_seq = tf.range(1, -1, -1.0) # [1., 0.]
pos_emb_layer = xlnet_modeling.PositionalEmbedding(d_model)
pos_emb_layer = xlnet_modeling.RelativePositionEncoding(d_model)
pos_emb = pos_emb_layer(pos_seq, batch_size=None).numpy().astype(float)
logging.info(pos_emb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment