Unverified commit f7498350 authored by Vasilis Vryniotis, committed by GitHub

Moving common layers to ops (#4504)

* Moving _make_divisible to utils.

* Replace the old ConvBNReLU and ConvBNActivation layers

* Fix minor bug.

* Moving SE layer to ops.

* Adding deprecation warnings on old layers.

* Apply changes to regnets.
parent 5760f356
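In short, the commit relocates three shared building blocks: the ConvBNReLU/ConvBNActivation layers are replaced by torchvision.ops.misc.ConvNormActivation, the EfficientNet SqueezeExcitation module moves to torchvision.ops.misc.SqueezeExcitation, and _make_divisible moves into the models' _utils module. A minimal import sketch of the before/after call sites (module paths as they appear in this diff; _make_divisible remains a private helper):

# Old import paths (now deprecated, kept only for backwards compatibility):
# from torchvision.models.mobilenetv2 import ConvBNReLU, ConvBNActivation, _make_divisible
# from torchvision.models.efficientnet import SqueezeExcitation

# New import paths introduced by this commit:
from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation
from torchvision.models._utils import _make_divisible  # private helper, shown for completeness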
from collections import OrderedDict
from torch import nn
from typing import Dict
from typing import Dict, Optional
class IntermediateLayerGetter(nn.ModuleDict):
......@@ -64,3 +64,19 @@ class IntermediateLayerGetter(nn.ModuleDict):
out_name = self.return_layers[name]
out[out_name] = x
return out
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8.
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that rounding down does not reduce the value by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
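# Worked examples (illustrative values): the helper rounds a scaled channel count
# to the nearest multiple of `divisor`, and bumps the result up by one divisor if
# plain rounding would fall below 90% of the input.
assert _make_divisible(50 * 0.75, 8) == 40   # 37.5 -> nearest multiple of 8
assert _make_divisible(24.0, 8) == 24        # already divisible, unchanged
assert _make_divisible(40, 32) == 64         # rounding down to 32 would lose >10%, so bump up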
......@@ -11,8 +11,8 @@ from .ssd import SSD, SSDScoringHead
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .. import mobilenet
from ..mobilenetv3 import ConvBNActivation
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
__all__ = ['ssdlite320_mobilenet_v3_large']
......@@ -28,8 +28,8 @@ def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1
ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=nn.ReLU6),
ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=nn.ReLU6),
# 1x1 projection to output channels
nn.Conv2d(in_channels, out_channels, 1)
......@@ -41,16 +41,16 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
intermediate_channels = out_channels // 2
return nn.Sequential(
# 1x1 projection to half output channels
ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(in_channels, intermediate_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
# 3x3 depthwise with stride 2 and padding 1
ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
# 1x1 projection to output channels
ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(intermediate_channels, out_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
)
......
......@@ -4,14 +4,13 @@ import torch
from functools import partial
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, List, Optional, Sequence
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible
from torchvision.ops import StochasticDepth
from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
"efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
......@@ -31,32 +30,6 @@ model_urls = {
}
class SqueezeExcitation(nn.Module):
def __init__(
self,
input_channels: int,
squeeze_channels: int,
activation: Callable[..., nn.Module] = nn.ReLU,
scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
) -> None:
super().__init__()
self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
self.activation = activation()
self.scale_activation = scale_activation()
def _scale(self, input: Tensor) -> Tensor:
scale = F.adaptive_avg_pool2d(input, 1)
scale = self.fc1(scale)
scale = self.activation(scale)
scale = self.fc2(scale)
return self.scale_activation(scale)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
class MBConvConfig:
# Stores information listed at Table 1 of the EfficientNet paper
def __init__(self,
......@@ -106,21 +79,21 @@ class MBConv(nn.Module):
# expand
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels:
layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
# depthwise
layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
# squeeze and excitation
squeeze_channels = max(1, cnf.input_channels // 4)
layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
# project
layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.Identity))
layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=None))
self.block = nn.Sequential(*layers)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
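# Note on the projection layer above (illustrative channel counts): passing
# activation_layer=None makes ConvNormActivation skip the activation module
# entirely, whereas the old activation_layer=nn.Identity appended a no-op module.
#   old: Sequential(Conv2d, BatchNorm2d, Identity)  -> 3 children
#   new: Sequential(Conv2d, BatchNorm2d)            -> 2 children
proj = ConvNormActivation(96, 24, kernel_size=1, norm_layer=nn.BatchNorm2d, activation_layer=None)
assert len(proj) == 2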
......@@ -174,8 +147,8 @@ class EfficientNet(nn.Module):
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.SiLU))
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.SiLU))
# building inverted residual blocks
total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
......@@ -202,8 +175,8 @@ class EfficientNet(nn.Module):
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 4 * lastconv_input_channels
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.SiLU))
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.SiLU))
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
......
import torch
import warnings
from functools import partial
from torch import nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ._utils import _make_divisible
from typing import Callable, Any, Optional, List
......@@ -13,50 +18,21 @@ model_urls = {
}
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNActivation(nn.Sequential):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if activation_layer is None:
activation_layer = nn.ReLU6
super().__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_planes),
activation_layer(inplace=True)
)
self.out_channels = out_planes
# necessary for backwards compatibility
class _DeprecatedConvBNAct(ConvNormActivation):
def __init__(self, *args, **kwargs):
warnings.warn(
"The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning)
if kwargs.get("norm_layer", None) is None:
kwargs["norm_layer"] = nn.BatchNorm2d
if kwargs.get("activation_layer", None) is None:
kwargs["activation_layer"] = nn.ReLU6
super().__init__(*args, **kwargs)
# necessary for backwards compatibility
ConvBNReLU = ConvBNActivation
ConvBNReLU = _DeprecatedConvBNAct
ConvBNActivation = _DeprecatedConvBNAct
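# Back-compat sketch (illustrative): the old names still construct a working block,
# but now emit a FutureWarning and forward to ConvNormActivation with the historical
# defaults (BatchNorm2d + ReLU6).
import warnings
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = ConvBNReLU(3, 32, kernel_size=3, stride=2)   # Conv2d + BatchNorm2d + ReLU6
assert any(issubclass(w.category, FutureWarning) for w in caught)
assert isinstance(legacy, ConvNormActivation)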
class InvertedResidual(nn.Module):
......@@ -81,10 +57,12 @@ class InvertedResidual(nn.Module):
layers: List[nn.Module] = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.ReLU6))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
activation_layer=nn.ReLU6),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
......@@ -154,7 +132,8 @@ class MobileNetV2(nn.Module):
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer,
activation_layer=nn.ReLU6)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
......@@ -163,7 +142,8 @@ class MobileNetV2(nn.Module):
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.ReLU6))
# make it nn.Sequential
self.features = nn.Sequential(*features)
......
......@@ -6,8 +6,8 @@ from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence
from .._internally_replaced_utils import load_state_dict_from_url
from .efficientnet import SqueezeExcitation as SElayer
from .mobilenetv2 import _make_divisible, ConvBNActivation
from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
from ._utils import _make_divisible
__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
......@@ -28,7 +28,8 @@ class SqueezeExcitation(SElayer):
self.relu = self.activation
delattr(self, 'activation')
warnings.warn(
"This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
"This SqueezeExcitation class is deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning)
class InvertedResidualConfig:
......@@ -64,21 +65,21 @@ class InvertedResidual(nn.Module):
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se:
squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
# project
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.Identity))
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=None))
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
......@@ -130,8 +131,8 @@ class MobileNetV3(nn.Module):
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.Hardswish))
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.Hardswish))
# building inverted residual blocks
for cnf in inverted_residual_setting:
......@@ -140,8 +141,8 @@ class MobileNetV3(nn.Module):
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.Hardswish))
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.Hardswish))
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
......
......@@ -5,9 +5,10 @@ from ..._internally_replaced_utils import load_state_dict_from_url
from typing import Any
from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model
from ...ops.misc import ConvNormActivation
__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
......@@ -55,7 +56,7 @@ class QuantizableMobileNetV2(MobileNetV2):
def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvBNReLU:
if type(m) == ConvNormActivation:
fuse_modules(m, ['0', '1', '2'], inplace=True)
if type(m) == QuantizableInvertedResidual:
m.fuse_model()
......
import torch
from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from ..efficientnet import SqueezeExcitation as SElayer
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
from ...ops.misc import ConvNormActivation, SqueezeExcitation
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\
model_urls, _mobilenet_v3_conf
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from typing import Any, List, Optional
......@@ -17,7 +17,7 @@ quant_model_urls = {
}
class QuantizableSqueezeExcitation(SElayer):
class QuantizableSqueezeExcitation(SqueezeExcitation):
_version = 2
def __init__(self, *args: Any, **kwargs: Any) -> None:
......@@ -103,9 +103,9 @@ class QuantizableMobileNetV3(MobileNetV3):
def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvBNActivation:
if type(m) == ConvNormActivation:
modules_to_fuse = ['0', '1']
if type(m[2]) == nn.ReLU:
if len(m) == 3 and type(m[2]) == nn.ReLU:
modules_to_fuse.append('2')
fuse_modules(m, modules_to_fuse, inplace=True)
elif type(m) == QuantizableSqueezeExcitation:
......
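The extra len(m) == 3 guard above is needed because a ConvNormActivation built with activation_layer=None now holds only two submodules (conv and norm), while the old ConvBNActivation always appended a third module. A minimal fusion sketch under that assumption (channel counts are illustrative):

from torch import nn
from torch.quantization import fuse_modules
from torchvision.ops.misc import ConvNormActivation

m = ConvNormActivation(16, 16, kernel_size=1, norm_layer=nn.BatchNorm2d, activation_layer=None)
m.eval()                                       # Conv/BN fusion expects eval mode
modules_to_fuse = ['0', '1']                   # Conv2d + BatchNorm2d
if len(m) == 3 and type(m[2]) == nn.ReLU:      # the guard added in this commit
    modules_to_fuse.append('2')
fuse_modules(m, modules_to_fuse, inplace=True)  # here: fuses only Conv2d + BatchNorm2d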
......@@ -12,8 +12,8 @@ from typing import Any, Callable, List, Optional, Tuple
from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
from torchvision.models.efficientnet import SqueezeExcitation
from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible
__all__ = ["RegNet", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf",
......@@ -32,7 +32,7 @@ model_urls = {
}
class SimpleStemIN(ConvBNActivation):
class SimpleStemIN(ConvNormActivation):
"""Simple stem for ImageNet: 3x3, BN, ReLU."""
def __init__(
......@@ -64,10 +64,10 @@ class BottleneckTransform(nn.Sequential):
w_b = int(round(width_out * bottleneck_multiplier))
g = w_b // group_width
layers["a"] = ConvBNActivation(width_in, w_b, kernel_size=1, stride=1,
norm_layer=norm_layer, activation_layer=activation_layer)
layers["b"] = ConvBNActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g,
norm_layer=norm_layer, activation_layer=activation_layer)
layers["a"] = ConvNormActivation(width_in, w_b, kernel_size=1, stride=1,
norm_layer=norm_layer, activation_layer=activation_layer)
layers["b"] = ConvNormActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g,
norm_layer=norm_layer, activation_layer=activation_layer)
if se_ratio:
# The SE reduction ratio is defined with respect to the
......@@ -79,8 +79,8 @@ class BottleneckTransform(nn.Sequential):
activation=activation_layer,
)
layers["c"] = ConvBNActivation(w_b, width_out, kernel_size=1, stride=1,
norm_layer=norm_layer, activation_layer=nn.Identity)
layers["c"] = ConvNormActivation(w_b, width_out, kernel_size=1, stride=1,
norm_layer=norm_layer, activation_layer=None)
super().__init__(layers)
......@@ -104,8 +104,8 @@ class ResBottleneckBlock(nn.Module):
self.proj = None
should_proj = (width_in != width_out) or (stride != 1)
if should_proj:
self.proj = ConvBNActivation(width_in, width_out, kernel_size=1,
stride=stride, norm_layer=norm_layer, activation_layer=nn.Identity)
self.proj = ConvNormActivation(width_in, width_out, kernel_size=1,
stride=stride, norm_layer=norm_layer, activation_layer=None)
self.f = BottleneckTransform(
width_in,
width_out,
......
......@@ -11,7 +11,7 @@ is implemented
import warnings
import torch
from torch import Tensor
from typing import List, Optional
from typing import Callable, List, Optional
class Conv2d(torch.nn.Conv2d):
......@@ -97,3 +97,56 @@ class FrozenBatchNorm2d(torch.nn.Module):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
class ConvNormActivation(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: bool = True,
) -> None:
if padding is None:
padding = (kernel_size - 1) // 2 * dilation
layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
dilation=dilation, groups=groups, bias=norm_layer is None)]
if norm_layer is not None:
layers.append(norm_layer(out_channels))
if activation_layer is not None:
layers.append(activation_layer(inplace=inplace))
super().__init__(*layers)
self.out_channels = out_channels
class SqueezeExcitation(torch.nn.Module):
def __init__(
self,
input_channels: int,
squeeze_channels: int,
activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
) -> None:
super().__init__()
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
self.activation = activation()
self.scale_activation = scale_activation()
def _scale(self, input: Tensor) -> Tensor:
scale = self.avgpool(input)
scale = self.fc1(scale)
scale = self.activation(scale)
scale = self.fc2(scale)
return self.scale_activation(scale)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
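For reference, both blocks added to ops/misc.py can be used directly from user code; a minimal usage sketch (shapes and channel counts are illustrative):

import torch
from torch import nn
from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation

x = torch.randn(1, 32, 56, 56)

# Conv2d + BatchNorm2d + ReLU, with padding derived from kernel_size and dilation
block = ConvNormActivation(32, 64, kernel_size=3, stride=2, activation_layer=nn.ReLU)
print(block(x).shape)   # torch.Size([1, 64, 28, 28])

# Channel-wise re-weighting: global pool -> fc1 -> ReLU -> fc2 -> Sigmoid -> scale
se = SqueezeExcitation(input_channels=32, squeeze_channels=8)
print(se(x).shape)      # torch.Size([1, 32, 56, 56])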