Unverified commit 180e981d, authored by Cyrus Leung and committed by GitHub
Browse files

[Chore] Replace swish with silu (#32459)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b84c426a
...@@ -36,10 +36,13 @@ def get_activation(name: str = "relu") -> torch.nn.Module: ...@@ -36,10 +36,13 @@ def get_activation(name: str = "relu") -> torch.nn.Module:
if name == "gelu": if name == "gelu":
return nn.GELU() return nn.GELU()
if name == "swish": if name == "swish":
return Swish() return nn.SiLU()
if name == "sigmoid": if name == "sigmoid":
return torch.nn.Sigmoid() return nn.Sigmoid()
return nn.Identity() if name == "identity":
return nn.Identity()
raise NotImplementedError(name)
def adaptive_enc_mask( def adaptive_enc_mask(
...@@ -93,44 +96,14 @@ def adaptive_enc_mask( ...@@ -93,44 +96,14 @@ def adaptive_enc_mask(
return mask_left & mask_right return mask_left & mask_right
class Swish(nn.Module):
"""Implement Swish activation module.
From https://arxiv.org/pdf/2005.03191.pdf
"""
def __init__(self) -> None:
super().__init__()
self.act_fn = nn.Sigmoid()
def forward(self, x: Tensor) -> Tensor:
"""Apply Swish function
Args:
x: torch.Tensor
Input.
"""
return x * self.act_fn(x)
class GLU(nn.Module): class GLU(nn.Module):
"""Implement Gated Linear Unit (GLU) module""" """Implement Gated Linear Unit (GLU) module"""
def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.act_name = act_name.lower() self.act_fn = get_activation(act_name)
if self.act_name == "relu":
self.act_fn = nn.ReLU(inplace=True)
elif self.act_name == "gelu":
self.act_fn = nn.GELU()
elif self.act_name == "swish":
self.act_fn = Swish()
elif self.act_name == "sigmoid":
self.act_fn = nn.Sigmoid()
else:
self.act_fn = nn.Identity()
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
"""GLU forward """GLU forward
...@@ -204,16 +177,7 @@ class GLUPointWiseConv(nn.Module): ...@@ -204,16 +177,7 @@ class GLUPointWiseConv(nn.Module):
padding=(kernel_size - 1) // 2, padding=(kernel_size - 1) // 2,
) )
if glu_type == "sigmoid": self.glu_act = get_activation(glu_type)
self.glu_act = nn.Sigmoid()
elif glu_type == "relu":
self.glu_act = nn.ReLU()
elif glu_type == "gelu":
self.glu_act = nn.GELU()
elif glu_type == "swish":
self.glu_act = Swish()
else:
raise ValueError(f"Unsupported activation type {self.glu_act}")
if bias_in_glu: if bias_in_glu:
self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment