Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c4f7eb12
"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "da7a8df0d2582c8dc91e5afafe51300899c91392"
Unverified
Commit
c4f7eb12
authored
Jan 14, 2022
by
Joao Gante
Committed by
GitHub
Jan 14, 2022
Browse files
add TF glu activation function (#15146)
parent
5f3c57fc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
0 deletions
+19
-0
src/transformers/activations_tf.py
src/transformers/activations_tf.py
+17
-0
tests/test_activations_tf.py
tests/test_activations_tf.py
+2
-0
No files found.
src/transformers/activations_tf.py
View file @
c4f7eb12
...
@@ -69,6 +69,22 @@ def quick_gelu(x):
...
@@ -69,6 +69,22 @@ def quick_gelu(x):
return
x
*
tf
.
math
.
sigmoid
(
coeff
*
x
)
return
x
*
tf
.
math
.
sigmoid
(
coeff
*
x
)
def glu(x, axis=-1):
    """
    Gated Linear Unit activation, as defined in the original paper (see https://arxiv.org/abs/1612.08083).

    The input `x` is split into two equal halves, A and B, along `axis`; the result is A * sigmoid(B),
    so B acts as a learned gate on A.

    Args:
        `x`: float Tensor to perform activation
        `axis`: dimension across which `x` will be split in half

    Returns:
        `x` with the GLU activation applied (with its size halved across the dimension `axis`).
    """
    linear_part, gate_part = tf.split(x, num_or_size_splits=2, axis=axis)
    gate = tf.math.sigmoid(gate_part)
    return linear_part * gate
if
version
.
parse
(
tf
.
version
.
VERSION
)
>=
version
.
parse
(
"2.4"
):
if
version
.
parse
(
tf
.
version
.
VERSION
)
>=
version
.
parse
(
"2.4"
):
def
approximate_gelu_wrap
(
x
):
def
approximate_gelu_wrap
(
x
):
...
@@ -91,6 +107,7 @@ ACT2FN = {
...
@@ -91,6 +107,7 @@ ACT2FN = {
"tanh"
:
tf
.
keras
.
activations
.
tanh
,
"tanh"
:
tf
.
keras
.
activations
.
tanh
,
"gelu_fast"
:
gelu_fast
,
"gelu_fast"
:
gelu_fast
,
"quick_gelu"
:
quick_gelu
,
"quick_gelu"
:
quick_gelu
,
"glu"
:
glu
,
}
}
...
...
tests/test_activations_tf.py
View file @
c4f7eb12
...
@@ -33,6 +33,8 @@ class TestTFActivations(unittest.TestCase):
...
@@ -33,6 +33,8 @@ class TestTFActivations(unittest.TestCase):
get_tf_activation
(
"gelu_new"
)
get_tf_activation
(
"gelu_new"
)
get_tf_activation
(
"gelu_fast"
)
get_tf_activation
(
"gelu_fast"
)
get_tf_activation
(
"mish"
)
get_tf_activation
(
"mish"
)
get_tf_activation
(
"quick_gelu"
)
get_tf_activation
(
"glu"
)
with
self
.
assertRaises
(
KeyError
):
with
self
.
assertRaises
(
KeyError
):
get_tf_activation
(
"bogus"
)
get_tf_activation
(
"bogus"
)
with
self
.
assertRaises
(
KeyError
):
with
self
.
assertRaises
(
KeyError
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment