deform_conv.py 6.64 KB
Newer Older
1
import math
2
from typing import Optional, Tuple
3
4
5
6
7

import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.modules.utils import _pair
8
from torch.nn.parameter import Parameter
9
from torchvision.extension import _assert_has_ops
10
11


12
13
14
15
16
17
18
19
def deform_conv2d(
    input: Tensor,
    offset: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    mask: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Performs Deformable Convolution v2, described in
    `Deformable ConvNets v2: More Deformable, Better Results
    <https://arxiv.org/abs/1811.11168>`__ if :attr:`mask` is not ``None``, and
    Deformable Convolution, described in
    `Deformable Convolutional Networks
    <https://arxiv.org/abs/1703.06211>`__ if :attr:`mask` is ``None``.

    Args:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
            offsets to be applied for each position in the convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights,
            split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
            masks to be applied for each position in the convolution kernel. Default: None

    Returns:
        Tensor[batch_size, out_channels, out_height, out_width]: result of convolution

    Raises:
        RuntimeError: if ``offset.shape[1]`` is not a positive multiple of
            ``2 * kernel_height * kernel_width``.

    Examples::
        >>> input = torch.rand(4, 3, 10, 10)
        >>> kh, kw = 3, 3
        >>> weight = torch.rand(5, 3, kh, kw)
        >>> # offset and mask should have the same spatial size as the output
        >>> # of the convolution. In this case, for an input of 10, stride of 1
        >>> # and kernel size of 3, without padding, the output size is 8
        >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
        >>> mask = torch.rand(4, kh * kw, 8, 8)
        >>> out = deform_conv2d(input, offset, weight, mask=mask)
        >>> print(out.shape)
        torch.Size([4, 5, 8, 8])
    """
    _assert_has_ops()
    out_channels = weight.shape[0]

    # The compiled op is always called with a mask tensor; ``use_mask`` is passed
    # alongside it, so when no mask was given an empty placeholder is supplied
    # (presumably ignored by the op when use_mask is False — confirm in the kernel).
    use_mask = mask is not None
    if mask is None:
        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)

    # Likewise a bias tensor is always passed; a zero bias stands in for "no bias".
    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

    # Accept either scalars or (height, width) pairs.
    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    _, n_in_channels, _, _ = input.shape

    # Each offset group carries a 2-D offset for every kernel position.
    offsets_per_group = 2 * weights_h * weights_w
    n_offset_grps = offset.shape[1] // offsets_per_group
    n_weight_grps = n_in_channels // weight.shape[1]

    # Reject both "too few channels" and "not an exact multiple"; floor division
    # alone would silently accept e.g. 20 offset channels for a 3x3 kernel.
    if n_offset_grps == 0 or offset.shape[1] % offsets_per_group != 0:
        raise RuntimeError(
            "the shape of the offset tensor at dimension 1 is not valid. It should "
            "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
            f"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={offsets_per_group}"
        )

    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,
    )
105
106
107
108


class DeformConv2d(nn.Module):
    """
    Module wrapper around :func:`deform_conv2d` that owns the convolution
    ``weight`` and optional ``bias`` parameters.

    See :func:`deform_conv2d` for the meaning of the constructor arguments and
    of the ``offset`` / ``mask`` inputs accepted by :meth:`forward`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()

        # Grouped convolution must split both channel dimensions evenly.
        if in_channels % groups:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        # Scalar geometry arguments are expanded into (height, width) pairs.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)

        kernel_h, kernel_w = self.kernel_size
        self.weight = Parameter(torch.empty(out_channels, in_channels // groups, kernel_h, kernel_w))

        if not bias:
            # Register the slot explicitly so ``self.bias`` exists and is None.
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` (Kaiming-uniform, ``a=sqrt(5)``) and, if present,
        ``bias`` (uniform in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
            offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
                out_height, out_width]): offsets to be applied for each position in the
                convolution kernel.
            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
                out_height, out_width]): masks to be applied for each position in the
                convolution kernel.
        """
        return deform_conv2d(
            input,
            offset,
            self.weight,
            self.bias,
            mask=mask,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )

    def __repr__(self) -> str:
        fields = [
            str(self.in_channels),
            str(self.out_channels),
            f"kernel_size={self.kernel_size}",
            f"stride={self.stride}",
        ]
        # Optional settings are only listed when they differ from their defaults.
        if self.padding != (0, 0):
            fields.append(f"padding={self.padding}")
        if self.dilation != (1, 1):
            fields.append(f"dilation={self.dilation}")
        if self.groups != 1:
            fields.append(f"groups={self.groups}")
        if self.bias is None:
            fields.append("bias=False")
        return f"{self.__class__.__name__}({', '.join(fields)})"