Commit fffea332 authored by Allen Wang, committed by A. Unique TensorFlower

MultiHeadRelativeAttention compatibility changes with XLNet

PiperOrigin-RevId: 330751568
parent cb6d8d6a
...@@ -20,14 +20,29 @@ import string
 import tensorflow as tf
-from official.nlp.modeling.layers import masked_softmax

 EinsumDense = tf.keras.layers.experimental.EinsumDense
 MultiHeadAttention = tf.keras.layers.MultiHeadAttention
 _CHR_IDX = string.ascii_lowercase
+
+
+def _large_compatible_negative(tensor_type):
+  """Large negative number as Tensor.
+
+  This function is necessary because the standard value for epsilon
+  in this module (-1e9) cannot be represented using tf.float16
+
+  Args:
+    tensor_type: a dtype to determine the type.
+
+  Returns:
+    a large negative number.
+  """
+  if tensor_type == tf.float16:
+    return tf.float16.min
+  return -1e9


 @tf.keras.utils.register_keras_serializable(package="Text")
 class CachedAttention(tf.keras.layers.MultiHeadAttention):
   """Attention layer with cache used for auto-agressive decoding.
...@@ -116,14 +131,15 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):
 def _rel_shift(x, klen=-1):
   """Performs relative shift to form the relative attention score."""
-  x = tf.transpose(x, perm=[1, 2, 0, 3])
+  x = tf.transpose(x, perm=[2, 3, 0, 1])
   x_size = tf.shape(x)
   x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
   x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
   x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
   x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
-  x = tf.transpose(x, perm=[2, 0, 1, 3])
+  x = tf.transpose(x, perm=[2, 3, 0, 1])
   return x
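The permutation change tracks the new score layout: the old einsum `"bind,bjnd->bijn"` produced scores shaped `[B, T, R, N]`, while `self._dot_product_equation` produces `[B, N, T, R]` (matching the `[B, N, S, S + M]` comment later in this diff), so `_rel_shift` now transposes with `[2, 3, 0, 1]` on the way in and out. A small sketch of the shift on a toy tensor, with shapes assumed purely for illustration:

```python
import tensorflow as tf


def _rel_shift(x, klen=-1):
  # Copy of the updated helper, expecting scores laid out as [B, N, T, R].
  x = tf.transpose(x, perm=[2, 3, 0, 1])
  x_size = tf.shape(x)
  x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
  x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
  x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
  x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
  x = tf.transpose(x, perm=[2, 3, 0, 1])
  return x


# Toy tensor: batch=1, heads=1, 3 queries, 4 relative positions.
scores = tf.reshape(tf.range(12, dtype=tf.float32), [1, 1, 3, 4])
shifted = _rel_shift(scores, klen=3)  # -> shape [1, 1, 3, 3]
```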
...@@ -200,15 +216,17 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
     to certain positions.
   """

+  def __init__(self,
+               kernel_initializer="variance_scaling",
+               **kwargs):
+    super().__init__(kernel_initializer=kernel_initializer,
+                     **kwargs)
+
   def _build_from_signature(self, query, value, key=None):
     super(MultiHeadRelativeAttention, self)._build_from_signature(
         query=query,
         value=value,
         key=key)
-    if hasattr(query, "shape"):
-      query_shape = tf.TensorShape(query.shape)
-    else:
-      query_shape = query
     if hasattr(value, "shape"):
       value_shape = tf.TensorShape(value.shape)
     else:
...@@ -230,36 +248,16 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
         bias_constraint=self._bias_constraint)
     with tf.init_scope():
-      free_dims = query_shape.rank - 1
-      einsum_equation, bias_axes, output_rank = _build_proj_equation(
+      einsum_equation, _, output_rank = _build_proj_equation(
           key_shape.rank - 1, bound_dims=1, output_dims=2)
       self._encoding_dense = EinsumDense(
           einsum_equation,
           output_shape=_get_output_shape(output_rank - 1,
                                          [self._num_heads, self._key_dim]),
-          bias_axes=bias_axes if self._use_bias else None,
+          bias_axes=None,
           name="encoding",
           **common_kwargs)
-      output_shape = [query_shape[-1]]
-      einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          free_dims, bound_dims=2, output_dims=len(output_shape))
-      # TODO(allencwang) - replace all einsums with programmatic equations.
-      einsum_equation = "abcd,ecd->abe"
-      self._output_dense = EinsumDense(
-          einsum_equation,
-          output_shape=_get_output_shape(output_rank - 1, output_shape),
-          bias_axes=bias_axes if self._use_bias else None,
-          name="attention_output",
-          **common_kwargs)
-
-  def _build_attention(self, rank):
-    self._masked_softmax = masked_softmax.MaskedSoftmax(
-        mask_expansion_axes=[1], normalization_axes=[2])
-    self._dropout_layer = tf.keras.layers.Dropout(
-        rate=self._dropout)

   def compute_attention(self,
                         query,
                         key,
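With the custom `_output_dense` and `_build_attention` overrides removed, only the relative-position `encoding` projection is still built by hand, and it now always drops its bias. For a rank-3 encoding `[B, L, D]`, `_build_proj_equation(2, bound_dims=1, output_dims=2)` typically yields the equation `"abc,cde->abde"`, i.e. a per-head projection to `[B, L, N, H]`. A standalone sketch, with the equation, names, and sizes assumed for illustration:

```python
import tensorflow as tf

num_heads, key_dim, hidden_size = 4, 8, 32

# Hypothetical standalone equivalent of the `encoding` EinsumDense built above.
encoding_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde",
    output_shape=[None, num_heads, key_dim],  # sequence length stays dynamic
    bias_axes=None,  # the commit forces the bias off for this projection
    name="encoding")

relative_position_encoding = tf.zeros([2, 10, hidden_size])  # [B, L, D]
position = encoding_dense(relative_position_encoding)        # -> [2, 10, 4, 8]
```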
...@@ -267,6 +265,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
                         position,
                         content_attention_bias,
                         positional_attention_bias,
+                        segment_matrix=None,
+                        segment_encoding=None,
+                        segment_attention_bias=None,
                         attention_mask=None):
     """Computes the attention.
...@@ -282,33 +283,59 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
         when calculating the content-based attention score.
       positional_attention_bias: Trainable bias parameter added to the query
         head when calculating the position-based attention score.
+      segment_matrix: Optional `Tensor` representing segmentation IDs used in
+        XLNet.
+      segment_encoding: Optional trainable `Tensor` representing the
+        segmentation encoding as used in XLNet.
+      segment_attention_bias: Optional trainable bias parameter added to the
+        query had when calculating the segment-based attention score used in
+        XLNet.
       attention_mask: (default None) Optional mask that is added to attention
         logits. If state is not None, the mask source sequence dimension should
         extend M.

     Returns:
       attention_output: Multi-headed output of attention computation of shape
-        `[B, T, N, key_dim]`.
+        `[B, S, N, key_dim]`.
     """
-    content_attention = tf.einsum("bind,bjnd->bijn",
-                                  query + content_attention_bias,
-                                  key)
-    positional_attention = tf.einsum("bind,bjnd->bijn",
-                                     query + positional_attention_bias,
-                                     position)
-    positional_attention = _rel_shift(
-        positional_attention, klen=tf.shape(content_attention)[2])
-
-    attention_scores = tf.multiply((content_attention + positional_attention),
-                                   1.0 / math.sqrt(float(self._key_dim)))
-
-    attention_scores = self._masked_softmax(attention_scores, attention_mask)
+    content_attention = tf.einsum(self._dot_product_equation,
+                                  key,
+                                  query + content_attention_bias)
+    positional_attention = tf.einsum(self._dot_product_equation,
+                                     position,
+                                     query + positional_attention_bias)
+    positional_attention = _rel_shift(
+        positional_attention, klen=tf.shape(content_attention)[3])
+
+    if segment_matrix is not None:
+      segment_attention = tf.einsum("bind,snd->bnis",
+                                    query + segment_attention_bias,
+                                    segment_encoding)
+      target_shape = tf.shape(positional_attention)
+      segment_attention = tf.where(
+          tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
+          tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
+          tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
+      attention_sum = (
+          content_attention + positional_attention + segment_attention)
+    else:
+      attention_sum = content_attention + positional_attention
+
+    attention_scores = tf.multiply(
+        attention_sum, 1.0 / math.sqrt(float(self._key_dim)))
+
+    # `attention_scores`: `[B, N, S, S + M]`
+    if attention_mask is not None:
+      attention_scores += (_large_compatible_negative(attention_scores.dtype)
+                           * attention_mask)
+
+    attention_scores = tf.nn.softmax(attention_scores, 3)
     attention_output = self._dropout_layer(attention_scores)

-    attention_output = tf.einsum("bijn,bjnd->bind", attention_output, value)
+    attention_output = tf.einsum(self._combine_equation,
+                                 attention_output,
+                                 value)
     return attention_output

   def call(self,
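The XLNet-style segment term deserves a note: `segment_encoding` is expected to carry two rows along its first axis, and the `tf.where` broadcast picks the score computed against row 1 wherever `segment_matrix` is True and row 0 elsewhere, so only a relative same-vs-different-segment signal enters the attention sum. A minimal sketch of that selection with assumed toy shapes:

```python
import tensorflow as tf

# Assumed toy shapes: B=1 batch, S=2 queries, K=3 keys, N=1 head, D=4 per head.
B, S, K, N, D = 1, 2, 3, 1, 4
query = tf.random.normal([B, S, N, D])
segment_attention_bias = tf.zeros([N, D])
segment_encoding = tf.random.normal([2, N, D])          # row 0 vs. row 1
segment_matrix = tf.constant([[[True, False, True],
                               [False, False, True]]])  # [B, S, K], boolean

# Score each query against both encoding rows ...
segment_attention = tf.einsum("bind,snd->bnis",
                              query + segment_attention_bias,
                              segment_encoding)          # [B, N, S, 2]
target_shape = [B, N, S, K]
# ... then select row 1 where segment_matrix holds, row 0 elsewhere.
segment_attention = tf.where(
    tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
    tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
    tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
```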
...@@ -318,6 +345,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
            positional_attention_bias,
            key=None,
            relative_position_encoding=None,
+           segment_matrix=None,
+           segment_encoding=None,
+           segment_attention_bias=None,
            state=None,
            attention_mask=None):
     """Compute multi-head relative attention over inputs.
...@@ -342,6 +372,13 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
       key: attention input.
       relative_position_encoding: relative positional encoding for key and
         value.
+      segment_matrix: Optional `Tensor` representing segmentation IDs used in
+        XLNet.
+      segment_encoding: Optional `Tensor` representing the segmentation
+        encoding as used in XLNet.
+      segment_attention_bias: Optional trainable bias parameter added to the
+        query had when calculating the segment-based attention score used in
+        XLNet.
       state: (default None) optional state. If passed, this is also attended
         over as in TransformerXL.
       attention_mask: (default None) Optional mask that is added to attention
...@@ -381,7 +418,12 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
         position=position,
         content_attention_bias=content_attention_bias,
         positional_attention_bias=positional_attention_bias,
+        segment_matrix=segment_matrix,
+        segment_encoding=segment_encoding,
+        segment_attention_bias=segment_attention_bias,
         attention_mask=attention_mask)

+    # `attention_output` = [B, S, N, H]
     attention_output = self._output_dense(attention_output)

     return attention_output
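End to end, the new keyword arguments thread straight from `call` into `compute_attention`. A rough usage sketch follows; the import path, shapes, and the bias/encoding variables are assumptions for illustration (in practice a TransformerXL/XLNet block owns these variables), not part of this commit.

```python
import tensorflow as tf
from official.nlp.modeling import layers  # import path assumed

batch, seq_len, rel_len, hidden, num_heads, key_dim = 2, 4, 8, 16, 2, 8

mhra = layers.MultiHeadRelativeAttention(num_heads=num_heads, key_dim=key_dim)

content = tf.random.normal([batch, seq_len, hidden])           # query/value
rel_pos_encoding = tf.random.normal([batch, rel_len, hidden])  # [B, R, D]
content_bias = tf.zeros([num_heads, key_dim])
positional_bias = tf.zeros([num_heads, key_dim])
segment_bias = tf.zeros([num_heads, key_dim])
segment_encoding = tf.random.normal([2, num_heads, key_dim])
segment_matrix = tf.zeros([batch, seq_len, seq_len], dtype=tf.bool)

output = mhra(
    query=content,
    value=content,
    content_attention_bias=content_bias,
    positional_attention_bias=positional_bias,
    relative_position_encoding=rel_pos_encoding,
    segment_matrix=segment_matrix,
    segment_encoding=segment_encoding,
    segment_attention_bias=segment_bias)  # -> [batch, seq_len, hidden]
```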