Unverified commit a4c1aac3, authored by Marc Sun and committed by GitHub
Browse files

store activation cls instead of function (#10832)

* store cls instead of an obj

* style
parent b2ca39c8
...@@ -24,12 +24,12 @@ from ..utils.import_utils import is_torch_npu_available, is_torch_version ...@@ -24,12 +24,12 @@ from ..utils.import_utils import is_torch_npu_available, is_torch_version
if is_torch_npu_available(): if is_torch_npu_available():
import torch_npu import torch_npu
ACTIVATION_FUNCTIONS = { ACT2CLS = {
"swish": nn.SiLU(), "swish": nn.SiLU,
"silu": nn.SiLU(), "silu": nn.SiLU,
"mish": nn.Mish(), "mish": nn.Mish,
"gelu": nn.GELU(), "gelu": nn.GELU,
"relu": nn.ReLU(), "relu": nn.ReLU,
} }
...@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module: ...@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
""" """
act_fn = act_fn.lower() act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS: if act_fn in ACT2CLS:
return ACTIVATION_FUNCTIONS[act_fn] return ACT2CLS[act_fn]()
else: else:
raise ValueError(f"Unsupported activation function: {act_fn}") raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
class FP32SiLU(nn.Module): class FP32SiLU(nn.Module):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment