You need to sign in or sign up before continuing.
Unverified Commit f959a34d authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Docs] Docstring for DeformConv (#921)

* [Docs] Docstring for DeformConv

* fix docstring

* fix according to comments

* revise according to comments

* lint
parent b5e1facc
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
......@@ -179,19 +181,41 @@ deform_conv2d = DeformConv2dFunction.apply
class DeformConv2d(nn.Module):
r"""Deformable 2D convolution.
Applies a deformable 2D convolution over an input signal composed of
several input planes. DeformConv2d was described in the paper
`Deformable Convolutional Networks
<https://arxiv.org/pdf/1703.06211.pdf>`_
Args:
in_channels (int): Number of channels in the input image.
out_channels (int): Number of channels produced by the convolution.
kernel_size(int, tuple): Size of the convolving kernel.
stride(int, tuple): Stride of the convolution. Default: 1.
padding (int or tuple): Zero-padding added to both sides of the input.
Default: 0.
dilation (int or tuple): Spacing between kernel elements. Default: 1.
groups (int): Number of blocked connections from input.
channels to output channels. Default: 1.
deform_groups (int): Number of deformable group partitions.
bias (bool): If True, adds a learnable bias to the output.
Default: True.
"""
@deprecated_api_warning({'deformable_groups': 'deform_groups'},
cls_name='DeformConv2d')
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deform_groups=1,
bias=False):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
deform_groups: int = 1,
bias: bool = False) -> None:
super(DeformConv2d, self).__init__()
assert not bias, \
......@@ -210,6 +234,7 @@ class DeformConv2d(nn.Module):
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
self.bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
......@@ -228,7 +253,27 @@ class DeformConv2d(nn.Module):
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
def forward(self, x: Tensor, offset: Tensor) -> Tensor:
"""Deformable Convolutional forward function.
Args:
x (Tensor): Input feature, shape (B, C_in, H_in, W_in)
offset (Tensor): Offset for deformable convolution, shape
(B, deform_groups*kernel_size[0]*kernel_size[1]*2,
H_out, W_out), H_out, W_out are equal to the output's.
An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
The spatial arrangement is like:
.. code:: text
(x0, y0) (x1, y1) (x2, y2)
(x3, y3) (x4, y4) (x5, y5)
(x6, y6) (x7, y7) (x8, y8)
Returns:
Tensor: Output of the layer.
"""
# To fix an assert error in deform_conv_cuda.cpp:128
# input image is smaller than kernel
input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
......@@ -246,6 +291,19 @@ class DeformConv2d(nn.Module):
pad_w].contiguous()
return out
def __repr__(self):
s = self.__class__.__name__
s += f'(in_channels={self.in_channels},\n'
s += f'out_channels={self.out_channels},\n'
s += f'kernel_size={self.kernel_size},\n'
s += f'stride={self.stride},\n'
s += f'padding={self.padding},\n'
s += f'dilation={self.dilation},\n'
s += f'groups={self.groups},\n'
s += f'deform_groups={self.deform_groups},\n'
s += f'bias={self.bias})'
return s
@CONV_LAYERS.register_module('DCN')
class DeformConv2dPack(DeformConv2d):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment