import math

import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple


def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
    """
    Performs Deformable Convolution, described in Deformable Convolutional Networks

    Arguments:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
            out_height, out_width]): offsets to be applied for each position in the
            convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
            convolution weights, split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1

    Returns:
        output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
    """

    out_channels = weight.shape[0]
    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    _, n_in_channels, in_h, in_w = input.shape

    n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
    n_weight_grps = n_in_channels // weight.shape[1]

    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        bias,
        stride_h, stride_w,
        pad_h, pad_w,
        dil_h, dil_w,
        n_weight_grps,
        n_offset_grps)

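# Usage sketch for the functional form (illustrative, not part of the original
# module): the offset tensor must have 2 * offset_groups * kernel_height * kernel_width
# channels and the same spatial size as the convolution output. All shapes below
# are example values.
#
#   input = torch.rand(4, 3, 10, 10)
#   kh, kw = 3, 3
#   weight = torch.rand(5, 3, kh, kw)
#   # no padding, stride 1 -> output is 8x8, so offsets are (4, 2 * kh * kw, 8, 8)
#   offset = torch.rand(4, 2 * kh * kw, 8, 8)
#   out = deform_conv2d(input, offset, weight)
#   # out.shape == (4, 5, 8, 8)

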
class DeformConv2d(nn.Module):
    """
    See deform_conv2d
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, offset_groups=1, bias=True):
        super(DeformConv2d, self).__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if in_channels % offset_groups != 0:
            raise ValueError('in_channels must be divisible by offset_groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.offset_groups = offset_groups

        self.weight = Parameter(torch.empty(out_channels, in_channels // groups,
                                            self.kernel_size[0], self.kernel_size[1]))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input, offset):
        """
        Arguments:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
            offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
                out_height, out_width]): offsets to be applied for each position in the
                convolution kernel.
        """
        return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
                             padding=self.padding, dilation=self.dilation)

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += '{in_channels}'
        s += ', {out_channels}'
        s += ', kernel_size={kernel_size}'
        s += ', stride={stride}'
        s += ', padding={padding}' if self.padding != (0, 0) else ''
        s += ', dilation={dilation}' if self.dilation != (1, 1) else ''
        s += ', groups={groups}' if self.groups != 1 else ''
        s += ', offset_groups={offset_groups}' if self.offset_groups != 1 else ''
        s += ', bias=False' if self.bias is None else ''
        s += ')'
        return s.format(**self.__dict__)
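

# Usage sketch for the module form (illustrative, not part of the original module):
# a minimal smoke test that wires DeformConv2d to a regular Conv2d predicting the
# offsets. The offset branch must output 2 * offset_groups * kernel_height *
# kernel_width channels with the same spatial size as the deformable convolution's
# output. The names x, deform and offset_conv are placeholders for this example.
if __name__ == "__main__":
    import torchvision  # noqa: F401 -- assumed to register the torchvision::deform_conv2d op

    x = torch.rand(4, 3, 10, 10)
    deform = DeformConv2d(3, 5, kernel_size=3, padding=1)
    # 2 * offset_groups(=1) * 3 * 3 = 18 offset channels; padding=1 keeps the 10x10 size
    offset_conv = nn.Conv2d(3, 18, kernel_size=3, padding=1)
    out = deform(x, offset_conv(x))
    assert out.shape == (4, 5, 10, 10)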