Commit 3325205b authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Implement `TwoStreamRelativeAttention`.

PiperOrigin-RevId: 331787357
parent 13a030ed
...@@ -344,3 +344,182 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention): ...@@ -344,3 +344,182 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
return attention_output return attention_output
@tf.keras.utils.register_keras_serializable(package="Text")
class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
"""Two-stream relative self-attention for XLNet.
In XLNet, each token has two associated vectors at each self-attention layer,
the content stream (h) and the query stream (g).
The content stream is the self-attention stream as in Transformer XL and
represents the context and content (the token itself).
The query stream only has access to contextual information and the position,
but not the content.
This layer shares the same build signature as `MultiHeadRelativeAttention` but
has different input/output projections.
**Note: This layer is currently experimental.
Call args:
content_stream: `Tensor` of shape `[B, T, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
position_attention_bias: Bias `Tensor` for position based attention of shape
`[num_heads, dim]`.
query_stream: `Tensor` of shape `[B, P, dim]`.
target_mapping: `Tensor` of shape `[B, P, S]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet of shape `[num_heads, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
content_attention_mask: a boolean mask of shape `[B, T, S]` that
prevents attention to certain positions for content attention computation.
query_attention_mask: a boolean mask of shape `[B, T, S]` that
prevents attention to certain position for query attention computation.
"""
def call(self,
content_stream,
content_attention_bias,
positional_attention_bias,
query_stream,
relative_position_encoding,
target_mapping=None,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
state=None,
content_attention_mask=None,
query_attention_mask=None):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Number of predictions (P): the number of predictions.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
content_stream: The content representation, commonly referred to as h.
This serves a similar role to the standard hidden states in
Transformer-XL.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
query_stream: The query representation, commonly referred to as g.
This only has access to contextual information and position, but not
content. If not provided, then this is MultiHeadRelativeAttention with
self-attention.
relative_position_encoding: relative positional encoding for key and
value.
target_mapping: Optional `Tensor` representing the target mapping used
in partial prediction.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL and XLNet.
content_attention_mask: (default None) Optional mask that is added to
content attention logits. If state is not None, the mask source sequence
dimension should extend M.
query_attention_mask: (default None) Optional mask that is added to
query attention logits. If state is not None, the mask source sequence
dimension should extend M.
Returns:
content_attention_output, query_attention_output: the results of the
computation, both of shape [B, T, E]. `T` is for target sequence shapes,
`E` is the query input last dimension if `output_shape` is `None`.
Otherwise, the multi-head outputs are projected to the shape specified
by `output_shape`.
"""
if not self._built_from_signature:
self._build_from_signature(content_stream, content_stream, content_stream)
if state is not None and state.shape.ndims > 1:
content_and_memory_stream = tf.concat([state, content_stream], 1)
else:
content_and_memory_stream = content_stream
# `query` = [B, T, N, H]
query = self._query_dense(content_stream)
# `key` = [B, S + M, N, H]
key = self._key_dense(content_and_memory_stream)
# `value` = [B, S + M, N, H]
value = self._value_dense(content_and_memory_stream)
# `position` = [B, L, N, H]
position = self._encoding_dense(relative_position_encoding)
content_attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=content_attention_mask)
# `content_attention_output` = [B, S, N, H]
content_attention_output = self._output_dense(content_attention_output)
query_attention_output = None
if query_stream is not None:
query = self._query_dense(query_stream)
if target_mapping is not None:
query = tf.einsum("bmnd,bml->blnd", query, target_mapping)
query_attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=query_attention_mask)
query_attention_output = tf.einsum("blnd,bml->bmnd",
query_attention_output,
target_mapping)
else:
query_attention_output = self.compute_attention(
query=query,
key=key,
value=value,
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=query_attention_mask)
query_attention_output = self._output_dense(query_attention_output)
return content_attention_output, query_attention_output
...@@ -29,6 +29,8 @@ def _create_mock_attention_data( ...@@ -29,6 +29,8 @@ def _create_mock_attention_data(
seq_length, seq_length,
batch_size, batch_size,
memory_length=0, memory_length=0,
num_predictions=2,
two_stream=False,
include_state=False, include_state=False,
include_mask=False, include_mask=False,
include_segment=False): include_segment=False):
...@@ -41,6 +43,9 @@ def _create_mock_attention_data( ...@@ -41,6 +43,9 @@ def _create_mock_attention_data(
seq_length: `int`, Sequence length of the input. seq_length: `int`, Sequence length of the input.
batch_size: `int`, the batch size. batch_size: `int`, the batch size.
memory_length: optional `int`, the length of the state. Defaults to 0. memory_length: optional `int`, the length of the state. Defaults to 0.
num_predictions: `int`, the number of predictions used in two stream
attention.
two_stream: `bool`, whether or not to generate two stream data.
include_state: optional `bool`, whether or not to include state data. include_state: optional `bool`, whether or not to include state data.
include_mask: optional `bool`, whether or not to include mask data. include_mask: optional `bool`, whether or not to include mask data.
include_segment: optional `bool`, whether or not to include segment data. include_segment: optional `bool`, whether or not to include segment data.
...@@ -54,13 +59,25 @@ def _create_mock_attention_data( ...@@ -54,13 +59,25 @@ def _create_mock_attention_data(
attention_bias_shape = (num_heads, key_dim) attention_bias_shape = (num_heads, key_dim)
data = dict( data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape),
relative_position_encoding=tf.random.normal(shape=encoding_shape), relative_position_encoding=tf.random.normal(shape=encoding_shape),
content_attention_bias=tf.random.normal(shape=attention_bias_shape), content_attention_bias=tf.random.normal(shape=attention_bias_shape),
positional_attention_bias=tf.random.normal(shape=attention_bias_shape)) positional_attention_bias=tf.random.normal(shape=attention_bias_shape))
if two_stream:
query_stream_shape = (batch_size, num_predictions, key_dim)
target_mapping_shape = (batch_size, num_predictions, seq_length)
stream_data = dict(
content_stream=tf.random.normal(shape=query_shape),
query_stream=tf.random.normal(shape=query_stream_shape),
target_mapping=tf.random.normal(shape=target_mapping_shape))
else:
stream_data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
data.update(stream_data)
if include_state: if include_state:
total_seq_length = seq_length + memory_length total_seq_length = seq_length + memory_length
state_data = dict( state_data = dict(
...@@ -71,9 +88,15 @@ def _create_mock_attention_data( ...@@ -71,9 +88,15 @@ def _create_mock_attention_data(
if include_mask: if include_mask:
mask_shape = (batch_size, num_heads, seq_length, total_seq_length) mask_shape = (batch_size, num_heads, seq_length, total_seq_length)
mask_data = dict( mask_data = np.random.randint(2, size=mask_shape).astype("float32")
attention_mask=np.random.randint(2, size=mask_shape).astype("float32")) if two_stream:
mask_data = dict(
content_attention_mask=mask_data,
query_attention_mask=mask_data)
else:
mask_data = dict(attention_mask=mask_data)
data.update(mask_data) data.update(mask_data)
if include_segment: if include_segment:
segment_encoding_shape = (2, num_heads, key_dim) segment_encoding_shape = (2, num_heads, key_dim)
segment_matrix = np.random.randint( segment_matrix = np.random.randint(
...@@ -115,6 +138,7 @@ class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase): ...@@ -115,6 +138,7 @@ class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase):
value_dim=value_dim, value_dim=value_dim,
seq_length=seq_length, seq_length=seq_length,
memory_length=memory_length, memory_length=memory_length,
two_stream=False,
batch_size=batch_size, batch_size=batch_size,
include_state=state, include_state=state,
include_mask=mask, include_mask=mask,
...@@ -123,6 +147,44 @@ class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase): ...@@ -123,6 +147,44 @@ class MultiHeadRelativeAttentionTest(keras_parameterized.TestCase):
self.assertEqual(output.shape, [batch_size, seq_length, key_dim]) self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
@keras_parameterized.run_all_keras_modes
class TwoStreamRelativeAttentionTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
num_predictions=[2, 10],
memory_length=[0, 4],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_attention_scores(self,
num_predictions,
memory_length,
state,
mask,
segment):
"""Tests combinations of attention score calculations."""
batch_size, num_heads, key_dim, seq_length = 2, 12, 64, 8
test_layer = relative_attention.TwoStreamRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=key_dim)
data = _create_mock_attention_data(
num_heads=num_heads,
key_dim=key_dim,
value_dim=key_dim,
seq_length=seq_length,
memory_length=memory_length,
num_predictions=num_predictions,
two_stream=True,
batch_size=batch_size,
include_state=state,
include_mask=mask,
include_segment=segment)
content_output, query_output, = test_layer(**data)
self.assertEqual(content_output.shape, [batch_size, seq_length, key_dim])
self.assertEqual(query_output.shape, [batch_size, num_predictions, key_dim])
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(0) np.random.seed(0)
tf.random.set_seed(0) tf.random.set_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment