Unverified commit f7498350 authored by Vasilis Vryniotis, committed by GitHub

Moving common layers to ops (#4504)

* Moving _make_divisible to utils.

* Replace the old ConvBNReLU and ConvBNActivation layers.

* Fix minor bug.

* Moving SE layer to ops.

* Adding deprecation warnings on old layers.

* Apply changes to regnets.
parent 5760f356
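In short, the building blocks that several classification and detection models shared (ConvBNActivation/ConvBNReLU, the SqueezeExcitation layer, and the _make_divisible helper) now live in torchvision.ops.misc and torchvision.models._utils, and the model files import them from there. A minimal sketch of the new import location and a typical call, assuming a torchvision build that includes this commit:

import torch
from torch import nn
from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation

# Conv2d -> BatchNorm2d -> ReLU6, the pattern the MobileNet-style models build everywhere
block = ConvNormActivation(3, 32, kernel_size=3, stride=2,
                           norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)
se = SqueezeExcitation(input_channels=32, squeeze_channels=8)

y = block(torch.rand(1, 3, 224, 224))
print(y.shape)      # torch.Size([1, 32, 112, 112])
print(se(y).shape)  # same shape: SE only re-scales channels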
from collections import OrderedDict
from torch import nn
-from typing import Dict
+from typing import Dict, Optional


class IntermediateLayerGetter(nn.ModuleDict):
@@ -64,3 +64,19 @@ class IntermediateLayerGetter(nn.ModuleDict):
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
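For reference, _make_divisible rounds a (possibly scaled) channel count to the nearest multiple of divisor while never dropping more than 10% below the requested value. A few worked examples, assuming the function above is in scope:

# nearest multiple of 8 to 37.5 is 40                      -> 40
assert _make_divisible(37.5, 8) == 40
# 67 rounds down to 64, which is still >= 0.9 * 67 = 60.3  -> 64
assert _make_divisible(67, 8) == 64
# 39 rounds down to 32, but 32 < 0.9 * 39 = 35.1,
# so one extra divisor is added                            -> 48
assert _make_divisible(39, 16) == 48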
@@ -11,8 +11,8 @@ from .ssd import SSD, SSDScoringHead
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .. import mobilenet
-from ..mobilenetv3 import ConvBNActivation
from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops.misc import ConvNormActivation


__all__ = ['ssdlite320_mobilenet_v3_large']
@@ -28,8 +28,8 @@ def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
                      norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    return nn.Sequential(
        # 3x3 depthwise with stride 1 and padding 1
-        ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
-                         norm_layer=norm_layer, activation_layer=nn.ReLU6),
+        ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
+                           norm_layer=norm_layer, activation_layer=nn.ReLU6),
        # 1x1 projetion to output channels
        nn.Conv2d(in_channels, out_channels, 1)
@@ -41,16 +41,16 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
    intermediate_channels = out_channels // 2
    return nn.Sequential(
        # 1x1 projection to half output channels
-        ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
-                         norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(in_channels, intermediate_channels, kernel_size=1,
+                           norm_layer=norm_layer, activation_layer=activation),
        # 3x3 depthwise with stride 2 and padding 1
-        ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
-                         groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
+                           groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
        # 1x1 projetion to output channels
-        ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
-                         norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(intermediate_channels, out_channels, kernel_size=1,
+                           norm_layer=norm_layer, activation_layer=activation),
    )
...
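The SSDLite head blocks are assembled from the same shared op. A sketch of the depthwise 3x3 plus 1x1 projection pattern used by _prediction_block, with illustrative channel counts and norm settings (the real head derives these from the backbone and anchor configuration):

import torch
from torch import nn
from functools import partial
from torchvision.ops.misc import ConvNormActivation

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)  # illustrative choice
in_ch, out_ch = 64, 24
head = nn.Sequential(
    # 3x3 depthwise with stride 1 and padding 1
    ConvNormActivation(in_ch, in_ch, kernel_size=3, groups=in_ch,
                       norm_layer=norm_layer, activation_layer=nn.ReLU6),
    # 1x1 projection to output channels
    nn.Conv2d(in_ch, out_ch, 1),
)
print(head(torch.rand(1, in_ch, 20, 20)).shape)  # torch.Size([1, 24, 20, 20])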
@@ -4,14 +4,13 @@ import torch
from functools import partial
from torch import nn, Tensor
-from torch.nn import functional as F
from typing import Any, Callable, List, Optional, Sequence

from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation, SqueezeExcitation
+from ._utils import _make_divisible
from torchvision.ops import StochasticDepth
-from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible


__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
           "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
@@ -31,32 +30,6 @@ model_urls = {
}


-class SqueezeExcitation(nn.Module):
-    def __init__(
-        self,
-        input_channels: int,
-        squeeze_channels: int,
-        activation: Callable[..., nn.Module] = nn.ReLU,
-        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
-    ) -> None:
-        super().__init__()
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-        self.activation = activation()
-        self.scale_activation = scale_activation()
-
-    def _scale(self, input: Tensor) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.activation(scale)
-        scale = self.fc2(scale)
-        return self.scale_activation(scale)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input)
-        return scale * input
-
-
class MBConvConfig:
    # Stores information listed at Table 1 of the EfficientNet paper
    def __init__(self,
@@ -106,21 +79,21 @@ class MBConv(nn.Module):
        # expand
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
-            layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1,
-                                           norm_layer=norm_layer, activation_layer=activation_layer))
+            layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
+                                             norm_layer=norm_layer, activation_layer=activation_layer))

        # depthwise
-        layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
-                                       stride=cnf.stride, groups=expanded_channels,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
+                                         stride=cnf.stride, groups=expanded_channels,
+                                         norm_layer=norm_layer, activation_layer=activation_layer))

        # squeeze and excitation
        squeeze_channels = max(1, cnf.input_channels // 4)
        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

        # project
-        layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=nn.Identity))
+        layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                         activation_layer=None))

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -174,8 +147,8 @@ class EfficientNet(nn.Module):
        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
-        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
-                                       activation_layer=nn.SiLU))
+        layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
+                                         activation_layer=nn.SiLU))

        # building inverted residual blocks
        total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
@@ -202,8 +175,8 @@ class EfficientNet(nn.Module):
        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 4 * lastconv_input_channels
-        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
-                                       norm_layer=norm_layer, activation_layer=nn.SiLU))
+        layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                         norm_layer=norm_layer, activation_layer=nn.SiLU))

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
...
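The net effect in the EfficientNet code is that MBConv is now assembled entirely from the shared ops. A minimal sketch of the resulting expand, depthwise, SE, project sequence, with illustrative channel widths rather than the real EfficientNet configuration:

import torch
from torch import nn
from functools import partial
from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation

in_ch, exp_ch, out_ch = 16, 96, 24  # hypothetical MBConv widths
block = nn.Sequential(
    ConvNormActivation(in_ch, exp_ch, kernel_size=1,
                       activation_layer=nn.SiLU),                  # expand
    ConvNormActivation(exp_ch, exp_ch, kernel_size=3, stride=2,
                       groups=exp_ch, activation_layer=nn.SiLU),   # depthwise
    SqueezeExcitation(exp_ch, max(1, in_ch // 4),
                      activation=partial(nn.SiLU, inplace=True)),  # squeeze and excitation
    ConvNormActivation(exp_ch, out_ch, kernel_size=1,
                       activation_layer=None),                     # project, no activation
)
print(block(torch.rand(1, in_ch, 32, 32)).shape)  # torch.Size([1, 24, 16, 16])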
import torch
+import warnings
+from functools import partial
from torch import nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
+from ._utils import _make_divisible
from typing import Callable, Any, Optional, List
@@ -13,50 +18,21 @@ model_urls = {
}


-def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
-    """
-    This function is taken from the original tf repo.
-    It ensures that all layers have a channel number that is divisible by 8
-    It can be seen here:
-    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
-    """
-    if min_value is None:
-        min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-class ConvBNActivation(nn.Sequential):
-    def __init__(
-        self,
-        in_planes: int,
-        out_planes: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-        activation_layer: Optional[Callable[..., nn.Module]] = None,
-        dilation: int = 1,
-    ) -> None:
-        padding = (kernel_size - 1) // 2 * dilation
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if activation_layer is None:
-            activation_layer = nn.ReLU6
-        super().__init__(
-            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
-                      bias=False),
-            norm_layer(out_planes),
-            activation_layer(inplace=True)
-        )
-        self.out_channels = out_planes
-
-
-# necessary for backwards compatibility
-ConvBNReLU = ConvBNActivation
+# necessary for backwards compatibility
+class _DeprecatedConvBNAct(ConvNormActivation):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
+            "Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning)
+        if kwargs.get("norm_layer", None) is None:
+            kwargs["norm_layer"] = nn.BatchNorm2d
+        if kwargs.get("activation_layer", None) is None:
+            kwargs["activation_layer"] = nn.ReLU6
+        super().__init__(*args, **kwargs)
+
+
+ConvBNReLU = _DeprecatedConvBNAct
+ConvBNActivation = _DeprecatedConvBNAct


class InvertedResidual(nn.Module):
@@ -81,10 +57,12 @@ class InvertedResidual(nn.Module):
        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
-            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
+            layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
+                                             activation_layer=nn.ReLU6))
        layers.extend([
            # dw
-            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
+            ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
+                               activation_layer=nn.ReLU6),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
@@ -154,7 +132,8 @@ class MobileNetV2(nn.Module):
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
+        features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer,
+                                                        activation_layer=nn.ReLU6)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
@@ -163,7 +142,8 @@ class MobileNetV2(nn.Module):
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
-        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
+        features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,
+                                           activation_layer=nn.ReLU6))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
...
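The old mobilenetv2 names keep working but now go through the deprecation shim and emit a FutureWarning. A quick way to see the shim in action, assuming a torchvision build that includes this commit:

import warnings
from torchvision.models.mobilenetv2 import ConvBNReLU  # deprecated alias

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = ConvBNReLU(3, 16, kernel_size=3, stride=2)  # still builds Conv2d + BatchNorm2d + ReLU6

print(type(layer).__mro__[1].__name__)  # ConvNormActivation
print(caught[0].category.__name__)      # FutureWarning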
@@ -6,8 +6,8 @@ from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence

from .._internally_replaced_utils import load_state_dict_from_url
-from .efficientnet import SqueezeExcitation as SElayer
-from .mobilenetv2 import _make_divisible, ConvBNActivation
+from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
+from ._utils import _make_divisible


__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -28,7 +28,8 @@ class SqueezeExcitation(SElayer):
        self.relu = self.activation
        delattr(self, 'activation')
        warnings.warn(
-            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
+            "This SqueezeExcitation class is deprecated and will be removed in future versions. "
+            "Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning)


class InvertedResidualConfig:
@@ -64,21 +65,21 @@ class InvertedResidual(nn.Module):
        # expand
        if cnf.expanded_channels != cnf.input_channels:
-            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
-                                           norm_layer=norm_layer, activation_layer=activation_layer))
+            layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
+                                             norm_layer=norm_layer, activation_layer=activation_layer))

        # depthwise
        stride = 1 if cnf.dilation > 1 else cnf.stride
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
-                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
+                                         stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
+                                         norm_layer=norm_layer, activation_layer=activation_layer))
        if cnf.use_se:
            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

        # project
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=nn.Identity))
+        layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                         activation_layer=None))

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
@@ -130,8 +131,8 @@ class MobileNetV3(nn.Module):
        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
-        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
-                                       activation_layer=nn.Hardswish))
+        layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
+                                         activation_layer=nn.Hardswish))

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
@@ -140,8 +141,8 @@ class MobileNetV3(nn.Module):
        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
-        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
-                                       norm_layer=norm_layer, activation_layer=nn.Hardswish))
+        layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                         norm_layer=norm_layer, activation_layer=nn.Hardswish))

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
...
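Note that the projection layers now pass activation_layer=None instead of nn.Identity: with ConvNormActivation, None means the activation module is simply not appended, so the projection is a two-element Conv2d plus BatchNorm2d sequence rather than a three-element one ending in Identity. A small check, assuming a torchvision build with this commit:

from torchvision.ops.misc import ConvNormActivation

proj = ConvNormActivation(96, 24, kernel_size=1, activation_layer=None)
print(len(proj), [type(m).__name__ for m in proj])  # 2 ['Conv2d', 'BatchNorm2d']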
@@ -5,9 +5,10 @@ from ..._internally_replaced_utils import load_state_dict_from_url
from typing import Any

-from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
+from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model
+from ...ops.misc import ConvNormActivation


__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
@@ -55,7 +56,7 @@ class QuantizableMobileNetV2(MobileNetV2):
    def fuse_model(self) -> None:
        for m in self.modules():
-            if type(m) == ConvBNReLU:
+            if type(m) == ConvNormActivation:
                fuse_modules(m, ['0', '1', '2'], inplace=True)
            if type(m) == QuantizableInvertedResidual:
                m.fuse_model()
...
import torch
from torch import nn, Tensor

from ..._internally_replaced_utils import load_state_dict_from_url
-from ..efficientnet import SqueezeExcitation as SElayer
-from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+from ...ops.misc import ConvNormActivation, SqueezeExcitation
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\
    model_urls, _mobilenet_v3_conf
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from typing import Any, List, Optional
@@ -17,7 +17,7 @@ quant_model_urls = {
}


-class QuantizableSqueezeExcitation(SElayer):
+class QuantizableSqueezeExcitation(SqueezeExcitation):
    _version = 2

    def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -103,9 +103,9 @@ class QuantizableMobileNetV3(MobileNetV3):
    def fuse_model(self) -> None:
        for m in self.modules():
-            if type(m) == ConvBNActivation:
+            if type(m) == ConvNormActivation:
                modules_to_fuse = ['0', '1']
-                if type(m[2]) == nn.ReLU:
+                if len(m) == 3 and type(m[2]) == nn.ReLU:
                    modules_to_fuse.append('2')
                fuse_modules(m, modules_to_fuse, inplace=True)
            elif type(m) == QuantizableSqueezeExcitation:
...
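The new len(m) == 3 guard in fuse_model matters because a ConvNormActivation built with activation_layer=None has only two sub-modules, so unconditionally indexing m[2] would fail. A minimal sketch of the same fusion logic on standalone blocks (run in eval mode; the printed fused module names are indicative):

import torch
from torch import nn
from torch.quantization import fuse_modules
from torchvision.ops.misc import ConvNormActivation

conv_bn_relu = ConvNormActivation(3, 16, activation_layer=nn.ReLU).eval()
conv_bn = ConvNormActivation(16, 32, activation_layer=None).eval()

for m in (conv_bn_relu, conv_bn):
    modules_to_fuse = ['0', '1']
    if len(m) == 3 and type(m[2]) == nn.ReLU:
        modules_to_fuse.append('2')
    fuse_modules(m, modules_to_fuse, inplace=True)

print([type(c).__name__ for c in conv_bn_relu])  # e.g. ['ConvReLU2d', 'Identity', 'Identity']
print([type(c).__name__ for c in conv_bn])       # e.g. ['Conv2d', 'Identity']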
@@ -12,8 +12,8 @@ from typing import Any, Callable, List, Optional, Tuple
from torch import nn, Tensor

from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
-from torchvision.models.efficientnet import SqueezeExcitation
+from ..ops.misc import ConvNormActivation, SqueezeExcitation
+from ._utils import _make_divisible


__all__ = ["RegNet", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf",
@@ -32,7 +32,7 @@ model_urls = {
}


-class SimpleStemIN(ConvBNActivation):
+class SimpleStemIN(ConvNormActivation):
    """Simple stem for ImageNet: 3x3, BN, ReLU."""

    def __init__(
@@ -64,10 +64,10 @@ class BottleneckTransform(nn.Sequential):
        w_b = int(round(width_out * bottleneck_multiplier))
        g = w_b // group_width

-        layers["a"] = ConvBNActivation(width_in, w_b, kernel_size=1, stride=1,
-                                       norm_layer=norm_layer, activation_layer=activation_layer)
-        layers["b"] = ConvBNActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g,
-                                       norm_layer=norm_layer, activation_layer=activation_layer)
+        layers["a"] = ConvNormActivation(width_in, w_b, kernel_size=1, stride=1,
+                                         norm_layer=norm_layer, activation_layer=activation_layer)
+        layers["b"] = ConvNormActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g,
+                                         norm_layer=norm_layer, activation_layer=activation_layer)

        if se_ratio:
            # The SE reduction ratio is defined with respect to the
@@ -79,8 +79,8 @@ class BottleneckTransform(nn.Sequential):
                activation=activation_layer,
            )

-        layers["c"] = ConvBNActivation(w_b, width_out, kernel_size=1, stride=1,
-                                       norm_layer=norm_layer, activation_layer=nn.Identity)
+        layers["c"] = ConvNormActivation(w_b, width_out, kernel_size=1, stride=1,
+                                         norm_layer=norm_layer, activation_layer=None)

        super().__init__(layers)
@@ -104,8 +104,8 @@ class ResBottleneckBlock(nn.Module):
        self.proj = None
        should_proj = (width_in != width_out) or (stride != 1)
        if should_proj:
-            self.proj = ConvBNActivation(width_in, width_out, kernel_size=1,
-                                         stride=stride, norm_layer=norm_layer, activation_layer=nn.Identity)
+            self.proj = ConvNormActivation(width_in, width_out, kernel_size=1,
+                                           stride=stride, norm_layer=norm_layer, activation_layer=None)
        self.f = BottleneckTransform(
            width_in,
            width_out,
...
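With the shared block in ops, the RegNet stem can stay a thin subclass of ConvNormActivation. A sketch of that pattern with a hypothetical TinyStem standing in for SimpleStemIN (the real class takes its widths, norm, and activation from the RegNet configuration):

from torch import nn
from torchvision.ops.misc import ConvNormActivation

class TinyStem(ConvNormActivation):  # hypothetical stand-in for SimpleStemIN
    """Simple stem for ImageNet: 3x3, BN, ReLU."""
    def __init__(self, width_in: int, width_out: int) -> None:
        super().__init__(width_in, width_out, kernel_size=3, stride=2,
                         norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU)

stem = TinyStem(3, 32)
print([type(m).__name__ for m in stem])  # ['Conv2d', 'BatchNorm2d', 'ReLU']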
@@ -11,7 +11,7 @@ is implemented
import warnings
import torch
from torch import Tensor
-from typing import List, Optional
+from typing import Callable, List, Optional


class Conv2d(torch.nn.Conv2d):
@@ -97,3 +97,56 @@ class FrozenBatchNorm2d(torch.nn.Module):
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
+
+
+class ConvNormActivation(torch.nn.Sequential):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        dilation: int = 1,
+        inplace: bool = True,
+    ) -> None:
+        if padding is None:
+            padding = (kernel_size - 1) // 2 * dilation
+        layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
+                                  dilation=dilation, groups=groups, bias=norm_layer is None)]
+        if norm_layer is not None:
+            layers.append(norm_layer(out_channels))
+        if activation_layer is not None:
+            layers.append(activation_layer(inplace=inplace))
+        super().__init__(*layers)
+        self.out_channels = out_channels
+
+
+class SqueezeExcitation(torch.nn.Module):
+    def __init__(
+        self,
+        input_channels: int,
+        squeeze_channels: int,
+        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
+        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
+    ) -> None:
+        super().__init__()
+        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
+        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
+
+    def _scale(self, input: Tensor) -> Tensor:
+        scale = self.avgpool(input)
+        scale = self.fc1(scale)
+        scale = self.activation(scale)
+        scale = self.fc2(scale)
+        return self.scale_activation(scale)
+
+    def forward(self, input: Tensor) -> Tensor:
+        scale = self._scale(input)
+        return scale * input
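Two behavioural details of the new block are worth noting: the convolution only keeps its bias when no normalization layer is requested (bias=norm_layer is None), and the default padding of (kernel_size - 1) // 2 * dilation preserves the spatial size for odd kernels. A small check, assuming a torchvision build with this commit:

import torch
from torchvision.ops.misc import ConvNormActivation

same_pad = ConvNormActivation(8, 8, kernel_size=5)    # padding defaults to 2
no_norm = ConvNormActivation(8, 8, norm_layer=None)   # conv keeps its bias, no BN appended

print(same_pad(torch.rand(1, 8, 17, 17)).shape)            # torch.Size([1, 8, 17, 17])
print(same_pad[0].bias is None, no_norm[0].bias is None)   # True False
print(len(no_norm))                                        # 2  (Conv2d + ReLU)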