Commit 4d34474e authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 342726521
parent 0fbeba16
...@@ -71,6 +71,9 @@ class MobileBertEncoderConfig(hyperparams.Config): ...@@ -71,6 +71,9 @@ class MobileBertEncoderConfig(hyperparams.Config):
intra_bottleneck_size: the size of bottleneck. intra_bottleneck_size: the size of bottleneck.
initializer_range: The stddev of the truncated_normal_initializer for initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: whether to share linear transformation for keys key_query_shared_bottleneck: whether to share linear transformation for keys
and queries. and queries.
num_feedforward_networks: number of stacked feed-forward networks. num_feedforward_networks: number of stacked feed-forward networks.
...@@ -94,6 +97,7 @@ class MobileBertEncoderConfig(hyperparams.Config): ...@@ -94,6 +97,7 @@ class MobileBertEncoderConfig(hyperparams.Config):
attention_probs_dropout_prob: float = 0.1 attention_probs_dropout_prob: float = 0.1
intra_bottleneck_size: int = 1024 intra_bottleneck_size: int = 1024
initializer_range: float = 0.02 initializer_range: float = 0.02
use_bottleneck_attention: bool = False
key_query_shared_bottleneck: bool = False key_query_shared_bottleneck: bool = False
num_feedforward_networks: int = 1 num_feedforward_networks: int = 1
normalization_type: str = "layer_norm" normalization_type: str = "layer_norm"
...@@ -253,6 +257,7 @@ def build_encoder( ...@@ -253,6 +257,7 @@ def build_encoder(
attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob, attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
intra_bottleneck_size=encoder_cfg.intra_bottleneck_size, intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
initializer_range=encoder_cfg.initializer_range, initializer_range=encoder_cfg.initializer_range,
use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck, key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
num_feedforward_networks=encoder_cfg.num_feedforward_networks, num_feedforward_networks=encoder_cfg.num_feedforward_networks,
normalization_type=encoder_cfg.normalization_type, normalization_type=encoder_cfg.normalization_type,
......
...@@ -163,6 +163,7 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -163,6 +163,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
hidden_dropout_prob=0.1, hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
intra_bottleneck_size=128, intra_bottleneck_size=128,
use_bottleneck_attention=False,
key_query_shared_bottleneck=True, key_query_shared_bottleneck=True,
num_feedforward_networks=4, num_feedforward_networks=4,
normalization_type='no_norm', normalization_type='no_norm',
...@@ -181,6 +182,9 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -181,6 +182,9 @@ class MobileBertTransformer(tf.keras.layers.Layer):
attention_probs_dropout_prob: Dropout probability of the attention attention_probs_dropout_prob: Dropout probability of the attention
probabilities. probabilities.
intra_bottleneck_size: Size of bottleneck. intra_bottleneck_size: Size of bottleneck.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: Whether to share linear transformation for key_query_shared_bottleneck: Whether to share linear transformation for
keys and queries. keys and queries.
num_feedforward_networks: Number of stacked feed-forward networks. num_feedforward_networks: Number of stacked feed-forward networks.
...@@ -203,6 +207,7 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -203,6 +207,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
self.hidden_dropout_prob = hidden_dropout_prob self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.intra_bottleneck_size = intra_bottleneck_size self.intra_bottleneck_size = intra_bottleneck_size
self.use_bottleneck_attention = use_bottleneck_attention
self.key_query_shared_bottleneck = key_query_shared_bottleneck self.key_query_shared_bottleneck = key_query_shared_bottleneck
self.num_feedforward_networks = num_feedforward_networks self.num_feedforward_networks = num_feedforward_networks
self.normalization_type = normalization_type self.normalization_type = normalization_type
...@@ -328,7 +333,11 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -328,7 +333,11 @@ class MobileBertTransformer(tf.keras.layers.Layer):
layer_input = dense_layer(prev_output) layer_input = dense_layer(prev_output)
layer_input = layer_norm(layer_input) layer_input = layer_norm(layer_input)
if self.key_query_shared_bottleneck: if self.use_bottleneck_attention:
key_tensor = layer_input
query_tensor = layer_input
value_tensor = layer_input
elif self.key_query_shared_bottleneck:
dense_layer = self.block_layers['kq_shared_bottleneck'][0] dense_layer = self.block_layers['kq_shared_bottleneck'][0]
layer_norm = self.block_layers['kq_shared_bottleneck'][1] layer_norm = self.block_layers['kq_shared_bottleneck'][1]
shared_attention_input = dense_layer(prev_output) shared_attention_input = dense_layer(prev_output)
...@@ -337,9 +346,9 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -337,9 +346,9 @@ class MobileBertTransformer(tf.keras.layers.Layer):
query_tensor = shared_attention_input query_tensor = shared_attention_input
value_tensor = prev_output value_tensor = prev_output
else: else:
key_tensor = layer_input key_tensor = prev_output
query_tensor = layer_input query_tensor = prev_output
value_tensor = layer_input value_tensor = prev_output
# attention layer # attention layer
attention_layer = self.block_layers['attention'][0] attention_layer = self.block_layers['attention'][0]
......
...@@ -37,6 +37,7 @@ class MobileBERTEncoder(tf.keras.Model): ...@@ -37,6 +37,7 @@ class MobileBERTEncoder(tf.keras.Model):
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
intra_bottleneck_size=128, intra_bottleneck_size=128,
initializer_range=0.02, initializer_range=0.02,
use_bottleneck_attention=False,
key_query_shared_bottleneck=True, key_query_shared_bottleneck=True,
num_feedforward_networks=4, num_feedforward_networks=4,
normalization_type='no_norm', normalization_type='no_norm',
...@@ -62,6 +63,9 @@ class MobileBERTEncoder(tf.keras.Model): ...@@ -62,6 +63,9 @@ class MobileBERTEncoder(tf.keras.Model):
intra_bottleneck_size: Size of bottleneck. intra_bottleneck_size: Size of bottleneck.
initializer_range: The stddev of the truncated_normal_initializer for initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: Whether to share linear transformation for key_query_shared_bottleneck: Whether to share linear transformation for
keys and queries. keys and queries.
num_feedforward_networks: Number of stacked feed-forward networks. num_feedforward_networks: Number of stacked feed-forward networks.
...@@ -98,6 +102,7 @@ class MobileBERTEncoder(tf.keras.Model): ...@@ -98,6 +102,7 @@ class MobileBERTEncoder(tf.keras.Model):
hidden_dropout_prob=hidden_dropout_prob, hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob, attention_probs_dropout_prob=attention_probs_dropout_prob,
intra_bottleneck_size=intra_bottleneck_size, intra_bottleneck_size=intra_bottleneck_size,
use_bottleneck_attention=use_bottleneck_attention,
key_query_shared_bottleneck=key_query_shared_bottleneck, key_query_shared_bottleneck=key_query_shared_bottleneck,
num_feedforward_networks=num_feedforward_networks, num_feedforward_networks=num_feedforward_networks,
normalization_type=normalization_type, normalization_type=normalization_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment