Commit 4c226604 authored by Chen Chen, committed by A. Unique TensorFlower

Internal Change

PiperOrigin-RevId: 333163906
parent 4c693d66
@@ -127,7 +127,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
         self.dropout_rate,
         name='embedding_dropout')
-  def call(self, input_ids, token_type_ids=None, training=False):
+  def call(self, input_ids, token_type_ids=None):
     word_embedding_out = self.word_embedding(input_ids)
     word_embedding_out = tf.concat(
         [tf.pad(word_embedding_out[:, 1:], ((0, 0), (0, 1), (0, 0))),
@@ -142,7 +142,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
       type_embedding_out = self.type_embedding(token_type_ids)
       embedding_out += type_embedding_out
     embedding_out = self.layer_norm(embedding_out)
-    embedding_out = self.dropout_layer(embedding_out, training=training)
+    embedding_out = self.dropout_layer(embedding_out)
     return embedding_out
@@ -300,7 +300,6 @@ class TransformerLayer(tf.keras.layers.Layer):
   def call(self,
            input_tensor,
            attention_mask=None,
-           training=False,
            return_attention_scores=False):
     """Implementes the forward pass.
@@ -309,7 +308,6 @@ class TransformerLayer(tf.keras.layers.Layer):
       attention_mask: (optional) int32 tensor of shape [batch_size, seq_length,
         seq_length], with 1 for positions that can be attended to and 0 in
         positions that should not be.
-      training: If the model is in training mode.
       return_attention_scores: If return attention score.
     Returns:
@@ -326,7 +324,6 @@ class TransformerLayer(tf.keras.layers.Layer):
                        f'hidden size {self.hidden_size}'))
     prev_output = input_tensor
     # input bottleneck
     dense_layer = self.block_layers['bottleneck_input'][0]
     layer_norm = self.block_layers['bottleneck_input'][1]
@@ -355,7 +352,6 @@ class TransformerLayer(tf.keras.layers.Layer):
         key_tensor,
         attention_mask,
         return_attention_scores=True,
-        training=training
     )
     attention_output = layer_norm(attention_output + layer_input)
@@ -375,7 +371,7 @@ class TransformerLayer(tf.keras.layers.Layer):
     dropout_layer = self.block_layers['bottleneck_output'][1]
     layer_norm = self.block_layers['bottleneck_output'][2]
     layer_output = bottleneck(layer_output)
-    layer_output = dropout_layer(layer_output, training=training)
+    layer_output = dropout_layer(layer_output)
     layer_output = layer_norm(layer_output + prev_output)
     if return_attention_scores:
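The diff removes the explicit `training` argument from the `call()` signatures and stops forwarding it to the nested dropout and attention layers. A minimal sketch of the Keras behaviour this appears to rely on is shown below: when an outer layer is called with `training=...`, Keras records the flag in its call context, so inner layers such as `tf.keras.layers.Dropout` still resolve the correct training mode even though `call()` no longer accepts or forwards it. The `TinyBlock` layer is a hypothetical illustration, not part of the MobileBERT code in this commit.

import tensorflow as tf


class TinyBlock(tf.keras.layers.Layer):
  """Toy layer whose call() signature has no `training` argument."""

  def __init__(self, rate=0.5, **kwargs):
    super().__init__(**kwargs)
    self.dropout_layer = tf.keras.layers.Dropout(rate)

  def call(self, inputs):
    # No `training=...` is forwarded here; Dropout picks the flag up from the
    # Keras call context set by the outer __call__.
    return self.dropout_layer(inputs)


block = TinyBlock()
x = tf.ones((2, 4))
print(block(x, training=False))  # dropout disabled: output equals the input
print(block(x, training=True))   # some entries zeroed, the rest scaled by 1/(1 - rate)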