Unverified Commit fb0bd7b7 authored by Sylvain Gugger, committed by GitHub

Fix activations being all the same module (#19728)

parent 14fe3e04
@@ -13,6 +13,7 @@
 # limitations under the License.
 import math
+from collections import OrderedDict
 import torch
 from packaging import version
@@ -141,21 +142,29 @@ class LinearActivation(nn.Module):
         return input


-ACT2FN = {
-    "gelu": GELUActivation(),
-    "gelu_10": ClippedGELUActivation(-10, 10),
-    "gelu_fast": FastGELUActivation(),
-    "gelu_new": NewGELUActivation(),
-    "gelu_python": GELUActivation(use_gelu_python=True),
-    "linear": LinearActivation(),
-    "mish": MishActivation(),
-    "quick_gelu": QuickGELUActivation(),
-    "relu": nn.ReLU(),
-    "sigmoid": nn.Sigmoid(),
-    "silu": SiLUActivation(),
-    "swish": SiLUActivation(),
-    "tanh": nn.Tanh(),
-}
+class ClassInstantier(OrderedDict):
+    def __getitem__(self, key):
+        content = super().__getitem__(key)
+        cls, kwargs = content if isinstance(content, tuple) else (content, {})
+        return cls(**kwargs)
+
+
+ACT2CLS = {
+    "gelu": GELUActivation,
+    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+    "gelu_fast": FastGELUActivation,
+    "gelu_new": NewGELUActivation,
+    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+    "linear": LinearActivation,
+    "mish": MishActivation,
+    "quick_gelu": QuickGELUActivation,
+    "relu": nn.ReLU,
+    "sigmoid": nn.Sigmoid,
+    "silu": SiLUActivation,
+    "swish": SiLUActivation,
+    "tanh": nn.Tanh,
+}
+ACT2FN = ClassInstantier(ACT2CLS)


 def get_activation(activation_string):
...
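Note (not part of the commit): a minimal, standalone sketch of the ClassInstantier pattern introduced above, using a placeholder DummyActivation class instead of the real activation classes, to show why every lookup now yields a distinct object:

from collections import OrderedDict


class ClassInstantier(OrderedDict):
    # Each lookup instantiates the stored class (optionally with kwargs)
    # instead of handing back one shared, pre-built instance.
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


class DummyActivation:
    # Placeholder standing in for GELUActivation and friends.
    def __init__(self, scale=1.0):
        self.scale = scale


ACT2FN = ClassInstantier({"dummy": (DummyActivation, {"scale": 2.0})})

a = ACT2FN["dummy"]
b = ACT2FN["dummy"]
print(a is b)            # False: a fresh object is built on every lookup
print(a.scale, b.scale)  # 2.0 2.0: the stored kwargs are applied each time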
@@ -63,3 +63,11 @@ class TestActivations(unittest.TestCase):
             get_activation("bogus")
         with self.assertRaises(KeyError):
             get_activation(None)
+
+    def test_activations_are_distinct_objects(self):
+        act1 = get_activation("gelu")
+        act1.a = 1
+        act2 = get_activation("gelu")
+        self.assertEqual(act1.a, 1)
+        with self.assertRaises(AttributeError):
+            _ = act2.a
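Usage note (an assumption about surrounding modeling code, not shown in this diff): model layers commonly pull their activation via ACT2FN[config.hidden_act] or get_activation(...). With this change each such lookup builds a fresh nn.Module, so two layers or two models no longer share a single activation object, and per-instance state (like the `a` attribute in the test above) cannot leak from one to another.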