Unverified Commit fb0bd7b7 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix activations being all the same module (#19728)

parent 14fe3e04
......@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import OrderedDict
import torch
from packaging import version
......@@ -141,21 +142,29 @@ class LinearActivation(nn.Module):
return input
ACT2FN = {
"gelu": GELUActivation(),
"gelu_10": ClippedGELUActivation(-10, 10),
"gelu_fast": FastGELUActivation(),
"gelu_new": NewGELUActivation(),
"gelu_python": GELUActivation(use_gelu_python=True),
"linear": LinearActivation(),
"mish": MishActivation(),
"quick_gelu": QuickGELUActivation(),
"relu": nn.ReLU(),
"sigmoid": nn.Sigmoid(),
"silu": SiLUActivation(),
"swish": SiLUActivation(),
"tanh": nn.Tanh(),
class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"swish": SiLUActivation,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
......
......@@ -63,3 +63,11 @@ class TestActivations(unittest.TestCase):
get_activation("bogus")
with self.assertRaises(KeyError):
get_activation(None)
def test_activations_are_distinct_objects(self):
act1 = get_activation("gelu")
act1.a = 1
act2 = get_activation("gelu")
self.assertEqual(act1.a, 1)
with self.assertRaises(AttributeError):
_ = act2.a
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment