Unverified Commit 3f290e6c authored by Julien Plu, committed by GitHub

Fix mixed precision in TF models (#9163)

* Fix Gelu precision

* Fix gelu_fast

* Naming

* Fix usage and apply style

* add TF gelu approximate version

* add TF gelu approximate version

* add TF gelu approximate version

* Apply style

* Fix albert

* Remove the usage of the Activation layer
parent 248fa1ae
@@ -15,9 +15,10 @@
 import math
 
 import tensorflow as tf
+from packaging import version
 
 
-def gelu(x):
+def _gelu(x):
     """
     Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
     initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -25,12 +26,12 @@ def gelu(x):
     https://arxiv.org/abs/1606.08415
     """
     x = tf.convert_to_tensor(x)
-    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
+    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
 
     return x * cdf
def gelu_new(x): def _gelu_new(x):
""" """
Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841 Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841
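The cast added to `_gelu` above is the heart of the mixed-precision fix: under the `mixed_float16` policy, hidden states reach the activation as float16 tensors, while `tf.math.sqrt(2.0)` always produces a float32 constant, so the division fails with a dtype mismatch. A minimal sketch of the failure and the fix (these function names are illustrative, not from the diff):

import tensorflow as tf

def gelu_broken(x):
    x = tf.convert_to_tensor(x)
    # tf.math.sqrt(2.0) is always float32, so this division fails
    # when x is float16 (e.g. under the mixed_float16 policy).
    return 0.5 * x * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))

def gelu_fixed(x):
    x = tf.convert_to_tensor(x)
    # Casting the constant to x.dtype keeps the op well-typed for
    # float16, bfloat16 and float32 inputs alike.
    return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))

x16 = tf.constant([-1.0, 0.5, 2.0], dtype=tf.float16)
print(gelu_fixed(x16))  # works for float16 inputs
# gelu_broken(x16)      # InvalidArgumentError: half vs. float dtype mismatch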
@@ -56,21 +57,33 @@ def mish(x):
 def gelu_fast(x):
     x = tf.convert_to_tensor(x)
-    coeff1 = tf.cast(7978845608, x.dtype)
+    coeff1 = tf.cast(0.7978845608, x.dtype)
     coeff2 = tf.cast(0.044715, x.dtype)
 
     return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
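The one-character fix above restores a dropped leading `0.`: the intended coefficient is sqrt(2/pi) ≈ 0.7978845608, the constant used by the tanh-based GELU approximation, whereas the old integer literal 7978845608 drove the tanh into saturation for any realistic input. A quick sanity check:

import math

print(math.sqrt(2.0 / math.pi))  # 0.7978845608028654
# The old literal was the same digits with the decimal point dropped,
# i.e. roughly 1e10 times too large.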
+
+if version.parse(tf.version.VERSION) >= version.parse("2.4"):
+
+    def approximate_gelu_wrap(x):
+        return tf.keras.activations.gelu(x, approximate=True)
+
+    gelu = tf.keras.activations.gelu
+    gelu_new = approximate_gelu_wrap
+else:
+    gelu = _gelu
+    gelu_new = _gelu_new
+
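The version gate just above exists because `tf.keras.activations.gelu`, including its `approximate=True` tanh variant, only shipped with TensorFlow 2.4; older installs keep the hand-written `_gelu`/`_gelu_new` fallbacks. On TF >= 2.4 the exact and approximate forms can be compared directly (an illustrative check, not part of the diff):

import tensorflow as tf  # requires TF >= 2.4

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
exact = tf.keras.activations.gelu(x)                     # erf-based
approx = tf.keras.activations.gelu(x, approximate=True)  # tanh-based
print(tf.reduce_max(tf.abs(exact - approx)))  # small, on the order of 1e-4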
 ACT2FN = {
-    "gelu": tf.keras.layers.Activation(gelu),
+    "gelu": gelu,
     "relu": tf.keras.activations.relu,
     "swish": tf.keras.activations.swish,
     "silu": tf.keras.activations.swish,
-    "gelu_new": tf.keras.layers.Activation(gelu_new),
+    "gelu_new": gelu_new,
-    "mish": tf.keras.layers.Activation(mish),
+    "mish": mish,
     "tanh": tf.keras.activations.tanh,
-    "gelu_fast": tf.keras.layers.Activation(gelu_fast),
+    "gelu_fast": gelu_fast,
 }
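With the `tf.keras.layers.Activation` wrappers removed, an `ACT2FN` lookup now returns a plain callable; lookup by config string is otherwise unchanged. A minimal sketch of the pattern (this `get_activation` helper and the reduced dictionary are illustrative, not the library's exact API):

import tensorflow as tf

ACT2FN = {  # reduced stand-in for the mapping above
    "relu": tf.keras.activations.relu,
    "tanh": tf.keras.activations.tanh,
}

def get_activation(name):
    # Look up an activation function by its config string.
    if name not in ACT2FN:
        raise KeyError(f"function {name} not found in ACT2FN mapping")
    return ACT2FN[name]

act = get_activation("relu")
print(act(tf.constant([-1.0, 2.0])))  # [0. 2.]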
@@ -542,7 +542,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.activation(inputs=hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.LayerNorm(inputs=hidden_states)
         seq_length = shape_list(tensor=hidden_states)[1]
         hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
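This hunk and the `Intermediate` hunks that follow are the knock-on effect of the activations change: a `tf.keras.layers.Activation` layer accepts the `inputs=` keyword, but a plain activation function's parameter is named `x`, so keeping `inputs=` would raise a TypeError. A minimal illustration (names here are hypothetical):

import tensorflow as tf

act_layer = tf.keras.layers.Activation(tf.keras.activations.relu)
act_fn = tf.keras.activations.relu  # plain function, parameter named x

t = tf.constant([-1.0, 2.0])
print(act_layer(inputs=t))  # Keras layers accept the inputs= keyword
print(act_fn(t))            # plain functions must be called positionally
# act_fn(inputs=t)          # TypeError: unexpected keyword argument 'inputs'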
@@ -428,7 +428,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -327,7 +327,7 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -709,7 +709,7 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -388,7 +388,7 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -448,7 +448,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -382,7 +382,7 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states