Commit 09cb3dff authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342726521
parent a1b04a45
@@ -71,6 +71,9 @@ class MobileBertEncoderConfig(hyperparams.Config):
     intra_bottleneck_size: the size of the bottleneck.
     initializer_range: The stddev of the truncated_normal_initializer for
       initializing all weight matrices.
+    use_bottleneck_attention: whether to use attention inputs from the
+      bottleneck transformation. If True, `key_query_shared_bottleneck`
+      is ignored.
     key_query_shared_bottleneck: whether to share the linear transformation
       for keys and queries.
     num_feedforward_networks: number of stacked feed-forward networks.
@@ -94,6 +97,7 @@ class MobileBertEncoderConfig(hyperparams.Config):
   attention_probs_dropout_prob: float = 0.1
   intra_bottleneck_size: int = 1024
   initializer_range: float = 0.02
+  use_bottleneck_attention: bool = False
   key_query_shared_bottleneck: bool = False
   num_feedforward_networks: int = 1
   normalization_type: str = "layer_norm"
@@ -253,6 +257,7 @@ def build_encoder(
       attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
       intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
       initializer_range=encoder_cfg.initializer_range,
+      use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
       key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
       num_feedforward_networks=encoder_cfg.num_feedforward_networks,
       normalization_type=encoder_cfg.normalization_type,
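For context, a minimal sketch of how the new flag would be set through this config path. The import paths and the `EncoderConfig` wrapper are assumptions based on the tf-models-official layout around the time of this change; only `MobileBertEncoderConfig` fields visible in this diff are used.

```python
# Hypothetical usage sketch; import paths and the EncoderConfig wrapper
# are assumptions, not taken from this diff.
from official.nlp.configs import encoders

mobilebert_cfg = encoders.MobileBertEncoderConfig(
    intra_bottleneck_size=1024,
    use_bottleneck_attention=True,      # new flag added in this commit
    key_query_shared_bottleneck=False,  # ignored when the flag above is True
)

# build_encoder() then forwards use_bottleneck_attention to the encoder,
# as shown in the hunk above.
encoder = encoders.build_encoder(
    encoders.EncoderConfig(type='mobilebert', mobilebert=mobilebert_cfg))
```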
@@ -163,6 +163,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
                hidden_dropout_prob=0.1,
                attention_probs_dropout_prob=0.1,
                intra_bottleneck_size=128,
+               use_bottleneck_attention=False,
                key_query_shared_bottleneck=True,
                num_feedforward_networks=4,
                normalization_type='no_norm',
@@ -181,6 +182,9 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     attention_probs_dropout_prob: Dropout probability of the attention
       probabilities.
     intra_bottleneck_size: Size of the bottleneck.
+    use_bottleneck_attention: Whether to use attention inputs from the
+      bottleneck transformation. If True, `key_query_shared_bottleneck`
+      is ignored.
     key_query_shared_bottleneck: Whether to share the linear transformation
       for keys and queries.
     num_feedforward_networks: Number of stacked feed-forward networks.
@@ -203,6 +207,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     self.hidden_dropout_prob = hidden_dropout_prob
     self.attention_probs_dropout_prob = attention_probs_dropout_prob
     self.intra_bottleneck_size = intra_bottleneck_size
+    self.use_bottleneck_attention = use_bottleneck_attention
     self.key_query_shared_bottleneck = key_query_shared_bottleneck
     self.num_feedforward_networks = num_feedforward_networks
     self.normalization_type = normalization_type
@@ -328,7 +333,11 @@ class MobileBertTransformer(tf.keras.layers.Layer):
       layer_input = dense_layer(prev_output)
       layer_input = layer_norm(layer_input)
-      if self.key_query_shared_bottleneck:
+      if self.use_bottleneck_attention:
+        key_tensor = layer_input
+        query_tensor = layer_input
+        value_tensor = layer_input
+      elif self.key_query_shared_bottleneck:
         dense_layer = self.block_layers['kq_shared_bottleneck'][0]
         layer_norm = self.block_layers['kq_shared_bottleneck'][1]
         shared_attention_input = dense_layer(prev_output)
@@ -337,9 +346,9 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         query_tensor = shared_attention_input
         value_tensor = prev_output
       else:
-        key_tensor = layer_input
-        query_tensor = layer_input
-        value_tensor = layer_input
+        key_tensor = prev_output
+        query_tensor = prev_output
+        value_tensor = prev_output
       # attention layer
       attention_layer = self.block_layers['attention'][0]
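To make the three-way branch easier to follow, here is a standalone sketch of the selection logic, assuming the tensor roles shown in the two hunks above; the helper function, its signature, and the demo shapes are illustrative, not part of the library.

```python
import tensorflow as tf

def select_attention_inputs(prev_output,
                            layer_input,
                            use_bottleneck_attention=False,
                            key_query_shared_bottleneck=True,
                            shared_bottleneck=None):
  """Returns (key, query, value) tensors, mirroring the branch above."""
  if use_bottleneck_attention:
    # New path: key, query, and value all come from the bottleneck
    # transformation of the previous block's output.
    return layer_input, layer_input, layer_input
  elif key_query_shared_bottleneck:
    # Keys and queries share one bottleneck projection; values keep the
    # full-size previous output.
    shared = shared_bottleneck(prev_output)
    return shared, shared, prev_output
  else:
    # Also changed in this commit: the fallback now attends over the raw
    # previous output instead of the bottleneck input.
    return prev_output, prev_output, prev_output

# Illustrative shapes: hidden size 512, bottleneck size 128.
prev_output = tf.random.uniform([2, 8, 512])
layer_input = tf.keras.layers.Dense(128)(prev_output)
k, q, v = select_attention_inputs(prev_output, layer_input,
                                  use_bottleneck_attention=True)
```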
@@ -37,6 +37,7 @@ class MobileBERTEncoder(tf.keras.Model):
                attention_probs_dropout_prob=0.1,
                intra_bottleneck_size=128,
                initializer_range=0.02,
+               use_bottleneck_attention=False,
                key_query_shared_bottleneck=True,
                num_feedforward_networks=4,
                normalization_type='no_norm',
@@ -62,6 +63,9 @@ class MobileBERTEncoder(tf.keras.Model):
     intra_bottleneck_size: Size of the bottleneck.
     initializer_range: The stddev of the truncated_normal_initializer for
       initializing all weight matrices.
+    use_bottleneck_attention: Whether to use attention inputs from the
+      bottleneck transformation. If True, `key_query_shared_bottleneck`
+      is ignored.
     key_query_shared_bottleneck: Whether to share the linear transformation
       for keys and queries.
     num_feedforward_networks: Number of stacked feed-forward networks.
@@ -98,6 +102,7 @@ class MobileBERTEncoder(tf.keras.Model):
           hidden_dropout_prob=hidden_dropout_prob,
           attention_probs_dropout_prob=attention_probs_dropout_prob,
           intra_bottleneck_size=intra_bottleneck_size,
+          use_bottleneck_attention=use_bottleneck_attention,
           key_query_shared_bottleneck=key_query_shared_bottleneck,
           num_feedforward_networks=num_feedforward_networks,
           normalization_type=normalization_type,
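Finally, a corresponding sketch at the network level, assuming the `official.nlp.modeling.networks` module path; only keyword arguments that appear in this diff are passed, and everything else is left at its defaults.

```python
# Hypothetical usage sketch; the module path is an assumption.
from official.nlp.modeling.networks import mobile_bert_encoder

encoder = mobile_bert_encoder.MobileBERTEncoder(
    intra_bottleneck_size=128,
    use_bottleneck_attention=True,     # new flag from this commit
    key_query_shared_bottleneck=True,  # ignored when the flag above is True
    num_feedforward_networks=4,
    normalization_type='no_norm',
)
```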