##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Rectify Module"""
import warnings
import torch
from torch.nn import Conv2d
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

from ..functions import rectify

__all__ = ['RFConv2d']


class RFConv2d(Conv2d):
    """Rectified Convolution.

    A ``torch.nn.Conv2d`` whose output is post-processed by ``rectify``
    (imported from ``..functions``) when the layer uses any spatial padding,
    or unconditionally when ``average_mode`` is True.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: same meaning as for ``torch.nn.Conv2d``.
            ``kernel_size``, ``stride``, ``padding`` and ``dilation`` are
            normalised to 2-tuples with ``_pair``.
        average_mode (bool): passed to ``rectify`` as its final argument and
            forces rectification on even when there is no padding.

    NOTE(review): the exact transformation is defined by ``rectify`` in
    ``..functions``; presumably it corrects border outputs affected by
    padding -- confirm there.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 average_mode=False):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(RFConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups,
            bias=bias, padding_mode=padding_mode)
        # Assigned after nn.Module.__init__ so the attributes are set on a
        # fully initialised Module.
        # Rectification is enabled when the input is padded on either axis,
        # or when average mode is explicitly requested.
        self.rectify = average_mode or (padding[0] > 0 or padding[1] > 0)
        self.average = average_mode

    def _conv_forward(self, input, weight):
        """Apply the plain 2-D convolution of ``input`` with ``weight``.

        For a ``padding_mode`` other than ``'zeros'`` the input is padded
        explicitly with ``F.pad`` and the convolution then runs with no
        implicit padding; otherwise ``F.conv2d`` does the (zero) padding.
        """
        if self.padding_mode != 'zeros':
            # F.pad lists pads for the last dimension first:
            # (left, right, top, bottom) == (pad_w, pad_w, pad_h, pad_h).
            # Computed here rather than read from Conv2d's private padding
            # attribute, whose name (and, in old releases, ordering) differs
            # between PyTorch versions.
            pad_h, pad_w = self.padding
            padded = F.pad(input, (pad_w, pad_w, pad_h, pad_h),
                           mode=self.padding_mode)
            return F.conv2d(padded, weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Convolve ``input`` and, if enabled, rectify the result.

        ``rectify`` receives the convolution output, the original (unpadded)
        input and the layer's geometry (kernel size, stride, padding,
        dilation) plus the ``average`` flag.
        """
        output = self._conv_forward(input, self.weight)
        if self.rectify:
            output = rectify(output, input, self.kernel_size, self.stride,
                             self.padding, self.dilation, self.average)
        return output

    def extra_repr(self):
        """Extend ``Conv2d``'s repr with the rectify/average flags."""
        return super().extra_repr() + ', rectify={}, average_mode={}'. \
            format(self.rectify, self.average)