Unverified Commit ff126ae2 authored by Vasilis Vryniotis, committed by GitHub

Replace MobileNetV3's SqueezeExcitation with EfficientNet's one (#4487)

* Reuse EfficientNet SE layer.

* Deprecating the mobilenetv3.SqueezeExcitation layer.

* Passing the right activation on quantization.

* Making strict named param.

* Set default params if missing.

* Fixing typos.
parent 13bd09dd
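
For orientation, a rough sketch (not part of the commit) of how the squeeze-and-excitation block is built after this change: the shared efficientnet.SqueezeExcitation takes the squeeze width and the gating activation explicitly, so MobileNetV3 now computes squeeze_channels itself and binds nn.Hardsigmoid, while mobilenetv3.SqueezeExcitation survives only as a deprecated wrapper. The channel numbers below are illustrative, not taken from the diff.

# Illustrative sketch only (not from the diff): build the SE block the way the
# updated MobileNetV3 code does. Channel values are made-up examples.
from functools import partial

from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation as SElayer
from torchvision.models.mobilenetv2 import _make_divisible

expanded_channels = 96                                         # e.g. an SE-enabled bneck stage
squeeze_channels = _make_divisible(expanded_channels // 4, 8)  # 96 // 4 = 24, already divisible by 8

# MobileNetV3 binds Hardsigmoid as the gating activation instead of the default Sigmoid.
se_layer = partial(SElayer, scale_activation=nn.Hardsigmoid)
se_block = se_layer(expanded_channels, squeeze_channels)

# The old class still exists, but only as a thin deprecated wrapper:
from torchvision.models.mobilenetv3 import SqueezeExcitation
legacy = SqueezeExcitation(expanded_channels)  # emits a FutureWarning
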
+import warnings
 import torch
 from functools import partial
 from torch import nn, Tensor
 from torch.nn import functional as F
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
+from .efficientnet import SqueezeExcitation as SElayer
+from .mobilenetv2 import _make_divisible, ConvBNActivation

 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -18,25 +19,16 @@ model_urls = {
 }

-class SqueezeExcitation(nn.Module):
-    # Implemented as described at Figure 4 of the MobileNetV3 paper
+class SqueezeExcitation(SElayer):
+    """DEPRECATED
+    """
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
-        super().__init__()
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-
-    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.relu(scale)
-        scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input, True)
-        return scale * input
+        super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
+        self.relu = self.activation
+        delattr(self, 'activation')
+        warnings.warn(
+            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)

 class InvertedResidualConfig:
@@ -60,7 +52,7 @@ class InvertedResidualConfig:
 class InvertedResidual(nn.Module):
     # Implemented as described at section 5 of MobileNetV3 paper
     def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+                 se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
         super().__init__()
         if not (1 <= cnf.stride <= 2):
             raise ValueError('illegal stride value')
@@ -81,7 +73,8 @@ class InvertedResidual(nn.Module):
                                        stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                        norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
-            layers.append(se_layer(cnf.expanded_channels))
+            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
......
 import torch
 from torch import nn, Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
-    SqueezeExcitation, model_urls, _mobilenet_v3_conf
+from ..efficientnet import SqueezeExcitation as SElayer
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+    model_urls, _mobilenet_v3_conf
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from typing import Any, List, Optional
 from .utils import _replace_relu
@@ -16,16 +17,53 @@ quant_model_urls = {
 }

-class QuantizableSqueezeExcitation(SqueezeExcitation):
+class QuantizableSqueezeExcitation(SElayer):
+    _version = 2
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["scale_activation"] = nn.Hardsigmoid
         super().__init__(*args, **kwargs)
         self.skip_mul = nn.quantized.FloatFunctional()

     def forward(self, input: Tensor) -> Tensor:
-        return self.skip_mul.mul(self._scale(input, False), input)
+        return self.skip_mul.mul(self._scale(input), input)

     def fuse_model(self) -> None:
-        fuse_modules(self, ['fc1', 'relu'], inplace=True)
+        fuse_modules(self, ['fc1', 'activation'], inplace=True)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            default_state_dict = {
+                "scale_activation.activation_post_process.scale": torch.tensor([1.]),
+                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
+                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
+                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
+            }
+            for k, v in default_state_dict.items():
+                full_key = prefix + k
+                if full_key not in state_dict:
+                    state_dict[full_key] = v
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )

 class QuantizableInvertedResidual(InvertedResidual):
@@ -78,7 +116,7 @@ def _load_weights(
     arch: str,
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
-    progress: bool,
+    progress: bool
 ) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
......
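
The `_version = 2` bump plus the key backfilling in `_load_from_state_dict` above is what lets checkpoints saved before this change (which have no `scale_activation` observer buffers) still load without errors. Below is a toy, self-contained sketch of that versioning pattern, assuming nothing beyond plain PyTorch; `Demo` and its `gate_scale` buffer are invented names, not torchvision code.

# Toy illustration of the versioned _load_from_state_dict pattern used above.
# "Demo" and "gate_scale" are made-up names for demonstration only.
import torch
from torch import nn


class Demo(nn.Module):
    _version = 2  # bumped when the "gate_scale" buffer was introduced

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))
        self.register_buffer("gate_scale", torch.tensor([1.]))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Checkpoints written before the bump lack the new key; inject a default
            # so that strict loading does not report it as missing.
            full_key = prefix + "gate_scale"
            if full_key not in state_dict:
                state_dict[full_key] = torch.tensor([1.])
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


old_checkpoint = {"weight": torch.zeros(3)}   # simulates a pre-version-2 state dict
Demo().load_state_dict(old_checkpoint)        # loads cleanly; the default is injected

The hook can mutate the state dict because nn.Module.load_state_dict passes its own copy down to each submodule, which is the same mechanism QuantizableSqueezeExcitation relies on above.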