import math

import torch
import torch.nn as nn
from mmcv.cnn import uniform_init
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair

from ..functions.deform_conv import deform_conv


class DeformConv(Module):
    """Deformable 2D convolution layer without a bias term.

    The layer owns a single learnable ``weight`` of shape
    ``(out_channels, in_channels, kernel_h, kernel_w)``. The actual
    computation is delegated to ``deform_conv``, which receives the input,
    the offsets, the weight and the hyper-parameters stored here.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple[int, int]): Kernel size; an int is
            expanded to ``(k, k)`` via ``_pair``.
        stride (int or tuple[int, int]): Stride forwarded to
            ``deform_conv``. Default: 1.
        padding (int or tuple[int, int]): Padding forwarded to
            ``deform_conv``. Default: 0.
        dilation (int or tuple[int, int]): Dilation forwarded to
            ``deform_conv``. Default: 1.
        num_deformable_groups (int): Number of deformable (offset) groups
            forwarded to ``deform_conv``. Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 num_deformable_groups=1):
        super(DeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalise scalar hyper-parameters to (h, w) pairs.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.num_deformable_groups = num_deformable_groups

        # Allocate uninitialised storage; reset_parameters() fills it in.
        # torch.empty replaces the legacy torch.Tensor(*sizes) constructor.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` using a fan-in based uniform bound.

        ``fan_in`` is ``in_channels * kernel_h * kernel_w`` and the bound is
        ``stdv = 1 / sqrt(fan_in)``.
        """
        fan_in = self.in_channels
        for k in self.kernel_size:
            fan_in *= k
        stdv = 1. / math.sqrt(fan_in)
        # mmcv helper; presumably samples self.weight uniformly from
        # [-stdv, stdv] (this module defines no bias) -- confirm in mmcv.
        uniform_init(self, -stdv, stdv)

    def forward(self, input, offset):
        """Apply the deformable convolution.

        Args:
            input (Tensor): Input feature map -- presumably
                ``(N, in_channels, H, W)``; confirm against ``deform_conv``.
            offset (Tensor): Sampling offsets consumed by ``deform_conv``.

        Returns:
            Tensor: The result of ``deform_conv`` for these arguments.
        """
        return deform_conv(input, offset, self.weight, self.stride,
                           self.padding, self.dilation,
                           self.num_deformable_groups)

    def extra_repr(self):
        """Describe the layer's hyper-parameters for ``repr(module)``."""
        return ('{}, {}, kernel_size={}, stride={}, padding={}, '
                'dilation={}, num_deformable_groups={}'.format(
                    self.in_channels, self.out_channels, self.kernel_size,
                    self.stride, self.padding, self.dilation,
                    self.num_deformable_groups))