Unverified commit f15ba56f, authored by Aditya Oke and committed by GitHub
Browse files

Add Conv2dNormActivation and Conv3dNormActivation Blocks (#5445)



* Add ops.conv3d

* Refactor for conv2d and 3d

* Refactor

* Fix bug

* Address review

* Fix bug

* nit fix

* Fix flake

* Final fix

* remove documentation

* fix linter

* Update all the implementations to use new Conv

* Small doc fix
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent cdcb8b6b
...@@ -45,4 +45,6 @@ Operators ...@@ -45,4 +45,6 @@ Operators
FeaturePyramidNetwork FeaturePyramidNetwork
StochasticDepth StochasticDepth
FrozenBatchNorm2d FrozenBatchNorm2d
Conv2dNormActivation
Conv3dNormActivation
SqueezeExcitation SqueezeExcitation
...@@ -6,7 +6,7 @@ from torch import nn, Tensor ...@@ -6,7 +6,7 @@ from torch import nn, Tensor
from torch.nn import functional as F from torch.nn import functional as F
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation from ..ops.misc import Conv2dNormActivation
from ..ops.stochastic_depth import StochasticDepth from ..ops.stochastic_depth import StochasticDepth
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
...@@ -127,7 +127,7 @@ class ConvNeXt(nn.Module): ...@@ -127,7 +127,7 @@ class ConvNeXt(nn.Module):
# Stem # Stem
firstconv_output_channels = block_setting[0].input_channels firstconv_output_channels = block_setting[0].input_channels
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
3, 3,
firstconv_output_channels, firstconv_output_channels,
kernel_size=4, kernel_size=4,
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation from ...ops.misc import Conv2dNormActivation
from ...utils import _log_api_usage_once from ...utils import _log_api_usage_once
from .. import mobilenet from .. import mobilenet
from . import _utils as det_utils from . import _utils as det_utils
...@@ -29,7 +29,7 @@ def _prediction_block( ...@@ -29,7 +29,7 @@ def _prediction_block(
) -> nn.Sequential: ) -> nn.Sequential:
return nn.Sequential( return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1 # 3x3 depthwise with stride 1 and padding 1
ConvNormActivation( Conv2dNormActivation(
in_channels, in_channels,
in_channels, in_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
...@@ -47,11 +47,11 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., ...@@ -47,11 +47,11 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
intermediate_channels = out_channels // 2 intermediate_channels = out_channels // 2
return nn.Sequential( return nn.Sequential(
# 1x1 projection to half output channels # 1x1 projection to half output channels
ConvNormActivation( Conv2dNormActivation(
in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
), ),
# 3x3 depthwise with stride 2 and padding 1 # 3x3 depthwise with stride 2 and padding 1
ConvNormActivation( Conv2dNormActivation(
intermediate_channels, intermediate_channels,
intermediate_channels, intermediate_channels,
kernel_size=3, kernel_size=3,
...@@ -61,7 +61,7 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., ...@@ -61,7 +61,7 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
activation_layer=activation, activation_layer=activation,
), ),
# 1x1 projection to output channels # 1x1 projection to output channels
ConvNormActivation( Conv2dNormActivation(
intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
), ),
) )
......
...@@ -8,7 +8,7 @@ from torch import nn, Tensor ...@@ -8,7 +8,7 @@ from torch import nn, Tensor
from torchvision.ops import StochasticDepth from torchvision.ops import StochasticDepth
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import _make_divisible from ._utils import _make_divisible
...@@ -104,7 +104,7 @@ class MBConv(nn.Module): ...@@ -104,7 +104,7 @@ class MBConv(nn.Module):
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels: if expanded_channels != cnf.input_channels:
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
cnf.input_channels, cnf.input_channels,
expanded_channels, expanded_channels,
kernel_size=1, kernel_size=1,
...@@ -115,7 +115,7 @@ class MBConv(nn.Module): ...@@ -115,7 +115,7 @@ class MBConv(nn.Module):
# depthwise # depthwise
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
expanded_channels, expanded_channels,
expanded_channels, expanded_channels,
kernel_size=cnf.kernel, kernel_size=cnf.kernel,
...@@ -132,7 +132,7 @@ class MBConv(nn.Module): ...@@ -132,7 +132,7 @@ class MBConv(nn.Module):
# project # project
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
) )
) )
...@@ -193,7 +193,7 @@ class EfficientNet(nn.Module): ...@@ -193,7 +193,7 @@ class EfficientNet(nn.Module):
# building first layer # building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
) )
) )
...@@ -224,7 +224,7 @@ class EfficientNet(nn.Module): ...@@ -224,7 +224,7 @@ class EfficientNet(nn.Module):
lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 4 * lastconv_input_channels lastconv_output_channels = 4 * lastconv_input_channels
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
lastconv_input_channels, lastconv_input_channels,
lastconv_output_channels, lastconv_output_channels,
kernel_size=1, kernel_size=1,
......
...@@ -6,7 +6,7 @@ from torch import Tensor ...@@ -6,7 +6,7 @@ from torch import Tensor
from torch import nn from torch import nn
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import _make_divisible from ._utils import _make_divisible
...@@ -20,11 +20,11 @@ model_urls = { ...@@ -20,11 +20,11 @@ model_urls = {
# necessary for backwards compatibility # necessary for backwards compatibility
class _DeprecatedConvBNAct(ConvNormActivation): class _DeprecatedConvBNAct(Conv2dNormActivation):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn( warnings.warn(
"The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. " "The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. "
"Use torchvision.ops.misc.ConvNormActivation instead.", "Use torchvision.ops.misc.Conv2dNormActivation instead.",
FutureWarning, FutureWarning,
) )
if kwargs.get("norm_layer", None) is None: if kwargs.get("norm_layer", None) is None:
...@@ -56,12 +56,12 @@ class InvertedResidual(nn.Module): ...@@ -56,12 +56,12 @@ class InvertedResidual(nn.Module):
if expand_ratio != 1: if expand_ratio != 1:
# pw # pw
layers.append( layers.append(
ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
) )
layers.extend( layers.extend(
[ [
# dw # dw
ConvNormActivation( Conv2dNormActivation(
hidden_dim, hidden_dim,
hidden_dim, hidden_dim,
stride=stride, stride=stride,
...@@ -144,7 +144,7 @@ class MobileNetV2(nn.Module): ...@@ -144,7 +144,7 @@ class MobileNetV2(nn.Module):
input_channel = _make_divisible(input_channel * width_mult, round_nearest) input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [ features: List[nn.Module] = [
ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
] ]
# building inverted residual blocks # building inverted residual blocks
for t, c, n, s in inverted_residual_setting: for t, c, n, s in inverted_residual_setting:
...@@ -155,7 +155,7 @@ class MobileNetV2(nn.Module): ...@@ -155,7 +155,7 @@ class MobileNetV2(nn.Module):
input_channel = output_channel input_channel = output_channel
# building last several layers # building last several layers
features.append( features.append(
ConvNormActivation( Conv2dNormActivation(
input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
) )
) )
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import _make_divisible from ._utils import _make_divisible
...@@ -83,7 +83,7 @@ class InvertedResidual(nn.Module): ...@@ -83,7 +83,7 @@ class InvertedResidual(nn.Module):
# expand # expand
if cnf.expanded_channels != cnf.input_channels: if cnf.expanded_channels != cnf.input_channels:
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
cnf.input_channels, cnf.input_channels,
cnf.expanded_channels, cnf.expanded_channels,
kernel_size=1, kernel_size=1,
...@@ -95,7 +95,7 @@ class InvertedResidual(nn.Module): ...@@ -95,7 +95,7 @@ class InvertedResidual(nn.Module):
# depthwise # depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
cnf.expanded_channels, cnf.expanded_channels,
cnf.expanded_channels, cnf.expanded_channels,
kernel_size=cnf.kernel, kernel_size=cnf.kernel,
...@@ -112,7 +112,7 @@ class InvertedResidual(nn.Module): ...@@ -112,7 +112,7 @@ class InvertedResidual(nn.Module):
# project # project
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
) )
) )
...@@ -172,7 +172,7 @@ class MobileNetV3(nn.Module): ...@@ -172,7 +172,7 @@ class MobileNetV3(nn.Module):
# building first layer # building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
3, 3,
firstconv_output_channels, firstconv_output_channels,
kernel_size=3, kernel_size=3,
...@@ -190,7 +190,7 @@ class MobileNetV3(nn.Module): ...@@ -190,7 +190,7 @@ class MobileNetV3(nn.Module):
lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels lastconv_output_channels = 6 * lastconv_input_channels
layers.append( layers.append(
ConvNormActivation( Conv2dNormActivation(
lastconv_input_channels, lastconv_input_channels,
lastconv_output_channels, lastconv_output_channels,
kernel_size=1, kernel_size=1,
......
...@@ -6,7 +6,7 @@ import torch.nn.functional as F ...@@ -6,7 +6,7 @@ import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.ops.misc import ConvNormActivation from torchvision.ops import Conv2dNormActivation
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once from ...utils import _log_api_usage_once
...@@ -38,17 +38,17 @@ class ResidualBlock(nn.Module): ...@@ -38,17 +38,17 @@ class ResidualBlock(nn.Module):
# and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful
# for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm
# because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights). # because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights).
self.convnormrelu1 = ConvNormActivation( self.convnormrelu1 = Conv2dNormActivation(
in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
) )
self.convnormrelu2 = ConvNormActivation( self.convnormrelu2 = Conv2dNormActivation(
out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
) )
if stride == 1: if stride == 1:
self.downsample = nn.Identity() self.downsample = nn.Identity()
else: else:
self.downsample = ConvNormActivation( self.downsample = Conv2dNormActivation(
in_channels, in_channels,
out_channels, out_channels,
norm_layer=norm_layer, norm_layer=norm_layer,
...@@ -77,13 +77,13 @@ class BottleneckBlock(nn.Module): ...@@ -77,13 +77,13 @@ class BottleneckBlock(nn.Module):
super().__init__() super().__init__()
# See note in ResidualBlock for the reason behind bias=True # See note in ResidualBlock for the reason behind bias=True
self.convnormrelu1 = ConvNormActivation( self.convnormrelu1 = Conv2dNormActivation(
in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True
) )
self.convnormrelu2 = ConvNormActivation( self.convnormrelu2 = Conv2dNormActivation(
out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
) )
self.convnormrelu3 = ConvNormActivation( self.convnormrelu3 = Conv2dNormActivation(
out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True
) )
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
...@@ -91,7 +91,7 @@ class BottleneckBlock(nn.Module): ...@@ -91,7 +91,7 @@ class BottleneckBlock(nn.Module):
if stride == 1: if stride == 1:
self.downsample = nn.Identity() self.downsample = nn.Identity()
else: else:
self.downsample = ConvNormActivation( self.downsample = Conv2dNormActivation(
in_channels, in_channels,
out_channels, out_channels,
norm_layer=norm_layer, norm_layer=norm_layer,
...@@ -124,7 +124,9 @@ class FeatureEncoder(nn.Module): ...@@ -124,7 +124,9 @@ class FeatureEncoder(nn.Module):
assert len(layers) == 5 assert len(layers) == 5
# See note in ResidualBlock for the reason behind bias=True # See note in ResidualBlock for the reason behind bias=True
self.convnormrelu = ConvNormActivation(3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True) self.convnormrelu = Conv2dNormActivation(
3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True
)
self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1) self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1)
self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2) self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2)
...@@ -170,17 +172,17 @@ class MotionEncoder(nn.Module): ...@@ -170,17 +172,17 @@ class MotionEncoder(nn.Module):
assert len(flow_layers) == 2 assert len(flow_layers) == 2
assert len(corr_layers) in (1, 2) assert len(corr_layers) in (1, 2)
self.convcorr1 = ConvNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
if len(corr_layers) == 2: if len(corr_layers) == 2:
self.convcorr2 = ConvNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
else: else:
self.convcorr2 = nn.Identity() self.convcorr2 = nn.Identity()
self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)
# out_channels - 2 because we cat the flow (2 channels) at the end # out_channels - 2 because we cat the flow (2 channels) at the end
self.conv = ConvNormActivation( self.conv = Conv2dNormActivation(
corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3
) )
...@@ -301,7 +303,7 @@ class MaskPredictor(nn.Module): ...@@ -301,7 +303,7 @@ class MaskPredictor(nn.Module):
def __init__(self, *, in_channels, hidden_size, multiplier=0.25): def __init__(self, *, in_channels, hidden_size, multiplier=0.25):
super().__init__() super().__init__()
self.convrelu = ConvNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
# 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder
# and we interpolate with all 9 surrounding neighbors. See paper and appendix B. # and we interpolate with all 9 surrounding neighbors. See paper and appendix B.
self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0)
......
...@@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub ...@@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation from ...ops.misc import Conv2dNormActivation
from .utils import _fuse_modules, _replace_relu, quantize_model from .utils import _fuse_modules, _replace_relu, quantize_model
...@@ -54,7 +54,7 @@ class QuantizableMobileNetV2(MobileNetV2): ...@@ -54,7 +54,7 @@ class QuantizableMobileNetV2(MobileNetV2):
def fuse_model(self, is_qat: Optional[bool] = None) -> None: def fuse_model(self, is_qat: Optional[bool] = None) -> None:
for m in self.modules(): for m in self.modules():
if type(m) is ConvNormActivation: if type(m) is Conv2dNormActivation:
_fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True) _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
if type(m) is QuantizableInvertedResidual: if type(m) is QuantizableInvertedResidual:
m.fuse_model(is_qat) m.fuse_model(is_qat)
......
...@@ -5,7 +5,7 @@ from torch import nn, Tensor ...@@ -5,7 +5,7 @@ from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub from torch.ao.quantization import QuantStub, DeQuantStub
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation, SqueezeExcitation from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
from .utils import _fuse_modules, _replace_relu from .utils import _fuse_modules, _replace_relu
...@@ -103,7 +103,7 @@ class QuantizableMobileNetV3(MobileNetV3): ...@@ -103,7 +103,7 @@ class QuantizableMobileNetV3(MobileNetV3):
def fuse_model(self, is_qat: Optional[bool] = None) -> None: def fuse_model(self, is_qat: Optional[bool] = None) -> None:
for m in self.modules(): for m in self.modules():
if type(m) is ConvNormActivation: if type(m) is Conv2dNormActivation:
modules_to_fuse = ["0", "1"] modules_to_fuse = ["0", "1"]
if len(m) == 3 and type(m[2]) is nn.ReLU: if len(m) == 3 and type(m[2]) is nn.ReLU:
modules_to_fuse.append("2") modules_to_fuse.append("2")
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import _make_divisible from ._utils import _make_divisible
...@@ -55,7 +55,7 @@ model_urls = { ...@@ -55,7 +55,7 @@ model_urls = {
} }
class SimpleStemIN(ConvNormActivation): class SimpleStemIN(Conv2dNormActivation):
"""Simple stem for ImageNet: 3x3, BN, ReLU.""" """Simple stem for ImageNet: 3x3, BN, ReLU."""
def __init__( def __init__(
...@@ -88,10 +88,10 @@ class BottleneckTransform(nn.Sequential): ...@@ -88,10 +88,10 @@ class BottleneckTransform(nn.Sequential):
w_b = int(round(width_out * bottleneck_multiplier)) w_b = int(round(width_out * bottleneck_multiplier))
g = w_b // group_width g = w_b // group_width
layers["a"] = ConvNormActivation( layers["a"] = Conv2dNormActivation(
width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
) )
layers["b"] = ConvNormActivation( layers["b"] = Conv2dNormActivation(
w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
) )
...@@ -105,7 +105,7 @@ class BottleneckTransform(nn.Sequential): ...@@ -105,7 +105,7 @@ class BottleneckTransform(nn.Sequential):
activation=activation_layer, activation=activation_layer,
) )
layers["c"] = ConvNormActivation( layers["c"] = Conv2dNormActivation(
w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
) )
super().__init__(layers) super().__init__(layers)
...@@ -131,7 +131,7 @@ class ResBottleneckBlock(nn.Module): ...@@ -131,7 +131,7 @@ class ResBottleneckBlock(nn.Module):
self.proj = None self.proj = None
should_proj = (width_in != width_out) or (stride != 1) should_proj = (width_in != width_out) or (stride != 1)
if should_proj: if should_proj:
self.proj = ConvNormActivation( self.proj = Conv2dNormActivation(
width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
) )
self.f = BottleneckTransform( self.f = BottleneckTransform(
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
__all__ = [ __all__ = [
...@@ -163,7 +163,7 @@ class VisionTransformer(nn.Module): ...@@ -163,7 +163,7 @@ class VisionTransformer(nn.Module):
for i, conv_stem_layer_config in enumerate(conv_stem_configs): for i, conv_stem_layer_config in enumerate(conv_stem_configs):
seq_proj.add_module( seq_proj.add_module(
f"conv_bn_relu_{i}", f"conv_bn_relu_{i}",
ConvNormActivation( Conv2dNormActivation(
in_channels=prev_channels, in_channels=prev_channels,
out_channels=conv_stem_layer_config.out_channels, out_channels=conv_stem_layer_config.out_channels,
kernel_size=conv_stem_layer_config.kernel_size, kernel_size=conv_stem_layer_config.kernel_size,
......
...@@ -14,7 +14,7 @@ from .deform_conv import deform_conv2d, DeformConv2d ...@@ -14,7 +14,7 @@ from .deform_conv import deform_conv2d, DeformConv2d
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, SqueezeExcitation from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
from .poolers import MultiScaleRoIAlign from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool from .ps_roi_pool import ps_roi_pool, PSRoIPool
...@@ -51,6 +51,8 @@ __all__ = [ ...@@ -51,6 +51,8 @@ __all__ = [
"stochastic_depth", "stochastic_depth",
"StochasticDepth", "StochasticDepth",
"FrozenBatchNorm2d", "FrozenBatchNorm2d",
"Conv2dNormActivation",
"Conv3dNormActivation",
"SqueezeExcitation", "SqueezeExcitation",
"generalized_box_iou_loss", "generalized_box_iou_loss",
] ]
import warnings
from typing import Callable, List, Optional from typing import Callable, List, Optional
import torch import torch
...@@ -66,24 +67,6 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -66,24 +67,6 @@ class FrozenBatchNorm2d(torch.nn.Module):
class ConvNormActivation(torch.nn.Sequential): class ConvNormActivation(torch.nn.Sequential):
"""
Configurable block used for Convolution-Normalization-Activation blocks.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -97,13 +80,16 @@ class ConvNormActivation(torch.nn.Sequential): ...@@ -97,13 +80,16 @@ class ConvNormActivation(torch.nn.Sequential):
dilation: int = 1, dilation: int = 1,
inplace: Optional[bool] = True, inplace: Optional[bool] = True,
bias: Optional[bool] = None, bias: Optional[bool] = None,
conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
) -> None: ) -> None:
if padding is None: if padding is None:
padding = (kernel_size - 1) // 2 * dilation padding = (kernel_size - 1) // 2 * dilation
if bias is None: if bias is None:
bias = norm_layer is None bias = norm_layer is None
layers = [ layers = [
torch.nn.Conv2d( conv_layer(
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -114,8 +100,10 @@ class ConvNormActivation(torch.nn.Sequential): ...@@ -114,8 +100,10 @@ class ConvNormActivation(torch.nn.Sequential):
bias=bias, bias=bias,
) )
] ]
if norm_layer is not None: if norm_layer is not None:
layers.append(norm_layer(out_channels)) layers.append(norm_layer(out_channels))
if activation_layer is not None: if activation_layer is not None:
params = {} if inplace is None else {"inplace": inplace} params = {} if inplace is None else {"inplace": inplace}
layers.append(activation_layer(**params)) layers.append(activation_layer(**params))
...@@ -123,6 +111,110 @@ class ConvNormActivation(torch.nn.Sequential): ...@@ -123,6 +111,110 @@ class ConvNormActivation(torch.nn.Sequential):
_log_api_usage_once(self) _log_api_usage_once(self)
self.out_channels = out_channels self.out_channels = out_channels
if self.__class__ == ConvNormActivation:
warnings.warn(
"Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
)
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable Convolution2d-Normalization-Activation block.

    Thin specialization of ``ConvNormActivation`` that fixes the convolution
    layer to ``torch.nn.Conv2d`` and defaults the norm layer to
    ``torch.nn.BatchNorm2d``.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block.
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input.
            Default: ``None``, in which case it is calculated as ``padding = (kernel_size - 1) // 2 * dilation``.
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer stacked on top of the convolution layer.
            If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function stacked on top of the
            normalization layer (if not ``None``), otherwise on top of the conv layer.
            If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place.
            If ``None`` the argument is not forwarded to the activation layer. Default: ``True``
        bias (bool, optional): Whether to use bias in the convolution layer.
            By default, biases are included if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
        # Leading arguments keep the base-class positional order; the trailing
        # options are named so the fixed conv type stands out at the call site.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation=dilation,
            inplace=inplace,
            bias=bias,
            conv_layer=torch.nn.Conv2d,
        )
class Conv3dNormActivation(ConvNormActivation):
    """
    Configurable Convolution3d-Normalization-Activation block.

    Thin specialization of ``ConvNormActivation`` that fixes the convolution
    layer to ``torch.nn.Conv3d`` and defaults the norm layer to
    ``torch.nn.BatchNorm3d``.

    Args:
        in_channels (int): Number of channels in the input video.
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block.
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to every side of the input.
            Default: ``None``, in which case it is calculated as ``padding = (kernel_size - 1) // 2 * dilation``.
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer stacked on top of the convolution layer.
            If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function stacked on top of the
            normalization layer (if not ``None``), otherwise on top of the conv layer.
            If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place.
            If ``None`` the argument is not forwarded to the activation layer. Default: ``True``
        bias (bool, optional): Whether to use bias in the convolution layer.
            By default, biases are included if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
        # Leading arguments keep the base-class positional order; the trailing
        # options are named so the fixed conv type stands out at the call site.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation=dilation,
            inplace=inplace,
            bias=bias,
            conv_layer=torch.nn.Conv3d,
        )
class SqueezeExcitation(torch.nn.Module): class SqueezeExcitation(torch.nn.Module):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment