Commit c2582c3e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 382846192
parent ef21dabb
......@@ -161,6 +161,7 @@ class KernelEncoderConfig(hyperparams.Config):
redraw: bool = False
is_short_seq: bool = False
begin_kernel: int = 0
scale: Optional[float] = None
@dataclasses.dataclass
......@@ -377,6 +378,7 @@ def build_encoder(config: EncoderConfig,
redraw=encoder_cfg.redraw,
is_short_seq=encoder_cfg.is_short_seq,
begin_kernel=encoder_cfg.begin_kernel,
scale=encoder_cfg.scale,
)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
......
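The two hunks above plumb the new `scale` option from the experiment config into the encoder builder. A minimal usage sketch, assuming only the fields visible in this diff (everything else keeps its default; `None` falls back to 1/sqrt(key_dim) inside the attention layer):

# Hypothetical sketch; only fields shown in this diff are set explicitly.
encoder_cfg = KernelEncoderConfig(
    redraw=False,
    is_short_seq=False,
    begin_kernel=0,
    scale=0.125,  # explicit scale; None (the default) means 1/sqrt(key_dim)
)
# build_encoder(...) then forwards encoder_cfg.scale to the kernel attention layers.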
......@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
def _generalized_kernel(x, projection_matrix, is_query, f, h,
data_normalizer_fn=None):
def _generalized_kernel(x, projection_matrix, f, h):
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
Args:
x: The feature being transformed with shape [B, T, N, H].
projection_matrix: The matrix with shape [M, H] that we project x onto,
where M is the number of projections.
is_query: Whether the transform is applied to a query or a key. The
transform is symmetric, so this argument is not used.
f: A non-linear function applied on x or projected x.
h: A multiplier, as a function of x, applied after x is projected and
transformed. Only applied if projection_matrix is not None.
data_normalizer_fn: A function that takes x and returns a normalized
version of x.
Returns:
Transformed feature.
"""
# No asymmetric operations.
del is_query
if data_normalizer_fn is not None:
x = data_normalizer_fn(x)
if projection_matrix is None:
return h(x) * f(x)
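With the asymmetric `is_query` flag and `data_normalizer_fn` removed, the transform is fully described by `f`, `h`, and the optional projection. A tiny illustrative call of the new signature (the choice of `f` and `h` here is illustrative, not one of the map entries below); with `projection_matrix=None` it reduces to `h(x) * f(x)` exactly as shown above:

# Illustrative only; shapes follow the docstring: x is [B, T, N, H].
x = tf.random.normal([2, 4, 2, 8])
phi = _generalized_kernel(x, projection_matrix=None, f=tf.nn.relu, h=lambda _: 1.0)
# phi equals tf.nn.relu(x): h is a constant multiplier and no projection is applied.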
......@@ -139,9 +129,7 @@ _TRANSFORM_MAP = {
x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(
-0.5 * tf.math.reduce_sum(
tf.math.square(x), axis=-1, keepdims=True)),
data_normalizer_fn=lambda x: x /
(tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
tf.math.square(x), axis=-1, keepdims=True)),),
"expmod":
functools.partial(
_generalized_kernel,
......@@ -149,15 +137,7 @@ _TRANSFORM_MAP = {
f=lambda x: tf.math.exp(
x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(
-0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
data_normalizer_fn=lambda x: x /
(tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
"l2":
functools.partial(
_generalized_kernel,
f=lambda x: x,
h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
data_normalizer_fn=lambda x: x),
-0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),),
"identity": lambda x, projection_matrix, is_query: x
}
# pylint: enable=g-long-lambda
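For context on why `data_normalizer_fn` could be dropped: it divided both query and key by d**0.25, which scales their dot product by 1/sqrt(d). The layer now achieves the same effect by multiplying key and query by sqrt(scale), with scale defaulting to 1/sqrt(key_dim) (see the key/query scaling further down in this diff). A quick numeric check of that equivalence, with illustrative shapes:

# Sanity sketch: per-side division by d**0.25 == per-side multiplication by
# sqrt(1/sqrt(d)); both scale the dot product by 1/sqrt(d).
d = 64
q = tf.random.normal([1, 4, 2, d])
k = tf.random.normal([1, 4, 2, d])
scale = 1.0 / math.sqrt(float(d))  # the layer's default when scale is None
old_scores = tf.einsum("BTNH,BSNH->BTSN", q / d**0.25, k / d**0.25)
new_scores = tf.einsum("BTNH,BSNH->BTSN", q * math.sqrt(scale), k * math.sqrt(scale))
tf.debugging.assert_near(old_scores, new_scores, atol=1e-5)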
......@@ -170,7 +150,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
Rethinking Attention with Performers
(https://arxiv.org/abs/2009.14794)
- exp (Lemma 1, positive), relu, l2
- exp (Lemma 1, positive), relu
- random/deterministic projection
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
......@@ -195,14 +175,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
redraw=False,
is_short_seq=False,
begin_kernel=0,
scale=None,
**kwargs):
r"""Constructor of KernelAttention.
Args:
feature_transform: A non-linear transform of the keys and queries.
Possible transforms are "elu", "relu", "square", "exp", "expmod",
"l2", "identity". If <is_short_seq> = True, it is recommended to choose
feature_transform as "l2".
"identity".
num_random_features: Number of random features to be used for projection.
If num_random_features <= 0, no projection is used before the transform.
seed: The seed to begin drawing random features. Once the seed is set, the
......@@ -216,6 +196,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
(default option).
begin_kernel: Apply kernel attention after this sequence position and
apply softmax attention before it.
scale: The value by which to scale the dot product, as described in
`Attention Is All You Need`. If None, 1/sqrt(dk) is used, as in the paper.
**kwargs: The same arguments as the `MultiHeadAttention` layer.
"""
if feature_transform not in _TRANSFORM_MAP:
......@@ -234,8 +216,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
# 1. inference
# 2. no redraw
self._seed = seed
super().__init__(**kwargs)
if scale is None:
self._scale = 1.0 / math.sqrt(float(self._key_dim))
else:
self._scale = scale
self._projection_matrix = None
if num_random_features > 0:
self._projection_matrix = create_projection_matrix(
......@@ -275,6 +260,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
Returns:
attention_output: Multi-headed outputs of attention computation.
"""
if is_short_seq:
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = query * self._scale
if attention_mask is not None:
key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
attention_scores = tf.nn.softmax(attention_scores, axis=2)
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
return attention_output
projection_matrix = None
if self._num_random_features > 0:
......@@ -284,23 +280,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
else:
projection_matrix = self._projection_matrix
key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
# Note: we suspect splitting the scale between key and query yields smaller
# approximation variance when random projection is used.
# For simplicity, we also split when there is no random projection.
key *= math.sqrt(self._scale)
query *= math.sqrt(self._scale)
key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
if attention_mask is not None:
key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
if is_short_seq:
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
attention_scores = tf.nn.softmax(attention_scores, axis=2)
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
return attention_output
else:
kv = tf.einsum("BSNH,BSND->BNDH", key, value)
denominator = 1.0 / (
tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
_NUMERIC_STABLER)
return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
kv = tf.einsum("BSNH,BSND->BNDH", key, value)
denominator = 1.0 / (
tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
_NUMERIC_STABLER)
return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
def _build_from_signature(self, query, value, key=None):
super()._build_from_signature(query=query, value=value, key=key)
......@@ -391,6 +387,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
"redraw": self._redraw,
"is_short_seq": self._is_short_seq,
"begin_kernel": self._begin_kernel,
"scale": self._scale,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
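Because `scale` is now part of `get_config()`, the layer round-trips through Keras config serialization with the resolved value. A hedged sketch (constructor values are illustrative placeholders, and the private `_scale` attribute is inspected only to show the resolved default):

# Illustrative round trip; num_random_features=0 avoids drawing a projection matrix.
layer = KernelAttention(
    num_heads=2, key_dim=8, feature_transform="relu",
    num_random_features=0, seed=0, scale=None)
restored = KernelAttention.from_config(layer.get_config())
assert restored._scale == layer._scale == 1.0 / math.sqrt(8.0)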
......@@ -21,7 +21,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
_REDRAW = [True, False]
_TRAINING = [True, False]
_IS_SHORT_SEQ = [True, False]
......