Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f09c45e0
Unverified
Commit
f09c45e0
authored
Apr 19, 2022
by
Joao Gante
Committed by
GitHub
Apr 19, 2022
Browse files
TF: Add sigmoid activation function (#16819)
parent
74814574
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
31 deletions
+34
-31
src/transformers/activations.py
src/transformers/activations.py
+9
-9
src/transformers/activations_tf.py
src/transformers/activations_tf.py
+8
-7
tests/utils/test_activations.py
tests/utils/test_activations.py
+9
-8
tests/utils/test_activations_tf.py
tests/utils/test_activations_tf.py
+8
-7
No files found.
src/transformers/activations.py
View file @
f09c45e0
...
...
@@ -152,19 +152,19 @@ class LinearActivation(nn.Module):
ACT2FN
=
{
"relu"
:
nn
.
ReLU
(),
"silu"
:
SiLUActivation
(),
"swish"
:
SiLUActivation
(),
"gelu"
:
GELUActivation
(),
"tanh"
:
nn
.
Tanh
(),
"gelu_python"
:
GELUActivation
(
use_gelu_python
=
True
),
"gelu_new"
:
NewGELUActivation
(),
"gelu_fast"
:
FastGELUActivation
(),
"quick_gelu"
:
QuickGELUActivation
(),
"gelu_10"
:
ClippedGELUActivation
(
-
10
,
10
),
"mish"
:
MishActivation
(),
"gelu_fast"
:
FastGELUActivation
(),
"gelu_new"
:
NewGELUActivation
(),
"gelu_python"
:
GELUActivation
(
use_gelu_python
=
True
),
"linear"
:
LinearActivation
(),
"mish"
:
MishActivation
(),
"quick_gelu"
:
QuickGELUActivation
(),
"relu"
:
nn
.
ReLU
(),
"sigmoid"
:
nn
.
Sigmoid
(),
"silu"
:
SiLUActivation
(),
"swish"
:
SiLUActivation
(),
"tanh"
:
nn
.
Tanh
(),
}
...
...
src/transformers/activations_tf.py
View file @
f09c45e0
...
...
@@ -113,16 +113,17 @@ else:
ACT2FN
=
{
"gelu"
:
gelu
,
"relu"
:
tf
.
keras
.
activations
.
relu
,
"swish"
:
tf
.
keras
.
activations
.
swish
,
"silu"
:
tf
.
keras
.
activations
.
swish
,
"gelu_10"
:
gelu_10
,
"gelu_fast"
:
gelu_fast
,
"gelu_new"
:
gelu_new
,
"glu"
:
glu
,
"mish"
:
mish
,
"tanh"
:
tf
.
keras
.
activations
.
tanh
,
"gelu_fast"
:
gelu_fast
,
"quick_gelu"
:
quick_gelu
,
"gelu_10"
:
gelu_10
,
"glu"
:
glu
,
"relu"
:
tf
.
keras
.
activations
.
relu
,
"sigmoid"
:
tf
.
keras
.
activations
.
sigmoid
,
"silu"
:
tf
.
keras
.
activations
.
swish
,
"swish"
:
tf
.
keras
.
activations
.
swish
,
"tanh"
:
tf
.
keras
.
activations
.
tanh
,
}
...
...
tests/utils/test_activations.py
View file @
f09c45e0
...
...
@@ -46,18 +46,19 @@ class TestActivations(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
y_gelu
*
clipped_mask
,
y_gelu_10
*
clipped_mask
))
def
test_get_activation
(
self
):
get_activation
(
"swish"
)
get_activation
(
"silu"
)
get_activation
(
"relu"
)
get_activation
(
"tanh"
)
get_activation
(
"gelu_new"
)
get_activation
(
"gelu"
)
get_activation
(
"gelu_10"
)
get_activation
(
"gelu_fast"
)
get_activation
(
"gelu_new"
)
get_activation
(
"gelu_python"
)
get_activation
(
"gelu_10"
)
get_activation
(
"quick_gelu"
)
get_activation
(
"mish"
)
get_activation
(
"linear"
)
get_activation
(
"mish"
)
get_activation
(
"quick_gelu"
)
get_activation
(
"relu"
)
get_activation
(
"sigmoid"
)
get_activation
(
"silu"
)
get_activation
(
"swish"
)
get_activation
(
"tanh"
)
with
self
.
assertRaises
(
KeyError
):
get_activation
(
"bogus"
)
with
self
.
assertRaises
(
KeyError
):
...
...
tests/utils/test_activations_tf.py
View file @
f09c45e0
...
...
@@ -42,17 +42,18 @@ class TestTFActivations(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
y_gelu
*
clipped_mask
,
y_gelu_10
*
clipped_mask
))
def
test_get_activation
(
self
):
get_tf_activation
(
"swish"
)
get_tf_activation
(
"silu"
)
get_tf_activation
(
"gelu"
)
get_tf_activation
(
"relu"
)
get_tf_activation
(
"tanh"
)
get_tf_activation
(
"gelu_new"
)
get_tf_activation
(
"gelu_fast"
)
get_tf_activation
(
"gelu_10"
)
get_tf_activation
(
"gelu_fast"
)
get_tf_activation
(
"gelu_new"
)
get_tf_activation
(
"glu"
)
get_tf_activation
(
"mish"
)
get_tf_activation
(
"quick_gelu"
)
get_tf_activation
(
"glu"
)
get_tf_activation
(
"relu"
)
get_tf_activation
(
"sigmoid"
)
get_tf_activation
(
"silu"
)
get_tf_activation
(
"swish"
)
get_tf_activation
(
"tanh"
)
with
self
.
assertRaises
(
KeyError
):
get_tf_activation
(
"bogus"
)
with
self
.
assertRaises
(
KeyError
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment