Unverified Commit 49e32c26 authored by Xiaojie Li, committed by GitHub

Implementation of 2D convolution in tensorflow with `padding` as "same" (#529)



* update impad

* fix docstring

* add shape for impad

* fix unit test

* remove old version & fix doc

* fix linting

* fix doc

* add linear decay learning rate scheduler

* fix impad

* fix setup.cfg

* fix linting

* add yapf

* add swish

* fix lr_updater

* fix lr_updater.py

* update swish

* add swish

* fix inplace

* fix typo

* update

* add same padding

* fix docstring

* add unittest

* fix register

* change name
Co-authored-by: lixiaojie <lixiaojie@sensetime.com>
parent d7c895a3
from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv2d_adaptive_padding import Conv2dAdaptivePadding
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
@@ -24,5 +25,6 @@ __all__ = [
    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish',
    'Conv2dAdaptivePadding'
]
import math

from torch import nn
from torch.nn import functional as F

from .registry import CONV_LAYERS


@CONV_LAYERS.register_module()
class Conv2dAdaptivePadding(nn.Conv2d):
""" Implementation of 2D convolution in tensorflow with `padding` as
"same", which applies padding to input (if needed) so that input image
gets fully covered by filter and stride you specified. For stride 1, this
will ensure that output image size is same as input. For stride of 2,
output dimensions will be half, for example.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements.
Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the
output. Default: ``True``
"""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        # `padding` is accepted for interface compatibility but not used here;
        # the actual padding is computed dynamically in `forward`.
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x):
        img_h, img_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        # Target output size under "same" padding: ceil(input / stride).
        output_h = math.ceil(img_h / stride_h)
        output_w = math.ceil(img_w / stride_w)
        # Total padding needed so the (dilated) kernel covers the whole input.
        pad_h = (
            max((output_h - 1) * self.stride[0] +
                (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
        pad_w = (
            max((output_w - 1) * self.stride[1] +
                (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
        if pad_h > 0 or pad_w > 0:
            # Split the padding between the two sides; if the total is odd,
            # the extra pixel goes to the right/bottom, as in TensorFlow.
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
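
As a quick check of the padding arithmetic in `forward`, a minimal sketch; the input size, kernel and stride below are chosen purely for illustration:

import math

import torch

from mmcv.cnn.bricks import Conv2dAdaptivePadding

# 13x13 input, kernel 3, stride 2, dilation 1 (illustrative values).
img, kernel, stride, dilation = 13, 3, 2, 1
output = math.ceil(img / stride)  # "same" target size: ceil(13 / 2) = 7
pad = max((output - 1) * stride + (kernel - 1) * dilation + 1 - img, 0)  # 2
assert (pad // 2, pad - pad // 2) == (1, 1)  # one padded pixel on each side

# The layer produces exactly this output size without an explicit padding arg.
conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=2)
assert conv(torch.rand(1, 3, 13, 13)).shape == torch.Size([1, 3, 7, 7])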
import torch

from mmcv.cnn.bricks import Conv2dAdaptivePadding


def test_conv2d_samepadding():
    # test Conv2dAdaptivePadding with stride=1
    inputs = torch.rand((1, 3, 28, 28))
    conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=1)
    output = conv(inputs)
    assert output.shape == inputs.shape

    inputs = torch.rand((1, 3, 13, 13))
    conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=1)
    output = conv(inputs)
    assert output.shape == inputs.shape

    # test Conv2dAdaptivePadding with stride=2
    inputs = torch.rand((1, 3, 28, 28))
    conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=2)
    output = conv(inputs)
    assert output.shape == torch.Size([1, 3, 14, 14])

    inputs = torch.rand((1, 3, 13, 13))
    conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=2)
    output = conv(inputs)
    assert output.shape == torch.Size([1, 3, 7, 7])
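
For stride 1 with an odd kernel size and dilation 1, the adaptive padding works out to the usual symmetric padding of kernel_size // 2 per side, so the output shape matches a plain nn.Conv2d with explicit padding. A minimal sketch (channel counts chosen for illustration):

import torch
from torch import nn

from mmcv.cnn.bricks import Conv2dAdaptivePadding

x = torch.rand(1, 3, 28, 28)
same_conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=1)
plain_conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
# Both preserve the 28x28 spatial size; only the padding bookkeeping differs.
assert same_conv(x).shape == plain_conv(x).shape == torch.Size([1, 8, 28, 28])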