Unverified Commit 4fca874e authored by Jay, committed by GitHub

Remove hard-coded uses of float32 to fix mixed precision use (#6648)

parent 0344428f
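
The diff below removes every place where the TF BERT and ELECTRA implementations cast to a hard-coded tf.float32. Under the Keras mixed-precision policy the layers compute and return float16 activations, so a fixed float32 cast produces a dtype mismatch at the next add or matmul; deriving the cast target from an existing tensor's dtype keeps the code policy-agnostic. A minimal sketch of the failure mode and of the pattern the commit adopts, assuming TF 2.4+ where tf.keras.mixed_precision.set_global_policy is available (the tensors and names here are illustrative, not taken from the patch):

import tensorflow as tf

# Run under mixed precision: layers keep float32 variables but compute and
# return float16 activations.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(2)
x = dense(tf.random.normal((2, 3)))   # x.dtype == float16 under this policy
mask = tf.constant([[1.0, 0.0]])

# Hard-coded dtype (the pattern being removed): float16 + float32 is a dtype
# mismatch and fails once mixed precision is enabled.
# bad = x + tf.cast(mask, tf.float32)

# Dtype taken from the activations (the pattern this commit introduces):
# behaves the same for plain float32 and mixed-precision runs.
good = x + tf.cast(mask, x.dtype)
print(good.dtype)                     # float16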
@@ -215,8 +215,8 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
+        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
@@ -281,7 +281,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
@@ -613,6 +613,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         if token_type_ids is None:
             token_type_ids = tf.fill(input_shape, 0)

+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -626,7 +628,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

         # Prepare head mask if needed
@@ -640,7 +642,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)

-        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
...
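
The extended-attention-mask change above is safe in half precision because the additive bias of -10000.0 sits well inside float16's representable range (about ±65504). A hedged sketch of the same mask construction outside the model, with a toy tensor standing in for embedding_output:

import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 0]])               # 1 = attend, 0 = padding
embedding_output = tf.zeros((1, 4, 8), dtype=tf.float16)   # stand-in for the real embeddings

# Broadcastable shape [batch_size, 1, 1, to_seq_length], cast to the embeddings'
# dtype instead of a hard-coded float32, then turned into a large negative bias.
extended = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended = tf.cast(extended, embedding_output.dtype)
extended = (1.0 - extended) * -10000.0
print(extended.dtype)                                      # float16
print(extended.numpy()[0, 0, 0])                           # [-0. -0. -0. -10000.]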
@@ -134,8 +134,8 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
+        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
@@ -194,7 +194,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
     config_class = ElectraConfig
     base_model_prefix = "electra"

-    def get_extended_attention_mask(self, attention_mask, input_shape):
+    def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
         if attention_mask is None:
             attention_mask = tf.fill(input_shape, 1)
@@ -211,7 +211,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

         return extended_attention_mask
@@ -314,11 +314,11 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
         if token_type_ids is None:
             token_type_ids = tf.fill(input_shape, 0)

-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
-        head_mask = self.get_head_mask(head_mask)
         hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
+        head_mask = self.get_head_mask(head_mask)

         if hasattr(self, "embeddings_project"):
             hidden_states = self.embeddings_project(hidden_states, training=training)
...
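
A hedged end-to-end sketch of what the commit enables (it assumes the transformers TF BERT port plus downloadable pretrained weights; the exact output container differs between library versions, but the hidden states should come back as float16 once the hard-coded casts are gone):

import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

# Enable mixed precision before building the model so its layers pick up the policy.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("mixed precision smoke test", return_tensors="tf")
outputs = model(inputs)
print(outputs[0].dtype)   # expected: float16 last hidden state under mixed_float16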