Unverified commit ff126ae2, authored by Vasilis Vryniotis, committed by GitHub

Replace MobileNetV3's SqueezeExcitation with EfficientNet's one (#4487)

* Reuse EfficientNet SE layer.

* Deprecating the mobilenetv3.SqueezeExcitation layer.

* Passing the right activation on quantization.

* Making strict named param.

* Set default params if missing.

* Fixing typos.
parent 13bd09dd
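
For orientation, a minimal sketch of what this change means for callers (import paths as in the diffs below; the snippet itself is not part of the commit):

import torch
from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation as SElayer
from torchvision.models.mobilenetv3 import SqueezeExcitation  # deprecated shim after this commit

x = torch.randn(1, 24, 32, 32)

# Recommended path: use the shared EfficientNet SE block directly, passing the
# Hardsigmoid gate that MobileNetV3 uses as its scale activation.
se = SElayer(input_channels=24, squeeze_channels=8, scale_activation=nn.Hardsigmoid)
out = se(x)  # same shape as x, channels rescaled by the learned gate

# Old path still works, but the class now subclasses SElayer and emits a FutureWarning.
legacy = SqueezeExcitation(input_channels=24, squeeze_factor=4)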
torchvision/models/mobilenetv3.py

+import warnings
+
 import torch
 from functools import partial
 from torch import nn, Tensor
-from torch.nn import functional as F
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence

 from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
+from .efficientnet import SqueezeExcitation as SElayer
+from .mobilenetv2 import _make_divisible, ConvBNActivation


 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]

@@ -18,25 +19,16 @@ model_urls = {
 }


-class SqueezeExcitation(nn.Module):
-    # Implemented as described at Figure 4 of the MobileNetV3 paper
+class SqueezeExcitation(SElayer):
+    """DEPRECATED
+    """
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
-        super().__init__()
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-
-    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.relu(scale)
-        scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input, True)
-        return scale * input
+        super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
+        self.relu = self.activation
+        delattr(self, 'activation')
+        warnings.warn(
+            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)


 class InvertedResidualConfig:
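
For reference, a rough functional sketch of what the reused SE block computes for MobileNetV3 inputs, equivalent to the removed _scale/forward above (the helper name se_forward is illustrative only):

import torch
from torch import nn
from torch.nn import functional as F

def se_forward(x: torch.Tensor, fc1: nn.Conv2d, fc2: nn.Conv2d) -> torch.Tensor:
    scale = F.adaptive_avg_pool2d(x, 1)   # squeeze: global average pool to 1x1
    scale = F.relu(fc1(scale))            # reduce channels (the layer's activation step)
    scale = F.hardsigmoid(fc2(scale))     # restore channels, gate via scale_activation
    return scale * x                      # excite: per-channel rescaling of the input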
@@ -60,7 +52,7 @@ class InvertedResidualConfig:
 class InvertedResidual(nn.Module):
     # Implemented as described at section 5 of MobileNetV3 paper
     def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+                 se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
         super().__init__()
         if not (1 <= cnf.stride <= 2):
             raise ValueError('illegal stride value')

@@ -81,7 +73,8 @@ class InvertedResidual(nn.Module):
                                            stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                            norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
-            layers.append(se_layer(cnf.expanded_channels))
+            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
...
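
A quick worked example of the squeeze width that is now computed per block rather than inside the SE layer (72 is one of the MobileNetV3-Large expanded widths that uses SE; the values are purely illustrative):

from torchvision.models.mobilenetv2 import _make_divisible

expanded_channels = 72
squeeze_channels = _make_divisible(expanded_channels // 4, 8)  # 72 // 4 = 18, rounded up to 24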
torchvision/models/quantization/mobilenetv3.py

 import torch
 from torch import nn, Tensor

 from ..._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
-    SqueezeExcitation, model_urls, _mobilenet_v3_conf
+from ..efficientnet import SqueezeExcitation as SElayer
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+    model_urls, _mobilenet_v3_conf
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from typing import Any, List, Optional
 from .utils import _replace_relu

@@ -16,16 +17,53 @@ quant_model_urls = {
 }


-class QuantizableSqueezeExcitation(SqueezeExcitation):
+class QuantizableSqueezeExcitation(SElayer):
+    _version = 2
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["scale_activation"] = nn.Hardsigmoid
         super().__init__(*args, **kwargs)
         self.skip_mul = nn.quantized.FloatFunctional()

     def forward(self, input: Tensor) -> Tensor:
-        return self.skip_mul.mul(self._scale(input, False), input)
+        return self.skip_mul.mul(self._scale(input), input)

     def fuse_model(self) -> None:
-        fuse_modules(self, ['fc1', 'relu'], inplace=True)
+        fuse_modules(self, ['fc1', 'activation'], inplace=True)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            default_state_dict = {
+                "scale_activation.activation_post_process.scale": torch.tensor([1.]),
+                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
+                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
+                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
+            }
+            for k, v in default_state_dict.items():
+                full_key = prefix + k
+                if full_key not in state_dict:
+                    state_dict[full_key] = v
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )


 class QuantizableInvertedResidual(InvertedResidual):
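
The _version bump plus the _load_from_state_dict override above is the standard nn.Module mechanism for keeping previously published checkpoints loadable with strict=True. A generic sketch of the same pattern, with a made-up extra_bias parameter standing in for the new scale_activation observer state:

import torch
from torch import nn

class Widened(nn.Module):
    _version = 2  # recorded in the state-dict metadata on save

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))
        self.extra_bias = nn.Parameter(torch.zeros(4))  # new in version 2

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Old checkpoints never saved extra_bias; fill in a default so strict loading succeeds.
            state_dict.setdefault(prefix + "extra_bias", torch.zeros(4))
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)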
@@ -78,7 +116,7 @@ def _load_weights(
     arch: str,
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
-    progress: bool,
+    progress: bool
 ) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
...
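
Usage-wise, the quantized builders should keep working against the existing pretrained checkpoints, since the loader above fills in the missing scale_activation observer entries for version-1 state dicts. A hedged sketch (builder flags as in torchvision at the time of this commit):

import torchvision

model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()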