Unverified Commit 8a83cf2e authored by Kai Zhang, committed by GitHub

Allow custom activation in SqueezeExcitation of EfficientNet (#4448)



* allow custom activation in SqueezeExcitation

* use ReLU as the default activation

* make scale activation parameterizable
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 2e0949e2
@@ -32,17 +32,25 @@ model_urls = {
 class SqueezeExcitation(nn.Module):
-    def __init__(self, input_channels: int, squeeze_channels: int):
+    def __init__(
+        self,
+        input_channels: int,
+        squeeze_channels: int,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
+    ) -> None:
         super().__init__()
         self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
         self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
 
     def _scale(self, input: Tensor) -> Tensor:
         scale = F.adaptive_avg_pool2d(input, 1)
         scale = self.fc1(scale)
-        scale = F.silu(scale, inplace=True)
+        scale = self.activation(scale)
         scale = self.fc2(scale)
-        return scale.sigmoid()
+        return self.scale_activation(scale)
 
     def forward(self, input: Tensor) -> Tensor:
         scale = self._scale(input)
@@ -108,7 +116,7 @@ class MBConv(nn.Module):
         # squeeze and excitation
         squeeze_channels = max(1, cnf.input_channels // 4)
-        layers.append(se_layer(expanded_channels, squeeze_channels))
+        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
 
         # project
         layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
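For context (not part of the commit), below is a minimal usage sketch of the parameterized SqueezeExcitation after this change. It assumes the class is importable from torchvision.models.efficientnet, the module this diff edits; the channel sizes and the nn.Hardsigmoid gate are illustrative choices, not values prescribed by EfficientNet.

# Minimal usage sketch (illustrative, not part of the commit).
from functools import partial

import torch
from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation  # assumed import path

# EfficientNet-style SE block: SiLU between the two 1x1 convs, default Sigmoid gate,
# mirroring how MBConv now constructs se_layer in this diff.
se_silu = SqueezeExcitation(
    input_channels=64,
    squeeze_channels=16,
    activation=partial(nn.SiLU, inplace=True),
)

# MobileNetV3-style SE block: default ReLU internally, Hardsigmoid gate.
se_hard = SqueezeExcitation(
    input_channels=64,
    squeeze_channels=16,
    scale_activation=nn.Hardsigmoid,
)

x = torch.randn(2, 64, 32, 32)
print(se_silu(x).shape)  # torch.Size([2, 64, 32, 32]) -- the SE block preserves the input shape
print(se_hard(x).shape)  # torch.Size([2, 64, 32, 32])

Passing the activations as callables (e.g. via functools.partial) rather than module instances lets the block instantiate fresh modules per layer while keeping ReLU/Sigmoid as backwards-compatible defaults.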