Commit ddee474e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 272043067
parent 77710731
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling.activations import gelu from official.modeling import activations
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
...@@ -30,7 +30,7 @@ class GeluTest(keras_parameterized.TestCase): ...@@ -30,7 +30,7 @@ class GeluTest(keras_parameterized.TestCase):
def test_gelu(self): def test_gelu(self):
expected_data = [[0.14967535, 0., -0.10032465], expected_data = [[0.14967535, 0., -0.10032465],
[-0.15880796, -0.04540223, 2.9963627]] [-0.15880796, -0.04540223, 2.9963627]]
gelu_data = gelu.gelu([[.25, 0, -.25], [-1, -2, 3]]) gelu_data = activations.gelu([[.25, 0, -.25], [-1, -2, 3]])
self.assertAllClose(expected_data, gelu_data) self.assertAllClose(expected_data, gelu_data)
......
...@@ -21,14 +21,14 @@ from __future__ import print_function ...@@ -21,14 +21,14 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling.activations import swish from official.modeling import activations
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase): class CustomizedSwishTest(keras_parameterized.TestCase):
def test_gelu(self): def test_gelu(self):
customized_swish_data = swish.swish([[.25, 0, -.25], [-1, -2, 3]]) customized_swish_data = activations.swish([[.25, 0, -.25], [-1, -2, 3]])
swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]]) swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]])
self.assertAllClose(customized_swish_data, swish_data) self.assertAllClose(customized_swish_data, swish_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment