Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fb0bd7b7
"vscode:/vscode.git/clone" did not exist on "328ade855b653ba803f2a02349f82fd84a4e059c"
Unverified
Commit
fb0bd7b7
authored
Oct 18, 2022
by
Sylvain Gugger
Committed by
GitHub
Oct 18, 2022
Browse files
Fix activations being all the same module (#19728)
parent
14fe3e04
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
14 deletions
+31
-14
src/transformers/activations.py
src/transformers/activations.py
+23
-14
tests/utils/test_activations.py
tests/utils/test_activations.py
+8
-0
No files found.
src/transformers/activations.py
View file @
fb0bd7b7
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
math
from
collections
import
OrderedDict
import
torch
from
packaging
import
version
...
...
@@ -141,21 +142,29 @@ class LinearActivation(nn.Module):
return
input
ACT2FN
=
{
"gelu"
:
GELUActivation
(),
"gelu_10"
:
ClippedGELUActivation
(
-
10
,
10
),
"gelu_fast"
:
FastGELUActivation
(),
"gelu_new"
:
NewGELUActivation
(),
"gelu_python"
:
GELUActivation
(
use_gelu_python
=
True
),
"linear"
:
LinearActivation
(),
"mish"
:
MishActivation
(),
"quick_gelu"
:
QuickGELUActivation
(),
"relu"
:
nn
.
ReLU
(),
"sigmoid"
:
nn
.
Sigmoid
(),
"silu"
:
SiLUActivation
(),
"swish"
:
SiLUActivation
(),
"tanh"
:
nn
.
Tanh
(),
class
ClassInstantier
(
OrderedDict
):
def
__getitem__
(
self
,
key
):
content
=
super
().
__getitem__
(
key
)
cls
,
kwargs
=
content
if
isinstance
(
content
,
tuple
)
else
(
content
,
{})
return
cls
(
**
kwargs
)
ACT2CLS
=
{
"gelu"
:
GELUActivation
,
"gelu_10"
:
(
ClippedGELUActivation
,
{
"min"
:
-
10
,
"max"
:
10
}),
"gelu_fast"
:
FastGELUActivation
,
"gelu_new"
:
NewGELUActivation
,
"gelu_python"
:
(
GELUActivation
,
{
"use_gelu_python"
:
True
}),
"linear"
:
LinearActivation
,
"mish"
:
MishActivation
,
"quick_gelu"
:
QuickGELUActivation
,
"relu"
:
nn
.
ReLU
,
"sigmoid"
:
nn
.
Sigmoid
,
"silu"
:
SiLUActivation
,
"swish"
:
SiLUActivation
,
"tanh"
:
nn
.
Tanh
,
}
ACT2FN
=
ClassInstantier
(
ACT2CLS
)
def
get_activation
(
activation_string
):
...
...
tests/utils/test_activations.py
View file @
fb0bd7b7
...
...
@@ -63,3 +63,11 @@ class TestActivations(unittest.TestCase):
get_activation
(
"bogus"
)
with
self
.
assertRaises
(
KeyError
):
get_activation
(
None
)
def
test_activations_are_distinct_objects
(
self
):
act1
=
get_activation
(
"gelu"
)
act1
.
a
=
1
act2
=
get_activation
(
"gelu"
)
self
.
assertEqual
(
act1
.
a
,
1
)
with
self
.
assertRaises
(
AttributeError
):
_
=
act2
.
a
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment