"model/models/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "df94175a0fb0356c9b9e9a62b73d908633c08810"
Unverified commit a75e3198, authored by Yih-Dar, committed by GitHub

Fix mixed precision issue in TF DistilBert (#6915)

* Remove hard-coded uses of float32 to fix mixed precision use in TF Distilbert

* fix style

* fix gelu dtype issue in TF Distilbert

* fix numeric overflow while using half precision
parent e95d262f
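For context, a minimal sketch of the mixed-precision setup this commit targets (not part of the patch). The checkpoint name is only an assumed example, and the policy call shown is the TF 2.4+ API; the TF 2.3-era releases current at the time used `tf.keras.mixed_precision.experimental.set_policy` instead.

```python
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertModel

# Half-precision compute with float32 variables (TF 2.4+ API; on older
# releases use tf.keras.mixed_precision.experimental.set_policy(...)).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = TFDistilBertModel.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Mixed precision smoke test.", return_tensors="tf")
outputs = model(inputs)   # before this fix, float16 activations met hard-coded float32 constants
print(outputs[0].dtype)   # hidden states follow the compute dtype (typically float16 under this policy)
```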
@@ -76,7 +76,7 @@ def gelu(x):
         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
         Also see https://arxiv.org/abs/1606.08415
     """
-    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
+    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.math.sqrt(2.0), dtype=x.dtype)))
     return x * cdf
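The hunk above casts the sqrt(2) constant to the input dtype because TF does not promote mismatched tensor dtypes: with a float16 input, dividing by the float32 result of `tf.math.sqrt(2.0)` fails. A small check of the patched formula (not part of the commit):

```python
import tensorflow as tf

def gelu(x):
    # Patched formula: the constant follows x.dtype, so half-precision inputs work.
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.math.sqrt(2.0), dtype=x.dtype)))
    return x * cdf

x = tf.constant([-1.0, 0.0, 2.0], dtype=tf.float16)
print(gelu(x))  # runs in float16; the pre-patch line raises a dtype mismatch here
```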
@@ -168,7 +168,9 @@ class TFEmbeddings(tf.keras.layers.Layer):
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
-        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
+        position_embeddings = tf.cast(
+            self.position_embeddings(position_ids), inputs_embeds.dtype
+        )  # (bs, max_seq_length, dim)
 
         embeddings = inputs_embeds + position_embeddings  # (bs, max_seq_length, dim)
         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
@@ -261,9 +263,12 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_length, k_length)
         mask = tf.reshape(mask, mask_reshape)  # (bs, n_heads, qlen, klen)
         # scores.masked_fill_(mask, -float('inf'))   # (bs, n_heads, q_length, k_length)
-        scores = scores - 1e30 * (1.0 - mask)
+        scores_dtype = scores.dtype
+        # calculate `scores` in `tf.float32` to avoid numeric overflow
+        scores = tf.cast(scores, dtype=tf.float32) - 1e30 * (1.0 - tf.cast(mask, dtype=tf.float32))
 
-        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
+        weights = tf.cast(tf.nn.softmax(scores, axis=-1), dtype=scores_dtype)  # (bs, n_heads, qlen, klen)
         weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
 
         # Mask heads if we want to
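Why the last hunk moves the additive mask into float32: float16 tops out near 6.5e4, so the 1e30 constant overflows to inf, and inf * 0.0 at unmasked positions yields NaN. A toy reproduction of the overflow and of the patched float32 path (not part of the commit; the values are arbitrary):

```python
import tensorflow as tf

scores = tf.constant([[0.5, 1.5]], dtype=tf.float16)
mask = tf.constant([[1.0, 0.0]], dtype=tf.float16)  # 1 = attend, 0 = masked out
big = tf.cast(tf.constant(1e30), tf.float16)        # overflows to inf in half precision
print(scores - big * (1.0 - mask))                  # [[nan, -inf]]: even the unmasked score is corrupted

# Patched path: build the masked scores in float32, softmax, then cast back.
safe = tf.cast(scores, tf.float32) - 1e30 * (1.0 - tf.cast(mask, tf.float32))
weights = tf.cast(tf.nn.softmax(safe, axis=-1), scores.dtype)
print(weights)                                      # finite attention weights
```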