"...resnet50_tensorflow.git" did not exist on "fb35d6bef6eecc640fb865dd2fc73d8fee2a93b6"
Commit 4e50660b authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 333019996
parent db446fcd
@@ -28,6 +28,7 @@ from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder
@dataclasses.dataclass
@@ -60,18 +61,18 @@ class MobileBertEncoderConfig(hyperparams.Config):
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
      layer.
    intermediate_act_fn: the non-linear activation function to apply to the
      output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of the bottleneck.
    initializer_range: The stddev of the truncated_normal_initializer for
      initializing all weight matrices.
    key_query_shared_bottleneck: whether to share linear transformation for keys
      and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization, only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
@@ -116,12 +117,32 @@ class AlbertEncoderConfig(hyperparams.Config):
  initializer_range: float = 0.02
@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
"""BigBird encoder configuration."""
vocab_size: int = 50358
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 4096
num_rand_blocks: int = 3
block_size: int = 64
type_vocab_size: int = 16
initializer_range: float = 0.02
embedding_size: Optional[int] = None
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
  type: Optional[str] = "bert"
  albert: AlbertEncoderConfig = AlbertEncoderConfig()
  bert: BertEncoderConfig = BertEncoderConfig()
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
@@ -129,6 +150,7 @@ ENCODER_CLS = {
    "bert": networks.BertEncoder,
    "mobilebert": networks.MobileBERTEncoder,
    "albert": networks.AlbertTransformerEncoder,
    "bigbird": bigbird_encoder.BigBirdEncoder,
}
@@ -226,6 +248,24 @@ def build_encoder(
          stddev=encoder_cfg.initializer_range),
      dict_outputs=True)
if encoder_type == "bigbird":
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
num_rand_blocks=encoder_cfg.num_rand_blocks,
block_size=encoder_cfg.block_size,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size)
  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
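The new "bigbird" entry plugs into the existing `OneOfConfig` machinery, so downstream tasks can select the BigBird encoder purely through configuration. Below is a minimal sketch of that selection, assuming the module keeps the `official.nlp.configs.encoders` path and that `build_encoder` accepts the config as its first argument, as the hunk above suggests; the override values are illustrative only.

# Illustrative usage sketch; not part of this change.
from official.nlp.configs import encoders  # assumed module path

config = encoders.EncoderConfig(
    type="bigbird",
    bigbird=encoders.BigBirdEncoderConfig(
        num_layers=2,
        hidden_size=128,
        num_attention_heads=2,
        intermediate_size=512))
bigbird_model = encoders.build_encoder(config)  # expected to return a BigBirdEncoder
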
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based bigbird attention layer."""
import numpy as np
import tensorflow as tf
MAX_SEQ_LEN = 4096
def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_blocked_mask: 2D Tensor of shape [batch_size,
from_seq_length//from_block_size, from_block_size].
to_blocked_mask: int32 Tensor of shape [batch_size,
to_seq_length//to_block_size, to_block_size].
Returns:
float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4,
from_block_size, 3*to_block_size].
"""
exp_blocked_to_pad = tf.concat([
to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:,
3:-1]
], 2)
band_mask = tf.einsum("BLQ,BLK->BLQK", from_blocked_mask[:, 2:-2],
exp_blocked_to_pad)
band_mask = tf.expand_dims(band_mask, 1)
return band_mask
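# Illustrative note, not from the original file: the three shifted slices
# concatenated above line up key blocks (i-1, i, i+1) with query block i, so
# band_mask[b, 0, i-2] masks the 3-block sliding window that the middle query
# blocks attend to in bigbird_block_sparse_attention below.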
def bigbird_block_rand_mask(from_seq_length,
to_seq_length,
from_block_size,
to_block_size,
num_rand_blocks,
last_idx=-1):
"""Create adjacency list of random attention.
Args:
from_seq_length: int. length of from sequence.
to_seq_length: int. length of to sequence.
from_block_size: int. size of block in from sequence.
to_block_size: int. size of block in to sequence.
num_rand_blocks: int. Number of random chunks per row.
    last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
      if positive then num_rand_blocks blocks chosen only up to last_idx.
Returns:
adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
"""
  assert from_seq_length // from_block_size == to_seq_length // to_block_size, \
      "Error: the number of blocks needs to be the same!"
rand_attn = np.zeros(
(from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
last = to_seq_length // to_block_size - 1
if last_idx > (2 * to_block_size):
last = (last_idx // to_block_size) - 1
r = num_rand_blocks # shorthand
for i in range(1, from_seq_length // from_block_size - 1):
start = i - 2
end = i
if i == 1:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
elif i == 2:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
elif i == from_seq_length // from_block_size - 3:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
# Missing -3: should have been sliced till last-3
elif i == from_seq_length // from_block_size - 2:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
# Missing -4: should have been sliced till last-4
else:
if start > last:
start = last
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
elif (end + 1) == last:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
else:
rand_attn[i - 1, :] = np.random.permutation(
np.concatenate((middle_seq[:start], middle_seq[end + 1:last])))[:r]
return rand_attn
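# Illustrative sanity check, not from the original change: with 1024 query/key
# tokens and 64-token blocks, the adjacency list above has 1024 // 64 - 2 = 14
# rows, one per non-global query block, each holding `num_rand_blocks` random
# key blocks drawn from the middle blocks and (for most rows) avoiding the
# 3-block sliding window around the query block.
def _example_rand_attn_shape():
  """Illustrative only: shows the shape produced by bigbird_block_rand_mask."""
  example = bigbird_block_rand_mask(
      from_seq_length=1024,
      to_seq_length=1024,
      from_block_size=64,
      to_block_size=64,
      num_rand_blocks=3)
  assert example.shape == (14, 3)
  return example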
def create_rand_mask_from_inputs(from_blocked_mask, to_blocked_mask, rand_attn,
num_attention_heads, num_rand_blocks,
batch_size, from_seq_length, from_block_size):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_blocked_mask: 2D Tensor of shape [batch_size,
from_seq_length//from_block_size, from_block_size].
to_blocked_mask: int32 Tensor of shape [batch_size,
to_seq_length//to_block_size, to_block_size].
rand_attn: [batch_size, num_attention_heads,
from_seq_length//from_block_size-2, num_rand_blocks]
num_attention_heads: int. Number of attention heads.
num_rand_blocks: int. Number of random chunks per row.
batch_size: int. Batch size for computation.
from_seq_length: int. length of from sequence.
from_block_size: int. size of block in from sequence.
Returns:
float Tensor of shape [batch_size, num_attention_heads,
from_seq_length//from_block_size-2,
from_block_size, num_rand_blocks*to_block_size].
"""
num_windows = from_seq_length // from_block_size - 2
rand_mask = tf.reshape(
tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [
batch_size, num_attention_heads, num_windows,
num_rand_blocks * from_block_size
])
rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
rand_mask)
return rand_mask
def bigbird_block_sparse_attention(
query_layer, key_layer, value_layer, band_mask, from_mask, to_mask,
from_blocked_mask, to_blocked_mask, rand_attn, num_attention_heads,
num_rand_blocks, size_per_head, batch_size, from_seq_length, to_seq_length,
from_block_size, to_block_size):
"""BigBird attention sparse calculation using blocks in linear time.
Assumes from_seq_length//from_block_size == to_seq_length//to_block_size.
Args:
query_layer: float Tensor of shape [batch_size, num_attention_heads,
from_seq_length, size_per_head]
key_layer: float Tensor of shape [batch_size, num_attention_heads,
to_seq_length, size_per_head]
value_layer: float Tensor of shape [batch_size, num_attention_heads,
to_seq_length, size_per_head]
band_mask: (optional) int32 Tensor of shape [batch_size, 1,
from_seq_length//from_block_size-4, from_block_size, 3*to_block_size]. The
values should be 1 or 0. The attention scores will effectively be set to
-infinity for any positions in the mask that are 0, and will be unchanged
for positions that are 1.
from_mask: (optional) int32 Tensor of shape [batch_size, 1, from_seq_length,
1]. The values should be 1 or 0. The attention scores will effectively be
set to -infinity for any positions in the mask that are 0, and will be
unchanged for positions that are 1.
to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1, to_seq_length].
The values should be 1 or 0. The attention scores will effectively be set
to -infinity for any positions in the mask that are 0, and will be
unchanged for positions that are 1.
from_blocked_mask: (optional) int32 Tensor of shape [batch_size,
from_seq_length//from_block_size, from_block_size]. Same as from_mask,
just reshaped.
to_blocked_mask: (optional) int32 Tensor of shape [batch_size,
to_seq_length//to_block_size, to_block_size]. Same as to_mask, just
reshaped.
rand_attn: [batch_size, num_attention_heads,
from_seq_length//from_block_size-2, num_rand_blocks]
num_attention_heads: int. Number of attention heads.
num_rand_blocks: int. Number of random chunks per row.
size_per_head: int. Size of each attention head.
batch_size: int. Batch size for computation.
from_seq_length: int. length of from sequence.
to_seq_length: int. length of to sequence.
from_block_size: int. size of block in from sequence.
to_block_size: int. size of block in to sequence.
Returns:
float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
size_per_head].
"""
rand_attn = tf.expand_dims(rand_attn, 0)
rand_attn = tf.repeat(rand_attn, batch_size, 0)
rand_mask = create_rand_mask_from_inputs(
from_blocked_mask,
to_blocked_mask,
rand_attn,
num_attention_heads,
num_rand_blocks,
batch_size,
from_seq_length,
from_block_size,
)
# Define shorthands
h = num_attention_heads
r = num_rand_blocks
d = size_per_head
b = batch_size
m = from_seq_length
n = to_seq_length
wm = from_block_size
wn = to_block_size
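  # The computation below splits the m//wm query blocks into five groups:
  #   * block 0 and block -1 attend to every key (global attention),
  #   * block 1 and block -2 attend to a fixed window of four key blocks plus
  #     r random key blocks,
  #   * the remaining "middle" blocks attend to a 3-block sliding band, the
  #     first and last (global) key blocks, and r random key blocks.
  # This keeps the per-head cost linear in sequence length (roughly
  # m * (5 + r) * wn * d plus the two global rows) instead of the quadratic
  # m * n * d of dense attention.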
query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])
blocked_query_matrix = tf.reshape(query_layer, (b, h, m // wm, wm, -1))
blocked_key_matrix = tf.reshape(key_layer, (b, h, n // wn, wn, -1))
blocked_value_matrix = tf.reshape(value_layer, (b, h, n // wn, wn, -1))
gathered_key = tf.reshape(
tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"),
(b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1]
gathered_value = tf.reshape(
tf.gather(
blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"),
(b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1]
first_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0],
key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
first_product = tf.multiply(first_product, 1.0 / np.sqrt(d))
first_product += (1.0 - tf.cast(to_mask, dtype=tf.float32)) * -10000.0
first_attn_weights = tf.nn.softmax(first_product) # [b, h, wm, n]
first_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", first_attn_weights,
value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
first_context_layer = tf.expand_dims(first_context_layer, 2)
second_key_mat = tf.concat([
blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1],
blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :,
-1], gathered_key[:, :, 0]
], 2) # [b, h, (4+r)*wn, -1]
second_value_mat = tf.concat([
blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1],
blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1],
gathered_value[:, :, 0]
], 2) # [b, h, (4+r)*wn, -1]
second_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat
) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
second_seq_pad = tf.concat([
to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:],
tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
], 3)
second_rand_pad = tf.concat(
[tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, 0]], 3)
second_product = tf.multiply(second_product, 1.0 / np.sqrt(d))
second_product += (1.0 -
tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0
second_attn_weights = tf.nn.softmax(second_product) # [b , h, wm, (4+r)*wn]
second_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", second_attn_weights, second_value_mat
) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
second_context_layer = tf.expand_dims(second_context_layer, 2)
exp_blocked_key_matrix = tf.concat([
blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2],
blocked_key_matrix[:, :, 3:-1]
], 3) # [b, h, m//wm-4, 3*wn, -1]
exp_blocked_value_matrix = tf.concat([
blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2],
blocked_value_matrix[:, :, 3:-1]
], 3) # [b, h, m//wm-4, 3*wn, -1]
middle_query_matrix = blocked_query_matrix[:, :, 2:-2]
inner_band_product = tf.einsum(
"BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix
) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1]
# ==> [b, h, m//wm-4, wm, 3*wn]
inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d))
rand_band_product = tf.einsum(
"BHLQD,BHLKD->BHLQK", middle_query_matrix,
gathered_key[:, :,
1:-1]) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1]
# ==> [b, h, m//wm-4, wm, r*wn]
rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d))
first_band_product = tf.einsum(
"BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0]
) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d))
last_band_product = tf.einsum(
"BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1]
) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d))
inner_band_product += (1.0 - band_mask) * -10000.0
first_band_product += (1.0 -
tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0
last_band_product += (1.0 -
tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0
rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0
band_product = tf.concat([
first_band_product, inner_band_product, rand_band_product,
last_band_product
], -1) # [b, h, m//wm-4, wm, (5+r)*wn]
attn_weights = tf.nn.softmax(band_product) # [b, h, m//wm-4, wm, (5+r)*wn]
context_layer = tf.einsum(
"BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :,
wn:4 * wn], exp_blocked_value_matrix
) # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1]
# ==> [b, h, m//wm-4, wm, -1]
context_layer += tf.einsum(
"BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :,
4 * wn:-wn], gathered_value[:, :, 1:-1]
) # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1]
# ==> [b, h, m//wm-4, wm, -1]
context_layer += tf.einsum(
"BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn],
blocked_value_matrix[:, :, 0]
) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
context_layer += tf.einsum(
"BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :,
-wn:], blocked_value_matrix[:, :, -1]
) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
second_last_key_mat = tf.concat([
blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3],
blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1],
gathered_key[:, :, -1]
], 2) # [b, h, (4+r)*wn, -1]
second_last_value_mat = tf.concat([
blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3],
blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1],
gathered_value[:, :, -1]
], 2) # [b, h, (4+r)*wn, -1]
second_last_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat
) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
second_last_seq_pad = tf.concat([
to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:],
tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
], 3)
second_last_rand_pad = tf.concat(
[tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, -1]], 3)
second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d))
second_last_product += (
1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
second_last_attn_weights = tf.nn.softmax(
second_last_product) # [b, h, wm, (4+r)*wn]
second_last_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat
) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
second_last_context_layer = tf.expand_dims(second_last_context_layer, 2)
last_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1],
key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
last_product = tf.multiply(last_product, 1.0 / np.sqrt(d))
last_product += (1.0 - to_mask) * -10000.0
last_attn_weights = tf.nn.softmax(last_product) # [b, h, wm, n]
last_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", last_attn_weights,
value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
last_context_layer = tf.expand_dims(last_context_layer, 2)
context_layer = tf.concat([
first_context_layer, second_context_layer, context_layer,
second_last_context_layer, last_context_layer
], 2)
context_layer = tf.reshape(context_layer, (b, h, m, -1)) * from_mask
context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
return context_layer
class BigBirdMasks(tf.keras.layers.Layer):
"""Creates bigbird attention masks."""
def __init__(self, block_size, **kwargs):
super().__init__(**kwargs)
self._block_size = block_size
def call(self, inputs):
encoder_shape = tf.shape(inputs)
batch_size, seq_length = encoder_shape[0], encoder_shape[1]
# reshape and cast for blocking
inputs = tf.cast(inputs, dtype=tf.float32)
blocked_encoder_mask = tf.reshape(
inputs, (batch_size, seq_length // self._block_size, self._block_size))
encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))
band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
blocked_encoder_mask)
return [band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask]
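# Illustrative note, not from the original file: for an input mask of shape
# [batch, seq_length] and block size B, BigBirdMasks returns
#   band_mask:            [batch, 1, seq_length//B - 4, B, 3*B]
#   encoder_from_mask:    [batch, 1, seq_length, 1]
#   encoder_to_mask:      [batch, 1, 1, seq_length]
#   blocked_encoder_mask: [batch, seq_length//B, B]
# which is the mask tuple consumed by BigBirdAttention._compute_attention below.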
@tf.keras.utils.register_keras_serializable(package="Text")
class BigBirdAttention(tf.keras.layers.MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding.
Arguments are the same as `MultiHeadAttention` layer.
"""
def __init__(self,
num_rand_blocks=3,
from_block_size=64,
to_block_size=64,
max_rand_mask_length=MAX_SEQ_LEN,
seed=None,
**kwargs):
super().__init__(**kwargs)
self._num_rand_blocks = num_rand_blocks
self._from_block_size = from_block_size
self._to_block_size = to_block_size
self._seed = seed
# Generates random attention.
np.random.seed(self._seed)
# pylint: disable=g-complex-comprehension
rand_attn = [
bigbird_block_rand_mask(
max_rand_mask_length,
max_rand_mask_length,
from_block_size,
to_block_size,
num_rand_blocks,
last_idx=1024) for _ in range(self._num_heads)
]
# pylint: enable=g-complex-comprehension
rand_attn = np.stack(rand_attn, axis=0)
self.rand_attn = tf.constant(rand_attn, dtype=tf.int32)
def _compute_attention(self, query, key, value, attention_mask=None):
(band_mask, encoder_from_mask, encoder_to_mask,
blocked_encoder_mask) = attention_mask
query_shape = tf.shape(query)
from_seq_length = query_shape[1]
to_seq_length = tf.shape(key)[1]
rand_attn = self.rand_attn[:, :(from_seq_length // self._from_block_size -
2)]
return bigbird_block_sparse_attention(
query,
key,
value,
band_mask,
encoder_from_mask,
encoder_to_mask,
blocked_encoder_mask,
blocked_encoder_mask,
num_attention_heads=self._num_heads,
num_rand_blocks=self._num_rand_blocks,
size_per_head=self._key_dim,
batch_size=query_shape[0],
from_seq_length=from_seq_length,
to_seq_length=to_seq_length,
from_block_size=self._from_block_size,
to_block_size=self._to_block_size,
rand_attn=rand_attn)
def call(self, query, value, key=None, attention_mask=None, **kwargs):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output = self._compute_attention(query, key, value,
attention_mask)
attention_output.set_shape([None, None, self._num_heads, self._key_dim])
attention_output = self._output_dense(attention_output)
return attention_output
def get_config(self):
config = {
"num_rand_blocks": self._num_rand_blocks,
"from_block_size": self._from_block_size,
"to_block_size": self._to_block_size,
"seed": self._seed
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.projects.bigbird.attention."""
import tensorflow as tf
from official.nlp.projects.bigbird import attention
class BigbirdAttentionTest(tf.test.TestCase):
def test_attention(self):
num_heads = 12
key_dim = 64
seq_length = 1024
batch_size = 2
block_size = 64
mask_layer = attention.BigBirdMasks(block_size=block_size)
encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
masks = mask_layer(encoder_inputs_mask)
test_layer = attention.BigBirdAttention(
num_heads=num_heads,
key_dim=key_dim,
from_block_size=block_size,
to_block_size=block_size,
seed=0)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
output = test_layer(
query=query,
value=value,
attention_mask=masks)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
def test_config(self):
num_heads = 12
key_dim = 64
block_size = 64
test_layer = attention.BigBirdAttention(
num_heads=num_heads,
key_dim=key_dim,
from_block_size=block_size,
to_block_size=block_size,
seed=0)
print(test_layer.get_config())
new_layer = attention.BigBirdAttention.from_config(
test_layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.projects.bigbird import attention
@tf.keras.utils.register_keras_serializable(package='Text')
class BigBirdEncoder(tf.keras.Model):
"""Transformer-based encoder network with BigBird attentions.
*Note* that the network is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. This determines the variable shape for the positional
      embeddings; it defaults to the BigBird maximum of 4096 tokens.
type_vocab_size: The number of types that the 'type_ids' input can take.
intermediate_size: The intermediate size for the transformer layers.
activation: The activation to use for the transformer layers.
dropout_rate: The dropout rate to use for the transformer layers.
attention_dropout_rate: The dropout rate to use for the attention layers
within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
"""
def __init__(self,
vocab_size,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_sequence_length=attention.MAX_SEQ_LEN,
type_vocab_size=16,
intermediate_size=3072,
block_size=64,
num_rand_blocks=3,
activation=activations.gelu,
dropout_rate=0.1,
attention_dropout_rate=0.1,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
embedding_width=None,
**kwargs):
activation = tf.keras.activations.get(activation)
initializer = tf.keras.initializers.get(initializer)
self._self_setattr_tracking = False
self._config_dict = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size,
'block_size': block_size,
'num_rand_blocks': num_rand_blocks,
'activation': tf.keras.activations.serialize(activation),
'dropout_rate': dropout_rate,
'attention_dropout_rate': attention_dropout_rate,
'initializer': tf.keras.initializers.serialize(initializer),
'embedding_width': embedding_width,
}
word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
embeddings = self._embedding_norm_layer(embeddings)
embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
embeddings = self._embedding_projection(embeddings)
self._transformer_layers = []
data = embeddings
masks = attention.BigBirdMasks(block_size=block_size)(mask)
encoder_outputs = []
attn_head_dim = hidden_size // num_attention_heads
for i in range(num_layers):
layer = layers.TransformerScaffold(
num_attention_heads,
intermediate_size,
activation,
attention_cls=attention.BigBirdAttention,
attention_cfg=dict(
num_heads=num_attention_heads,
key_dim=attn_head_dim,
kernel_initializer=initializer,
from_block_size=block_size,
to_block_size=block_size,
num_rand_blocks=num_rand_blocks,
max_rand_mask_length=max_sequence_length,
seed=i),
dropout_rate=dropout_rate,
attention_dropout_rate=dropout_rate,
kernel_initializer=initializer)
self._transformer_layers.append(layer)
data = layer([data, masks])
encoder_outputs.append(data)
outputs = dict(
sequence_output=encoder_outputs[-1], encoder_outputs=encoder_outputs)
super().__init__(
inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return self._config_dict
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
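The `embedding_width` argument enables the ALBERT-style factorized embeddings described in the class docstring, which the tests below do not exercise. A minimal sketch under stated assumptions: the sequence length is kept at 1024 (a multiple of the default 64-token block size) because the precomputed random-attention table is built with last_idx=1024, and all sizes here are illustrative.

# Illustrative usage sketch; not part of this change.
import numpy as np
from official.nlp.projects.bigbird import encoder

small_encoder = encoder.BigBirdEncoder(
    vocab_size=1024,
    hidden_size=256,
    num_attention_heads=4,   # head size 256 // 4 = 64
    num_layers=1,
    embedding_width=64,      # factorized: [1024, 64] and [64, 256] matrices
    max_sequence_length=1024)
word_ids = np.random.randint(1024, size=(2, 1024))
mask = np.ones((2, 1024), dtype=np.int32)
type_ids = np.zeros((2, 1024), dtype=np.int32)
outputs = small_encoder([word_ids, mask, type_ids])
print(outputs["sequence_output"].shape)  # (2, 1024, 256)
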
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.projects.bigbird.encoder."""
import numpy as np
import tensorflow as tf
from official.nlp.projects.bigbird import encoder
class BigBirdEncoderTest(tf.test.TestCase):
def test_encoder(self):
sequence_length = 1024
batch_size = 2
vocab_size = 1024
network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
outputs = network([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs["sequence_output"].shape,
(batch_size, sequence_length, 768))
def test_save_restore(self):
sequence_length = 1024
batch_size = 2
vocab_size = 1024
network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
inputs = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
ref_outputs = network(inputs)
model_path = self.get_temp_dir() + "/model"
network.save(model_path)
loaded = tf.keras.models.load_model(model_path)
outputs = loaded(inputs)
self.assertAllClose(outputs["sequence_output"],
ref_outputs["sequence_output"])
if __name__ == "__main__":
tf.test.main()