Unverified Commit a4c1aac3 authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

store activation cls instead of function (#10832)

* store cls instead of an obj

* style
parent b2ca39c8
......@@ -24,12 +24,12 @@ from ..utils.import_utils import is_torch_npu_available, is_torch_version
if is_torch_npu_available():
import torch_npu
ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
ACT2CLS = {
"swish": nn.SiLU,
"silu": nn.SiLU,
"mish": nn.Mish,
"gelu": nn.GELU,
"relu": nn.ReLU,
}
......@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
"""
act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
if act_fn in ACT2CLS:
return ACT2CLS[act_fn]()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
class FP32SiLU(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment