Commit f8c2a917 authored by Hongkun Yu, committed by A. Unique TensorFlower

Migrate to tf gelu. Will remove activations/gelu.py after dependencies are cleaned up.

PiperOrigin-RevId: 323499265
parent b8014d55
@@ -14,12 +14,6 @@
 # ==============================================================================
 """Gaussian error linear unit."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-import math
 import tensorflow as tf
@@ -35,6 +29,4 @@ def gelu(x):
   Returns:
     `x` with the GELU activation applied.
   """
-  cdf = 0.5 * (1.0 + tf.tanh(
-      (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
-  return x * cdf
+  return tf.keras.activations.gelu(x, approximate=True)
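
The deleted lines and the new one-line body compute the same approximate GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal sketch that checks this numerically, assuming a TensorFlow build where tf.keras.activations.gelu accepts the approximate argument (as this commit requires); the helper name and test values are illustrative only:

import math
import numpy as np
import tensorflow as tf

def _gelu_tanh_reference(x):
  # Reimplementation of the tanh-approximation formula removed above,
  # kept only for comparison in this sketch.
  cdf = 0.5 * (1.0 + tf.tanh(
      math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))
  return x * cdf

x = tf.constant(np.linspace(-5.0, 5.0, 101), dtype=tf.float32)
np.testing.assert_allclose(
    _gelu_tanh_reference(x).numpy(),
    tf.keras.activations.gelu(x, approximate=True).numpy(),
    rtol=1e-5, atol=1e-6)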
@@ -14,32 +14,14 @@
 # ==============================================================================
 """Keras layers of XLNet model in TF 2.0."""
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import copy
-import numpy as np
+import functools
 import tensorflow as tf
 from official.nlp.xlnet import data_utils
-def gelu(x):
-  """Gaussian Error Linear Unit.
-  This is a smoother version of the RELU.
-  Original paper: https://arxiv.org/abs/1606.08415
-  Args:
-    x: float Tensor to perform activation.
-  Returns:
-    `x` with the GELU activation applied.
-  """
-  cdf = 0.5 * (1.0 + tf.tanh(
-      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
-  return x * cdf
+gelu = functools.partial(tf.keras.activations.gelu, approximate=True)
 def rel_shift(x, klen=-1):
......
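
Since functools.partial returns an ordinary callable, the new module-level gelu binding above can be dropped in wherever the deleted function was used, including as a Keras layer activation. A small usage sketch, with an illustrative layer and shapes that are not taken from the XLNet code:

import functools
import tensorflow as tf

# Same binding as in the diff above: GELU with the tanh approximation.
gelu = functools.partial(tf.keras.activations.gelu, approximate=True)

# The partial object can be called directly ...
y = gelu(tf.constant([-1.0, 0.0, 1.0]))

# ... or handed to a layer as its activation (illustrative layer, not XLNet's).
dense = tf.keras.layers.Dense(8, activation=gelu)
out = dense(tf.random.normal([2, 4]))  # out has shape (2, 8)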