Unverified Commit 90150733 authored by Julien Plu, committed by GitHub

Fix mixed precision issue for GPT2 (#8572)

* Fix mixed precision issue for GPT2

* Forgot one cast

* oops

* Forgotten casts
parent 1073a2bd
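For context, a hedged sketch of the scenario this commit addresses: running TF GPT-2 under Keras' global mixed-precision policy, which makes layer computations float16 while variables stay float32. The model name and token ids below are illustrative, not taken from the commit.

```python
import tensorflow as tf
from transformers import TFGPT2Model

# Illustrative only: enable mixed precision globally, so Keras layers
# compute in float16 while keeping their variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = TFGPT2Model.from_pretrained("gpt2")
input_ids = tf.constant([[464, 3290, 318]])  # arbitrary token ids

# Before this fix, the forward pass mixed float16 activations with
# hard-coded float32 tensors and failed under this policy.
outputs = model(input_ids)
print(outputs[0].dtype)  # last hidden state, computed in float16
```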
@@ -97,7 +97,7 @@ class TFAttention(tf.keras.layers.Layer):
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
-            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
+            dk = tf.cast(shape_list(k)[-1], dtype=w.dtype)  # scale attention_scores
             w = w / tf.math.sqrt(dk)
         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
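The first hunk changes only the dtype of the scaling constant. TensorFlow does not promote dtypes implicitly, so under a float16 policy the float32 `dk` cannot be divided into the float16 scores `w`. A minimal sketch of the failure and the fix, with made-up shapes (the real tensors are [batch, heads, sequence, features]):

```python
import tensorflow as tf

w = tf.random.normal((1, 2, 4, 4), dtype=tf.float16)  # attention scores under mixed precision

# Old code: the head size was always cast to float32.
dk = tf.cast(4, tf.float32)
try:
    w / tf.math.sqrt(dk)  # float16 / float32: TF refuses to mix dtypes
except tf.errors.InvalidArgumentError as e:
    print("fails:", e)

# Patched code: follow the dtype of w, so the division works under any policy.
dk = tf.cast(4, dtype=w.dtype)
scaled = w / tf.math.sqrt(dk)
print(scaled.dtype)  # float16
```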
@@ -352,6 +352,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             token_type_embeds = self.wte(token_type_ids, mode="embedding")
         else:
             token_type_embeds = 0
+        position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
+        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         hidden_states = self.drop(hidden_states, training=training)
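The second hunk applies the same idea to the embedding sum. Under mixed precision `inputs_embeds` can be float16 while the position embeddings come back float32, and `token_type_embeds` may even be the Python int 0, so both are cast to `inputs_embeds.dtype` before the addition. A minimal sketch with made-up shapes:

```python
import tensorflow as tf

inputs_embeds = tf.random.normal((1, 4, 8), dtype=tf.float16)    # token embeddings
position_embeds = tf.random.normal((1, 4, 8), dtype=tf.float32)  # e.g. still float32
token_type_embeds = 0                                            # the `else` branch above

# The two added lines: align everything with the dtype of inputs_embeds.
position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)

hidden_states = inputs_embeds + position_embeds + token_type_embeds
print(hidden_states.dtype)  # float16
```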