Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import functools
import math
import tensorflow as tf
from official.modeling import tf_utils
_NUMERIC_STABLER = 1e-6
@@ -39,6 +41,236 @@ class KernelMask(tf.keras.layers.Layer):
return mask
def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
"""Pads a tensor so that shape[axis] is divisible by chunk_length.
Args:
tensor: Input tensor to pad.
axis: Axis to pad along.
chunk_length: The output tensor will have shape[axis] divisible by
chunk_length.
padding: Pad the input tensor across the axis from either left or right if
padding is set to "left" or "right"; applies no padding if padding is set
to None. In the latter case, the axis dimension of the input tensor must
be divisible by the chunk_length.
Returns:
Padded tensor with shape[axis] divisible by chunk_length.
"""
if padding is None:
return tensor
shape = tf.shape(tensor)
rank = tf.rank(tensor)
if axis < 0:
axis += rank
axis_length = shape[axis]
pad_length = -axis_length % chunk_length
if padding == "right":
axis_paddings = [[0, pad_length]]
elif padding == "left":
axis_paddings = [[pad_length, 0]]
else:
raise ValueError(
"Illegal padding value; must be one of \"left\", \"right\" or None.")
paddings = tf.concat([
tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
],
axis=0)
return tf.pad(tensor, paddings)
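# Illustrative sketch (not part of the original file): with a [2, 7, 4] input,
# chunk_length=3 and padding="right", two rows of zeros are appended along
# axis -2, since -7 % 3 == 2, giving shape [2, 9, 4].
def _example_pad_to_chunk_length():
  x = tf.ones([2, 7, 4])
  padded = pad_to_chunk_length(x, -2, 3, padding="right")
  assert padded.shape == (2, 9, 4)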
def split_tensor_into_chunks(tensor, axis, chunk_length):
"""Reshape tensor along given axis using chunk_length.
Args:
tensor: Input tensor.
axis: Reshape tensor along this axis.
chunk_length: Split the axis into [axis/chunk_length, chunk_length]
Returns:
Reshaped tensor.
"""
shape = tf.shape(tensor)
num_chunks = shape[axis] // chunk_length
new_shape = tf.concat(
[shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
return tf.reshape(tensor, new_shape)
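# Sketch continuing the hypothetical example above: splitting the padded
# [2, 9, 4] tensor along axis 1 with chunk_length=3 yields [2, 3, 3, 4],
# i.e. [batch, num_chunks, chunk_length, features].
def _example_split_tensor_into_chunks():
  padded = tf.ones([2, 9, 4])
  chunked = split_tensor_into_chunks(padded, 1, 3)
  assert chunked.shape == (2, 3, 3, 4)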
def rectangular_window_sum(tensor, window_length):
"""Summarizes tensor elements over a sliding rectangular window.
Sums elements of the input tensor of shape [B, T', C', H, dim]
across a rectangular window sliding along the dimension T'.
Args:
tensor: Tensor of shape `[B, T', C', H, dim]`.
window_length: The length of the rectangular window.
Returns:
A tensor of shape [B, T', C', H, dim] containing sums over the
window.
"""
tensor_cumsum = tf.cumsum(tensor, axis=-4)
tensor_winsum = tensor_cumsum - tf.pad(
tensor_cumsum,
[[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
return tensor_winsum
def weighted_window_sum(tensor, window_length, window_weights):
"""Summarizes tensor elements over a sliding weighted window.
Computes a weighted sum of elements of the input tensor of shape [B,
T', C', H, dim] across a window sliding along the dimension T'.
Args:
tensor: Tensor of shape `[B, T', C', H, dim]`.
window_length: The length of the window.
window_weights: Tensor of shape [window_length] containing window weights.
Returns:
A tensor of shape [B, T', C', H, dim] containing sums over the
window.
"""
# Flatten the last three dimensions of the [B, T', C', H, dim] shape
# into a single channels dimension.
tensor_shape = tf.shape(tensor)
tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])
# Apply the same weights to all channels.
conv_filter = tf.tile(
tf.reshape(window_weights, [-1, 1, 1, 1]),
multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
tensor_winsum_2d = tf.nn.depthwise_conv2d(
tensor_2d,
conv_filter,
strides=[1, 1, 1, 1],
padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])
# Unflatten the channels dimension into the original shape.
tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
return tensor_winsum
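# Consistency sketch (an equivalence we assume, not asserted by the original
# code): with unit weights, the depthwise-conv path of weighted_window_sum
# should reproduce the cumulative-sum path of rectangular_window_sum.
def _example_window_sums_agree():
  x = tf.random.normal([2, 5, 2, 2, 2])
  rect = rectangular_window_sum(x, 3)
  weighted = weighted_window_sum(x, 3, tf.ones([3]))
  tf.debugging.assert_near(rect, weighted, atol=1e-5)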
def causal_windowed_performer_attention(query_matrix,
key_matrix,
value_matrix,
chunk_length,
window_length,
window_decay=None,
padding=None,
cache=None):
"""Applies windowed causal kernel attention with query, key, value tensors.
We partition the T-length input sequence into N chunks, each of
chunk_length tokens (thus: T = N * chunk_length). Within each chunk,
we apply bidirectional (non-causal) Performers' implicit attention
and we model relationships between different chunks using
Performers' causal attention. We consider a windowed causal variant of
the Performer, where the current chunk attends only to a window of the
window_length most recent chunks.
Below is an example with T=9, chunk_length=3, window_length=2. In
this example, a 1 indicates that attention is computed between a pair of
positions, while a 0 indicates that it is not:
111000000
111000000
111000000
111111000
111111000
111111000
000111111
000111111
000111111
The user can either ensure that sequence_length is divisible by
chunk_length, or set padding="left"/"right" to pad the sequence on the
left or right, respectively, until it is divisible by chunk_length.
Args:
query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
chunk_length: Length of each chunk in tokens.
window_length: Length of attention window in chunks.
window_decay: Float window decay factor or `None`. If set, exponentially
decay past attention window values by this factor before summation.
padding: Pad the query, value and key input tensors across the axis from
either left or right if padding is set to "left" or "right"; apply no
padding if padding is set to None. In the latter case, the axis dimension
of the query, value and key input tensors must be divisible by the
chunk_length.
cache: Cache to accumulate history in memory. Used at inference time
(streaming, decoding) for causal attention.
Returns:
Window causal performer attention of shape `[B, T, H, out_dim]`.
"""
if cache is None: # Training
old_shape = tf.shape(value_matrix)
query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
new_shape = tf.shape(value_matrix)
chunked_query_matrix = split_tensor_into_chunks(
query_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, dim]
chunked_key_matrix = split_tensor_into_chunks(
key_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, dim]
chunked_value_matrix = split_tensor_into_chunks(
value_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, out_dim]
kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
chunked_value_matrix)
k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
if window_decay is None:
kp_v_winsum = rectangular_window_sum(kp_v, window_length)
k_winsum = rectangular_window_sum(k_sum, window_length)
else:
# Compute exponentially decaying weights.
decaying_weights = tf.math.pow(
tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
numerator = tf.einsum(
"BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
k_winsum = tf.squeeze(k_winsum, -3)
denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
attention = numerator / denominator
attention = tf.reshape(attention, new_shape)
start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
attention = tf.slice(attention, start, old_shape)
# Queued window cache (drop instead of decay) not yet supported.
else: # Streaming
if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
kv = window_decay * cache["kv"] + tf.einsum(
"BTHD,BTHO->BHOD", key_matrix, value_matrix)
cache["kv"] = kv
k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
cache["k_sum"] = k_sum
denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
1.0 / (denominator + _NUMERIC_STABLER))
return attention
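# Usage sketch under assumed shapes (B=2, T=9, H=4, dim=8): a training-mode
# call (cache=None) with chunk_length=3 and window_length=2 realizes the
# block-banded pattern drawn in the docstring above. The query/key inputs are
# expected to already be non-negative kernel features.
def _example_causal_windowed_attention():
  q = tf.random.uniform([2, 9, 4, 8])
  k = tf.random.uniform([2, 9, 4, 8])
  v = tf.random.normal([2, 9, 4, 8])
  out = causal_windowed_performer_attention(
      q, k, v, chunk_length=3, window_length=2)
  assert out.shape == (2, 9, 4, 8)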
def create_projection_matrix(m, d, seed=None):
r"""Constructs the matrix of random projections.
@@ -56,8 +288,8 @@ def create_projection_matrix(m, d, seed=None):
The matrix of random projections of the shape [m, d].
"""
nb_full_blocks = math.ceil(m / d)
block_list = tf.TensorArray(
tf.float32, size=tf.cast(nb_full_blocks, dtype=tf.int32))
stateful = False
if seed is None:
stateful = True
@@ -85,11 +317,13 @@ def create_projection_matrix(m, d, seed=None):
return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
Args:
x: The feature being transformed with shape [B, T, N ,H].
y: The extra stats-tensor of shape [B, T, N ,H].
is_query: True if x is a query-tensor.
projection_matrix: The matrix with shape [M, H] that we project x onto, where
M is the number of projections.
f: A non-linear function applied on x or projected x.
@@ -99,7 +333,8 @@ def _generalized_kernel(x, projection_matrix, f, h):
Returns:
Transformed feature.
"""
del y
del is_query
if projection_matrix is None:
return h(x) * f(x)
else:
@@ -108,8 +343,124 @@ def _generalized_kernel(x, projection_matrix, f, h):
tf.cast(tf.shape(projection_matrix)[0], tf.float32))
def expplus(data_orig,
other_data,
is_query,
projection_matrix=None,
numerical_stabilizer=0.000001,
normalize_data=True,
numerical_renormalizer=True,
extra_renormalize_exp_fun=False):
"""FAVOR++ mechanism from the CRT paper: https://arxiv.org/abs/2205.15317 .
Args:
data_orig: data tensor of shape [B,T,H,D] for which random features are to
be computed
other_data: additional tensor of the shape [B,F,H,D] used to collect stats
to determine the exact instantiation of the random feature mechanism
is_query: boolean indicating whether <data_orig> tensor is a query tensor
projection_matrix: tensor of the shape [M,D] encoding random projections for
random features (M stands for the number of random features)
numerical_stabilizer: numerical stabilizer for the kernel features
normalize_data: whether to sqrt-d-normalize queries/keys as in the regular
attention
numerical_renormalizer: whether to apply additional renormalization for
numerical stability
extra_renormalize_exp_fun: extra renormalizer for the exponential mapping
applied to construct random features
Returns:
Random feature map tensor for the unbiased softmax-kernel estimation.
"""
data = data_orig
if projection_matrix is None:
return data_orig
projection_matrix = tf.cast(projection_matrix, data.dtype)
if normalize_data:
data_normalizer = 1.0 / tf.math.sqrt(
(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
else:
data_normalizer = 1.0
lengths = tf.math.square(data)
lengths = tf.reduce_sum(lengths, axis=tf.keras.backend.ndim(data) - 1)
lengths = tf.expand_dims(lengths, axis=tf.keras.backend.ndim(data) - 1)
lengths = tf.math.sqrt(lengths)
data /= lengths
ratio = 1.0 / tf.math.sqrt(
tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
projection_matrix)
diag_data = tf.math.square(data)
diag_data = tf.math.reduce_sum(
diag_data, axis=tf.keras.backend.ndim(data) - 1)
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
# Calculating coefficients A, B of the FAVOR++ mechanism:
_, l, _, _ = tf_utils.get_shape_list(data_orig)
l = tf.cast(l, dtype=tf.float32)
first_sum_of_squares = tf.math.square(data)
first_sum_of_squares = tf.math.reduce_sum(
first_sum_of_squares, axis=(1, -1), keepdims=True)
first_sum_of_squares *= (data_normalizer * data_normalizer)
first_sum_of_squares /= l # data.shape[1]
second_sum_of_squares = tf.math.square(other_data)
second_sum_of_squares = tf.math.reduce_sum(
second_sum_of_squares, axis=(1, -1), keepdims=True)
second_sum_of_squares *= (data_normalizer * data_normalizer)
second_sum_of_squares /= l # other_data.shape[1]
data_sum = tf.math.reduce_sum(data, axis=(1,), keepdims=True)
other_data_sum = tf.math.reduce_sum(other_data, axis=(1,), keepdims=True)
d_prod = tf.einsum("blhd,blhd->blh", data_sum, other_data_sum)
d_prod = tf.expand_dims(d_prod, axis=-1)
d_prod *= (data_normalizer * data_normalizer)
d_prod *= (2.0 / (l * l))
ave = first_sum_of_squares + second_sum_of_squares + d_prod
dim = projection_matrix.shape[-1]
a_coeff = (1.0 / (4.0 * ave)) * (
tf.math.sqrt((2.0 * ave + dim) *
(2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim)
a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
a_coeff = tf.stop_gradient(a_coeff)
b_coeff = tf.stop_gradient(b_coeff)
d_coeff = tf.stop_gradient(d_coeff)
# Calculating diag_omega for the FAVOR++ mechanism:
diag_omega = tf.math.square(projection_matrix)
diag_omega = tf.math.reduce_sum(
diag_omega, axis=tf.keras.backend.ndim(projection_matrix) - 1)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = a_coeff * diag_omega
if numerical_renormalizer:
if is_query:
last_dims_t = (len(data_dash.shape) - 1,)
stab = b_coeff * tf.math.reduce_max(
data_dash, axis=last_dims_t, keepdims=True)
else:
stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
if extra_renormalize_exp_fun:
extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
stab = tf.math.maximum(stab, extra_stab)
data_dash = ratio * d_coeff * (
tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
numerical_stabilizer)
else:
data_dash = ratio * d_coeff * (
tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
numerical_stabilizer)
return data_dash
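# Usage sketch with hypothetical shapes: unlike the transforms in
# _CAUSAL_SUPPORT_TRANSFORM_MAP below, expplus also consumes the other side
# of the attention pair to estimate its A and B coefficients.
def _example_expplus():
  proj = create_projection_matrix(128, 64, seed=tf.constant([0, 1]))
  q = tf.random.normal([2, 16, 4, 64])
  k = tf.random.normal([2, 16, 4, 64])
  q_prime = expplus(q, k, is_query=True, projection_matrix=proj)
  assert q_prime.shape == (2, 16, 4, 128)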
# pylint: disable=g-long-lambda
_CAUSAL_SUPPORT_TRANSFORM_MAP = {
"elu":
functools.partial(
_generalized_kernel,
@@ -117,19 +468,22 @@ _TRANSFORM_MAP = {
h=lambda x: 1),
"relu":
functools.partial(
_generalized_kernel,
# Improve numerical stability and avoid NaNs in some cases by adding
# a tiny epsilon.
f=lambda x: tf.keras.activations.relu(x) + 1e-3,
h=lambda x: 1),
"square":
functools.partial(_generalized_kernel, f=tf.math.square, h=lambda x: 1),
"exp":
functools.partial(
_generalized_kernel,
# Avoid exp explosion by shifting.
f=lambda x: tf.math.exp(x - tf.math.reduce_max(
x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
tf.math.square(x), axis=-1, keepdims=True)),
),
"expmod":
functools.partial(
_generalized_kernel,
@@ -142,6 +496,16 @@ _TRANSFORM_MAP = {
"identity":
functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}
_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
"expplus": expplus,
}
_TRANSFORM_MAP = {
**_CAUSAL_SUPPORT_TRANSFORM_MAP,
**_NON_CAUSAL_SUPPORT_TRANSFORM_MAP
}
# pylint: enable=g-long-lambda
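# Dispatch sketch (illustrative only): every entry in _TRANSFORM_MAP shares
# the (x, y, is_query, projection_matrix) calling convention; the
# causal-support transforms simply discard y and is_query.
def _example_transform_dispatch():
  x = tf.random.normal([2, 16, 4, 64])
  # With no projection matrix, _generalized_kernel returns h(x) * f(x).
  feats = _TRANSFORM_MAP["relu"](x, x, True, None)
  assert feats.shape == x.shape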
@@ -154,6 +518,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
(https://arxiv.org/abs/2009.14794)
- exp (Lemma 1, positive), relu
- random/deterministic projection
Chefs' Random Tables: Non-Trigonometric Random Features
(https://arxiv.org/abs/2205.15317)
- expplus (OPRF mechanism)
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
(https://arxiv.org/abs/2006.16236)
@@ -178,13 +545,19 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
is_short_seq=False,
begin_kernel=0,
scale=None,
scale_by_length=False,
use_causal_windowed=False,
causal_chunk_length=1,
causal_window_length=3,
causal_window_decay=None,
causal_padding=None,
**kwargs):
r"""Constructor of KernelAttention.
Args:
feature_transform: A non-linear transform of the keys and queries.
Possible transforms are "elu", "relu", "square", "exp", "expplus",
"expmod", "identity".
num_random_features: Number of random features to be used for projection.
If num_random_features <= 0, no projection is used before the transform.
seed: The seed to begin drawing random features. Once the seed is set, the
@@ -194,12 +567,28 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
redraw: Whether to redraw projection every forward pass during training.
The argument is only effective when num_random_features > 0.
is_short_seq: boolean predicate indicating whether input data consists of
very short sequences or not; in most cases this should be False (default
option).
begin_kernel: Apply kernel_attention after this sequence id and apply
softmax attention before this.
scale: The value to scale the dot product as described in `Attention Is
All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
scale_by_length: boolean predicate indicating whether to additionally scale
the dot product based on the key length. The scale is multiplied by
log_512(n), where n is the key length, to stabilize attention entropy
against length. Refer to https://kexue.fm/archives/8823 for details.
use_causal_windowed: If True, perform windowed causal attention. See the
causal_windowed_performer_attention function docstring for more details.
causal_chunk_length: Length of each chunk in tokens.
causal_window_length: Length of attention window in chunks.
causal_window_decay: Float window decay factor or `None`. If set,
exponentially decay past attention window values by this factor before
summation.
causal_padding: Pad the query, value and key input tensors across the axis
from either left or right if padding is set to "left" or "right"; apply
no padding if padding is set to None. In the latter case, the axis
dimension of the query, value and key input tensors must be divisible by
the chunk_length.
**kwargs: The same arguments as for the `MultiHeadAttention` layer.
"""
if feature_transform not in _TRANSFORM_MAP:
@@ -214,6 +603,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
self._redraw = redraw
self._is_short_seq = is_short_seq
self._begin_kernel = begin_kernel
self._scale_by_length = scale_by_length
# We use the seed for two scenarios:
# 1. inference
# 2. no redraw
@@ -228,6 +618,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
self._projection_matrix = create_projection_matrix(
self._num_random_features, self._key_dim,
tf.constant([self._seed, self._seed + 1]))
self.use_causal_windowed = use_causal_windowed
self.causal_chunk_length = causal_chunk_length
self.causal_window_length = causal_window_length
self.causal_window_decay = causal_window_decay
self.causal_padding = causal_padding
if self.use_causal_windowed and self._is_short_seq:
raise ValueError(
"use_causal_windowed and short_seq methods are mutually exclusive")
def _compute_attention(self,
query,
@@ -236,6 +634,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
feature_transform,
is_short_seq,
attention_mask=None,
cache=None,
training=False,
numeric_stabler=_NUMERIC_STABLER):
"""Applies kernel attention with query, key, value tensors.
@@ -252,9 +651,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
is_short_seq: boolean predicate indicating whether input data consists of
short or long sequences; usually short sequence is defined as having
length L <= 1024.
attention_mask: a boolean mask of shape `[B, S]`, that prevents attending
to masked positions. Note that the mask is only applied to the keys. The
user may want to mask the output if the query contains pads.
cache: Cache to accumulate history in memory. Used at inference time
(streaming, decoding) for causal attention.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
numeric_stabler: A scalar value added to avoid divide by 0.
@@ -263,6 +664,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_output: Multi-headed outputs of attention computation.
"""
projection_matrix = None
if self._num_random_features > 0:
if self._redraw and training:
projection_matrix = create_projection_matrix(self._num_random_features,
@@ -270,35 +672,53 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
else:
projection_matrix = self._projection_matrix
if self._scale_by_length:
scale = tf.math.log(tf.reduce_sum(attention_mask,
axis=-1)) * self._scale / math.log(512)
scale = tf.reshape(scale, [-1, 1, 1, 1])
else:
scale = self._scale
if is_short_seq:
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = query * scale
else:
# Note: we suspect splitting the scale between key and query yields smaller
# approximation variance when random projection is used.
# For simplicity, we also split when there's no random projection.
key *= tf.math.sqrt(scale)
query *= tf.math.sqrt(scale)
key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
projection_matrix)
query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
projection_matrix)
if attention_mask is not None:
key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
if is_short_seq:
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
attention_scores = tf.nn.softmax(attention_scores, axis=2)
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
elif self.use_causal_windowed:
attention_output = causal_windowed_performer_attention(
query_prime,
key_prime,
value,
chunk_length=self.causal_chunk_length,
window_length=self.causal_window_length,
window_decay=self.causal_window_decay,
padding=self.causal_padding,
cache=cache)
else:
kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
denominator = 1.0 / (
tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
_NUMERIC_STABLER)
attention_output = tf.einsum(
"BTNH,BNDH,BTN->BTND", query, kv, denominator)
tf.einsum("BTNH,BNH->BTN", query_prime,
tf.reduce_sum(key_prime, axis=1)) + _NUMERIC_STABLER)
attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv,
denominator)
return attention_output
def _build_from_signature(self, query, value, key=None):
@@ -313,15 +733,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._output_dense_softmax = self._make_output_dense(
self._query_shape.rank - 1,
common_kwargs,
name="attention_output_softmax")
self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)
def call(self, query, value, key=None, attention_mask=None, cache=None,
training=False):
"""Compute attention with kernel mechanism.
@@ -330,15 +747,32 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, S]`, that prevents attending
to masked positions. Note that the mask is only applied to the keys. The
user may want to mask the output if the query contains pads.
cache: Cache to accumulate history in memory. Used at inference time
(streaming, decoding) for causal attention.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
Returns:
Multi-headed outputs of attention computation.
"""
if cache is not None:
if training:
raise ValueError(
"Cache is not supported when training is True.")
if not self.use_causal_windowed:
raise ValueError(
"Cache is not supported for non use_causal_windowed case.")
if self._begin_kernel:
raise ValueError(
"Cache is not supported when begin_kernel is set since the bahvior "
"is too complicated.")
if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
raise ValueError("Cache is not supported for feature_transform %s" %
(self._feature_transform))
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
@@ -357,25 +791,26 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
if self._begin_kernel > 0:
attention_output_softmax = self._compute_attention(
query[:, :self._begin_kernel], key, value, "identity", True,
attention_mask, training)
attention_output_softmax = self._dropout_softmax(attention_output_softmax)
attention_output_softmax = self._output_dense_softmax(
attention_output_softmax)
attention_output_kernel = self._compute_attention(
query[:, self._begin_kernel:], key, value, self._feature_transform,
self._is_short_seq, attention_mask, training)
attention_output_kernel = self._dropout_layer(attention_output_kernel)
attention_output_kernel = self._output_dense(attention_output_kernel)
attention_output = tf.concat(
[attention_output_softmax, attention_output_kernel], axis=1)
else:
attention_output = self._compute_attention(query, key, value,
self._feature_transform,
self._is_short_seq,
attention_mask,
cache,
training)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_output = self._dropout_layer(attention_output)
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention
_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
_REDRAW = [True, False]
_TRAINING = [True, False]
_IS_SHORT_SEQ = [True, False]
@@ -30,9 +30,67 @@ _BEGIN_KERNEL = [0, 512]
class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
# expplus is only designed for the bi-directional use case.
# exp can be numerically unstable.
@parameterized.parameters(itertools.product(
["relu", "elu"], [1, 4], [0.9]))
def test_causal_windowed_attention_projection_streaming(
self, feature_transform, causal_chunk_length, causal_weight_decay):
num_heads = 12
key_dim = 64
seq_length = 16
num_chunks = seq_length // causal_chunk_length
causal_window_length = num_chunks
batch_size = 2
training = False
num_random_features = 0
test_layer = attention.KernelAttention(
num_heads=num_heads,
key_dim=key_dim,
feature_transform=feature_transform,
num_random_features=num_random_features,
redraw=False,
is_short_seq=False,
begin_kernel=False,
use_causal_windowed=True,
causal_chunk_length=causal_chunk_length,
causal_window_length=causal_window_length,
causal_window_decay=causal_weight_decay,
causal_padding=None,
)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim), seed=2)
value = query
encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
output = test_layer(
query=query,
value=value,
attention_mask=masks,
training=training)
dim = num_random_features if num_random_features > 0 else key_dim
kv_cache = tf.zeros(
(batch_size, num_heads, dim, dim))
k_sum_cache = tf.zeros((batch_size, num_heads, dim))
stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks):
stream_output.append(
test_layer(
query=query[:, i * causal_chunk_length:(i + 1) *
causal_chunk_length, :],
value=value[:, i * causal_chunk_length:(i + 1) *
causal_chunk_length, :],
attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
causal_chunk_length],
cache=cache,
training=training))
stream_output = tf.concat(stream_output, axis=1)
self.assertAllClose(output, stream_output)
@parameterized.parameters(
itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
_IS_SHORT_SEQ, _BEGIN_KERNEL))
def test_attention_projection(
self, feature_transform, num_random_features, training, redraw, is_short,
begin_kernel):
@@ -60,6 +118,41 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
training=training)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
@parameterized.parameters(
itertools.product(["relu", "exp"], [127], _TRAINING, [True, False],
[0], [None, 0.97], [None, "left", "right"]))
def test_causal_windowed_attention_projection(
self, feature_transform, num_random_features, training, redraw,
begin_kernel, causal_window_decay, causal_padding):
num_heads = 12
key_dim = 64
seq_length = 1024
batch_size = 2
test_layer = attention.KernelAttention(
num_heads=num_heads,
key_dim=key_dim,
feature_transform=feature_transform,
num_random_features=num_random_features,
redraw=redraw,
is_short_seq=False,
begin_kernel=begin_kernel,
use_causal_windowed=True,
causal_chunk_length=8,
causal_window_length=3,
causal_window_decay=causal_window_decay,
causal_padding=causal_padding)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
output = test_layer(
query=query,
value=value,
attention_mask=masks,
training=training)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
@parameterized.parameters(itertools.product(
_FEATURE_TRANSFORM, [0], _TRAINING, [False],
_IS_SHORT_SEQ, _BEGIN_KERNEL))
@@ -90,15 +183,41 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
training=training)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
@parameterized.parameters([128, 512])
def test_attention_scale_by_length(self, seq_length):
num_heads = 12
key_dim = 64
batch_size = 2
test_layer = attention.KernelAttention(
num_heads=num_heads,
key_dim=key_dim,
num_random_features=0,
scale_by_length=True)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
output_scale_by_length = test_layer(
query=query, value=value, attention_mask=masks)
test_layer._scale_by_length = False
output_no_scale_by_length = test_layer(
query=query, value=value, attention_mask=masks)
if seq_length == 512: # Equals because log(seq_length, base=512) = 1.0
self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
else:
self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
def test_unsupported_feature_transform(self):
with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
_ = attention.KernelAttention(feature_transform="test")
def test_redraw_true_no_projection(self):
with self.assertRaisesRegex(
ValueError, "There is nothing to redraw when num_random_features.*"):
_ = attention.KernelAttention(
num_heads=2, key_dim=64, feature_transform="elu",
num_random_features=0, redraw=True)
def test_config(self):
@@ -107,7 +226,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
test_layer = attention.KernelAttention(
num_heads=num_heads,
key_dim=key_dim,
feature_transform="exp",
num_random_features=128,
is_short_seq=True)
new_layer = attention.KernelAttention.from_config(
@@ -115,5 +234,25 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
def test_rectangular_window_sum(self):
x = tf.ones([2, 5, 2, 2, 2])
winsum = attention.rectangular_window_sum(x, 3)
self.assertEqual(winsum.shape, x.shape)
self.assertAllClose(
tf.tile(
tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
[2, 1, 2, 2, 2]),
winsum)
def test_weighted_window_sum(self):
x = tf.ones([2, 5, 2, 2, 2])
winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
self.assertEqual(winsum.shape, x.shape)
self.assertAllClose(
tf.tile(
tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
[2, 1, 2, 2, 2]),
winsum)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,7 +47,7 @@ class MaskedLM(tf.keras.layers.Layer):
output='logits',
name=None,
**kwargs):
super().__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
@@ -73,7 +73,7 @@ class MaskedLM(tf.keras.layers.Layer):
initializer='zeros',
trainable=True)
super().build(input_shape)
def call(self, sequence_data, masked_positions):
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
@@ -115,7 +115,8 @@ class MaskedLM(tf.keras.layers.Layer):
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(
positions + tf.cast(flat_offsets, positions.dtype), [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ class MaskedSoftmax(tf.keras.layers.Layer):
self._normalization_axes = (-1,)
else:
self._normalization_axes = normalization_axes
super().__init__(**kwargs)
def call(self, scores, mask=None):
@@ -81,5 +81,5 @@ class MaskedSoftmax(tf.keras.layers.Layer):
'mask_expansion_axes': self._mask_expansion_axes,
'normalization_axes': self._normalization_axes
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ class MatMulWithMargin(tf.keras.layers.Layer):
logit_scale=1.0,
logit_margin=0.0,
**kwargs):
super().__init__(**kwargs)
self.logit_scale = logit_scale
self.logit_margin = logit_margin
@@ -61,7 +61,7 @@ class MatMulWithMargin(tf.keras.layers.Layer):
config = {
'logit_scale': self.logit_scale,
'logit_margin': self.logit_margin}
config.update(super().get_config())
return config
@classmethod
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based mixing layers.
Based on the mixing layers used by FNet
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
(https://arxiv.org/abs/2205.12399).
Mixing layers can be used as drop-in replacements for self-attention layers. For
interoperability with attention layers, we use the same `query` and `value` call
signature.
Note: These mixing layers currently only support encoder stacks. Decoder stacks
can be supported in the future by utilizing the `value` inputs.
"""
import enum
import functools
from typing import Callable, Tuple, Union
import numpy as np
from scipy import linalg
import tensorflow as tf
from official.modeling import tf_utils
_Initializer = Union[str, tf.keras.initializers.Initializer]
default_kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=2e-2)
class MixingMechanism(enum.Enum):
"""Determines the type of mixing layer.
Possible options:
FOURIER: Fourier Transform mixing.
LINEAR: Mixing using dense matrix multiplications with learnable weights.
HARTLEY: Hartley Transform mixing.
"""
FOURIER = "fourier"
HARTLEY = "hartley"
LINEAR = "linear"
class MixingLayer(tf.keras.layers.Layer):
"""Mixing layer base class.
This class cannot be used directly. It just specifies the API for mixing
layer subclasses. For interoperability with attention layers, we use the same
`query` and `value` call signature.
Based on the mixing layers used by FNet
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
(https://arxiv.org/abs/2205.12399).
"""
def __init__(self, name: str = "mixing", **kwargs):
"""Initializes layer.
Args:
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Calls the layer.
Subclasses should return tensors of shape
<float>[batch_size, max_seq_length, hidden_dim].
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Raises:
NotImplementedError. This class should not be called directly.
"""
raise NotImplementedError("Abstract method")
class FourierTransformLayer(MixingLayer):
"""Fourier Transform layer.
Applies 2D Fourier Transform over final two dimensions of `query` inputs -
typically the sequence and hidden dimensions.
"""
def __init__(self,
use_fft: bool = False,
name: str = "fourier_transform",
**kwargs):
"""Initializes layer.
Args:
use_fft: Whether to use Fast Fourier Transform (True) or the Discrete
Fourier Transform (DFT) matrix (False) to compute the Fourier Transform.
See _pick_fourier_transform() for recommendations on when to use FFT or
DFT.
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
self.use_fft = use_fft
def build(self, input_shape: Tuple[int, ...]):
"""Picks the Fourier Transform implementation."""
self.fourier_transform = _pick_fourier_transform(
self.use_fft,
max_seq_length=input_shape[-2],
hidden_dim=input_shape[-1])
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Applies layer to `query`.
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Returns:
Real part of discrete Fourier Transform of `query` inputs with shape
<float32>[batch_size, max_seq_length, hidden_dim].
"""
del value # Ignored by encoder-only mixing layers
query = tf.cast(query, tf.complex64)
return tf.math.real(self.fourier_transform(query))
class HartleyTransformLayer(MixingLayer):
"""Hartley Transform layer.
Applies 2D Hartley Transform over final two dimensions of `query` inputs -
typically the sequence and hidden dimensions.
"""
def __init__(self,
use_fft: bool = False,
name: str = "hartley_transform",
**kwargs):
"""Initializes layer.
Args:
use_fft: Whether to use Fast Fourier Transform (True) or the Discrete
Fourier Transform (DFT) matrix (False) to compute the Hartley Transform.
See _pick_fourier_transform() for recommendations on when to use FFT or
DFT.
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
self.use_fft = use_fft
def build(self, input_shape: Tuple[int, ...]):
"""Picks the Fourier Transform implementation."""
self.fourier_transform = _pick_fourier_transform(
self.use_fft,
max_seq_length=input_shape[-2],
hidden_dim=input_shape[-1])
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Applies layer to `query`.
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Returns:
Real part of discrete Hartley Transform of `query` inputs with shape
<float32>[batch_size, max_seq_length, hidden_dim].
"""
del value # Ignored by encoder-only mixing layers
query = tf.cast(query, tf.complex64)
frequencies = self.fourier_transform(query)
return tf.math.real(frequencies) - tf.math.imag(frequencies)
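# Sketch: for real inputs the Hartley transform H{x} = Re(F{x}) - Im(F{x}) is
# itself real-valued and invertible, so unlike FourierTransformLayer (which
# keeps only the real part) the subtraction above discards no information.
def _example_hartley_is_real():
  x = tf.random.uniform([1, 8, 4])
  out = HartleyTransformLayer(use_fft=True)(query=x, value=x)
  assert out.dtype == tf.float32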
class LinearTransformLayer(MixingLayer):
"""Dense, linear transformation layer.
Applies matrix multiplications over sequence and hidden dimensions.
"""
def __init__(self,
kernel_initializer: _Initializer = default_kernel_initializer,
name: str = "linear_transform",
**kwargs):
"""Initializes layer.
Args:
kernel_initializer: Initialization scheme for kernel.
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
self.kernel_initializer = kernel_initializer
def build(self, input_shape: Tuple[int, ...]):
"""Creates the hidden and sequence matrix variables of the layer."""
self.mat_hidden = self.add_weight(
shape=(input_shape[-1], input_shape[-1]),
initializer=tf_utils.clone_initializer(self.kernel_initializer),
trainable=True,
name="hidden_kernel")
self.mat_seq = self.add_weight(
shape=(input_shape[-2], input_shape[-2]),
initializer=tf_utils.clone_initializer(self.kernel_initializer),
trainable=True,
name="seq_kernel")
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Applies layer to `query`.
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Returns:
Linearly transformed `query` inputs with shape
<float>[batch_size, max_seq_length, hidden_dim].
"""
del value # Ignored by encoder-only mixing layers
return tf.einsum("bij,jk,ni->bnk", query, self.mat_hidden, self.mat_seq)
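# Reading the einsum above (a sketch, not original code): out[b] equals
# mat_seq @ query[b] @ mat_hidden, i.e. one learned mixing matrix applied
# along the sequence axis and one along the hidden axis, mirroring the 2D
# spectral transforms.
def _example_linear_transform_equivalence():
  layer = LinearTransformLayer()
  x = tf.random.normal([2, 4, 3])
  out = layer(query=x, value=x)
  expected = tf.einsum("ni,bij,jk->bnk", layer.mat_seq, x, layer.mat_hidden)
  tf.debugging.assert_near(out, expected, atol=1e-5)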
def _pick_fourier_transform(
use_fft: bool, max_seq_length: int,
hidden_dim: int) -> Callable[[tf.Tensor], tf.Tensor]:
"""Returns FFT or DFT Fourier Transform implementation.
On TPUs, we recommend using the Discrete Fourier Transform (DFT) matrix
(use_fft=False), except for very long sequence lengths. On GPUs and CPUs, the
Fast Fourier Transform (use_fft=True) is generally optimal for all sequence
lengths.
Note: When using the FFT it is recommended to use a sequence length that is a
power of 2.
Args:
use_fft: If True, return FFT. Otherwise, return DFT matrix.
max_seq_length: Maximum sequence length of inputs. Only used if
use_fft=False.
hidden_dim: Size of hidden dimension of inputs. Only used if use_fft=False.
Returns:
Fourier Transform.
"""
if use_fft:
return tf.signal.fft2d
else:
dft_mat_seq = linalg.dft(max_seq_length).astype(np.complex64)
dft_mat_hidden = linalg.dft(hidden_dim).astype(np.complex64)
def two_dim_matmul(x: tf.Tensor, matrix_dim_one: tf.Tensor,
matrix_dim_two: tf.Tensor) -> tf.Tensor:
"""Applies 2D matrix multiplication to input tensors of rank >= 2."""
return tf.einsum("...ij,jk,ni->...nk", tf.cast(x, tf.complex64),
matrix_dim_two, matrix_dim_one)
return functools.partial(
two_dim_matmul,
matrix_dim_one=tf.convert_to_tensor(dft_mat_seq),
matrix_dim_two=tf.convert_to_tensor(dft_mat_hidden))
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mixing.py."""
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import mixing
class MixingTest(tf.test.TestCase):
def test_base_mixing_layer(self):
inputs = tf.random.uniform((3, 8, 16),
minval=0,
maxval=10,
dtype=tf.float32)
with self.assertRaisesRegex(NotImplementedError, "Abstract method"):
_ = mixing.MixingLayer()(query=inputs, value=inputs)
def test_fourier_layer(self):
batch_size = 4
max_seq_length = 8
hidden_dim = 16
inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
minval=0,
maxval=10,
dtype=tf.float32)
outputs = mixing.FourierTransformLayer(use_fft=True)(
query=inputs, value=inputs)
self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
def test_hartley_layer(self):
batch_size = 3
max_seq_length = 16
hidden_dim = 4
inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
minval=0,
maxval=12,
dtype=tf.float32)
outputs = mixing.HartleyTransformLayer(use_fft=True)(
query=inputs, value=inputs)
self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
def test_linear_mixing_layer(self):
batch_size = 2
max_seq_length = 4
hidden_dim = 3
inputs = tf.ones((batch_size, max_seq_length, hidden_dim), dtype=tf.float32)
outputs = mixing.LinearTransformLayer(
kernel_initializer=tf.keras.initializers.Ones())(
query=inputs, value=inputs)
# hidden_dim * (max_seq_length * 1) = 12.
expected_outputs = [
[
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
],
[
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
],
]
np.testing.assert_allclose(outputs, expected_outputs, rtol=1e-6, atol=1e-6)
def test_pick_fourier_transform(self):
# Ensure we don't hit an edge case which exceeds the fixed numerical error.
tf.random.set_seed(1)
np.random.seed(1)
batch_size = 3
max_seq_length = 4
hidden_dim = 8
fft = mixing._pick_fourier_transform(
use_fft=True, max_seq_length=max_seq_length, hidden_dim=hidden_dim)
dft_matmul = mixing._pick_fourier_transform(
use_fft=False, max_seq_length=max_seq_length, hidden_dim=hidden_dim)
inputs = tf.random.uniform([batch_size, max_seq_length, hidden_dim])
inputs = tf.cast(inputs, tf.complex64)
np.testing.assert_allclose(
fft(inputs), dft_matmul(inputs), rtol=1e-6, atol=1e-6)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
"""MobileBERT embedding and transformer layers."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import on_device_embedding
from official.nlp.modeling.layers import position_embedding
@@ -24,7 +26,7 @@ class NoNorm(tf.keras.layers.Layer):
"""Apply element-wise linear transformation to the last dimension."""
def __init__(self, name=None):
super().__init__(name=name)
def build(self, shape):
kernal_size = shape[-1]
@@ -96,7 +98,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
dropout_rate: Dropout rate.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
self.word_vocab_size = word_vocab_size
self.word_embed_size = word_embed_size
self.type_vocab_size = type_vocab_size
@@ -109,21 +111,21 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.word_embedding = on_device_embedding.OnDeviceEmbedding(
self.word_vocab_size,
self.word_embed_size,
initializer=tf_utils.clone_initializer(self.initializer),
name='word_embedding')
self.type_embedding = on_device_embedding.OnDeviceEmbedding(
self.type_vocab_size,
self.output_embed_size,
initializer=tf_utils.clone_initializer(self.initializer),
name='type_embedding')
self.pos_embedding = position_embedding.PositionEmbedding(
max_length=max_sequence_length,
initializer=tf_utils.clone_initializer(self.initializer),
name='position_embedding')
self.word_embedding_proj = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.output_embed_size],
kernel_initializer=tf_utils.clone_initializer(self.initializer),
bias_axes='d',
name='embedding_projection')
self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
@@ -220,7 +222,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -242,11 +244,11 @@ class MobileBertTransformer(tf.keras.layers.Layer):
self.block_layers = {}
# add input bottleneck
dense_layer_2d = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_input/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='bottleneck_input/norm')
@@ -254,11 +256,11 @@ class MobileBertTransformer(tf.keras.layers.Layer):
layer_norm]
if self.key_query_shared_bottleneck:
dense_layer_2d = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='kq_shared_bottleneck/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='kq_shared_bottleneck/norm')
@@ -272,7 +274,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
value_dim=attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.intra_bottleneck_size,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='attention')
layer_norm = _get_norm_layer(self.normalization_type,
name='attention/norm')
@@ -284,19 +286,19 @@ class MobileBertTransformer(tf.keras.layers.Layer):
for ffn_layer_idx in range(self.num_feedforward_networks):
layer_prefix = f'ffn_layer_{ffn_layer_idx}'
layer_name = layer_prefix + '/intermediate_dense'
intermediate_layer = tf.keras.layers.EinsumDense(
'abc,cd->abd',
activation=self.intermediate_act_fn,
output_shape=[None, self.intermediate_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/output_dense'
output_layer = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/norm'
layer_norm = _get_norm_layer(self.normalization_type,
@@ -306,12 +308,12 @@ class MobileBertTransformer(tf.keras.layers.Layer):
layer_norm])
# add output bottleneck
bottleneck = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.hidden_size],
activation=None,
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_output/dense')
dropout_layer = tf.keras.layers.Dropout(
self.hidden_dropout_prob,
@@ -445,6 +447,7 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
activation=None,
initializer='glorot_uniform',
output='logits',
output_weights_use_proj=False,
**kwargs):
"""Class initialization.
@@ -455,9 +458,12 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
uniform initializer.
output: The output style for this layer. Can be either `logits` or
`predictions`.
output_weights_use_proj: Use a projection instead of concatenating extra
output weights; this may reduce MLM task accuracy but also reduces the
number of model parameters.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
@@ -467,6 +473,7 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
self._output_weights_use_proj = output_weights_use_proj
def build(self, input_shape):
self._vocab_size, embedding_width = self.embedding_table.shape
@@ -474,15 +481,22 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transform/dense')
if hidden_size > embedding_width:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
initializer=self.initializer,
trainable=True)
if self._output_weights_use_proj:
self.extra_output_weights = self.add_weight(
'output_weights_proj',
shape=(embedding_width, hidden_size),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
else:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
elif hidden_size == embedding_width:
self.extra_output_weights = None
else:
......@@ -507,10 +521,16 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
if self.extra_output_weights is None:
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
else:
lm_data = tf.matmul(
lm_data,
tf.concat([self.embedding_table, self.extra_output_weights], axis=1),
transpose_b=True)
if self._output_weights_use_proj:
lm_data = tf.matmul(
lm_data, self.extra_output_weights, transpose_b=True)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
else:
lm_data = tf.matmul(
lm_data,
tf.concat([self.embedding_table, self.extra_output_weights],
axis=1),
transpose_b=True)
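# Shape sketch (illustrative): with the projection, lm_data of shape
# [..., hidden_size] is first mapped to [..., embedding_width] via
# extra_output_weights ([embedding_width, hidden_size], transposed) and
# then matched against the [vocab_size, embedding_width] embedding table,
# avoiding the larger [vocab_size, hidden_size - embedding_width]
# concatenation branch.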
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mixture of Experts layers and their routing mechanisms."""
import dataclasses
from typing import Any, Callable, Optional, Tuple
from absl import logging
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
_InitializerType = tf.keras.initializers.Initializer
_DEFAULT_KERNEL_INITIALIZER = tf.keras.initializers.TruncatedNormal(stddev=2e-2)
_DEFAULT_BIAS_INITIALIZER = tf.keras.initializers.Zeros()
################## Routers (gating functions) ##################
def _router_z_loss(router_logits: tf.Tensor) -> float:
"""Computes router z-loss.
The router z-loss was introduced in Designing Effective Sparse Expert Models
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
small in an effort to improve stability.
Args:
router_logits: <float32>[num_groups, tokens_per_group, num_experts] router
logits.
Returns:
Scalar router z-loss <float32>.
"""
num_groups, tokens_per_group, _ = router_logits.shape
log_z = tf.math.reduce_logsumexp(router_logits, axis=-1)
z_loss = log_z**2
return tf.math.reduce_sum(z_loss) / (num_groups * tokens_per_group)
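# Illustrative hand computation (mirrors the z-loss unit test below): a
# single group with one token and logits [10.0, 5.0] gives
# log_z = logsumexp([10, 5]) = 5 + log(e**5 + 1) ~= 10.0067, so
#
#   _router_z_loss(tf.constant([[[10.0, 5.0]]]))  # ~= 10.0067**2 ~= 100.13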
@dataclasses.dataclass
class RouterMask:
"""Dispatch and combine arrays for expert routing with masked matmuls.
Attributes:
dispatch_mask:
<float>[num_groups, tokens_per_group, num_experts, expert_capacity]
dispatch array that is 1 if the token gets routed to the
corresponding expert, and 0 otherwise.
combine_array:
<float>[num_groups, tokens_per_group, num_experts, expert_capacity]
combine array used for combining expert outputs and
scaling with router probability.
"""
dispatch_mask: tf.Tensor
combine_array: tf.Tensor
RouterOutput = RouterMask
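# Reading the mask (illustrative): if token t of group g is routed to
# expert e and lands in buffer slot c, then dispatch_mask[g, t, e, c] == 1
# and combine_array[g, t, e, c] holds the corresponding router probability;
# all other entries are zero.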
class Router(tf.keras.layers.Layer):
"""Abstract base router class, defining router API and inner workings.
Computations are performed in float32 for stability, and returned after
conversion according to the precision policy. See the discussion of
"selective precision" in https://arxiv.org/abs/2101.03961.
Uses Keras add_loss() and add_metric() APIs.
Attributes:
num_experts: Number of experts, used to check consistency with
FeedForwardExperts.
jitter_noise: Amplitude of jitter noise applied to router logits.
router_weights: Dense layer that computes logits for all tokens, which are
then used as expert or token weights.
"""
def __init__(
self,
num_experts: int,
*,
jitter_noise: float = 0.0,
use_bias: bool = True,
kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
name: str = "router",
dtype: Any = tf.float32,
**kwargs):
"""Init.
Args:
num_experts: Number of experts.
jitter_noise: Amplitude of jitter noise applied to router logits.
use_bias: Whether or not to use the bias term in computing the router
weights.
kernel_initializer: Kernel initializer for router weights.
bias_initializer: Bias initializer for router weights.
name: Layer name.
dtype: The dtype of the layer's computations and weights. tf.float32 is
recommended for stability.
**kwargs: Forwarded to super.
"""
super().__init__(name=name, dtype=dtype, **kwargs)
self.num_experts = num_experts # Used to check consistency with
# FeedForwardExperts.
self.jitter_noise = jitter_noise
self.router_weights = tf.keras.layers.Dense(
num_experts,
use_bias=use_bias,
kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
bias_initializer=tf_utils.clone_initializer(bias_initializer),
name="router_weights",
dtype=dtype)
def call(self,
inputs: tf.Tensor,
*,
expert_capacity: int,
training: Optional[bool] = None) -> RouterOutput:
"""Computes dispatch and combine arrays for routing to experts.
Args:
inputs: Inputs to send to experts of shape
<float>[num_groups, tokens_per_group, hidden_dim].
expert_capacity: Each group will send this many tokens to each expert.
training: If true, apply jitter noise during routing. If not provided,
the value is taken from tf.keras.backend.learning_phase().
Returns:
Router indices or mask arrays (depending on router type).
"""
if training is None:
training = tf.keras.backend.learning_phase()
# inputs shape <float>[num_groups, tokens_per_group, hidden_dim]
router_probs, router_logits = self._compute_router_probabilities(
inputs, apply_jitter=training)
# router_probs <float32>[num_groups, tokens_per_group, num_experts]
# router_logits <float>[num_groups, tokens_per_group, num_experts]
router_z_loss = _router_z_loss(router_logits)
self.add_loss(router_z_loss)
self.add_metric(router_z_loss, name="router_z_loss")
routing_instructions = self._compute_routing_instructions(
router_probs, expert_capacity)
return routing_instructions
def _compute_router_probabilities(
self, inputs: tf.Tensor,
apply_jitter: bool) -> Tuple[tf.Tensor, tf.Tensor]:
"""Computes router probabilities from input tokens.
Args:
inputs: Inputs from which router probabilities are computed, shape
<float>[num_groups, tokens_per_group, hidden_dim].
apply_jitter: If true, apply jitter noise.
Returns:
- <float32>[num_groups, tokens_per_group, num_experts] probabilities for
each token and expert. Used for routing tokens to experts.
- <float32>[num_groups, tokens_per_group, num_experts] raw router logits.
Used for computing router z-loss.
"""
if apply_jitter and self.jitter_noise > 0:
inputs *= tf.random.uniform(
inputs.shape,
minval=1.0 - self.jitter_noise,
maxval=1.0 + self.jitter_noise,
dtype=inputs.dtype)
# inputs <float>, router_logits <float32>
router_logits = self.router_weights(inputs)
router_probs = tf.keras.activations.softmax(router_logits, axis=-1)
return router_probs, router_logits
def _compute_routing_instructions(self, router_probs: tf.Tensor,
expert_capacity: int) -> RouterOutput:
"""Computes instructions for routing inputs to experts."""
raise NotImplementedError(
"Router is an abstract class that should be subclassed.")
class MaskedRouter(Router):
"""Abstract base router class for masked matmul dispatch routers.
MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine
array for sending and receiving (via masked matmuls) inputs and outputs to and
from experts.
Routing using masked matmuls is generally faster than scatter-based routing on
TPUs.
Uses Keras add_loss() and add_metric() APIs.
"""
def _compute_routing_instructions(self, router_probs: tf.Tensor,
expert_capacity: int) -> RouterMask:
"""Computes masks for the top-k experts per token.
Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.
expert_capacity: Each group will send this many tokens to each expert.
Returns:
Router mask arrays.
"""
raise NotImplementedError(
"MaskedRouter is an abstract class that should be subclassed.")
class ExpertsChooseMaskedRouter(MaskedRouter):
"""Masked matmul router using experts choose tokens assignment.
This router uses the same mechanism as in Mixture-of-Experts with Expert
Choice (https://arxiv.org/abs/2202.09368): each expert selects its top
expert_capacity tokens. An individual token may be processed by multiple
experts or none at all.
Note: "experts choose routing" should not be used in decoder blocks because it
breaks the autoregressive behavior, leading to a mismatch between training
(teacher forcing) and inference (autoregressive decoding).
Uses Keras add_loss() and add_metric() APIs.
"""
def _compute_routing_instructions(self, router_probs: tf.Tensor,
expert_capacity: int) -> RouterMask:
"""Computes masks for the highest probability token per expert.
Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.
expert_capacity: Each group will send this many tokens to each expert.
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
num_groups, tokens_per_group, _ = router_probs.shape
router_probs_t = tf.transpose(router_probs, perm=[0, 2, 1])
# router_probs_t: <float32>[num_groups, num_experts, tokens_per_group]
# Top expert_capacity router probabilities and corresponding token indices
# for each expert.
# Shapes [num_groups, num_experts, expert_capacity]
expert_gate, expert_index = tf.math.top_k(
router_probs_t, k=expert_capacity, sorted=False)
# Convert to one-hot mask of expert indices for each token in each group.
# Shape: [num_groups, num_experts, expert_capacity, tokens_per_group].
dispatch_mask = tf.one_hot(
expert_index, tokens_per_group, dtype=router_probs.dtype)
# Move axes to conform with shape expected by MoeLayer API.
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]
dispatch_mask = tf.transpose(dispatch_mask, perm=[0, 3, 1, 2])
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities.
# Shape: [num_groups, num_experts, tokens_per_group, expert_capacity]
combine_array = tf.einsum(
"...ec,...tec->...tec",
expert_gate,
dispatch_mask)
# Add load balancing loss.
# Each expert is choosing tokens until it reaches full capacity, so we don't
# need an auxiliary load balancing loss for expert choice routing.
self.add_metric(0.0, name="load_balancing_loss")
# Gather expert metrics.
# Number of tokens that were dispatched to at least one expert.
num_tokens = num_groups * tokens_per_group
num_tokens_dispatched_somewhere = tf.math.reduce_sum(tf.math.reduce_max(
dispatch_mask, axis=(-1, -2)))
fraction_tokens_left_behind = 1.0 - num_tokens_dispatched_somewhere / float(
num_tokens)
# Total number of tokens that were dispatched (one token could be
# dispatched to multiple experts).
num_tokens_dispatched = tf.math.reduce_sum(dispatch_mask)
# Of the tokens dispatched, how confident was the router in its routing?
router_confidence = tf.math.reduce_sum(
combine_array) / num_tokens_dispatched
expert_usage = 1.0  # Experts are always fully utilized under "experts choose tokens" routing.
self.add_metric(fraction_tokens_left_behind,
name="fraction_tokens_left_behind")
self.add_metric(router_confidence, name="router_confidence")
self.add_metric(expert_usage, name="expert_usage")
# Return to default dtype now that router computation is complete.
dtype = tf.keras.mixed_precision.global_policy().compute_dtype
dispatch_mask = tf.cast(dispatch_mask, dtype)
combine_array = tf.cast(combine_array, dtype)
output = RouterMask(dispatch_mask, combine_array)
return output
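# Illustrative usage sketch (shapes follow the router unit test below):
#
#   router = ExpertsChooseMaskedRouter(num_experts=3, jitter_noise=0.1)
#   # inputs: <float>[num_groups=2, tokens_per_group=3, hidden_dim=3]
#   mask = router(inputs, expert_capacity=2, training=False)
#   # mask.dispatch_mask and mask.combine_array both have shape
#   # [num_groups=2, tokens_per_group=3, num_experts=3, expert_capacity=2]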
################## Model layers ##################
class FeedForward(tf.keras.layers.Layer):
"""Feed-forward layer - position independent, dense, nonlinear transformation.
Typically used in an MLP Transformer block.
"""
def __init__(
self,
d_ff: int,
*,
dropout_rate: float = 0.1,
activation: Callable[[tf.Tensor],
tf.Tensor] = tf.keras.activations.gelu,
kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
name: str = "feed_forward",
**kwargs):
"""Initializes layer.
Args:
d_ff: Dimension of feed-forward layer.
dropout_rate: The dropout probability.
activation: (Nonlinear) transform applied in layer.
kernel_initializer: Initialization scheme for kernel.
bias_initializer: Initialization scheme for bias.
name: Layer name.
**kwargs: Forwarded to super.
"""
super().__init__(name=name, **kwargs)
self.activation = activation
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.intermediate_layer = tf.keras.layers.Dense(
d_ff,
kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
name="intermediate")
self.dropout_layer = tf.keras.layers.Dropout(dropout_rate)
def build(self, input_shape: Tuple[int, int, int]):
"""Creates the input shape dependent output weight variables."""
self.output_layer = tf.keras.layers.Dense(
input_shape[-1],
kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
name="output")
def call(self,
inputs: tf.Tensor,
*,
training: Optional[bool] = None) -> tf.Tensor:
"""Applies layer to inputs.
Args:
inputs: Batch of input embeddings, of shape
<float>[batch_size, seq_len, hidden_dim].
training: Only apply dropout during training.
Returns:
Transformed inputs with the same shape as inputs
<float>[batch_size, seq_len, hidden_dim].
"""
x = self.intermediate_layer(inputs)
x = self.activation(x)
x = self.output_layer(x)
x = self.dropout_layer(x, training=training)
return x
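# Numeric sanity sketch (mirrors test_feed_forward_manual below): with
# all-ones kernels and biases, relu activation, d_ff=32 and hidden_dim=3,
# an all-ones input yields 32 * (1 * 3 + 1) + 1 = 129.0 in every channel:
#
#   layer = FeedForward(d_ff=32, dropout_rate=0.1,
#                       activation=tf.keras.activations.relu,
#                       kernel_initializer=tf.keras.initializers.get('ones'),
#                       bias_initializer=tf.keras.initializers.get('ones'))
#   layer(tf.ones((1, 2, 3)), training=False)  # -> 129.0 everywhere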
class FeedForwardExperts(tf.keras.layers.Layer):
"""Feed-forward layer with multiple experts.
Note that call() takes inputs with shape
[num_groups, num_experts, expert_capacity, hidden_dim]
which is different from the usual [batch_size, seq_len, hidden_dim] used by
the FeedForward layer.
The experts are independent FeedForward layers of the
same shape, i.e. the kernel doesn't have shape [hidden_dim, out_dim], but
[num_experts, hidden_dim, out_dim].
"""
def __init__(
self,
num_experts: int,
d_ff: int,
*,
dropout_rate: float = 0.1,
activation: Callable[[tf.Tensor],
tf.Tensor] = tf.keras.activations.gelu,
kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
name: str = "experts",
**kwargs):
"""Initializes layer.
Args:
num_experts: Number of experts (i.e. number of independent feed-forward
blocks).
d_ff: Dimension of feed-forward layer of each expert.
dropout_rate: The dropout probability (expert_dropout_rate).
activation: (Nonlinear) transform applied in layer.
kernel_initializer: Initialization scheme for kernel.
bias_initializer: Initialization scheme for bias.
name: Layer name.
**kwargs: Forwarded to super.
"""
super().__init__(name=name, **kwargs)
self.num_experts = num_experts
self.activation = activation
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.intermediate_layer = tf.keras.layers.EinsumDense(
"gech,ehf->gecf",
output_shape=(self.num_experts, None, d_ff),
bias_axes="ef",
kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
name="intermediate")
self.dropout_layer = tf.keras.layers.Dropout(dropout_rate)
def build(self, input_shape: Tuple[int, int, int, int]):
"""Creates the input shape dependent output weight variables."""
if input_shape[1] != self.num_experts:
raise ValueError(
f"Input shape {input_shape} is inconsistent with num_experts "
f"{self.num_experts}.")
self.output_layer = tf.keras.layers.EinsumDense(
"gecf,efh->gech",
output_shape=(self.num_experts, None, input_shape[-1]),
bias_axes="eh",
kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
name="output")
def call(self,
inputs: tf.Tensor,
*,
training: Optional[bool] = None) -> tf.Tensor:
"""Applies layer to inputs.
Args:
inputs: Inputs of shape
<float>[num_groups, num_experts, expert_capacity, hidden_dim].
training: Only apply dropout during training.
Returns:
Transformed inputs with the same shape as inputs
<float>[num_groups, num_experts, expert_capacity, hidden_dim].
"""
x = self.intermediate_layer(inputs)
x = self.activation(x)
x = self.output_layer(x)
x = self.dropout_layer(x, training=training)
return x
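# Kernel layout note: the einsums "gech,ehf->gecf" and "gecf,efh->gech"
# give every expert its own kernel, i.e. weights of shape
# [num_experts, hidden_dim, d_ff] and [num_experts, d_ff, hidden_dim], so
# all experts run in parallel with no cross-expert mixing.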
class MoeLayer(tf.keras.layers.Layer):
"""Sparse MoE layer with per-token routing.
In this TF implementation, all experts need to fit onto a single device
allowing for batch parallelism only.
Uses Keras add_loss() and add_metric() APIs.
Attributes:
num_experts: Number of experts (i.e. number of independent feed-forward
blocks).
"""
def __init__(
self,
experts: FeedForwardExperts,
router: MaskedRouter,
*,
train_capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_expert_capacity: int = 4,
max_group_size: int = 4096,
strict_group_size: bool = False,
name: str = "moe",
**kwargs):
"""Init.
Args:
experts: Instance of FeedForwardExperts. Needs to have the same
num_experts as the router.
router: Instance of MaskedRouter to route the tokens to
the different experts.
train_capacity_factor: Scaling factor to increase the expert token
capacity during training. This factor plays an analogous, but slightly
different, role depending on the routing assignment algorithm:
- For "tokens choose" routing, the capacity factor only affects the
maximum number of tokens that an expert will process. It does not
affect how many experts a given token is routed to; see the
num_selected_experts attributes of "tokens choose" routers.
- For "experts choose" routing, because experts always fill their
buffer, increasing the capacity factor will increase the number of
tokens that an expert will process AND will indirectly increase the
number of experts that a given token is routed to.
eval_capacity_factor: As above, but used during evaluation.
min_expert_capacity: Minimum token processing capacity for each expert.
max_group_size: The total number of tokens on each device is subdivided
into groups of this size. Router computations are then performed on a
per-group basis. A larger group size will result in slower but more
accurate top-k and sorting computations, whereas a smaller group size
will result in faster but more approximate (and potentially less stable)
routing choices. Note that actual group size may be smaller than
max_group_size for consistency with the number of experts and tokens;
see also `strict_group_size` attribute. In practice,
we find that imperfect routing choices are tolerable and recommend
choosing a group size on the order of 4096 tokens, although this number
will vary based on model configuration and size.
strict_group_size: If True, fail if unable to set the token group size
equal to max_group_size. If False (default), the actual group size may
be smaller than max_group_size for consistency with the number of
experts and tokens.
name: Layer name.
**kwargs: Forwarded to super.
"""
super().__init__(name=name, **kwargs)
self._experts = experts
self._router = router
self.num_experts = experts.num_experts
assert experts.num_experts == router.num_experts
self._train_capacity_factor = train_capacity_factor
self._eval_capacity_factor = eval_capacity_factor
self._max_group_size = max_group_size
self._min_expert_capacity = min_expert_capacity
self._strict_group_size = strict_group_size
def call(self,
inputs: tf.Tensor,
*,
training: Optional[bool] = None) -> tf.Tensor:
"""Applies MoeLayer.
Args:
inputs: Batch of input embeddings of shape
<float>[batch_size, seq_length, hidden_dim].
training: Only apply dropout and jitter noise during training. If not
provided, the value is taken from tf.keras.backend.learning_phase().
Returns:
Transformed inputs with same shape as inputs:
<float>[batch_size, seq_length, hidden_dim].
Raises:
ValueError: If we cannot find a group_size satisfying the given
requirements.
"""
if training is None:
training = tf.keras.backend.learning_phase()
# inputs shape [batch_size, seq_length, hidden_dim]
per_device_batch_size, seq_length, hidden_dim = inputs.shape
num_tokens = per_device_batch_size * seq_length
num_groups = self._num_groups(num_tokens, self._max_group_size)
tokens_per_group = num_tokens // num_groups
if training:
capacity_factor = self._train_capacity_factor
else:
capacity_factor = self._eval_capacity_factor
# Each group will send expert_capacity tokens to each expert.
expert_capacity = int(
round(capacity_factor * tokens_per_group / self.num_experts))
expert_capacity = max(expert_capacity, self._min_expert_capacity)
logging.info(
"Selected expert_capacity=%d for num_experts=%d and training=%r.",
expert_capacity, self.num_experts, training)
# Reshape batch and sequence/token dimensions for expert routing.
x = tf.reshape(inputs, (num_groups, tokens_per_group, hidden_dim))
x = self._mask_and_dispatch_to_experts(x, expert_capacity, training)
# Return to original input shape.
x = tf.reshape(x, (per_device_batch_size, seq_length, hidden_dim))
return x
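# Worked numbers (matching the MoeLayer unit test below): batch_size=2 and
# seq_length=10 give num_tokens=20; max_group_size=9 leads to num_groups=4
# and tokens_per_group=5, so with capacity_factor=1.0 and num_experts=2,
# expert_capacity = max(int(round(1.0 * 5 / 2)), min_expert_capacity) = 2.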
def _num_groups(self, num_tokens: int, max_group_size: int) -> int:
"""Returns the number of token routing groups.
Note that the quantities are local to the device.
We select the smallest num_groups such that:
- num_groups >= num_tokens / max_group_size (ensuring the group size is no
larger than max_group_size),
- num_tokens % num_groups = 0 (ensuring that the group size evenly divides
into the num_tokens),
Args:
num_tokens: Number of tokens from input batch.
max_group_size: Maximum size of each token routing group. Actual group
size may end up being smaller unless strict_group_size==True.
Returns:
Number of token routing groups.
Raises:
ValueError if we cannot find a group_size satisfying the above
requirements.
"""
# Increase the number of groups (and decrease the group size) until we have
# a viable number of groups.
min_num_groups = int(np.ceil(num_tokens / max_group_size))
num_groups = min_num_groups
while num_groups < num_tokens and num_tokens % num_groups != 0:
num_groups += 1
group_size = num_tokens // num_groups
logging.info(
"Selected group_size=%d and num_groups=%d for input num_tokens=%d, "
"max_group_size=%d, num_experts=%d.",
group_size, num_groups, num_tokens, max_group_size, self.num_experts)
if group_size < self._min_expert_capacity:
raise ValueError(
f"Local (per-device) group_size {group_size} is smaller than "
f"min_expert_capacity {self._min_expert_capacity}, which is probably "
"not intended. Please increase max_group_size {max_group_size} to"
" seq_length or increase batch_size or decrease min_expert_capacity.")
if self._strict_group_size and group_size != self._max_group_size:
raise ValueError(
f"Selected group_size={group_size} is less than the "
f"max_group_size={max_group_size}. Exiting because strict mode is "
"active (strict_group_size=True)")
return num_groups
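# Worked example: num_tokens=20 with max_group_size=9 gives
# min_num_groups = ceil(20 / 9) = 3; since 20 % 3 != 0, num_groups is
# increased to 4 (20 % 4 == 0), yielding group_size = 20 // 4 = 5.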
def _mask_and_dispatch_to_experts(self, inputs: tf.Tensor,
expert_capacity: int,
training: bool) -> tf.Tensor:
"""Wraps expert masked routing and dispatching algorithm.
This algorithm takes the following steps:
(1) Compute dispatch mask and combine array using self._router.
(2) Dispatch inputs to experts based on dispatch mask.
(3) Recombine individual expert outputs using combine array.
Args:
inputs: <float>[num_groups, tokens_per_group, hidden_dim] inputs to
send to experts.
expert_capacity: Each group will send this many tokens to each expert.
training: If true, apply jitter noise during routing and dropout
during expert computation.
Returns:
<float>[num_groups, num_tokens_per_group, hidden_dim] outputs from
experts.
"""
# Shape [num_groups, tokens_per_group, num_experts, expert_capacity]
router_mask = self._router(
inputs,
expert_capacity=expert_capacity,
training=training)
# Shape [num_groups, num_experts, expert_capacity, hidden_dim]
expert_inputs = tf.einsum(
"gth,gtec->gech",
inputs,
router_mask.dispatch_mask)
expert_outputs = self._experts(expert_inputs, training=training)
# Shape [num_groups, tokens_per_group, hidden_dim]
combined_outputs = tf.einsum(
"gech,gtec->gth",
expert_outputs,
router_mask.combine_array)
return combined_outputs
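# Minimal assembly sketch (mirrors test_moe_layer below):
#
#   experts = FeedForwardExperts(num_experts=2, d_ff=33, dropout_rate=0.1)
#   router = ExpertsChooseMaskedRouter(2, jitter_noise=0.1)
#   moe_layer = MoeLayer(experts, router, train_capacity_factor=1.0,
#                        eval_capacity_factor=1.0, max_group_size=9,
#                        min_expert_capacity=1)
#   outputs = moe_layer(tf.ones((2, 10, 7)), training=True)  # (2, 10, 7)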
class MoeLayerWithBackbone(tf.keras.layers.Layer):
"""Sparse MoE layer plus a FeedForward layer evaluated for all tokens.
Uses Keras add_loss() and add_metric() APIs.
"""
def __init__(
self,
moe: MoeLayer,
backbone_d_ff: int,
*,
dropout_rate: float = 0.1,
activation: Callable[[tf.Tensor],
tf.Tensor] = tf.keras.activations.gelu,
kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
name: str = "moe_with_backbone",
**kwargs):
"""Init.
Args:
moe: Instance of MoeLayer with experts and router.
backbone_d_ff: Dimension of feed-forward layer of a lightweight backbone,
which is evaluated for all tokens.
dropout_rate: Dropout rate for the backbone.
activation: (Nonlinear) transform applied in the backbone.
kernel_initializer: Initialization scheme for kernels in the backbone.
bias_initializer: Initialization scheme for biases in the backbone.
name: Layer name.
**kwargs: Forwarded to super.
"""
super().__init__(name=name, **kwargs)
self._moe = moe
self._backbone = FeedForward(
backbone_d_ff,
dropout_rate=dropout_rate,
activation=activation,
kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
bias_initializer=tf_utils.clone_initializer(bias_initializer),
name="backbone")
def call(self,
inputs: tf.Tensor,
*,
training: Optional[bool] = None) -> tf.Tensor:
"""Applies MoeLayerWithBackbone layer.
Args:
inputs: Batch of input embeddings of shape
<float>[batch_size, seq_length, hidden_dim].
training: Only apply dropout and jitter noise during training. If not
provided, the value is taken from tf.keras.backend.learning_phase().
Returns:
Transformed inputs with same shape as inputs:
<float>[batch_size, seq_length, hidden_dim].
"""
return self._backbone(
inputs, training=training) + self._moe(
inputs, training=training)
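# The output is simply backbone(inputs) + moe(inputs): the dense backbone
# processes every token, while the sparse MoE branch only spends capacity
# on the tokens dispatched to experts.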
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for moe.py."""
import ml_collections
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import moe
def small_config() -> ml_collections.ConfigDict:
"""Creates a small model config that can be used by all tests."""
config = ml_collections.ConfigDict()
config.d_ff = 32
config.dropout_rate = 0.1
config.num_experts = 2
config.expert_d_ff = 33
config.expert_dropout_rate = 0.1
config.jitter_noise = 0.1
config.train_capacity_factor = 1.0
config.eval_capacity_factor = 1.0
config.min_expert_capacity = 1
config.max_group_size = 9
config.backbone_d_ff = 13
return config
def make_input_ones(batch_size: int = 2,
seq_length: int = 10,
hidden_dim: int = 7) -> tf.Tensor:
return tf.ones((batch_size, seq_length, hidden_dim), dtype=tf.float32)
def make_experts_input_ones(num_groups: int = 1,
num_experts: int = 2,
expert_capacity: int = 5,
hidden_dim: int = 7) -> tf.Tensor:
return tf.ones((num_groups, num_experts, expert_capacity, hidden_dim),
dtype=tf.float32)
class MoeTest(tf.test.TestCase):
def tearDown(self):
super().tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_router_z_loss_dtype(self):
x = tf.constant([[[10.0, 5.0]]], dtype=tf.float32)
y = moe._router_z_loss(x)
expected = (5 + np.log(np.exp(5) + 1))**2
self.assertAllClose(expected, y, atol=1e-7)
x = tf.constant([[[10.0, 5.0]]], dtype=tf.bfloat16)
y = moe._router_z_loss(x)
expected = 100.0
self.assertAllClose(expected, y, atol=1e-7)
def test_router_z_loss_shape(self):
x = make_input_ones(2, 5, 7)
y = moe._router_z_loss(x)
expected = (np.log(7) + 1)**2
self.assertAllClose(expected, y, atol=1e-7)
def test_experts_choose_masked_router_dtype_shape(self):
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
num_groups = 2
tokens_per_group = 3
hidden_dim = tokens_per_group
num_experts = tokens_per_group
expert_capacity = 2
x = np.zeros([num_groups, tokens_per_group, hidden_dim])
x[0, 0, 0] += 1
x[0, :2, :2] += 1
x[1, 1:, 1:] += 1
x[1, -1, -1] += 1
router = moe.ExpertsChooseMaskedRouter(
num_experts=num_experts,
jitter_noise=0.1,
use_bias=True,
kernel_initializer=tf.keras.initializers.get('identity'),
bias_initializer=tf.keras.initializers.get('ones'))
router_mask = router(x, expert_capacity=expert_capacity, training=False)
self.assertDTypeEqual(router_mask.dispatch_mask, tf.bfloat16)
self.assertDTypeEqual(router_mask.combine_array, tf.bfloat16)
expect_shape = [num_groups, tokens_per_group, num_experts, expert_capacity]
self.assertEqual(expect_shape, router_mask.dispatch_mask.shape)
self.assertEqual(expect_shape, router_mask.combine_array.shape)
# The top_k output may not be sorted, so we can't compare outputs directly.
# Check that the output contains only 0s and 1s
out_dm = router_mask.dispatch_mask.numpy()
self.assertSetEqual({0, 1}, set(out_dm.flatten().astype(np.int32)))
# Check that the right tokens were selected.
out_dm_indices = np.dot(
out_dm.transpose((0, 2, 3, 1)), np.arange(tokens_per_group))
# Shape [num_groups, num_experts, expert_capacity]
self.assertSetEqual({0, 1}, set(out_dm_indices[0, 0, :].astype(np.int32)))
self.assertSetEqual({1, 2}, set(out_dm_indices[0, 1, :].astype(np.int32)))
self.assertSetEqual({1, 2}, set(out_dm_indices[0, 2, :].astype(np.int32)))
self.assertSetEqual({0, 1}, set(out_dm_indices[1, 0, :].astype(np.int32)))
self.assertSetEqual({0, 1}, set(out_dm_indices[1, 1, :].astype(np.int32)))
self.assertSetEqual({1, 2}, set(out_dm_indices[1, 2, :].astype(np.int32)))
out_ca = router_mask.combine_array.numpy()
out_ca = np.dot(out_ca, np.ones((expert_capacity,)))
expected_combine_array = np.array(
[[[0.66, 0.0, 0.0], [0.42, 0.42, 0.16], [0.0, 0.33, 0.33]],
[[0.33, 0.33, 0.0], [0.16, 0.42, 0.42], [0.0, 0.0, 0.66]]])
self.assertAllClose(expected_combine_array, out_ca, atol=1e-2)
def test_feed_forward_shape_and_vars(self):
config = small_config()
layer = moe.FeedForward(d_ff=config.d_ff, dropout_rate=config.dropout_rate)
inputs = make_input_ones()
outputs = layer(inputs)
self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))
var_names = sorted([v.name for v in layer.trainable_variables])
self.assertAllEqual(['feed_forward/intermediate/bias:0',
'feed_forward/intermediate/kernel:0',
'feed_forward/output/bias:0',
'feed_forward/output/kernel:0'], var_names)
def test_feed_forward_manual(self):
config = small_config()
layer = moe.FeedForward(
d_ff=config.d_ff,
dropout_rate=config.dropout_rate,
activation=tf.keras.activations.relu,
kernel_initializer=tf.keras.initializers.get('ones'),
bias_initializer=tf.keras.initializers.get('ones'))
inputs = make_input_ones(1, 2, 3)
outputs = layer(inputs, training=False)
manual_outputs = tf.constant([[[129.0, 129.0, 129.0],
[129.0, 129.0, 129.0]]])
self.assertAllClose(manual_outputs, outputs, atol=1e-7)
def test_feed_forward_experts_shape_and_vars(self):
config = small_config()
layer = moe.FeedForwardExperts(
num_experts=config.num_experts,
d_ff=config.expert_d_ff,
dropout_rate=config.expert_dropout_rate)
inputs = make_experts_input_ones()
outputs = layer(inputs)
self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))
var_names = sorted([v.name for v in layer.trainable_variables])
self.assertAllEqual(['experts/intermediate/bias:0',
'experts/intermediate/kernel:0',
'experts/output/bias:0',
'experts/output/kernel:0'], var_names)
def test_feed_forward_experts_manual(self):
config = small_config()
layer = moe.FeedForwardExperts(
num_experts=1,
d_ff=config.expert_d_ff,
dropout_rate=config.expert_dropout_rate,
activation=tf.keras.activations.relu,
kernel_initializer=tf.keras.initializers.get('ones'),
bias_initializer=tf.keras.initializers.get('ones'))
inputs = make_experts_input_ones(1, 1, 2, 3)
outputs = layer(inputs, training=False)
manual_outputs = tf.constant([[[[133.0, 133.0, 133.0],
[133.0, 133.0, 133.0]]]])
self.assertAllClose(manual_outputs, outputs, atol=1e-7)
def test_moe_layer(self):
config = small_config()
experts = moe.FeedForwardExperts(
num_experts=config.num_experts,
d_ff=config.expert_d_ff,
dropout_rate=config.expert_dropout_rate)
router = moe.ExpertsChooseMaskedRouter(
config.num_experts,
jitter_noise=config.jitter_noise)
moe_layer = moe.MoeLayer(
experts,
router,
train_capacity_factor=config.train_capacity_factor,
eval_capacity_factor=config.eval_capacity_factor,
max_group_size=config.max_group_size,
min_expert_capacity=config.min_expert_capacity)
inputs = make_input_ones()
with self.assertLogs('absl', level='INFO') as cm:
outputs = moe_layer(inputs, training=True)
self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))
self.assertEqual(cm.output, [
('INFO:absl:Selected group_size=5 and num_groups=4 for input '
'num_tokens=20, max_group_size=9, num_experts=2.'),
('INFO:absl:Selected expert_capacity=2 for num_experts=2 and '
'training=True.')])
var_names = sorted([v.name for v in moe_layer.trainable_variables])
self.assertAllEqual(['moe/experts/intermediate/bias:0',
'moe/experts/intermediate/kernel:0',
'moe/experts/output/bias:0',
'moe/experts/output/kernel:0',
'moe/router/router_weights/bias:0',
'moe/router/router_weights/kernel:0'], var_names)
self.assertLen(moe_layer.losses, 1)
metrics = [metric.name for metric in moe_layer.metrics]
self.assertSetEqual(
{
'router_z_loss', 'load_balancing_loss',
'fraction_tokens_left_behind', 'router_confidence', 'expert_usage'
}, set(metrics))
def test_moe_layer_with_backbone(self):
config = small_config()
experts = moe.FeedForwardExperts(
num_experts=config.num_experts,
d_ff=config.expert_d_ff,
dropout_rate=config.expert_dropout_rate)
router = moe.ExpertsChooseMaskedRouter(
config.num_experts,
jitter_noise=config.jitter_noise)
moe_layer = moe.MoeLayer(
experts,
router,
train_capacity_factor=config.train_capacity_factor,
eval_capacity_factor=config.eval_capacity_factor,
max_group_size=config.max_group_size,
min_expert_capacity=config.min_expert_capacity)
layer = moe.MoeLayerWithBackbone(moe_layer, config.backbone_d_ff)
inputs = make_input_ones()
outputs = layer(inputs)
self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +18,7 @@
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import masked_softmax
......@@ -48,7 +49,7 @@ class VotingAttention(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(VotingAttention, self).__init__(**kwargs)
super().__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
......@@ -60,26 +61,28 @@ class VotingAttention(tf.keras.layers.Layer):
def build(self, unused_input_shapes):
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
self._query_dense = tf.keras.layers.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="query",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
self._key_dense = tf.keras.layers.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="key",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
super(VotingAttention, self).build(unused_input_shapes)
super().build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask):
num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
......@@ -120,7 +123,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
"""
def _build_attention(self, rank):
super(MultiChannelAttention, self)._build_attention(rank) # pytype: disable=attribute-error # typed-keras
super()._build_attention(rank) # pytype: disable=attribute-error # typed-keras
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -47,7 +47,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
scale_factor=None,
**kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs)
super().__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._initializer = initializer
......@@ -62,7 +62,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"use_one_hot": self._use_one_hot,
"scale_factor": self._scale_factor,
}
base_config = super(OnDeviceEmbedding, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
......@@ -72,7 +72,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
initializer=self._initializer,
dtype=tf.float32)
super(OnDeviceEmbedding, self).build(input_shape)
super().build(input_shape)
def call(self, inputs):
flat_inputs = tf.reshape(inputs, [-1])
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pack sequence optimization on accelerators."""
from typing import Dict
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import rezero_transformer
from official.nlp.modeling.layers import self_attention_mask
from official.nlp.modeling.layers import transformer_encoder_block
from official.nlp.modeling.layers import transformer_scaffold
@tf.keras.utils.register_keras_serializable(package='Text')
class PackBertEmbeddings(tf.keras.layers.Layer):
"""Performs packing tricks for BERT inputs to improve TPU utilization."""
def __init__(self, pack_sequences: int, **kwargs):
super().__init__(**kwargs)
self.pack_sequences = pack_sequences
def call(self, input_embeddings: tf.Tensor,
input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
input_embeddings, expected_rank=3)
reduced_batch_size = batch_size // self.pack_sequences
packed_seq_len = self.pack_sequences * seq_len
packed_embeddings = tf.reshape(
input_embeddings, [reduced_batch_size, packed_seq_len, embedding_dim])
input_mask = tf.reshape(input_mask, [reduced_batch_size, packed_seq_len])
example_ids = 1 + tf.range(self.pack_sequences)
# Shape: [reduced_batch_size, pack_sequences, seq_len].
example_ids = tf.tile(example_ids[None, :, None],
[reduced_batch_size, 1, seq_len])
example_ids = tf.reshape(example_ids, [reduced_batch_size, packed_seq_len])
example_ids = tf.where(
tf.math.equal(input_mask, 0), tf.zeros_like(example_ids), example_ids)
packing_mask = tf.cast(
tf.equal(
tf.expand_dims(example_ids, 2), tf.expand_dims(example_ids, 1)),
dtype=tf.bool)
attention_mask = self_attention_mask.get_mask(
packed_embeddings, input_mask, dtype=tf.bool)
combined_attention_mask = tf.cast(
tf.math.logical_and(attention_mask, packing_mask), tf.float32)
return dict(
packed_embeddings=packed_embeddings,
combined_attention_mask=combined_attention_mask)
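# Worked shapes (mirrors the pack_optimization unit test below): with
# pack_sequences=2 and input embeddings of shape [2, 4, 8], the layer
# returns packed_embeddings of shape [1, 8, 8] and combined_attention_mask
# of shape [1, 8, 8]; positions belonging to different packed examples are
# masked out of each other's attention.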
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedTransformerEncoderBlock(
transformer_encoder_block.TransformerEncoderBlock):
"""Transformer layer for packing optimization to stride over inputs."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._output_range is not None:
raise ValueError('StridedTransformerEncoderBlock does not '
'support `output_range` argument.')
def call(self, inputs, stride: tf.Tensor):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError('Unexpected inputs to %s with length %d.' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._norm_first:
source_tensor = input_tensor[:, ::stride, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm_kv(key_value)
target_tensor = input_tensor[:, ::stride, :]
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
# Important to not combine `self._norm_first` and
# `self._use_query_residual` into one if clause because else is only for
# `_norm_first == False`.
if self._use_query_residual:
attention_output = source_tensor + attention_output
else:
if self._use_query_residual:
attention_output = target_tensor + attention_output
attention_output = self._attention_layer_norm(attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + layer_output
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(layer_output + attention_output)
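# Stride sketch (mirrors the unit test below): with inputs of shape
# [2, 4, 8] and stride=2, only every second query position is kept, so the
# block returns [2, 2, 8] while keys and values still cover all 4
# positions.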
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedReZeroTransformer(rezero_transformer.ReZeroTransformer):
"""ReZeroTransformer for packing optimization to stride over inputs."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._output_range is not None:
raise ValueError(f'{self.__class__} does not '
'support `output_range` argument.')
def call(self, inputs, stride: tf.Tensor):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError(f'Unexpected inputs to {self.__class__} with '
f'length {len(inputs)}.')
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
target_tensor = input_tensor[:, ::stride, :]
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
attention_output = self._attention_layer_norm(attention_output)
else:
attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._inner_activation_layer(intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
layer_output = attention_output + tf.cast(self._rezero_a_ffn * layer_output,
tf.float32)
if self._use_layer_norm:
layer_output = self._output_layer_norm(layer_output)
return layer_output
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
"""TransformerScaffold for packing optimization to stride over inputs."""
def call(self, inputs, stride: tf.Tensor, training=None):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError('Unexpected inputs to %s with length %d.' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if key_value is None:
key_value = input_tensor
if self._norm_first:
source_tensor = input_tensor[:, ::stride, :]
input_tensor = self._attention_layer_norm(input_tensor, training=training)
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
target_tensor = input_tensor[:, ::stride, :]
attention_output = self._attention_layer(
query=target_tensor,
value=key_value,
attention_mask=attention_mask,
training=training)
attention_output = self._attention_dropout(
attention_output, training=training)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(
target_tensor + attention_output, training=training)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(
attention_output, training=training)
if self._feedforward_block is None:
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output, training=training)
layer_output = self._output_dropout(layer_output, training=training)
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(
layer_output + attention_output, training=training)
else:
if self._norm_first:
# if norm_first, assume the feedforward block will not apply layer norm
layer_output = self._feedforward_block(
attention_output, training=training)
layer_output += source_attention_output
else:
# if not norm_first, assume that the feedforward block does apply layer norm
layer_output = self._feedforward_block(
attention_output, training=training)
return layer_output
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pack_optimization."""
import tensorflow as tf
from official.nlp.modeling.layers import pack_optimization
class PackOptimizationTest(tf.test.TestCase):
def test_bert_embedding_packing(self):
batch_size, seq_len, embed_dim = 2, 4, 8
pack_sequences = 2
token_and_position_embed = tf.ones((batch_size, seq_len, embed_dim),
dtype=tf.float32)
input_mask = tf.ones((batch_size, seq_len), dtype=tf.int32)
layer = pack_optimization.PackBertEmbeddings(pack_sequences=pack_sequences)
outputs = layer(token_and_position_embed, input_mask)
self.assertEqual(outputs["packed_embeddings"].shape, (1, 8, embed_dim))
self.assertEqual(outputs["combined_attention_mask"].shape, (1, 8, 8))
def test_strided_transformer_encoder_block(self):
inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
transformer = pack_optimization.StridedTransformerEncoderBlock(
num_attention_heads=2, inner_dim=4, inner_activation="relu")
outputs = transformer([inputs, attention_mask],
stride=tf.constant(2, dtype=tf.int32))
self.assertEqual(outputs.shape, (2, 2, 8))
def test_strided_rezero_transformer(self):
inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
transformer = pack_optimization.StridedReZeroTransformer(
num_attention_heads=2, inner_dim=4, inner_activation="relu")
outputs = transformer([inputs, attention_mask],
stride=tf.constant(2, dtype=tf.int32))
self.assertEqual(outputs.shape, (2, 2, 8))
def test_strided_scaffold(self):
inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
test_layer = pack_optimization.StridedTransformerScaffold(
num_attention_heads=2,
inner_dim=128,
inner_activation="relu")
outputs = test_layer([inputs, attention_mask],
stride=tf.constant(2, dtype=tf.int32))
self.assertEqual(outputs.shape, (2, 2, 8))
if __name__ == "__main__":
tf.test.main()