Unverified Commit c4f7eb12 authored by Joao Gante, committed by GitHub

add TF glu activation function (#15146)

parent 5f3c57fc
@@ -69,6 +69,22 @@ def quick_gelu(x):
    return x * tf.math.sigmoid(coeff * x)

def glu(x, axis=-1):
    """
    Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
    the input `x` is split into two halves, A and B, across a dimension (`axis`), returning A * sigmoid(B).

    Args:
        `x`: float Tensor to perform the activation on
        `axis`: dimension across which `x` will be split in half

    Returns:
        `x` with the GLU activation applied (with its size halved across the dimension `axis`).
    """
    a, b = tf.split(x, 2, axis=axis)
    return a * tf.math.sigmoid(b)

if version.parse(tf.version.VERSION) >= version.parse("2.4"):

    def approximate_gelu_wrap(x):
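For intuition, here is a short, self-contained sketch (not part of the diff) showing that the GLU halves the chosen axis; it assumes only TensorFlow 2.x and mirrors the `glu` function added above:

```python
# Illustrative sketch only: GLU splits `x` into halves A and B and returns A * sigmoid(B).
import tensorflow as tf

def glu(x, axis=-1):
    # Split `x` into two equal halves along `axis`, then gate the first with sigmoid of the second.
    a, b = tf.split(x, 2, axis=axis)
    return a * tf.math.sigmoid(b)

x = tf.random.normal((2, 8))   # last dimension has size 8
y = glu(x, axis=-1)
print(y.shape)                 # (2, 4): the split dimension is halved
```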
@@ -91,6 +107,7 @@ ACT2FN = {
    "tanh": tf.keras.activations.tanh,
    "gelu_fast": gelu_fast,
    "quick_gelu": quick_gelu,
    "glu": glu,
}
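Activations registered in `ACT2FN` are looked up by name through `get_tf_activation`; a minimal usage sketch, assuming a `transformers` install that already contains this patch:

```python
# Minimal usage sketch, assuming a transformers build that includes this commit.
import tensorflow as tf
from transformers.activations_tf import get_tf_activation

act = get_tf_activation("glu")   # resolves to the `glu` function registered in ACT2FN
x = tf.random.normal((4, 16))
print(act(x).shape)              # (4, 8): GLU halves the last dimension by default
```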
@@ -33,6 +33,8 @@ class TestTFActivations(unittest.TestCase):
        get_tf_activation("gelu_new")
        get_tf_activation("gelu_fast")
        get_tf_activation("mish")
        get_tf_activation("quick_gelu")
        get_tf_activation("glu")

        with self.assertRaises(KeyError):
            get_tf_activation("bogus")
        with self.assertRaises(KeyError):