"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "6d37a3d03f0f6d5286c2d8f6ca10c4429d576377"
Commit c2582c3e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 382846192
parent ef21dabb
@@ -161,6 +161,7 @@ class KernelEncoderConfig(hyperparams.Config):
   redraw: bool = False
   is_short_seq: bool = False
   begin_kernel: int = 0
+  scale: Optional[float] = None


 @dataclasses.dataclass
@@ -377,6 +378,7 @@ def build_encoder(config: EncoderConfig,
         redraw=encoder_cfg.redraw,
         is_short_seq=encoder_cfg.is_short_seq,
         begin_kernel=encoder_cfg.begin_kernel,
+        scale=encoder_cfg.scale,
     )
     hidden_cfg = dict(
         num_attention_heads=encoder_cfg.num_attention_heads,
...
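For orientation, a minimal sketch of how the new knob is set at the config level; the field values below are illustrative, not from this change:

# Illustrative only: KernelEncoderConfig is the dataclass patched above.
encoder_cfg = KernelEncoderConfig(
    redraw=False,
    is_short_seq=False,
    begin_kernel=0,
    scale=0.125,  # scale=None (the default) means 1/sqrt(key_dim) downstream
)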
@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)


-def _generalized_kernel(x, projection_matrix, is_query, f, h,
-                        data_normalizer_fn=None):
+def _generalized_kernel(x, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.

   Args:
     x: The feature being transformed with shape [B, T, N, H].
     projection_matrix: The matrix with shape [M, H] that we project x to,
       where M is the number of projections.
-    is_query: Whether the transform is a query or key. The transform is
-      symmetric if the argument is not used.
     f: A non-linear function applied on x or projected x.
     h: A multiplier, a function of x, applied after x is projected and
       transformed. Only applied if projection_matrix is not None.
-    data_normalizer_fn: A function which takes x and returns a scalar that
-      normalizes the data.

   Returns:
     Transformed feature.
   """
-  # No asymmetric operations.
-  del is_query
-  if data_normalizer_fn is not None:
-    x = data_normalizer_fn(x)
   if projection_matrix is None:
     return h(x) * f(x)
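A hedged usage sketch of the simplified helper; the tensor shape and the f/h choices below are mine, not from the diff. With projection_matrix=None it reduces to h(x) * f(x):

import tensorflow as tf

x = tf.random.normal([2, 6, 4, 8])  # [B, T, N, H]
feat = _generalized_kernel(
    x,
    projection_matrix=None,
    f=tf.nn.relu,                          # non-linearity applied to x
    h=lambda t: tf.ones_like(t[..., :1]))  # trivial multiplier for the demo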
@@ -139,9 +129,7 @@ _TRANSFORM_MAP = {
             x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
         h=lambda x: tf.math.exp(
             -0.5 * tf.math.reduce_sum(
-                tf.math.square(x), axis=-1, keepdims=True)),
-        data_normalizer_fn=lambda x: x /
-        (tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
+                tf.math.square(x), axis=-1, keepdims=True)),),
     "expmod":
         functools.partial(
             _generalized_kernel,
@@ -149,15 +137,7 @@ _TRANSFORM_MAP = {
         f=lambda x: tf.math.exp(
             x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
         h=lambda x: tf.math.exp(
-            -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
-        data_normalizer_fn=lambda x: x /
-        (tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
-    "l2":
-        functools.partial(
-            _generalized_kernel,
-            f=lambda x: x,
-            h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
-            data_normalizer_fn=lambda x: x),
+            -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),),
     "identity": lambda x, projection_matrix, is_query: x
 }
 # pylint: enable=g-long-lambda
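As a standalone reminder of why the "exp" entry pairs f = exp(x - max) with h = exp(-||x||^2 / 2) (Performer, Lemma 1), here is a hedged Monte Carlo check, written independently of this commit:

import tensorflow as tf

h_dim, m = 8, 200000
q = tf.random.normal([h_dim]) * 0.3
k = tf.random.normal([h_dim]) * 0.3
w = tf.random.normal([m, h_dim])  # rows are omega ~ N(0, I)
phi = lambda x: tf.exp(tf.linalg.matvec(w, x) - 0.5 * tf.reduce_sum(x * x))
estimate = tf.reduce_sum(phi(q) * phi(k)) / m  # Monte Carlo E[phi(q) . phi(k)]
exact = tf.exp(tf.reduce_sum(q * k))           # the softmax kernel exp(q . k)
# estimate ~= exact up to sampling noise for small ||q||, ||k||.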
@@ -170,7 +150,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
   Rethinking Attention with Performers
   (https://arxiv.org/abs/2009.14794)
-  - exp (Lemma 1, positive), relu, l2
+  - exp (Lemma 1, positive), relu
   - random/deterministic projection

   Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
@@ -195,14 +175,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                redraw=False,
                is_short_seq=False,
                begin_kernel=0,
+               scale=None,
                **kwargs):
     r"""Constructor of KernelAttention.

     Args:
       feature_transform: A non-linear transform of the keys and queries.
         Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "l2", "identity". If <is_short_seq> = True, it is recommended to
-        choose feature_transform as "l2".
+        "identity".
       num_random_features: Number of random features to be used for projection.
         If num_random_features <= 0, no projection is used before the transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -216,6 +196,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         (default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
+      scale: The value to scale the dot product as described in `Attention Is
+        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
       **kwargs: The same arguments as the `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
@@ -234,8 +216,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     # 1. inference
     # 2. no redraw
     self._seed = seed
-
     super().__init__(**kwargs)
+    if scale is None:
+      self._scale = 1.0 / math.sqrt(float(self._key_dim))
+    else:
+      self._scale = scale
     self._projection_matrix = None
     if num_random_features > 0:
       self._projection_matrix = create_projection_matrix(
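The default matches scaled dot-product attention; a trivial check (the key_dim value is illustrative):

import math

key_dim = 64
assert 1.0 / math.sqrt(float(key_dim)) == 0.125  # the scale used when None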
@@ -275,6 +260,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
+    if is_short_seq:
+      # Note: Applying the scalar multiply at the smaller end of the einsum
+      # improves XLA performance, but may introduce slight numeric differences
+      # in the Transformer attention head.
+      query = query * self._scale
+      if attention_mask is not None:
+        key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
+      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
+      attention_scores = tf.nn.softmax(attention_scores, axis=2)
+      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+      return attention_output
     projection_matrix = None
     if self._num_random_features > 0:
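A hedged, self-contained rendering of that new short-sequence branch (shapes and names are mine): it is exact softmax attention expressed with einsum over [B, T/S, N, H] tensors.

import tensorflow as tf

scale = 0.5                                    # stands in for self._scale
query = tf.random.normal([1, 3, 2, 4]) * scale
key = tf.random.normal([1, 5, 2, 4])
value = tf.random.normal([1, 5, 2, 4])
scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
probs = tf.nn.softmax(scores, axis=2)          # normalize over the S axis
output = tf.einsum("BTSN,BSNH->BTNH", probs, value)  # [1, 3, 2, 4]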
@@ -284,23 +280,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     else:
       projection_matrix = self._projection_matrix
-    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
-    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
+    # Note: we suspect splitting the scale between key and query yields a
+    # smaller approximation variance when random projection is used.
+    # For simplicity, we also split when there's no random projection.
+    key *= math.sqrt(self._scale)
+    query *= math.sqrt(self._scale)
+    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
+    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
-    if is_short_seq:
-      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
-      attention_scores = tf.nn.softmax(attention_scores, axis=2)
-      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-      return attention_output
-    else:
-      kv = tf.einsum("BSNH,BSND->BNDH", key, value)
-      denominator = 1.0 / (
-          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
-          _NUMERIC_STABLER)
-      return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    kv = tf.einsum("BSNH,BSND->BNDH", key, value)
+    denominator = 1.0 / (
+        tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
+        _NUMERIC_STABLER)
+    return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)

   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
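A quick hedged check of the algebra behind the split, independent of this code: scaling key and query by sqrt(s) scales each raw dot product by exactly s.

import math
import tensorflow as tf

s = 0.125
q = tf.random.normal([4])
k = tf.random.normal([4])
lhs = tf.reduce_sum((math.sqrt(s) * q) * (math.sqrt(s) * k))
rhs = s * tf.reduce_sum(q * k)
tf.debugging.assert_near(lhs, rhs)  # identical up to float rounding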
@@ -391,6 +387,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         "redraw": self._redraw,
         "is_short_seq": self._is_short_seq,
         "begin_kernel": self._begin_kernel,
+        "scale": self._scale,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
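With "scale" serialized, a config round trip should preserve it; a hedged sketch, with illustrative constructor arguments:

layer = KernelAttention(
    num_heads=2, key_dim=32, feature_transform="exp",
    num_random_features=64, seed=0, scale=None)
restored = KernelAttention.from_config(layer.get_config())
assert restored.get_config()["scale"] == layer.get_config()["scale"]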
@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention

-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
+_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
...
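These module-level lists typically feed parameterized test cases; a hedged sketch of that pattern (the actual test bodies are outside this diff):

import itertools
from absl.testing import parameterized

class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, _REDRAW, _TRAINING, _IS_SHORT_SEQ))
  def test_attention_runs(self, feature_transform, redraw, training, short_seq):
    ...  # build a KernelAttention layer and call it under these settings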