ModelZoo / ResNet50_tensorflow / Commits / 32e4ca51

Commit 32e4ca51 authored Nov 28, 2023 by qianyj

    Update code to v2.11.0

Parents: 9485aa1d, 71060f67

Showing 20 changed files with 2444 additions and 122 deletions (+2444, -122)
official/nlp/modeling/layers/kernel_attention.py               +488   -53
official/nlp/modeling/layers/kernel_attention_test.py          +149   -10
official/nlp/modeling/layers/masked_lm.py                      +5     -4
official/nlp/modeling/layers/masked_lm_test.py                 +1     -1
official/nlp/modeling/layers/masked_softmax.py                 +3     -3
official/nlp/modeling/layers/masked_softmax_test.py            +1     -1
official/nlp/modeling/layers/mat_mul_with_margin.py            +3     -3
official/nlp/modeling/layers/mat_mul_with_margin_test.py       +1     -1
official/nlp/modeling/layers/mixing.py                         +283   -0
official/nlp/modeling/layers/mixing_test.py                    +109   -0
official/nlp/modeling/layers/mobile_bert_layers.py             +51    -31
official/nlp/modeling/layers/mobile_bert_layers_test.py        +1     -1
official/nlp/modeling/layers/moe.py                            +761   -0
official/nlp/modeling/layers/moe_test.py                       +255   -0
official/nlp/modeling/layers/multi_channel_attention.py        +11    -8
official/nlp/modeling/layers/multi_channel_attention_test.py   +1     -1
official/nlp/modeling/layers/on_device_embedding.py            +4     -4
official/nlp/modeling/layers/on_device_embedding_test.py       +1     -1
official/nlp/modeling/layers/pack_optimization.py              +250   -0
official/nlp/modeling/layers/pack_optimization_test.py         +66    -0
official/nlp/modeling/layers/kernel_attention.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -18,6 +18,8 @@ import functools
 import math

 import tensorflow as tf

+from official.modeling import tf_utils
+
 _NUMERIC_STABLER = 1e-6
 ...
@@ -39,6 +41,236 @@ class KernelMask(tf.keras.layers.Layer):
     return mask


+def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
+  """Pads a tensor so that shape[axis] is divisible by chunk_length.
+
+  Args:
+    tensor: Input tensor to pad.
+    axis: Axis to pad along.
+    chunk_length: The output tensor will have shape[axis] divisible by
+      chunk_length.
+    padding: Pad the input tensor across the axis from either left or right
+      if padding is set to "left" or "right"; applies no padding if padding
+      is set to None. In the latter case, the axis dimension of the input
+      tensor must be divisible by the chunk_length.
+
+  Returns:
+    Padded tensor with shape[axis] divisible by chunk_length.
+  """
+  if padding is None:
+    return tensor
+  shape = tf.shape(tensor)
+  rank = tf.rank(tensor)
+  if axis < 0:
+    axis += rank
+  axis_length = shape[axis]
+  pad_length = -axis_length % chunk_length
+  if padding == "right":
+    axis_paddings = [[0, pad_length]]
+  elif padding == "left":
+    axis_paddings = [[pad_length, 0]]
+  else:
+    raise ValueError(
+        "Illegal padding value; must be one of \"left\", \"right\" or None.")
+  paddings = tf.concat([
+      tf.zeros([axis, 2], dtype=tf.int32),
+      axis_paddings,
+      tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
+  ], axis=0)
+  return tf.pad(tensor, paddings)
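A quick illustration of the pad-length arithmetic above (not part of the
commit; plain Python shown, and TensorFlow's floormod behaves the same way
for a positive divisor): -L % c is the smallest non-negative k such that
(L + k) % c == 0.

    for length, chunk in [(9, 4), (8, 4), (1, 4)]:
        print(-length % chunk)  # prints 3, 0, 3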
+def split_tensor_into_chunks(tensor, axis, chunk_length):
+  """Reshapes a tensor along the given axis using chunk_length.
+
+  Args:
+    tensor: Input tensor.
+    axis: Reshape tensor along this axis.
+    chunk_length: Split the axis into [axis/chunk_length, chunk_length].
+
+  Returns:
+    Reshaped tensor.
+  """
+  shape = tf.shape(tensor)
+  num_chunks = shape[axis] // chunk_length
+  new_shape = tf.concat(
+      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
+  return tf.reshape(tensor, new_shape)
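A shape-only sketch of the chunking (illustrative, assuming eager TF 2.x):

    x = tf.zeros([2, 6, 3])
    y = split_tensor_into_chunks(x, 1, 2)
    print(y.shape)  # (2, 3, 2, 3): axis 1 split into 3 chunks of length 2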
+def rectangular_window_sum(tensor, window_length):
+  """Summarizes tensor elements over a sliding rectangular window.
+
+  Sums elements of the input tensor of shape [B, T', C', H, dim] across a
+  rectangular window sliding along the dimension T'.
+
+  Args:
+    tensor: Tensor of shape `[B, T', C', H, dim]`.
+    window_length: The length of the rectangular window.
+
+  Returns:
+    A tensor of shape [B, T', C', H, dim] containing sums over the window.
+  """
+  tensor_cumsum = tf.cumsum(tensor, axis=-4)
+  tensor_winsum = tensor_cumsum - tf.pad(
+      tensor_cumsum,
+      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
+  return tensor_winsum
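The cumulative-sum trick above in one dimension (illustrative only): the
window sum at position t is cumsum[t] - cumsum[t - window_length].

    x = tf.ones([5])
    c = tf.cumsum(x)                    # [1, 2, 3, 4, 5]
    shifted = tf.pad(c, [[3, 0]])[:-3]  # [0, 0, 0, 1, 2]
    print(c - shifted)                  # [1, 2, 3, 3, 3] for window_length=3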
+def weighted_window_sum(tensor, window_length, window_weights):
+  """Summarizes tensor elements over a sliding weighted window.
+
+  Computes a weighted sum of elements of the input tensor of shape
+  [B, T', C', H, dim] across a window sliding along the dimension T'.
+
+  Args:
+    tensor: Tensor of shape `[B, T', C', H, dim]`.
+    window_length: The length of the window.
+    window_weights: Tensor of shape [window_length] containing window weights.
+
+  Returns:
+    A tensor of shape [B, T', C', H, dim] containing sums over the window.
+  """
+  # Flatten the last three dimensions of the [B, T', C', H, dim] shape
+  # into a single channels dimension.
+  tensor_shape = tf.shape(tensor)
+  tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])
+  # Apply the same weights to all channels.
+  conv_filter = tf.tile(
+      tf.reshape(window_weights, [-1, 1, 1, 1]),
+      multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
+  tensor_winsum_2d = tf.nn.depthwise_conv2d(
+      tensor_2d,
+      conv_filter,
+      strides=[1, 1, 1, 1],
+      padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])
+  # Unflatten the channels dimension into the original shape.
+  tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
+  return tensor_winsum
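A small numeric sketch (illustrative; it mirrors the unit test added later
in this commit): with weights [0.01, 0.1, 1.0] the most recent window
position receives the largest weight, so for an all-ones input the
per-position sums ramp up to 0.01 + 0.1 + 1.0.

    x = tf.ones([2, 5, 2, 2, 2])
    w = tf.constant([0.01, 0.1, 1.0])
    print(weighted_window_sum(x, 3, w)[0, :, 0, 0, 0])
    # [1.0, 1.1, 1.11, 1.11, 1.11]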
+def causal_windowed_performer_attention(query_matrix,
+                                        key_matrix,
+                                        value_matrix,
+                                        chunk_length,
+                                        window_length,
+                                        window_decay=None,
+                                        padding=None,
+                                        cache=None):
+  """Applies windowed causal kernel attention with query, key, value tensors.
+
+  We partition the T-length input sequence into N chunks, each of
+  chunk_length tokens (thus T = N * chunk_length). Within each chunk we
+  apply the bidirectional (non-causal) Performer's implicit attention, and
+  we model relationships between different chunks using the Performer's
+  causal attention. We consider the windowed causal variant of the
+  Performer, where the current chunk attends only to the window of the
+  window_length most recent chunks.
+
+  Below is an example with T=9, chunk_length=3, window_length=2, where 1
+  indicates that attention is computed between the pair and 0 indicates
+  that it is not:
+
+  111000000
+  111000000
+  111000000
+  111111000
+  111111000
+  111111000
+  000111111
+  000111111
+  000111111
+
+  The user can ensure sequence_length is divisible by chunk_length, or use
+  padding="left"/"right" to pad the sequence length at the left or right
+  respectively and make it divisible by chunk_length.
+
+  Args:
+    query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
+    key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
+    value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
+    chunk_length: Length of each chunk in tokens.
+    window_length: Length of attention window in chunks.
+    window_decay: Float window decay factor or `None`. If set, exponentially
+      decay past attention window values by this factor before summation.
+    padding: Pad the query, value and key input tensors across the axis from
+      either left or right if padding is set to "left" or "right"; apply no
+      padding if padding is set to None. In the latter case, the axis
+      dimension of the query, value and key input tensors must be divisible
+      by the chunk_length.
+    cache: Cache to accumulate history in memory. Used at inference time
+      (streaming, decoding) for causal attention.
+
+  Returns:
+    Windowed causal performer attention of shape `[B, T, H, out_dim]`.
+  """
+  if cache is None:  # Training
+    old_shape = tf.shape(value_matrix)
+    query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+    key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
+    value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
+    new_shape = tf.shape(value_matrix)
+    chunked_query_matrix = split_tensor_into_chunks(
+        query_matrix, -3, chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+    chunked_key_matrix = split_tensor_into_chunks(
+        key_matrix, -3, chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+    chunked_value_matrix = split_tensor_into_chunks(
+        value_matrix, -3, chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
+    kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
+                     chunked_value_matrix)
+    k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
+    if window_decay is None:
+      kp_v_winsum = rectangular_window_sum(kp_v, window_length)
+      k_winsum = rectangular_window_sum(k_sum, window_length)
+    else:
+      # Compute exponentially decaying weights.
+      decaying_weights = tf.math.pow(
+          tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
+          tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
+      kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
+      k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
+    numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix,
+                          kp_v_winsum)
+    k_winsum = tf.squeeze(k_winsum, -3)
+    denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
+    denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
+    attention = numerator / denominator
+    attention = tf.reshape(attention, new_shape)
+    start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
+    attention = tf.slice(attention, start, old_shape)
+    # Queued window cache (drop instead of decay) not yet supported.
+  else:  # Streaming
+    if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
+      raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
+    kv = window_decay * cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix,
+                                                value_matrix)
+    cache["kv"] = kv
+    k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
+    cache["k_sum"] = k_sum
+    denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
+    attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
+                          1.0 / (denominator + _NUMERIC_STABLER))
+  return attention
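A minimal smoke test of the training path (illustrative only, not part of
the diff; exp makes the kernel features positive so the denominator stays
well behaved):

    q = tf.math.exp(tf.random.normal([2, 8, 4, 16]))  # [B, T, H, dim]
    k = tf.math.exp(tf.random.normal([2, 8, 4, 16]))
    v = tf.random.normal([2, 8, 4, 16])
    out = causal_windowed_performer_attention(
        q, k, v, chunk_length=2, window_length=2)
    print(out.shape)  # (2, 8, 4, 16)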
 def create_projection_matrix(m, d, seed=None):
   r"""Constructs the matrix of random projections.
 ...
@@ -56,8 +288,8 @@ def create_projection_matrix(m, d, seed=None):
     The matrix of random projections of the shape [m, d].
   """
   nb_full_blocks = math.ceil(m / d)
   block_list = tf.TensorArray(
       tf.float32, size=tf.cast(nb_full_blocks, dtype=tf.int32))
   stateful = False
   if seed is None:
     stateful = True
 ...
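A usage sketch grounded in how the layer's constructor invokes this helper
below (illustrative): the stateless path expects a two-element integer seed.

    proj = create_projection_matrix(128, 64, seed=tf.constant([0, 1]))
    print(proj.shape)  # (128, 64): [num_random_features, key_dim]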
@@ -85,11 +317,13 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)


-def _generalized_kernel(x, projection_matrix, f, h):
+def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.

   Args:
     x: The feature being transformed with shape [B, T, N, H].
+    y: The extra stats-tensor of shape [B, T, N, H].
+    is_query: True if x is a query-tensor.
     projection_matrix: The matrix with shape [M, H] that we project x to,
       where M is the number of projections.
     f: A non-linear function applied on x or projected x.
 ...
@@ -99,7 +333,8 @@ def _generalized_kernel(x, projection_matrix, f, h):
   Returns:
     Transformed feature.
   """
+  del y
+  del is_query
   if projection_matrix is None:
     return h(x) * f(x)
   else:
 ...
@@ -108,8 +343,124 @@ def _generalized_kernel(x, projection_matrix, f, h):
       tf.cast(tf.shape(projection_matrix)[0], tf.float32))


+def expplus(data_orig,
+            other_data,
+            is_query,
+            projection_matrix=None,
+            numerical_stabilizer=0.000001,
+            normalize_data=True,
+            numerical_renormalizer=True,
+            extra_renormalize_exp_fun=False):
+  """FAVOR++ mechanism from the CRT paper: https://arxiv.org/abs/2205.15317 .
+
+  Args:
+    data_orig: data tensor of shape [B,T,H,D] for which random features are
+      to be computed
+    other_data: additional tensor of the shape [B,F,H,D] used to collect
+      stats to determine the exact instantiation of the random feature
+      mechanism
+    is_query: boolean indicating whether the <data_orig> tensor is a query
+      tensor
+    projection_matrix: tensor of the shape [M,D] encoding random projections
+      for random features (M stands for the number of random features)
+    numerical_stabilizer: numerical stabilizer for the kernel features
+    normalize_data: whether to sqrt-d-normalize queries/keys as in regular
+      attention
+    numerical_renormalizer: whether to apply additional renormalization for
+      numerical stability
+    extra_renormalize_exp_fun: extra renormalizer for the exponential mapping
+      applied to construct random features
+
+  Returns:
+    Random feature map tensor for the unbiased softmax-kernel estimation.
+  """
+  data = data_orig
+  if projection_matrix is None:
+    return data_orig
+  projection_matrix = tf.cast(projection_matrix, data.dtype)
+  if normalize_data:
+    data_normalizer = 1.0 / tf.math.sqrt(
+        (tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
+  else:
+    data_normalizer = 1.0
+    lengths = tf.math.square(data)
+    lengths = tf.reduce_sum(lengths, axis=tf.keras.backend.ndim(data) - 1)
+    lengths = tf.expand_dims(lengths, axis=tf.keras.backend.ndim(data) - 1)
+    lengths = tf.math.sqrt(lengths)
+    data /= lengths
+  ratio = 1.0 / tf.math.sqrt(
+      tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
+  data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
+                        projection_matrix)
+  diag_data = tf.math.square(data)
+  diag_data = tf.math.reduce_sum(
+      diag_data, axis=tf.keras.backend.ndim(data) - 1)
+  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
+  diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
+
+  # Calculating coefficients A, B of the FAVOR++ mechanism:
+  _, l, _, _ = tf_utils.get_shape_list(data_orig)
+  l = tf.cast(l, dtype=tf.float32)
+  first_sum_of_squares = tf.math.square(data)
+  first_sum_of_squares = tf.math.reduce_sum(
+      first_sum_of_squares, axis=(1, -1), keepdims=True)
+  first_sum_of_squares *= (data_normalizer * data_normalizer)
+  first_sum_of_squares /= l  # data.shape[1]
+  second_sum_of_squares = tf.math.square(other_data)
+  second_sum_of_squares = tf.math.reduce_sum(
+      second_sum_of_squares, axis=(1, -1), keepdims=True)
+  second_sum_of_squares *= (data_normalizer * data_normalizer)
+  second_sum_of_squares /= l  # other_data.shape[1]
+  data_sum = tf.math.reduce_sum(data, axis=(1,), keepdims=True)
+  other_data_sum = tf.math.reduce_sum(other_data, axis=(1,), keepdims=True)
+  d_prod = tf.einsum("blhd,blhd->blh", data_sum, other_data_sum)
+  d_prod = tf.expand_dims(d_prod, axis=-1)
+  d_prod *= (data_normalizer * data_normalizer)
+  d_prod *= (2.0 / (l * l))
+  ave = first_sum_of_squares + second_sum_of_squares + d_prod
+  dim = projection_matrix.shape[-1]
+  a_coeff = (1.0 / (4.0 * ave)) * (
+      tf.math.sqrt((2.0 * ave + dim) * (2.0 * ave + dim) + 8.0 * dim * ave) -
+      2.0 * ave - dim)
+  a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
+  b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
+  d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
+  a_coeff = tf.stop_gradient(a_coeff)
+  b_coeff = tf.stop_gradient(b_coeff)
+  d_coeff = tf.stop_gradient(d_coeff)
+
+  # Calculating diag_omega for the FAVOR++ mechanism:
+  diag_omega = tf.math.square(projection_matrix)
+  diag_omega = tf.math.reduce_sum(
+      diag_omega, axis=tf.keras.backend.ndim(projection_matrix) - 1)
+  diag_omega = tf.expand_dims(diag_omega, axis=0)
+  diag_omega = tf.expand_dims(diag_omega, axis=0)
+  diag_omega = tf.expand_dims(diag_omega, axis=0)
+  diag_omega = a_coeff * diag_omega
+
+  if numerical_renormalizer:
+    if is_query:
+      last_dims_t = (len(data_dash.shape) - 1,)
+      stab = b_coeff * tf.math.reduce_max(
+          data_dash, axis=last_dims_t, keepdims=True)
+    else:
+      stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
+    if extra_renormalize_exp_fun:
+      extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
+      stab = tf.math.maximum(stab, extra_stab)
+    data_dash = ratio * d_coeff * (
+        tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
+        numerical_stabilizer)
+  else:
+    data_dash = ratio * d_coeff * (
+        tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
+        numerical_stabilizer)
+
+  return data_dash
+
+
 # pylint: disable=g-long-lambda
-_TRANSFORM_MAP = {
+_CAUSAL_SUPPORT_TRANSFORM_MAP = {
     "elu":
         functools.partial(
             _generalized_kernel,
 ...
@@ -117,19 +468,22 @@ _TRANSFORM_MAP = {
             h=lambda x: 1),
     "relu":
         functools.partial(
-            _generalized_kernel,
-            f=tf.keras.activations.relu,
-            h=lambda x: 1),
+            _generalized_kernel,
+            # Improve numerical stability and avoid NaNs in some cases by
+            # adding a tiny epsilon.
+            f=lambda x: tf.keras.activations.relu(x) + 1e-3,
+            h=lambda x: 1),
     "square":
         functools.partial(_generalized_kernel, f=tf.math.square, h=lambda x: 1),
     "exp":
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
             f=lambda x: tf.math.exp(
                 x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
             h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
                 tf.math.square(x), axis=-1, keepdims=True)),
         ),
     "expmod":
         functools.partial(
             _generalized_kernel,
 ...
@@ -142,6 +496,16 @@ _TRANSFORM_MAP = {
     "identity":
         functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
+
+_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
+    "expplus": expplus,
+}
+
+_TRANSFORM_MAP = {
+    **_CAUSAL_SUPPORT_TRANSFORM_MAP,
+    **_NON_CAUSAL_SUPPORT_TRANSFORM_MAP
+}
 # pylint: enable=g-long-lambda
 ...
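Each entry maps a name to a feature map phi(x) = h(x) * f(x), optionally
applied after random projection. An illustrative call (not part of the
diff; with projection_matrix=None the y and is_query arguments are simply
discarded by _generalized_kernel):

    x = tf.random.normal([2, 8, 4, 16])             # [B, T, N, H]
    phi = _TRANSFORM_MAP["relu"](x, x, True, None)  # relu(x) + 1e-3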
@@ -154,6 +518,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     (https://arxiv.org/abs/2009.14794)
     - exp (Lemma 1, positive), relu
     - random/deterministic projection
+  Chefs' Random Tables: Non-Trigonometric Random Features
+  (https://arxiv.org/abs/2205.15317)
+    - expplus (OPRF mechanism)
   Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
   (https://arxiv.org/abs/2006.16236)
 ...
@@ -178,13 +545,19 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                is_short_seq=False,
                begin_kernel=0,
                scale=None,
+               scale_by_length=False,
+               use_causal_windowed=False,
+               causal_chunk_length=1,
+               causal_window_length=3,
+               causal_window_decay=None,
+               causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.

     Args:
-      feature_transform: A non-linear transform of the keys and quries.
-        Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries.
+        Possible transforms are "elu", "relu", "square", "exp", "expplus",
+        "expmod", "identity".
       num_random_features: Number of random features to be used for
         projection. If num_random_features <= 0, no projection is used before
         the transform.
       seed: The seed to begin drawing random features. Once the seed is set,
 ...
@@ -194,12 +567,28 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       redraw: Whether to redraw projection every forward pass during training.
         The argument is only effective when num_random_features > 0.
       is_short_seq: boolean predicate indicating whether input data consists
         of very short sequences or not; in most cases this should be False
         (the default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
       scale: The value to scale the dot product as described in `Attention Is
         All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
+      scale_by_length: boolean predicate indicating whether to additionally
+        scale the dot product based on key length. Set as log_512^(n) to
+        stabilize attention entropy against length. Refer to
+        https://kexue.fm/archives/8823 for details.
+      use_causal_windowed: If true, perform windowed causal attention. See
+        the causal_windowed_performer_attention function docstring for more
+        details.
+      causal_chunk_length: Length of each chunk in tokens.
+      causal_window_length: Length of attention window in chunks.
+      causal_window_decay: Float window decay factor or `None`. If set,
+        exponentially decay past attention window values by this factor
+        before summation.
+      causal_padding: Pad the query, value and key input tensors across the
+        axis from either left or right if padding is set to "left" or
+        "right"; apply no padding if padding is set to None. In the latter
+        case, the axis dimension of the query, value and key input tensors
+        must be divisible by the chunk_length.
       **kwargs: The same arguments as the `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
 ...
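A hypothetical construction exercising the new windowed-causal options
(values are illustrative; it mirrors how the unit tests in this commit
drive the layer):

    layer = KernelAttention(
        num_heads=8, key_dim=64, feature_transform="exp",
        num_random_features=128, redraw=False, use_causal_windowed=True,
        causal_chunk_length=8, causal_window_length=3,
        causal_window_decay=0.97)
    x = tf.random.normal([2, 64, 512])
    mask = tf.ones([2, 64], dtype=tf.float32)
    y = layer(query=x, value=x, attention_mask=mask)  # [2, 64, 512]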
@@ -214,6 +603,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     self._redraw = redraw
     self._is_short_seq = is_short_seq
     self._begin_kernel = begin_kernel
+    self._scale_by_length = scale_by_length
     # We use the seed for two scenarios:
     # 1. inference
     # 2. no redraw
 ...
@@ -228,6 +618,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       self._projection_matrix = create_projection_matrix(
           self._num_random_features, self._key_dim,
           tf.constant([self._seed, self._seed + 1]))
+    self.use_causal_windowed = use_causal_windowed
+    self.causal_chunk_length = causal_chunk_length
+    self.causal_window_length = causal_window_length
+    self.causal_window_decay = causal_window_decay
+    self.causal_padding = causal_padding
+    if self.use_causal_windowed and self._is_short_seq:
+      raise ValueError(
+          "use_causal_windowed and short_seq methods are mutually exclusive")

   def _compute_attention(self,
                          query,
 ...
@@ -236,6 +634,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                          feature_transform,
                          is_short_seq,
                          attention_mask=None,
+                         cache=None,
                          training=False,
                          numeric_stabler=_NUMERIC_STABLER):
     """Applies kernel attention with query, key, value tensors.
 ...
@@ -252,9 +651,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       is_short_seq: boolean predicate indicating whether input data consists
         of short or long sequences; usually a short sequence is defined as
         having length L <= 1024.
       attention_mask: a boolean mask of shape `[B, S]`, that prevents
         attending to masked positions. Note that the mask is only applied to
         the keys. User may want to mask the output if query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid divide by 0.
 ...
@@ -263,6 +664,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output: Multi-headed outputs of attention computation.
     """
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
         projection_matrix = create_projection_matrix(
             self._num_random_features,
 ...
@@ -270,35 +672,53 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     else:
       projection_matrix = self._projection_matrix
+    if self._scale_by_length:
+      scale = tf.math.log(
+          tf.reduce_sum(attention_mask, axis=-1)) * self._scale / math.log(512)
+      scale = tf.reshape(scale, [-1, 1, 1, 1])
+    else:
+      scale = self._scale
     if is_short_seq:
       # Note: Applying scalar multiply at the smaller end of einsum improves
       # XLA performance, but may introduce slight numeric differences in
       # the Transformer attention head.
-      query = query * self._scale
+      query = query * scale
     else:
       # Note: we suspect splitting the scale to key, query yields smaller
       # approximation variance when random projection is used.
       # For simplicity, we also split when there's no random projection.
-      key *= math.sqrt(self._scale)
-      query *= math.sqrt(self._scale)
+      key *= tf.math.sqrt(scale)
+      query *= tf.math.sqrt(scale)
-    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
-    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
+    key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
+                                                  projection_matrix)
+    query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
+                                                    projection_matrix)
     if attention_mask is not None:
-      key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
+      key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
     if is_short_seq:
-      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
+      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+    elif self.use_causal_windowed:
+      attention_output = causal_windowed_performer_attention(
+          query_prime,
+          key_prime,
+          value,
+          chunk_length=self.causal_chunk_length,
+          window_length=self.causal_window_length,
+          window_decay=self.causal_window_decay,
+          padding=self.causal_padding,
+          cache=cache)
     else:
-      kv = tf.einsum("BSNH,BSND->BNDH", key, value)
-      denominator = 1.0 / (
-          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
-          _NUMERIC_STABLER)
-      attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query, kv,
-                                   denominator)
+      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
+      denominator = 1.0 / (
+          tf.einsum("BTNH,BNH->BTN", query_prime,
+                    tf.reduce_sum(key_prime, axis=1)) + _NUMERIC_STABLER)
+      attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv,
+                                   denominator)
     return attention_output
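The scale_by_length factor is log_512(n), where n is the number of unmasked
keys, so it is exactly 1.0 at n = 512; a quick arithmetic check in plain
Python (illustrative only):

    import math
    for n in [128, 512, 2048]:
        print(n, math.log(n) / math.log(512))  # 0.778, 1.0, 1.222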
   def _build_from_signature(self, query, value, key=None):
 ...
@@ -313,15 +733,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
             kernel_constraint=self._kernel_constraint,
             bias_constraint=self._bias_constraint)
     self._output_dense_softmax = self._make_output_dense(
-        self._query_shape.rank - 1,
-        common_kwargs,
+        self._query_shape.rank - 1, common_kwargs,
         name="attention_output_softmax")
     self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)

-  def call(self, query, value, key=None, attention_mask=None,
-           training=False):
+  def call(self, query, value, key=None, attention_mask=None, cache=None,
+           training=False):
     """Compute attention with kernel mechanism.
 ...
@@ -330,15 +747,32 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       value: Value `Tensor` of shape `[B, S, dim]`.
       key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
         use `value` for both `key` and `value`, which is the most common case.
       attention_mask: a boolean mask of shape `[B, S]`, that prevents
         attending to masked positions. Note that the mask is only applied to
         the keys. User may want to mask the output if query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).

     Returns:
       Multi-headed outputs of attention computation.
     """
+    if cache is not None:
+      if training:
+        raise ValueError("Cache is not supported when training is True.")
+      if not self.use_causal_windowed:
+        raise ValueError(
+            "Cache is not supported for non use_causal_windowed case.")
+      if self._begin_kernel:
+        raise ValueError(
+            "Cache is not supported when begin_kernel is set since the "
+            "behavior is too complicated.")
+      if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
+        raise ValueError("Cache is not supported for feature_transform %s" %
+                         (self._feature_transform))
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
 ...
@@ -357,25 +791,26 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     if self._begin_kernel > 0:
       attention_output_softmax = self._compute_attention(
           query[:, :self._begin_kernel], key, value, "identity", True,
           attention_mask, training)
       attention_output_softmax = self._dropout_softmax(attention_output_softmax)
       attention_output_softmax = self._output_dense_softmax(
           attention_output_softmax)
       attention_output_kernel = self._compute_attention(
           query[:, self._begin_kernel:], key, value, self._feature_transform,
           self._is_short_seq, attention_mask, training)
       attention_output_kernel = self._dropout_layer(attention_output_kernel)
       attention_output_kernel = self._output_dense(attention_output_kernel)
       attention_output = tf.concat(
           [attention_output_softmax, attention_output_kernel], axis=1)
     else:
-      attention_output = self._compute_attention(query, key, value,
-                                                 self._feature_transform,
-                                                 self._is_short_seq,
-                                                 attention_mask, training)
+      attention_output = self._compute_attention(query, key, value,
+                                                 self._feature_transform,
+                                                 self._is_short_seq,
+                                                 attention_mask, cache,
+                                                 training)
       # This is actually dropping out entire tokens to attend to, which might
       # seem a bit unusual, but is taken from the original Transformer paper.
       attention_output = self._dropout_layer(attention_output)
 ...
official/nlp/modeling/layers/kernel_attention_test.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention

-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
+_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
 ...
@@ -30,9 +30,67 @@ _BEGIN_KERNEL = [0, 512]

 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

+  # expplus is only designed for the bi-directional use case.
+  # exp can be numerically unstable.
   @parameterized.parameters(
-      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
+      itertools.product(["relu", "elu"], [1, 4], [0.9]))
+  def test_causal_windowed_attention_projection_streaming(
+      self, feature_transform, causal_chunk_length, causal_weight_decay):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 16
+    num_chunks = seq_length // causal_chunk_length
+    causal_window_length = num_chunks
+    batch_size = 2
+    training = False
+    num_random_features = 0
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        feature_transform=feature_transform,
+        num_random_features=num_random_features,
+        redraw=False,
+        is_short_seq=False,
+        begin_kernel=False,
+        use_causal_windowed=True,
+        causal_chunk_length=causal_chunk_length,
+        causal_window_length=causal_window_length,
+        causal_window_decay=causal_weight_decay,
+        causal_padding=None,
+    )
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim), seed=2)
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output = test_layer(
+        query=query, value=value, attention_mask=masks, training=training)
+    dim = num_random_features if num_random_features > 0 else key_dim
+    kv_cache = tf.zeros((batch_size, num_heads, dim, dim))
+    k_sum_cache = tf.zeros((batch_size, num_heads, dim))
+    stream_output = []
+    cache = {"kv": kv_cache, "k_sum": k_sum_cache}
+    for i in range(num_chunks):
+      stream_output.append(
+          test_layer(
+              query=query[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              value=value[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
+                                   causal_chunk_length],
+              cache=cache,
+              training=training))
+    stream_output = tf.concat(stream_output, axis=1)
+    self.assertAllClose(output, stream_output)
+
+  @parameterized.parameters(
+      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
+                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
   def test_attention_projection(
       self, feature_transform, num_random_features, training, redraw, is_short,
       begin_kernel):
 ...
@@ -60,6 +118,41 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         training=training)
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

+  @parameterized.parameters(
+      itertools.product(["relu", "exp"], [127], _TRAINING, [True, False],
+                        [0], [None, 0.97], [None, "left", "right"]))
+  def test_causal_windowed_attention_projection(
+      self, feature_transform, num_random_features, training, redraw,
+      begin_kernel, causal_window_decay, causal_padding):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 1024
+    batch_size = 2
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        feature_transform=feature_transform,
+        num_random_features=num_random_features,
+        redraw=redraw,
+        is_short_seq=False,
+        begin_kernel=begin_kernel,
+        use_causal_windowed=True,
+        causal_chunk_length=8,
+        causal_window_length=3,
+        causal_window_decay=causal_window_decay,
+        causal_padding=causal_padding)
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
+    value = query
+    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output = test_layer(
+        query=query, value=value, attention_mask=masks, training=training)
+    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
+
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [0], _TRAINING, [False],
                         _IS_SHORT_SEQ, _BEGIN_KERNEL))
 ...
@@ -90,15 +183,41 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         training=training)
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

+  @parameterized.parameters([128, 512])
+  def test_attention_scale_by_length(self, seq_length):
+    num_heads = 12
+    key_dim = 64
+    batch_size = 2
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        num_random_features=0,
+        scale_by_length=True)
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    test_layer._scale_by_length = False
+    output_no_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    if seq_length == 512:
+      # Equal because log(seq_length, base=512) = 1.0
+      self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
+    else:
+      self.assertNotAllClose(output_scale_by_length,
+                             output_no_scale_by_length)
+
   def test_unsupported_feature_transform(self):
-    with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
-      _ = attention.KernelAttention(feature_transform='test')
+    with self.assertRaisesRegex(ValueError,
+                                "Unsupported feature_transform.*"):
+      _ = attention.KernelAttention(feature_transform="test")

   def test_redraw_true_no_projection(self):
     with self.assertRaisesRegex(
-        ValueError, 'There is nothing to redraw when num_random_features.*'):
+        ValueError, "There is nothing to redraw when num_random_features.*"):
       _ = attention.KernelAttention(
-          num_heads=2, key_dim=64, feature_transform='elu',
+          num_heads=2, key_dim=64, feature_transform="elu",
           num_random_features=0, redraw=True)

   def test_config(self):
 ...
@@ -107,7 +226,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     test_layer = attention.KernelAttention(
         num_heads=num_heads,
         key_dim=key_dim,
-        feature_transform='exp',
+        feature_transform="exp",
         num_random_features=128,
         is_short_seq=True)
     new_layer = attention.KernelAttention.from_config(
 ...
@@ -115,5 +234,25 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(test_layer.get_config(), new_layer.get_config())

-if __name__ == '__main__':
+  def test_rectangular_window_sum(self):
+    x = tf.ones([2, 5, 2, 2, 2])
+    winsum = attention.rectangular_window_sum(x, 3)
+    self.assertEqual(winsum.shape, x.shape)
+    self.assertAllClose(
+        tf.tile(
+            tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
+            [2, 1, 2, 2, 2]), winsum)
+
+  def test_weighted_window_sum(self):
+    x = tf.ones([2, 5, 2, 2, 2])
+    winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
+    self.assertEqual(winsum.shape, x.shape)
+    self.assertAllClose(
+        tf.tile(
+            tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
+            [2, 1, 2, 2, 2]), winsum)
+
+
+if __name__ == "__main__":
   tf.test.main()
official/nlp/modeling/layers/masked_lm.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -47,7 +47,7 @@ class MaskedLM(tf.keras.layers.Layer):
                output='logits',
                name=None,
                **kwargs):
-    super(MaskedLM, self).__init__(name=name, **kwargs)
+    super().__init__(name=name, **kwargs)
     self.embedding_table = embedding_table
     self.activation = activation
     self.initializer = tf.keras.initializers.get(initializer)
 ...
...
@@ -73,7 +73,7 @@ class MaskedLM(tf.keras.layers.Layer):
initializer
=
'zeros'
,
trainable
=
True
)
super
(
MaskedLM
,
self
).
build
(
input_shape
)
super
().
build
(
input_shape
)
def
call
(
self
,
sequence_data
,
masked_positions
):
masked_lm_input
=
self
.
_gather_indexes
(
sequence_data
,
masked_positions
)
...
...
@@ -115,7 +115,8 @@ class MaskedLM(tf.keras.layers.Layer):
     flat_offsets = tf.reshape(
         tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
-    flat_positions = tf.reshape(positions + flat_offsets, [-1])
+    flat_positions = tf.reshape(
+        positions + tf.cast(flat_offsets, positions.dtype), [-1])
     flat_sequence_tensor = tf.reshape(sequence_tensor,
                                       [batch_size * seq_length, width])
     output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
 ...
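The added cast aligns dtypes when `positions` arrives with a different
integer type (e.g. int64 from a data pipeline) than the generated int32
offsets. The flatten-and-offset gather itself works like this small
standalone sketch (illustrative only):

    batch_size, seq_length, width = 2, 4, 3
    sequence = tf.reshape(
        tf.range(batch_size * seq_length * width, dtype=tf.float32),
        [batch_size, seq_length, width])
    positions = tf.constant([[0, 2], [1, 3]])
    offsets = tf.reshape(tf.range(0, batch_size) * seq_length, [-1, 1])
    flat_positions = tf.reshape(
        positions + tf.cast(offsets, positions.dtype), [-1])  # [0, 2, 5, 7]
    flat_seq = tf.reshape(sequence, [batch_size * seq_length, width])
    gathered = tf.gather(flat_seq, flat_positions)  # one row per position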
official/nlp/modeling/layers/masked_lm_test.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/nlp/modeling/layers/masked_softmax.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

@@ -53,7 +53,7 @@ class MaskedSoftmax(tf.keras.layers.Layer):
       self._normalization_axes = (-1,)
     else:
       self._normalization_axes = normalization_axes
-    super(MaskedSoftmax, self).__init__(**kwargs)
+    super().__init__(**kwargs)

   def call(self, scores, mask=None):
 ...

@@ -81,5 +81,5 @@ class MaskedSoftmax(tf.keras.layers.Layer):
         'mask_expansion_axes': self._mask_expansion_axes,
         'normalization_axes': self._normalization_axes
     }
-    base_config = super(MaskedSoftmax, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
official/nlp/modeling/layers/masked_softmax_test.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/nlp/modeling/layers/mat_mul_with_margin.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

@@ -36,7 +36,7 @@ class MatMulWithMargin(tf.keras.layers.Layer):
                logit_scale=1.0,
                logit_margin=0.0,
                **kwargs):
-    super(MatMulWithMargin, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self.logit_scale = logit_scale
     self.logit_margin = logit_margin
 ...

@@ -61,7 +61,7 @@ class MatMulWithMargin(tf.keras.layers.Layer):
     config = {
         'logit_scale': self.logit_scale,
         'logit_margin': self.logit_margin}
-    config.update(super(MatMulWithMargin, self).get_config())
+    config.update(super().get_config())
     return config

   @classmethod
 ...
official/nlp/modeling/layers/mat_mul_with_margin_test.py

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/nlp/modeling/layers/mixing.py (new file, 0 → 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based mixing layers.

Based on the mixing layers used by FNet
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
(https://arxiv.org/abs/2205.12399).

Mixing layers can be used as drop-in replacements for self-attention layers.
For interoperability with attention layers, we use the same `query` and
`value` call signature.

Note: These mixing layers currently only support encoder stacks. Decoder
stacks can be supported in the future by utilizing the `value` inputs.
"""

import enum
import functools
from typing import Callable, Tuple, Union

import numpy as np
from scipy import linalg
import tensorflow as tf

from official.modeling import tf_utils

_Initializer = Union[str, tf.keras.initializers.Initializer]

default_kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=2e-2)


class MixingMechanism(enum.Enum):
  """Determines the type of mixing layer.

  Possible options:
    FOURIER: Fourier Transform mixing.
    LINEAR: Mixing using dense matrix multiplications with learnable weights.
    HARTLEY: Hartley Transform mixing.
  """
  FOURIER = "fourier"
  HARTLEY = "hartley"
  LINEAR = "linear"


class MixingLayer(tf.keras.layers.Layer):
  """Mixing layer base class.

  This class cannot be used directly. It just specifies the API for mixing
  layer subclasses. For interoperability with attention layers, we use the
  same `query` and `value` call signature.

  Based on the mixing layers used by FNet
  (https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
  (https://arxiv.org/abs/2205.12399).
  """

  def __init__(self, name: str = "mixing", **kwargs):
    """Initializes layer.

    Args:
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Calls the layer.

    Subclasses should return tensors of shape
    <float>[batch_size, max_seq_length, hidden_dim].

    Args:
      query: Batch of input embeddings, typically of shape
        <float>[batch_size, max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword
        arguments.

    Raises:
      NotImplementedError. This class should not be called directly.
    """
    raise NotImplementedError("Abstract method")


class FourierTransformLayer(MixingLayer):
  """Fourier Transform layer.

  Applies 2D Fourier Transform over the final two dimensions of `query`
  inputs - typically the sequence and hidden dimensions.
  """

  def __init__(self,
               use_fft: bool = False,
               name: str = "fourier_transform",
               **kwargs):
    """Initializes layer.

    Args:
      use_fft: Whether to use the Fast Fourier Transform (True) or the
        Discrete Fourier Transform (DFT) matrix (False) to compute the
        Fourier Transform. See _pick_fourier_transform() for recommendations
        on when to use the FFT or DFT.
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)
    self.use_fft = use_fft

  def build(self, input_shape: Tuple[int, ...]):
    """Picks the Fourier Transform implementation."""
    self.fourier_transform = _pick_fourier_transform(
        self.use_fft,
        max_seq_length=input_shape[-2],
        hidden_dim=input_shape[-1])

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Applies layer to `query`.

    Args:
      query: Batch of input embeddings, typically of shape
        <float>[batch_size, max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword
        arguments.

    Returns:
      Real part of discrete Fourier Transform of `query` inputs with shape
      <float32>[batch_size, max_seq_length, hidden_dim].
    """
    del value  # Ignored by encoder-only mixing layers
    query = tf.cast(query, tf.complex64)
    return tf.math.real(self.fourier_transform(query))


class HartleyTransformLayer(MixingLayer):
  """Hartley Transform layer.

  Applies 2D Hartley Transform over the final two dimensions of `query`
  inputs - typically the sequence and hidden dimensions.
  """

  def __init__(self,
               use_fft: bool = False,
               name: str = "hartley_transform",
               **kwargs):
    """Initializes layer.

    Args:
      use_fft: Whether to use the Fast Fourier Transform (True) or the
        Discrete Fourier Transform (DFT) matrix (False) to compute the
        Hartley Transform. See _pick_fourier_transform() for recommendations
        on when to use the FFT or DFT.
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)
    self.use_fft = use_fft

  def build(self, input_shape: Tuple[int, ...]):
    """Picks the Fourier Transform implementation."""
    self.fourier_transform = _pick_fourier_transform(
        self.use_fft,
        max_seq_length=input_shape[-2],
        hidden_dim=input_shape[-1])

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Applies layer to `query`.

    Args:
      query: Batch of input embeddings, typically of shape
        <float>[batch_size, max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword
        arguments.

    Returns:
      Real part of discrete Hartley Transform of `query` inputs with shape
      <float32>[batch_size, max_seq_length, hidden_dim].
    """
    del value  # Ignored by encoder-only mixing layers
    query = tf.cast(query, tf.complex64)
    frequencies = self.fourier_transform(query)
    return tf.math.real(frequencies) - tf.math.imag(frequencies)


class LinearTransformLayer(MixingLayer):
  """Dense, linear transformation layer.

  Applies matrix multiplications over sequence and hidden dimensions.
  """

  def __init__(self,
               kernel_initializer: _Initializer = default_kernel_initializer,
               name: str = "linear_transform",
               **kwargs):
    """Initializes layer.

    Args:
      kernel_initializer: Initialization scheme for kernel.
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)
    self.kernel_initializer = kernel_initializer

  def build(self, input_shape: Tuple[int, ...]):
    """Creates the hidden and sequence matrix variables of the layer."""
    self.mat_hidden = self.add_weight(
        shape=(input_shape[-1], input_shape[-1]),
        initializer=tf_utils.clone_initializer(self.kernel_initializer),
        trainable=True,
        name="hidden_kernel")
    self.mat_seq = self.add_weight(
        shape=(input_shape[-2], input_shape[-2]),
        initializer=tf_utils.clone_initializer(self.kernel_initializer),
        trainable=True,
        name="seq_kernel")

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Applies layer to `query`.

    Args:
      query: Batch of input embeddings, typically of shape
        <float>[batch_size, max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword
        arguments.

    Returns:
      Linearly transformed `query` inputs with shape
      <float>[batch_size, max_seq_length, hidden_dim].
    """
    del value  # Ignored by encoder-only mixing layers
    return tf.einsum("bij,jk,ni->bnk", query, self.mat_hidden, self.mat_seq)


def _pick_fourier_transform(
    use_fft: bool, max_seq_length: int,
    hidden_dim: int) -> Callable[[tf.Tensor], tf.Tensor]:
  """Returns FFT or DFT Fourier Transform implementation.

  On TPUs, we recommend using the Discrete Fourier Transform (DFT) matrix
  (use_fft=False), except for very long sequence lengths. On GPUs and CPUs,
  the Fast Fourier Transform (use_fft=True) is generally optimal for all
  sequence lengths.

  Note: When using the FFT, it is recommended to use a sequence length that
  is a power of 2.

  Args:
    use_fft: If True, return FFT. Otherwise, return DFT matrix.
    max_seq_length: Maximum sequence length of inputs. Only used if
      use_fft=False.
    hidden_dim: Size of hidden dimension of inputs. Only used if
      use_fft=False.

  Returns:
    Fourier Transform.
  """
  if use_fft:
    return tf.signal.fft2d
  else:
    dft_mat_seq = linalg.dft(max_seq_length).astype(np.complex64)
    dft_mat_hidden = linalg.dft(hidden_dim).astype(np.complex64)

    def two_dim_matmul(x: tf.Tensor, matrix_dim_one: tf.Tensor,
                       matrix_dim_two: tf.Tensor) -> tf.Tensor:
      """Applies 2D matrix multiplication to input tensors of rank >= 2."""
      return tf.einsum("...ij,jk,ni->...nk", tf.cast(x, tf.complex64),
                       matrix_dim_two, matrix_dim_one)

    return functools.partial(
        two_dim_matmul,
        matrix_dim_one=tf.convert_to_tensor(dft_mat_seq),
        matrix_dim_two=tf.convert_to_tensor(dft_mat_hidden))
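The Hartley layer relies on the identity DHT(x) = Re(DFT(x)) - Im(DFT(x)),
which lets a real-valued transform reuse the complex FFT machinery. A
one-dimensional numpy check (illustrative only, not part of the diff):

    import numpy as np
    n = 8
    x = np.random.randn(n)
    f = np.fft.fft(x)
    k, t = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    cas = np.cos(2 * np.pi * k * t / n) + np.sin(2 * np.pi * k * t / n)
    print(np.allclose(cas @ x, f.real - f.imag))  # True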
official/nlp/modeling/layers/mixing_test.py (new file, 0 → 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mixing.py."""

import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import mixing


class MixingTest(tf.test.TestCase):

  def test_base_mixing_layer(self):
    inputs = tf.random.uniform((3, 8, 16), minval=0, maxval=10,
                               dtype=tf.float32)
    with self.assertRaisesRegex(NotImplementedError, "Abstract method"):
      _ = mixing.MixingLayer()(query=inputs, value=inputs)

  def test_fourier_layer(self):
    batch_size = 4
    max_seq_length = 8
    hidden_dim = 16
    inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
                               minval=0, maxval=10, dtype=tf.float32)
    outputs = mixing.FourierTransformLayer(use_fft=True)(
        query=inputs, value=inputs)
    self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))

  def test_hartley_layer(self):
    batch_size = 3
    max_seq_length = 16
    hidden_dim = 4
    inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
                               minval=0, maxval=12, dtype=tf.float32)
    outputs = mixing.HartleyTransformLayer(use_fft=True)(
        query=inputs, value=inputs)
    self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))

  def test_linear_mixing_layer(self):
    batch_size = 2
    max_seq_length = 4
    hidden_dim = 3
    inputs = tf.ones((batch_size, max_seq_length, hidden_dim),
                     dtype=tf.float32)
    outputs = mixing.LinearTransformLayer(
        kernel_initializer=tf.keras.initializers.Ones())(
            query=inputs, value=inputs)
    # hidden_dim * (max_seq_length * 1) = 12.
    expected_outputs = [
        [
            [12., 12., 12.],
            [12., 12., 12.],
            [12., 12., 12.],
            [12., 12., 12.],
        ],
        [
            [12., 12., 12.],
            [12., 12., 12.],
            [12., 12., 12.],
            [12., 12., 12.],
        ],
    ]
    np.testing.assert_allclose(outputs, expected_outputs, rtol=1e-6,
                               atol=1e-6)

  def test_pick_fourier_transform(self):
    # Ensure we don't hit an edge case which exceeds the fixed numerical
    # error.
    tf.random.set_seed(1)
    np.random.seed(1)
    batch_size = 3
    max_seq_length = 4
    hidden_dim = 8

    fft = mixing._pick_fourier_transform(
        use_fft=True, max_seq_length=max_seq_length, hidden_dim=hidden_dim)
    dft_matmul = mixing._pick_fourier_transform(
        use_fft=False, max_seq_length=max_seq_length, hidden_dim=hidden_dim)

    inputs = tf.random.uniform([batch_size, max_seq_length, hidden_dim])
    inputs = tf.cast(inputs, tf.complex64)

    np.testing.assert_allclose(
        fft(inputs), dft_matmul(inputs), rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
  tf.test.main()
official/nlp/modeling/layers/mobile_bert_layers.py  View file @ 32e4ca51
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,6 +15,8 @@
 """MobileBERT embedding and transformer layers."""
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 from official.nlp.modeling.layers import on_device_embedding
 from official.nlp.modeling.layers import position_embedding
...
@@ -24,7 +26,7 @@ class NoNorm(tf.keras.layers.Layer):
   """Apply element-wise linear transformation to the last dimension."""
 
   def __init__(self, name=None):
-    super(NoNorm, self).__init__(name=name)
+    super().__init__(name=name)
 
   def build(self, shape):
     kernal_size = shape[-1]
...
@@ -96,7 +98,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
       dropout_rate: Dropout rate.
       **kwargs: keyword arguments.
     """
-    super(MobileBertEmbedding, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self.word_vocab_size = word_vocab_size
     self.word_embed_size = word_embed_size
     self.type_vocab_size = type_vocab_size
...
@@ -109,21 +111,21 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
     self.word_embedding = on_device_embedding.OnDeviceEmbedding(
         self.word_vocab_size,
         self.word_embed_size,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(self.initializer),
         name='word_embedding')
     self.type_embedding = on_device_embedding.OnDeviceEmbedding(
         self.type_vocab_size,
         self.output_embed_size,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(self.initializer),
         name='type_embedding')
     self.pos_embedding = position_embedding.PositionEmbedding(
         max_length=max_sequence_length,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(self.initializer),
         name='position_embedding')
-    self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
+    self.word_embedding_proj = tf.keras.layers.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.output_embed_size],
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         bias_axes='d',
         name='embedding_projection')
     self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
...
@@ -220,7 +222,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     Raises:
       ValueError: A Tensor shape or parameter is invalid.
     """
-    super(MobileBertTransformer, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self.hidden_size = hidden_size
     self.num_attention_heads = num_attention_heads
     self.intermediate_size = intermediate_size
...
@@ -242,11 +244,11 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     self.block_layers = {}
     # add input bottleneck
-    dense_layer_2d = tf.keras.layers.experimental.EinsumDense(
+    dense_layer_2d = tf.keras.layers.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.intra_bottleneck_size],
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='bottleneck_input/dense')
     layer_norm = _get_norm_layer(self.normalization_type,
                                  name='bottleneck_input/norm')
...
@@ -254,11 +256,11 @@ class MobileBertTransformer(tf.keras.layers.Layer):
                                   layer_norm]
     if self.key_query_shared_bottleneck:
-      dense_layer_2d = tf.keras.layers.experimental.EinsumDense(
+      dense_layer_2d = tf.keras.layers.EinsumDense(
           'abc,cd->abd',
           output_shape=[None, self.intra_bottleneck_size],
           bias_axes='d',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(self.initializer),
           name='kq_shared_bottleneck/dense')
       layer_norm = _get_norm_layer(self.normalization_type,
                                    name='kq_shared_bottleneck/norm')
...
@@ -272,7 +274,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         value_dim=attention_head_size,
         dropout=self.attention_probs_dropout_prob,
         output_shape=self.intra_bottleneck_size,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='attention')
     layer_norm = _get_norm_layer(self.normalization_type,
                                  name='attention/norm')
...
@@ -284,19 +286,19 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     for ffn_layer_idx in range(self.num_feedforward_networks):
       layer_prefix = f'ffn_layer_{ffn_layer_idx}'
       layer_name = layer_prefix + '/intermediate_dense'
-      intermediate_layer = tf.keras.layers.experimental.EinsumDense(
+      intermediate_layer = tf.keras.layers.EinsumDense(
           'abc,cd->abd',
           activation=self.intermediate_act_fn,
           output_shape=[None, self.intermediate_size],
           bias_axes='d',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(self.initializer),
           name=layer_name)
       layer_name = layer_prefix + '/output_dense'
-      output_layer = tf.keras.layers.experimental.EinsumDense(
+      output_layer = tf.keras.layers.EinsumDense(
           'abc,cd->abd',
           output_shape=[None, self.intra_bottleneck_size],
           bias_axes='d',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(self.initializer),
           name=layer_name)
       layer_name = layer_prefix + '/norm'
       layer_norm = _get_norm_layer(self.normalization_type,
...
@@ -306,12 +308,12 @@ class MobileBertTransformer(tf.keras.layers.Layer):
                                    layer_norm])
     # add output bottleneck
-    bottleneck = tf.keras.layers.experimental.EinsumDense(
+    bottleneck = tf.keras.layers.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.hidden_size],
         activation=None,
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='bottleneck_output/dense')
     dropout_layer = tf.keras.layers.Dropout(self.hidden_dropout_prob,
...
@@ -445,6 +447,7 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
                activation=None,
                initializer='glorot_uniform',
                output='logits',
+               output_weights_use_proj=False,
                **kwargs):
     """Class initialization.
...
@@ -455,9 +458,12 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
         uniform initializer.
       output: The output style for this layer. Can be either `logits` or
         `predictions`.
+      output_weights_use_proj: Use a projection instead of concatenating extra
+        output weights; this may reduce the MLM task accuracy but will also
+        reduce the model params.
       **kwargs: keyword arguments.
     """
-    super(MobileBertMaskedLM, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self.embedding_table = embedding_table
     self.activation = activation
     self.initializer = tf.keras.initializers.get(initializer)
...
@@ -467,6 +473,7 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
           ('Unknown `output` value "%s". `output` can be either "logits" or '
            '"predictions"') % output)
     self._output_type = output
+    self._output_weights_use_proj = output_weights_use_proj
 
   def build(self, input_shape):
     self._vocab_size, embedding_width = self.embedding_table.shape
...
@@ -474,15 +481,22 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
     self.dense = tf.keras.layers.Dense(
         hidden_size,
         activation=self.activation,
-        kernel_initializer=self.initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='transform/dense')
     if hidden_size > embedding_width:
-      self.extra_output_weights = self.add_weight(
-          'extra_output_weights',
-          shape=(self._vocab_size, hidden_size - embedding_width),
-          initializer=self.initializer,
-          trainable=True)
+      if self._output_weights_use_proj:
+        self.extra_output_weights = self.add_weight(
+            'output_weights_proj',
+            shape=(embedding_width, hidden_size),
+            initializer=tf_utils.clone_initializer(self.initializer),
+            trainable=True)
+      else:
+        self.extra_output_weights = self.add_weight(
+            'extra_output_weights',
+            shape=(self._vocab_size, hidden_size - embedding_width),
+            initializer=tf_utils.clone_initializer(self.initializer),
+            trainable=True)
     elif hidden_size == embedding_width:
       self.extra_output_weights = None
     else:
...
@@ -507,10 +521,16 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
     if self.extra_output_weights is None:
       lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
     else:
-      lm_data = tf.matmul(
-          lm_data,
-          tf.concat([self.embedding_table, self.extra_output_weights], axis=1),
-          transpose_b=True)
+      if self._output_weights_use_proj:
+        lm_data = tf.matmul(lm_data, self.extra_output_weights,
+                            transpose_b=True)
+        lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
+      else:
+        lm_data = tf.matmul(
+            lm_data,
+            tf.concat([self.embedding_table, self.extra_output_weights],
+                      axis=1),
+            transpose_b=True)
     logits = tf.nn.bias_add(lm_data, self.bias)
     masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape(
...
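To see why the projection variant saves parameters, compare the two options under illustrative BERT-like sizes (an editorial sketch; vocab_size 30522 and the widths below are assumptions, not values from this commit):

# Concatenation learns [vocab_size, hidden_size - embedding_width] extra output
# weights; the projection learns only [embedding_width, hidden_size].
vocab_size, hidden_size, embedding_width = 30522, 512, 128  # illustrative only
concat_params = vocab_size * (hidden_size - embedding_width)  # 11,720,448
proj_params = embedding_width * hidden_size                   # 65,536
print(concat_params, proj_params)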
official/nlp/modeling/layers/mobile_bert_layers_test.py  View file @ 32e4ca51
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
official/nlp/modeling/layers/moe.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mixture of Experts layers and their routing mechanisms."""

import dataclasses
from typing import Any, Callable, Optional, Tuple

from absl import logging
import numpy as np
import tensorflow as tf

from official.modeling import tf_utils


_InitializerType = tf.keras.initializers.Initializer

_DEFAULT_KERNEL_INITIALIZER = tf.keras.initializers.TruncatedNormal(stddev=2e-2)
_DEFAULT_BIAS_INITIALIZER = tf.keras.initializers.Zeros()


################## Routers (gating functions) ##################


def _router_z_loss(router_logits: tf.Tensor) -> float:
  """Computes router z-loss.

  The router z-loss was introduced in Designing Effective Sparse Expert Models
  (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
  small in an effort to improve stability.

  Args:
    router_logits: <float32>[num_groups, tokens_per_group, num_experts] router
      logits.

  Returns:
    Scalar router z-loss <float32>.
  """
  num_groups, tokens_per_group, _ = router_logits.shape
  log_z = tf.math.reduce_logsumexp(router_logits, axis=-1)
  z_loss = log_z**2
  return tf.math.reduce_sum(z_loss) / (num_groups * tokens_per_group)
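A quick check of the arithmetic (an editorial sketch, not part of this commit): for a single token with logits [10, 5], logsumexp is 10 + log(1 + e**-5) ≈ 10.007, so the z-loss is ≈ 100.13, the value test_router_z_loss_dtype in moe_test.py below asserts.

import tensorflow as tf

logits = tf.constant([[[10.0, 5.0]]])  # [num_groups=1, tokens_per_group=1, num_experts=2]
log_z = tf.math.reduce_logsumexp(logits, axis=-1)
print(float(log_z)**2)  # ~100.13, i.e. (5 + log(exp(5) + 1))**2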
@dataclasses.dataclass
class RouterMask:
  """Dispatch and combine arrays for expert routing with masked matmuls.

  Attributes:
    dispatch_mask:
      <float>[num_groups, tokens_per_group, num_experts, expert_capacity]
      dispatch array that is 1 if the token gets routed to the
      corresponding expert, and 0 otherwise.
    combine_array:
      <float>[num_groups, tokens_per_group, num_experts, expert_capacity]
      combine array used for combining expert outputs and
      scaling with router probability.
  """
  dispatch_mask: tf.Tensor
  combine_array: tf.Tensor

RouterOutput = RouterMask


class Router(tf.keras.layers.Layer):
  """Abstract base router class, defining router API and inner workings.

  Computations are performed in float32 for stability, and returned after
  conversion according to the precision policy. See the discussion of
  "selective precision" in https://arxiv.org/abs/2101.03961.

  Uses Keras add_loss() and add_metric() APIs.

  Attributes:
    num_experts: Number of experts, used to check consistency with
      FeedForwardExperts.
    jitter_noise: Amplitude of jitter noise applied to router logits.
    router_weights: Dense layer that computes logits for all tokens, which are
      then used as expert or token weights.
  """

  def __init__(self,
               num_experts: int,
               *,
               jitter_noise: float = 0.0,
               use_bias: bool = True,
               kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
               bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
               name: str = "router",
               dtype: Any = tf.float32,
               **kwargs):
    """Init.

    Args:
      num_experts: Number of experts.
      jitter_noise: Amplitude of jitter noise applied to router logits.
      use_bias: Whether or not to use the bias term in computing the router
        weights.
      kernel_initializer: Kernel initializer for router weights.
      bias_initializer: Bias initializer for router weights.
      name: Layer name.
      dtype: The dtype of the layer's computations and weights. tf.float32 is
        recommended for stability.
      **kwargs: Forwarded to super.
    """
    super().__init__(name=name, dtype=dtype, **kwargs)

    self.num_experts = num_experts  # Used to check consistency with
                                    # FeedForwardExperts.
    self.jitter_noise = jitter_noise

    self.router_weights = tf.keras.layers.Dense(
        num_experts,
        use_bias=use_bias,
        kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(bias_initializer),
        name="router_weights",
        dtype=dtype)

  def call(self,
           inputs: tf.Tensor,
           *,
           expert_capacity: int,
           training: Optional[bool] = None) -> RouterOutput:
    """Computes dispatch and combine arrays for routing to experts.

    Args:
      inputs: Inputs to send to experts of shape
        <float>[num_groups, tokens_per_group, hidden_dim].
      expert_capacity: Each group will send this many tokens to each expert.
      training: If true, apply jitter noise during routing. If not provided
        taken from tf.keras.backend.

    Returns:
      Router indices or mask arrays (depending on router type).
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    # inputs shape <float>[num_groups, tokens_per_group, hidden_dim]
    router_probs, router_logits = self._compute_router_probabilities(
        inputs, apply_jitter=training)
    # router_probs <float32>[num_groups, tokens_per_group, num_experts]
    # router_logits <float>[num_groups, tokens_per_group, num_experts]
    router_z_loss = _router_z_loss(router_logits)
    self.add_loss(router_z_loss)
    self.add_metric(router_z_loss, name="router_z_loss")

    routing_instructions = self._compute_routing_instructions(
        router_probs, expert_capacity)
    return routing_instructions

  def _compute_router_probabilities(
      self, inputs: tf.Tensor,
      apply_jitter: bool) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes router probabilities from input tokens.

    Args:
      inputs: Inputs from which router probabilities are computed, shape
        <float>[num_groups, tokens_per_group, hidden_dim].
      apply_jitter: If true, apply jitter noise.

    Returns:
      - <float32>[num_groups, tokens_per_group, num_experts] probabilities for
        each token and expert. Used for routing tokens to experts.
      - <float32>[num_groups, tokens_per_group, num_experts] raw router logits.
        Used for computing router z-loss.
    """
    if apply_jitter and self.jitter_noise > 0:
      inputs *= tf.random.uniform(
          inputs.shape,
          minval=1.0 - self.jitter_noise,
          maxval=1.0 + self.jitter_noise,
          dtype=inputs.dtype)
    # inputs <float>, router_logits <float32>
    router_logits = self.router_weights(inputs)
    router_probs = tf.keras.activations.softmax(router_logits, axis=-1)
    return router_probs, router_logits

  def _compute_routing_instructions(self, router_probs: tf.Tensor,
                                    expert_capacity: int) -> RouterOutput:
    """Computes instructions for routing inputs to experts."""
    raise NotImplementedError(
        "Router is an abstract class that should be subclassed.")


class MaskedRouter(Router):
  """Abstract base router class for masked matmul dispatch routers.

  MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine
  array for sending and receiving (via masked matmuls) inputs and outputs to
  and from experts.

  Routing using masked matmuls is generally faster than scatter-based routing
  on TPUs.

  Uses Keras add_loss() and add_metric() APIs.
  """

  def _compute_routing_instructions(self, router_probs: tf.Tensor,
                                    expert_capacity: int) -> RouterMask:
    """Computes masks for the top-k experts per token.

    Args:
      router_probs: <float32>[num_groups, tokens_per_group, num_experts]
        probabilities used to determine the routing of tokens to the experts.
      expert_capacity: Each group will send this many tokens to each expert.

    Returns:
      Router mask arrays.
    """
    raise NotImplementedError(
        "MaskedRouter is an abstract class that should be subclassed.")


class ExpertsChooseMaskedRouter(MaskedRouter):
  """Masked matmul router using experts choose tokens assignment.

  This router uses the same mechanism as in Mixture-of-Experts with Expert
  Choice (https://arxiv.org/abs/2202.09368): each expert selects its top
  expert_capacity tokens. An individual token may be processed by multiple
  experts or none at all.

  Note: "experts choose routing" should not be used in decoder blocks because
  it breaks the autoregressive behavior, leading to a mismatch between training
  (teacher forcing) and inference (autoregressive decoding).

  Uses Keras add_loss() and add_metric() APIs.
  """

  def _compute_routing_instructions(self, router_probs: tf.Tensor,
                                    expert_capacity: int) -> RouterMask:
    """Computes masks for the highest probability token per expert.

    Args:
      router_probs: <float32>[num_groups, tokens_per_group, num_experts]
        probabilities used to determine the routing of tokens to the experts.
      expert_capacity: Each group will send this many tokens to each expert.

    Returns:
      Dispatch and combine arrays for routing with masked matmuls.
    """
    num_groups, tokens_per_group, _ = router_probs.shape
    router_probs_t = tf.transpose(router_probs, perm=[0, 2, 1])
    # router_probs_t: <float32>[num_groups, num_experts, tokens_per_group]

    # Top expert_capacity router probability and corresponding token indices
    # for each expert.
    # Shapes [num_groups, num_experts, expert_capacity]
    expert_gate, expert_index = tf.math.top_k(
        router_probs_t, k=expert_capacity, sorted=False)

    # Convert to one-hot mask of expert indices for each token in each group.
    # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group].
    dispatch_mask = tf.one_hot(
        expert_index, tokens_per_group, dtype=router_probs.dtype)

    # Move axes to conform with shape expected by MoeLayer API.
    # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]
    dispatch_mask = tf.transpose(dispatch_mask, perm=[0, 3, 1, 2])

    # The combine array will be used for combining expert outputs, scaled by
    # the router probabilities.
    # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]
    combine_array = tf.einsum(
        "...ec,...tec->...tec", expert_gate, dispatch_mask)

    # Add load balancing loss.
    # Each expert is choosing tokens until it reaches full capacity, so we
    # don't need an auxiliary loading balancing loss for expert choice routing.
    self.add_metric(0.0, name="load_balancing_loss")

    # Gather expert metrics.
    # Number of tokens that were dispatched to at least one expert.
    num_tokens = num_groups * tokens_per_group
    num_tokens_dispatched_somewhere = tf.math.reduce_sum(
        tf.math.reduce_max(dispatch_mask, axis=(-1, -2)))
    fraction_tokens_left_behind = 1.0 - num_tokens_dispatched_somewhere / float(
        num_tokens)
    # Total number of tokens that were dispatched (one token could be
    # dispatched to multiple experts).
    num_tokens_dispatched = tf.math.reduce_sum(dispatch_mask)
    # Of the tokens dispatched, how confident was the router in its routing?
    router_confidence = tf.math.reduce_sum(
        combine_array) / num_tokens_dispatched

    expert_usage = 1.0  # Experts fully utilized when "expert choose tokens"

    self.add_metric(fraction_tokens_left_behind,
                    name="fraction_tokens_left_behind")
    self.add_metric(router_confidence, name="router_confidence")
    self.add_metric(expert_usage, name="expert_usage")

    # Return to default dtype now that router computation is complete.
    dtype = tf.keras.mixed_precision.global_policy().compute_dtype
    dispatch_mask = tf.cast(dispatch_mask, dtype)
    combine_array = tf.cast(combine_array, dtype)
    output = RouterMask(dispatch_mask, combine_array)
    return output
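A hedged usage sketch of the router above (shapes follow the docstrings; the sizes are illustrative, not from this commit): each of 4 experts picks its top 3 of 6 tokens per group.

import tensorflow as tf
from official.nlp.modeling.layers import moe

router = moe.ExpertsChooseMaskedRouter(num_experts=4, jitter_noise=0.0)
inputs = tf.random.uniform([2, 6, 8])  # [num_groups, tokens_per_group, hidden_dim]
mask = router(inputs, expert_capacity=3, training=False)
print(mask.dispatch_mask.shape)   # (2, 6, 4, 3)
print(mask.combine_array.shape)   # (2, 6, 4, 3)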
################## Model layers ##################


class FeedForward(tf.keras.layers.Layer):
  """Feed-forward layer - position independent, dense, nonlinear transformation.

  Typically used in an MLP Transformer block.
  """

  def __init__(self,
               d_ff: int,
               *,
               dropout_rate: float = 0.1,
               activation: Callable[[tf.Tensor], tf.Tensor] = tf.keras.activations.gelu,
               kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
               bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
               name: str = "feed_forward",
               **kwargs):
    """Initializes layer.

    Args:
      d_ff: Dimension of feed-forward layer.
      dropout_rate: The dropout probability.
      activation: (Nonlinear) transform applied in layer.
      kernel_initializer: Initialization scheme for kernel.
      bias_initializer: Initialization scheme for bias.
      name: Layer name.
      **kwargs: Forwarded to super.
    """
    super().__init__(name=name, **kwargs)
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer

    self.intermediate_layer = tf.keras.layers.Dense(
        d_ff,
        kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
        name="intermediate")
    self.dropout_layer = tf.keras.layers.Dropout(dropout_rate)

  def build(self, input_shape: Tuple[int, int, int]):
    """Creates the input shape dependent output weight variables."""
    self.output_layer = tf.keras.layers.Dense(
        input_shape[-1],
        kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
        name="output")

  def call(self,
           inputs: tf.Tensor,
           *,
           training: Optional[bool] = None) -> tf.Tensor:
    """Applies layer to inputs.

    Args:
      inputs: Batch of input embeddings, of shape
        <float>[batch_size, seq_len, hidden_dim].
      training: Only apply dropout during training.

    Returns:
      Transformed inputs with the same shape as inputs
      <float>[batch_size, seq_len, hidden_dim].
    """
    x = self.intermediate_layer(inputs)
    x = self.activation(x)
    x = self.output_layer(x)
    x = self.dropout_layer(x, training=training)
    return x


class FeedForwardExperts(tf.keras.layers.Layer):
  """Feed-forward layer with multiple experts.

  Note that call() takes inputs with shape
  [num_groups, num_experts, expert_capacity, hidden_dim]
  which is different from the usual [batch_size, seq_len, hidden_dim] used by
  the FeedForward layer.

  The experts are independent FeedForward layers of the
  same shape, i.e. the kernel doesn't have shape [hidden_dim, out_dim], but
  [num_experts, hidden_dim, out_dim].
  """

  def __init__(self,
               num_experts: int,
               d_ff: int,
               *,
               dropout_rate: float = 0.1,
               activation: Callable[[tf.Tensor], tf.Tensor] = tf.keras.activations.gelu,
               kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
               bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
               name: str = "experts",
               **kwargs):
    """Initializes layer.

    Args:
      num_experts: Number of experts (i.e. number of independent feed-forward
        blocks).
      d_ff: Dimension of feed-forward layer of each expert.
      dropout_rate: The dropout probability (expert_dropout_rate).
      activation: (Nonlinear) transform applied in layer.
      kernel_initializer: Initialization scheme for kernel.
      bias_initializer: Initialization scheme for bias.
      name: Layer name.
      **kwargs: Forwarded to super.
    """
    super().__init__(name=name, **kwargs)
    self.num_experts = num_experts
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer

    self.intermediate_layer = tf.keras.layers.EinsumDense(
        "gech,ehf->gecf",
        output_shape=(self.num_experts, None, d_ff),
        bias_axes="ef",
        kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
        name="intermediate")
    self.dropout_layer = tf.keras.layers.Dropout(dropout_rate)

  def build(self, input_shape: Tuple[int, int, int, int]):
    """Creates the input shape dependent output weight variables."""
    if input_shape[1] != self.num_experts:
      raise ValueError(
          f"Input shape {input_shape} is inconsistent with num_experts "
          f"{self.num_experts}.")

    self.output_layer = tf.keras.layers.EinsumDense(
        "gecf,efh->gech",
        output_shape=(self.num_experts, None, input_shape[-1]),
        bias_axes="eh",
        kernel_initializer=tf_utils.clone_initializer(self.kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self.bias_initializer),
        name="output")

  def call(self,
           inputs: tf.Tensor,
           *,
           training: Optional[bool] = None) -> tf.Tensor:
    """Applies layer to inputs.

    Args:
      inputs: Inputs of shape
        <float>[num_groups, num_experts, expert_capacity, hidden_dim].
      training: Only apply dropout during training.

    Returns:
      Transformed inputs with the same shape as inputs
      <float>[num_groups, num_experts, expert_capacity, hidden_dim].
    """
    x = self.intermediate_layer(inputs)
    x = self.activation(x)
    x = self.output_layer(x)
    x = self.dropout_layer(x, training=training)
    return x
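To make the per-expert kernel layout concrete, a small editorial sketch (sizes are illustrative): the intermediate kernel is [num_experts, hidden_dim, d_ff] and the output kernel [num_experts, d_ff, hidden_dim].

import tensorflow as tf
from official.nlp.modeling.layers import moe

experts = moe.FeedForwardExperts(num_experts=2, d_ff=16)
x = tf.ones([1, 2, 5, 7])  # [num_groups, num_experts, expert_capacity, hidden_dim]
y = experts(x, training=False)
print(y.shape)  # (1, 2, 5, 7)
for v in experts.trainable_variables:
  print(v.name, v.shape)  # intermediate kernel (2, 7, 16), output kernel (2, 16, 7), plus biases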
class MoeLayer(tf.keras.layers.Layer):
  """Sparse MoE layer with per-token routing.

  In this TF implementation, all experts need to fit onto a single device
  allowing for batch parallelism only.

  Uses Keras add_loss() and add_metric() APIs.

  Attributes:
    num_experts: Number of experts (i.e. number of independent feed-forward
      blocks).
  """

  def __init__(self,
               experts: FeedForwardExperts,
               router: MaskedRouter,
               *,
               train_capacity_factor: float = 1.0,
               eval_capacity_factor: float = 1.0,
               min_expert_capacity: int = 4,
               max_group_size: int = 4096,
               strict_group_size: bool = False,
               name: str = "moe",
               **kwargs):
    """Init.

    Args:
      experts: Instance of FeedForwardExperts. Needs to have the same
        num_experts as the router.
      router: Instance of MaskedRouter to route the tokens to
        the different experts.
      train_capacity_factor: Scaling factor to increase the expert token
        capacity during training. This factor plays an analogous, but slightly
        different, role depending on the routing assignment algorithm:
        - For "tokens choose" routing, the capacity factor only affects the
          maximum number of tokens that an expert will process. It does not
          affect how many experts a given token is routed to; see the
          num_selected_experts attributes of "tokens choose" routers.
        - For "experts choose" routing, because experts always fill their
          buffer, increasing the capacity factor will increase the number of
          tokens that an expert will process AND will indirectly increase the
          number of experts that a given token is routed to.
      eval_capacity_factor: As above, but used during evaluation.
      min_expert_capacity: Minimum token processing capacity for each expert.
      max_group_size: The total number of tokens on each device is subdivided
        into groups of this size. Router computations are then performed on a
        per-group basis. A larger group size will result in slower but more
        accurate top-k and sorting computations, whereas a smaller group size
        will result in faster but more approximate (and potentially less
        stable) routing choices. Note that actual group size may be smaller
        than max_group_size for consistency with the number of experts and
        tokens; see also `strict_group_size` attribute. In practice,
        we find that imperfect routing choices are tolerable and recommend
        choosing a group size on the order of 4096 tokens, although this number
        will vary based on model configuration and size.
      strict_group_size: If True, fail if unable to set the token group size
        equal to max_group_size. If False (default), the actual group size may
        be smaller than max_group_size for consistency with the number of
        experts and tokens.
      name: Layer name.
      **kwargs: Forwarded to super.
    """
    super().__init__(name=name, **kwargs)
    self._experts = experts
    self._router = router

    self.num_experts = experts.num_experts
    assert experts.num_experts == router.num_experts

    self._train_capacity_factor = train_capacity_factor
    self._eval_capacity_factor = eval_capacity_factor
    self._max_group_size = max_group_size
    self._min_expert_capacity = min_expert_capacity
    self._strict_group_size = strict_group_size

  def call(self,
           inputs: tf.Tensor,
           *,
           training: Optional[bool] = None) -> tf.Tensor:
    """Applies MoeLayer.

    Args:
      inputs: Batch of input embeddings of shape
        <float>[batch_size, seq_length, hidden_dim].
      training: Only apply dropout and jitter noise during training. If not
        provided taken from tf.keras.backend.

    Returns:
      Transformed inputs with same shape as inputs:
      <float>[batch_size, seq_length, hidden_dim].

    Raises:
      ValueError if we cannot find a group_size satisfying given requirements.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    # inputs shape [batch_size, seq_length, hidden_dim]
    per_device_batch_size, seq_length, hidden_dim = inputs.shape
    num_tokens = per_device_batch_size * seq_length
    num_groups = self._num_groups(num_tokens, self._max_group_size)
    tokens_per_group = num_tokens // num_groups

    if training:
      capacity_factor = self._train_capacity_factor
    else:
      capacity_factor = self._eval_capacity_factor
    # Each group will send expert_capacity tokens to each expert.
    expert_capacity = int(
        round(capacity_factor * tokens_per_group / self.num_experts))
    expert_capacity = max(expert_capacity, self._min_expert_capacity)
    logging.info(
        "Selected expert_capacity=%d for num_experts=%d and training=%r.",
        expert_capacity, self.num_experts, training)

    # Reshape batch and sequence/token dimensions for expert routing.
    x = tf.reshape(inputs, (num_groups, tokens_per_group, hidden_dim))

    x = self._mask_and_dispatch_to_experts(x, expert_capacity, training)

    # Return to original input shape.
    x = tf.reshape(x, (per_device_batch_size, seq_length, hidden_dim))
    return x

  def _num_groups(self, num_tokens: int, max_group_size: int) -> int:
    """Returns the number of token routing groups.

    Note that the quantities are local to the device.

    We select the smallest num_groups such that:
    - num_groups >= num_tokens / max_group_size (ensuring the group size is no
      larger than max_group_size),
    - num_tokens % num_groups = 0 (ensuring that the group size evenly divides
      into the num_tokens),

    Args:
      num_tokens: Number of tokens from input batch.
      max_group_size: Maximum size of each token routing group. Actual group
        size may end up being smaller unless strict_group_size==True.

    Returns:
      Number of token routing groups.

    Raises:
      ValueError if we cannot find a group_size satisfying the above
        requirements.
    """
    # Increase the number of groups (and decrease the group size) until we have
    # a viable number of groups.
    min_num_groups = int(np.ceil(num_tokens / max_group_size))
    num_groups = min_num_groups
    while num_groups < num_tokens and num_tokens % num_groups != 0:
      num_groups += 1

    group_size = num_tokens // num_groups
    logging.info(
        "Selected group_size=%d and num_groups=%d for input num_tokens=%d, "
        "max_group_size=%d, num_experts=%d.", group_size, num_groups,
        num_tokens, max_group_size, self.num_experts)

    if group_size < self._min_expert_capacity:
      raise ValueError(
          f"Local (per-device) group_size {group_size} is smaller than "
          f"min_expert_capacity {self._min_expert_capacity}, which is probably "
          f"not intended. Please increase max_group_size {max_group_size} to"
          " seq_length or increase batch_size or decrease min_expert_capacity.")

    if self._strict_group_size and group_size != self._max_group_size:
      raise ValueError(
          f"Selected group_size={group_size} is less than the "
          f"max_group_size={max_group_size}. Exiting because strict mode is "
          "active (strict_group_size=True)")

    return num_groups
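A worked instance of this selection (editorial; it mirrors the configuration asserted by test_moe_layer in moe_test.py below): 20 tokens with max_group_size=9 yield num_groups=4 and group_size=5, and with capacity factor 1.0 over 2 experts, expert_capacity = round(5 / 2) = 2.

num_tokens, max_group_size, num_experts = 20, 9, 2
num_groups = -(-num_tokens // max_group_size)  # ceil(20 / 9) = 3
while num_groups < num_tokens and num_tokens % num_groups != 0:
  num_groups += 1                              # 3 -> 4 (20 % 3 != 0)
tokens_per_group = num_tokens // num_groups    # 5
expert_capacity = int(round(1.0 * tokens_per_group / num_experts))  # 2
print(num_groups, tokens_per_group, expert_capacity)  # 4 5 2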
  def _mask_and_dispatch_to_experts(self, inputs: tf.Tensor,
                                    expert_capacity: int,
                                    training: bool) -> tf.Tensor:
    """Wraps expert masked routing and dispatching algorithm.

    This algorithm takes the following steps:
    (1) Compute dispatch mask and combine array using self._router.
    (2) Dispatch inputs to experts based on dispatch mask.
    (3) Recombine individual expert outputs using combine array.

    Args:
      inputs: <float>[num_groups, tokens_per_group, hidden_dim] inputs to
        send to experts.
      expert_capacity: Each group will send this many tokens to each expert.
      training: If true, apply jitter noise during routing and dropout
        during expert computation.

    Returns:
      <float>[num_groups, num_tokens_per_group, hidden_dim] outputs from
      experts.
    """
    # Shape [num_groups, tokens_per_group, num_experts, expert_capacity]
    router_mask = self._router(
        inputs, expert_capacity=expert_capacity, training=training)

    # Shape [num_groups, num_experts, expert_capacity, hidden_dim]
    expert_inputs = tf.einsum(
        "gth,gtec->gech", inputs, router_mask.dispatch_mask)

    expert_outputs = self._experts(expert_inputs, training=training)

    # Shape [num_groups, tokens_per_group, hidden_dim]
    combined_outputs = tf.einsum(
        "gech,gtec->gth", expert_outputs, router_mask.combine_array)

    return combined_outputs


class MoeLayerWithBackbone(tf.keras.layers.Layer):
  """Sparse MoE layer plus a FeedForward layer evaluated for all tokens.

  Uses Keras add_loss() and add_metric() APIs.
  """

  def __init__(self,
               moe: MoeLayer,
               backbone_d_ff: int,
               *,
               dropout_rate: float = 0.1,
               activation: Callable[[tf.Tensor], tf.Tensor] = tf.keras.activations.gelu,
               kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
               bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
               name: str = "moe_with_backbone",
               **kwargs):
    """Init.

    Args:
      moe: Instance of MoeLayer with experts and router.
      backbone_d_ff: Dimension of feed-forward layer of a lightweight backbone,
        which is evaluated for all tokens.
      dropout_rate: Dropout rate for the backbone.
      activation: (Nonlinear) transform applied in the backbone.
      kernel_initializer: Initialization scheme for kernels in the backbone.
      bias_initializer: Initialization scheme for biases in the backbone.
      name: Layer name.
      **kwargs: Forwarded to super.
    """
    super().__init__(name=name, **kwargs)
    self._moe = moe

    self._backbone = FeedForward(
        backbone_d_ff,
        dropout_rate=dropout_rate,
        activation=activation,
        kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(bias_initializer),
        name="backbone")

  def call(self,
           inputs: tf.Tensor,
           *,
           training: Optional[bool] = None) -> tf.Tensor:
    """Applies MoeLayerWithBackbone layer.

    Args:
      inputs: Batch of input embeddings of shape
        <float>[batch_size, seq_length, hidden_dim].
      training: Only apply dropout and jitter noise during training. If not
        provided taken from tf.keras.backend.

    Returns:
      Transformed inputs with same shape as inputs:
      <float>[batch_size, seq_length, hidden_dim].
    """
    return self._backbone(
        inputs, training=training) + self._moe(inputs, training=training)
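An end-to-end editorial sketch assembling the pieces above (sizes illustrative, not from this commit): experts and a router feed a MoeLayer, which is wrapped with a dense backbone.

import tensorflow as tf
from official.nlp.modeling.layers import moe

experts = moe.FeedForwardExperts(num_experts=2, d_ff=32)
router = moe.ExpertsChooseMaskedRouter(num_experts=2)
moe_layer = moe.MoeLayer(experts, router, max_group_size=16)
layer = moe.MoeLayerWithBackbone(moe_layer, backbone_d_ff=13)

x = tf.ones([2, 8, 7])  # [batch_size, seq_length, hidden_dim]
y = layer(x, training=False)
print(y.shape)  # (2, 8, 7); router z-loss and metrics land in layer.losses / layer.metrics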
official/nlp/modeling/layers/moe_test.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for moe.py."""

import ml_collections
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import moe


def small_config() -> ml_collections.ConfigDict:
  """Creates a small model config that can be used by all tests."""
  config = ml_collections.ConfigDict()
  config.d_ff = 32
  config.dropout_rate = 0.1
  config.num_experts = 2
  config.expert_d_ff = 33
  config.expert_dropout_rate = 0.1
  config.jitter_noise = 0.1
  config.train_capacity_factor = 1.0
  config.eval_capacity_factor = 1.0
  config.min_expert_capacity = 1
  config.max_group_size = 9
  config.backbone_d_ff = 13
  return config


def make_input_ones(batch_size: int = 2,
                    seq_length: int = 10,
                    hidden_dim: int = 7) -> tf.Tensor:
  return tf.ones((batch_size, seq_length, hidden_dim), dtype=tf.float32)


def make_experts_input_ones(num_groups: int = 1,
                            num_experts: int = 2,
                            expert_capacity: int = 5,
                            hidden_dim: int = 7) -> tf.Tensor:
  return tf.ones((num_groups, num_experts, expert_capacity, hidden_dim),
                 dtype=tf.float32)


class MoeTest(tf.test.TestCase):

  def tearDown(self):
    super().tearDown()
    tf.keras.mixed_precision.set_global_policy('float32')

  def test_router_z_loss_dtype(self):
    x = tf.constant([[[10.0, 5.0]]], dtype=tf.float32)
    y = moe._router_z_loss(x)
    expected = (5 + np.log(np.exp(5) + 1))**2
    self.assertAllClose(expected, y, atol=1e-7)

    x = tf.constant([[[10.0, 5.0]]], dtype=tf.bfloat16)
    y = moe._router_z_loss(x)
    expected = 100.0
    self.assertAllClose(expected, y, atol=1e-7)

  def test_router_z_loss_shape(self):
    x = make_input_ones(2, 5, 7)
    y = moe._router_z_loss(x)
    expected = (np.log(7) + 1)**2
    self.assertAllClose(expected, y, atol=1e-7)

  def test_experts_choose_masked_router_dtype_shape(self):
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    num_groups = 2
    tokens_per_group = 3
    hidden_dim = tokens_per_group
    num_experts = tokens_per_group
    expert_capacity = 2

    x = np.zeros([num_groups, tokens_per_group, hidden_dim])
    x[0, 0, 0] += 1
    x[0, :2, :2] += 1
    x[1, 1:, 1:] += 1
    x[1, -1, -1] += 1

    router = moe.ExpertsChooseMaskedRouter(
        num_experts=num_experts,
        jitter_noise=0.1,
        use_bias=True,
        kernel_initializer=tf.keras.initializers.get('identity'),
        bias_initializer=tf.keras.initializers.get('ones'))

    router_mask = router(x, expert_capacity=expert_capacity, training=False)
    self.assertDTypeEqual(router_mask.dispatch_mask, tf.bfloat16)
    self.assertDTypeEqual(router_mask.combine_array, tf.bfloat16)

    expect_shape = [num_groups, tokens_per_group, num_experts, expert_capacity]
    self.assertEqual(expect_shape, router_mask.dispatch_mask.shape)
    self.assertEqual(expect_shape, router_mask.combine_array.shape)

    # top_k call may not be sorted, so can't compare the output directly.
    # Check that the output contains only 0s and 1s.
    out_dm = router_mask.dispatch_mask.numpy()
    self.assertSetEqual({0, 1}, set(out_dm.flatten().astype(np.int32)))

    # Check that the right tokens were selected.
    out_dm_indices = np.dot(
        out_dm.transpose((0, 2, 3, 1)), np.arange(tokens_per_group))
    # Shape [num_groups, num_experts, expert_capacity]
    self.assertSetEqual({0, 1},
                        set(out_dm_indices[0, 0, :].astype(np.int32)))
    self.assertSetEqual({1, 2},
                        set(out_dm_indices[0, 1, :].astype(np.int32)))
    self.assertSetEqual({1, 2},
                        set(out_dm_indices[0, 2, :].astype(np.int32)))
    self.assertSetEqual({0, 1},
                        set(out_dm_indices[1, 0, :].astype(np.int32)))
    self.assertSetEqual({0, 1},
                        set(out_dm_indices[1, 1, :].astype(np.int32)))
    self.assertSetEqual({1, 2},
                        set(out_dm_indices[1, 2, :].astype(np.int32)))

    out_ca = router_mask.combine_array.numpy()
    out_ca = np.dot(out_ca, np.ones((expert_capacity,)))
    expected_combine_array = np.array(
        [[[0.66, 0.0, 0.0], [0.42, 0.42, 0.16], [0.0, 0.33, 0.33]],
         [[0.33, 0.33, 0.0], [0.16, 0.42, 0.42], [0.0, 0.0, 0.66]]])
    self.assertAllClose(expected_combine_array, out_ca, atol=1e-2)

  def test_feed_forward_shape_and_vars(self):
    config = small_config()
    layer = moe.FeedForward(
        d_ff=config.d_ff, dropout_rate=config.dropout_rate)
    inputs = make_input_ones()
    outputs = layer(inputs)
    self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))

    var_names = sorted([v.name for v in layer.trainable_variables])
    self.assertAllEqual([
        'feed_forward/intermediate/bias:0',
        'feed_forward/intermediate/kernel:0', 'feed_forward/output/bias:0',
        'feed_forward/output/kernel:0'
    ], var_names)

  def test_feed_forward_manual(self):
    config = small_config()
    layer = moe.FeedForward(
        d_ff=config.d_ff,
        dropout_rate=config.dropout_rate,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.get('ones'),
        bias_initializer=tf.keras.initializers.get('ones'))
    inputs = make_input_ones(1, 2, 3)
    outputs = layer(inputs, training=False)
    manual_outputs = tf.constant([[[129.0, 129.0, 129.0],
                                   [129.0, 129.0, 129.0]]])
    self.assertAllClose(manual_outputs, outputs, atol=1e-7)

  def test_feed_forward_experts_shape_and_vars(self):
    config = small_config()
    layer = moe.FeedForwardExperts(
        num_experts=config.num_experts,
        d_ff=config.expert_d_ff,
        dropout_rate=config.expert_dropout_rate)
    inputs = make_experts_input_ones()
    outputs = layer(inputs)
    self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))

    var_names = sorted([v.name for v in layer.trainable_variables])
    self.assertAllEqual([
        'experts/intermediate/bias:0', 'experts/intermediate/kernel:0',
        'experts/output/bias:0', 'experts/output/kernel:0'
    ], var_names)

  def test_feed_forward_experts_manual(self):
    config = small_config()
    layer = moe.FeedForwardExperts(
        num_experts=1,
        d_ff=config.expert_d_ff,
        dropout_rate=config.expert_dropout_rate,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.get('ones'),
        bias_initializer=tf.keras.initializers.get('ones'))
    inputs = make_experts_input_ones(1, 1, 2, 3)
    outputs = layer(inputs, training=False)
    manual_outputs = tf.constant([[[[133.0, 133.0, 133.0],
                                    [133.0, 133.0, 133.0]]]])
    self.assertAllClose(manual_outputs, outputs, atol=1e-7)

  def test_moe_layer(self):
    config = small_config()
    experts = moe.FeedForwardExperts(
        num_experts=config.num_experts,
        d_ff=config.expert_d_ff,
        dropout_rate=config.expert_dropout_rate)
    router = moe.ExpertsChooseMaskedRouter(
        config.num_experts, jitter_noise=config.jitter_noise)
    moe_layer = moe.MoeLayer(
        experts,
        router,
        train_capacity_factor=config.train_capacity_factor,
        eval_capacity_factor=config.eval_capacity_factor,
        max_group_size=config.max_group_size,
        min_expert_capacity=config.min_expert_capacity)
    inputs = make_input_ones()
    with self.assertLogs('absl', level='INFO') as cm:
      outputs = moe_layer(inputs, training=True)
    self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))
    self.assertEqual(cm.output, [
        ('INFO:absl:Selected group_size=5 and num_groups=4 for input '
         'num_tokens=20, max_group_size=9, num_experts=2.'),
        ('INFO:absl:Selected expert_capacity=2 for num_experts=2 and '
         'training=True.')
    ])

    var_names = sorted([v.name for v in moe_layer.trainable_variables])
    self.assertAllEqual([
        'moe/experts/intermediate/bias:0', 'moe/experts/intermediate/kernel:0',
        'moe/experts/output/bias:0', 'moe/experts/output/kernel:0',
        'moe/router/router_weights/bias:0', 'moe/router/router_weights/kernel:0'
    ], var_names)

    self.assertLen(moe_layer.losses, 1)
    metrics = [metric.name for metric in moe_layer.metrics]
    self.assertSetEqual(
        {
            'router_z_loss', 'load_balancing_loss',
            'fraction_tokens_left_behind', 'router_confidence', 'expert_usage'
        }, set(metrics))

  def test_moe_layer_with_backbone(self):
    config = small_config()
    experts = moe.FeedForwardExperts(
        num_experts=config.num_experts,
        d_ff=config.expert_d_ff,
        dropout_rate=config.expert_dropout_rate)
    router = moe.ExpertsChooseMaskedRouter(
        config.num_experts, jitter_noise=config.jitter_noise)
    moe_layer = moe.MoeLayer(
        experts,
        router,
        train_capacity_factor=config.train_capacity_factor,
        eval_capacity_factor=config.eval_capacity_factor,
        max_group_size=config.max_group_size,
        min_expert_capacity=config.min_expert_capacity)
    layer = moe.MoeLayerWithBackbone(moe_layer, config.backbone_d_ff)

    inputs = make_input_ones()
    outputs = layer(inputs)
    self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))


if __name__ == '__main__':
  tf.test.main()
official/nlp/modeling/layers/multi_channel_attention.py  View file @ 32e4ca51
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -18,6 +18,7 @@
 import math
 
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import masked_softmax
...
@@ -48,7 +49,7 @@ class VotingAttention(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(VotingAttention, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._num_heads = num_heads
     self._head_size = head_size
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
...
@@ -60,26 +61,28 @@ class VotingAttention(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
-    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+    self._query_dense = tf.keras.layers.EinsumDense(
         "BAE,ENH->BANH",
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
         name="query",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
-    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+    self._key_dense = tf.keras.layers.EinsumDense(
        "BAE,ENH->BANH",
        output_shape=(None, self._num_heads, self._head_size),
        bias_axes="NH",
        name="key",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
-    super(VotingAttention, self).build(unused_input_shapes)
+    super().build(unused_input_shapes)
 
   def call(self, encoder_outputs, doc_attention_mask):
     num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
...
@@ -120,7 +123,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
   """
 
   def _build_attention(self, rank):
-    super(MultiChannelAttention, self)._build_attention(rank)  # pytype: disable=attribute-error  # typed-keras
+    super()._build_attention(rank)  # pytype: disable=attribute-error  # typed-keras
     self._masked_softmax = masked_softmax.MaskedSoftmax(
         mask_expansion_axes=[2])
 
   def call(self,
...
official/nlp/modeling/layers/multi_channel_attention_test.py  View file @ 32e4ca51
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
official/nlp/modeling/layers/on_device_embedding.py  View file @ 32e4ca51
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -47,7 +47,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
                scale_factor=None,
                **kwargs):
 
-    super(OnDeviceEmbedding, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._vocab_size = vocab_size
     self._embedding_width = embedding_width
     self._initializer = initializer
...
@@ -62,7 +62,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
         "use_one_hot": self._use_one_hot,
         "scale_factor": self._scale_factor,
     }
-    base_config = super(OnDeviceEmbedding, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def build(self, input_shape):
...
@@ -72,7 +72,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
         initializer=self._initializer,
         dtype=tf.float32)
 
-    super(OnDeviceEmbedding, self).build(input_shape)
+    super().build(input_shape)
 
   def call(self, inputs):
     flat_inputs = tf.reshape(inputs, [-1])
...
official/nlp/modeling/layers/on_device_embedding_test.py  View file @ 32e4ca51
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
official/nlp/modeling/layers/pack_optimization.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pack sequence optimization on accelerators."""
from typing import Dict
import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling.layers import rezero_transformer
from official.nlp.modeling.layers import self_attention_mask
from official.nlp.modeling.layers import transformer_encoder_block
from official.nlp.modeling.layers import transformer_scaffold


@tf.keras.utils.register_keras_serializable(package='Text')
class PackBertEmbeddings(tf.keras.layers.Layer):
  """Performs packing tricks for BERT inputs to improve TPU utilization."""

  def __init__(self, pack_sequences: int, **kwargs):
    super().__init__(**kwargs)
    self.pack_sequences = pack_sequences

  def call(self, input_embeddings: tf.Tensor,
           input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
    batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
        input_embeddings, expected_rank=3)
    reduced_batch_size = batch_size // self.pack_sequences
    packed_seq_len = self.pack_sequences * seq_len
    packed_embeddings = tf.reshape(
        input_embeddings, [reduced_batch_size, packed_seq_len, embedding_dim])
    input_mask = tf.reshape(input_mask, [reduced_batch_size, packed_seq_len])

    example_ids = 1 + tf.range(self.pack_sequences)
    # Shape: [batch_size, seq_len, pack_sequences].
    example_ids = tf.tile(example_ids[None, :, None],
                          [reduced_batch_size, 1, seq_len])
    example_ids = tf.reshape(example_ids, [reduced_batch_size, packed_seq_len])
    example_ids = tf.where(
        tf.math.equal(input_mask, 0), tf.zeros_like(example_ids), example_ids)
    packing_mask = tf.cast(
        tf.equal(
            tf.expand_dims(example_ids, 2), tf.expand_dims(example_ids, 1)),
        dtype=tf.bool)
    attention_mask = self_attention_mask.get_mask(
        packed_embeddings, input_mask, dtype=tf.bool)
    combined_attention_mask = tf.cast(
        tf.math.logical_and(attention_mask, packing_mask), tf.float32)
    return dict(
        packed_embeddings=packed_embeddings,
        combined_attention_mask=combined_attention_mask)
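A hedged sketch of the packing arithmetic (sizes illustrative, not from this commit): packing 2 sequences per row turns a [4, 6, 8] batch into [2, 12, 8], and the combined mask blocks attention across the packed examples.

import tensorflow as tf
from official.nlp.modeling.layers import pack_optimization

packer = pack_optimization.PackBertEmbeddings(pack_sequences=2)
embeddings = tf.random.uniform([4, 6, 8])  # [batch_size, seq_len, dim]
mask = tf.ones([4, 6], dtype=tf.int32)
out = packer(embeddings, mask)
print(out['packed_embeddings'].shape)         # (2, 12, 8)
print(out['combined_attention_mask'].shape)   # (2, 12, 12)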
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedTransformerEncoderBlock(
    transformer_encoder_block.TransformerEncoderBlock):
  """Transformer layer for packing optimization to stride over inputs."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if self._output_range is not None:
      raise ValueError('StridedTransformerEncoderBlock does not '
                       'support `output_range` argument.')

  def call(self, inputs, stride: tf.Tensor):
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError('Unexpected inputs to %s with length at %d' %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if self._norm_first:
      source_tensor = input_tensor[:, ::stride, :]
      input_tensor = self._attention_layer_norm(input_tensor)
      if key_value is not None:
        key_value = self._attention_layer_norm_kv(key_value)

    target_tensor = input_tensor[:, ::stride, :]
    if attention_mask is not None:
      attention_mask = attention_mask[:, ::stride, :]

    if key_value is None:
      key_value = input_tensor
    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)

    if self._norm_first:
      # Important to not combine `self._norm_first` and
      # `self._use_query_residual` into one if clause because else is only for
      # `_norm_first == False`.
      if self._use_query_residual:
        attention_output = source_tensor + attention_output
    else:
      if self._use_query_residual:
        attention_output = target_tensor + attention_output
      attention_output = self._attention_layer_norm(attention_output)

    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)

    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + layer_output

    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(layer_output + attention_output)
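A hedged usage sketch of the strided block (constructor arguments follow TransformerEncoderBlock; sizes illustrative, not from this commit): with stride 2 the block attends from every second position over the full packed sequence, halving the query length.

import tensorflow as tf
from official.nlp.modeling.layers import pack_optimization

block = pack_optimization.StridedTransformerEncoderBlock(
    num_attention_heads=2, inner_dim=32, inner_activation='relu')
x = tf.random.uniform([2, 12, 16])
y = block(x, stride=2)
print(y.shape)  # (2, 6, 16)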
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
StridedReZeroTransformer
(
rezero_transformer
.
ReZeroTransformer
):
"""ReZeroTransformer for packing optimization to stride over inputs."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
self
.
_output_range
is
not
None
:
raise
ValueError
(
f
'
{
self
.
__class__
}
does not '
'support `output_range` argument.'
)

  def call(self, inputs, stride: tf.Tensor):
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError(f'Unexpected inputs to {self.__class__} with '
                         f'length {len(inputs)}.')
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    target_tensor = input_tensor[:, ::stride, :]
    if attention_mask is not None:
      attention_mask = attention_mask[:, ::stride, :]
    if key_value is None:
      key_value = input_tensor

    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    # ReZero: scale the residual branch by a learned scalar instead of
    # relying on layer normalization.
    attention_output = target_tensor + self._rezero_a * attention_output
    if self._use_layer_norm:
      attention_output = self._attention_layer_norm(attention_output)
    else:
      attention_output = tf.cast(attention_output, tf.float32)

    intermediate_output = self._intermediate_dense(attention_output)
    intermediate_output = self._inner_activation_layer(intermediate_output)
    layer_output = self._output_dense(intermediate_output)
    layer_output = self._output_dropout(layer_output)
    layer_output = attention_output + tf.cast(
        self._rezero_a_ffn * layer_output, tf.float32)
    if self._use_layer_norm:
      layer_output = self._output_layer_norm(layer_output)
    return layer_output
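
# The ReZero variant applies the same striding but scales each residual
# branch by a learned scalar (`_rezero_a` for attention, `_rezero_a_ffn` for
# the feedforward) rather than depending on layer normalization. A minimal
# sketch with assumed shapes:
#
#   block = StridedReZeroTransformer(
#       num_attention_heads=2, inner_dim=4, inner_activation='relu')
#   y = block([tf.zeros((2, 4, 8)), tf.ones((2, 4, 4))],
#             stride=tf.constant(2, dtype=tf.int32))  # -> shape (2, 2, 8)
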
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
  """TransformerScaffold for packing optimization to stride over inputs."""

  def call(self, inputs, stride: tf.Tensor, training=None):
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError('Unexpected inputs to %s with length %d.' %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if key_value is None:
      key_value = input_tensor

    if self._norm_first:
      source_tensor = input_tensor[:, ::stride, :]
      input_tensor = self._attention_layer_norm(input_tensor,
                                                training=training)

    if attention_mask is not None:
      attention_mask = attention_mask[:, ::stride, :]

    target_tensor = input_tensor[:, ::stride, :]

    attention_output = self._attention_layer(
        query=target_tensor,
        value=key_value,
        attention_mask=attention_mask,
        training=training)
    attention_output = self._attention_dropout(attention_output,
                                               training=training)
    if self._norm_first:
      attention_output = source_tensor + attention_output
    else:
      attention_output = self._attention_layer_norm(
          target_tensor + attention_output, training=training)

    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output,
                                                 training=training)

    if self._feedforward_block is None:
      intermediate_output = self._intermediate_dense(attention_output)
      intermediate_output = self._intermediate_activation_layer(
          intermediate_output)
      layer_output = self._output_dense(intermediate_output,
                                        training=training)
      layer_output = self._output_dropout(layer_output, training=training)
      layer_output = tf.cast(layer_output, tf.float32)
      if self._norm_first:
        layer_output = source_attention_output + layer_output
      else:
        layer_output = self._output_layer_norm(
            layer_output + attention_output, training=training)
    else:
      if self._norm_first:
        # If norm_first, assume the feedforward block will not apply layer
        # norm.
        layer_output = self._feedforward_block(attention_output,
                                               training=training)
        layer_output += source_attention_output
      else:
        # If not norm_first, assume that the feedforward block does apply
        # layer norm.
        layer_output = self._feedforward_block(attention_output,
                                               training=training)

    return layer_output
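
# End-to-end sketch (a hypothetical composition, not defined in this file):
# the packing and striding layers are designed to chain, with the packed
# embeddings and combined mask from PackBertEmbeddings feeding a strided
# block. `embeddings` and `input_mask` are assumed tensors of shape
# (2, 4, 8) and (2, 4); with stride 2 the output would be (1, 4, 8).
#
#   packed = PackBertEmbeddings(pack_sequences=2)(embeddings, input_mask)
#   encoded = StridedTransformerEncoderBlock(
#       num_attention_heads=2, inner_dim=4, inner_activation='relu')(
#           [packed['packed_embeddings'], packed['combined_attention_mask']],
#           stride=tf.constant(2, dtype=tf.int32))
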
official/nlp/modeling/layers/pack_optimization_test.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pack_optimization."""
import tensorflow as tf

from official.nlp.modeling.layers import pack_optimization
class PackOptimizationTest(tf.test.TestCase):

  def test_bert_embedding_packing(self):
    batch_size, seq_len, embed_dim = 2, 4, 8
    pack_sequences = 2
    token_and_position_embed = tf.ones((batch_size, seq_len, embed_dim),
                                       dtype=tf.float32)
    input_mask = tf.ones((batch_size, seq_len), dtype=tf.int32)
    layer = pack_optimization.PackBertEmbeddings(
        pack_sequences=pack_sequences)
    outputs = layer(token_and_position_embed, input_mask)
    self.assertEqual(outputs["packed_embeddings"].shape, (1, 8, embed_dim))
    self.assertEqual(outputs["combined_attention_mask"].shape, (1, 8, 8))

  def test_strided_transformer_encoder_block(self):
    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
    transformer = pack_optimization.StridedTransformerEncoderBlock(
        num_attention_heads=2, inner_dim=4, inner_activation="relu")
    outputs = transformer([inputs, attention_mask],
                          stride=tf.constant(2, dtype=tf.int32))
    self.assertEqual(outputs.shape, (2, 2, 8))

  def test_strided_rezero_transformer(self):
    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
    transformer = pack_optimization.StridedReZeroTransformer(
        num_attention_heads=2, inner_dim=4, inner_activation="relu")
    outputs = transformer([inputs, attention_mask],
                          stride=tf.constant(2, dtype=tf.int32))
    self.assertEqual(outputs.shape, (2, 2, 8))

  def test_strided_scaffold(self):
    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
    test_layer = pack_optimization.StridedTransformerScaffold(
        num_attention_heads=2, inner_dim=128, inner_activation="relu")
    outputs = test_layer([inputs, attention_mask],
                         stride=tf.constant(2, dtype=tf.int32))
    self.assertEqual(outputs.shape, (2, 2, 8))

if __name__ == "__main__":
  tf.test.main()