Unverified commit f09c45e0, authored by Joao Gante, committed by GitHub

TF: Add sigmoid activation function (#16819)

parent 74814574
@@ -152,19 +152,19 @@ class LinearActivation(nn.Module):
 ACT2FN = {
-    "relu": nn.ReLU(),
-    "silu": SiLUActivation(),
-    "swish": SiLUActivation(),
     "gelu": GELUActivation(),
-    "tanh": nn.Tanh(),
-    "gelu_python": GELUActivation(use_gelu_python=True),
-    "gelu_new": NewGELUActivation(),
-    "gelu_fast": FastGELUActivation(),
-    "quick_gelu": QuickGELUActivation(),
     "gelu_10": ClippedGELUActivation(-10, 10),
-    "mish": MishActivation(),
+    "gelu_fast": FastGELUActivation(),
+    "gelu_new": NewGELUActivation(),
+    "gelu_python": GELUActivation(use_gelu_python=True),
     "linear": LinearActivation(),
+    "mish": MishActivation(),
+    "quick_gelu": QuickGELUActivation(),
+    "relu": nn.ReLU(),
     "sigmoid": nn.Sigmoid(),
+    "silu": SiLUActivation(),
+    "swish": SiLUActivation(),
+    "tanh": nn.Tanh(),
 }
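The PyTorch mapping above already contained a "sigmoid" entry; this hunk only sorts the keys alphabetically. For reference, a minimal sketch of how an entry registered in ACT2FN is looked up by name; the transformers.activations module path is an assumption here, since only get_activation and the mapping itself appear in this diff:

import torch

# Assumed import path for the lookup helper exercised in the tests further down.
from transformers.activations import get_activation

act = get_activation("sigmoid")    # resolves to the registered nn.Sigmoid() instance
x = torch.tensor([-2.0, 0.0, 2.0])
print(act(x))                      # elementwise sigmoid, values strictly in (0, 1)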
@@ -113,16 +113,17 @@ else:
 ACT2FN = {
     "gelu": gelu,
-    "relu": tf.keras.activations.relu,
-    "swish": tf.keras.activations.swish,
-    "silu": tf.keras.activations.swish,
+    "gelu_10": gelu_10,
+    "gelu_fast": gelu_fast,
     "gelu_new": gelu_new,
+    "glu": glu,
     "mish": mish,
-    "tanh": tf.keras.activations.tanh,
-    "gelu_fast": gelu_fast,
     "quick_gelu": quick_gelu,
-    "gelu_10": gelu_10,
-    "glu": glu,
+    "relu": tf.keras.activations.relu,
+    "sigmoid": tf.keras.activations.sigmoid,
+    "silu": tf.keras.activations.swish,
+    "swish": tf.keras.activations.swish,
+    "tanh": tf.keras.activations.tanh,
 }
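This is the change the commit title refers to: the TF registry gains a "sigmoid" key mapped to tf.keras.activations.sigmoid, alongside an alphabetical sort. A minimal usage sketch, assuming the transformers.activations_tf module path (only get_tf_activation itself appears in the tests below):

import tensorflow as tf

# Assumed import path; get_tf_activation is the helper the TF tests call.
from transformers.activations_tf import get_tf_activation

act = get_tf_activation("sigmoid")  # resolves to tf.keras.activations.sigmoid
x = tf.constant([-2.0, 0.0, 2.0])
print(act(x))                       # elementwise sigmoid, values strictly in (0, 1)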
@@ -46,18 +46,19 @@ class TestActivations(unittest.TestCase):
         self.assertTrue(torch.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))
 
     def test_get_activation(self):
-        get_activation("swish")
-        get_activation("silu")
-        get_activation("relu")
-        get_activation("tanh")
-        get_activation("gelu_new")
+        get_activation("gelu")
+        get_activation("gelu_10")
         get_activation("gelu_fast")
+        get_activation("gelu_new")
         get_activation("gelu_python")
-        get_activation("gelu_10")
-        get_activation("quick_gelu")
-        get_activation("mish")
         get_activation("linear")
+        get_activation("mish")
+        get_activation("quick_gelu")
+        get_activation("relu")
         get_activation("sigmoid")
+        get_activation("silu")
+        get_activation("swish")
+        get_activation("tanh")
         with self.assertRaises(KeyError):
             get_activation("bogus")
         with self.assertRaises(KeyError):
@@ -42,17 +42,18 @@ class TestTFActivations(unittest.TestCase):
         self.assertTrue(np.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))
 
     def test_get_activation(self):
-        get_tf_activation("swish")
-        get_tf_activation("silu")
         get_tf_activation("gelu")
-        get_tf_activation("relu")
-        get_tf_activation("tanh")
-        get_tf_activation("gelu_new")
-        get_tf_activation("gelu_fast")
         get_tf_activation("gelu_10")
+        get_tf_activation("gelu_fast")
+        get_tf_activation("gelu_new")
+        get_tf_activation("glu")
         get_tf_activation("mish")
         get_tf_activation("quick_gelu")
-        get_tf_activation("glu")
+        get_tf_activation("relu")
+        get_tf_activation("sigmoid")
+        get_tf_activation("silu")
+        get_tf_activation("swish")
+        get_tf_activation("tanh")
         with self.assertRaises(KeyError):
             get_tf_activation("bogus")
         with self.assertRaises(KeyError):