import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once


def deform_conv2d(
    input: Tensor,
    offset: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    mask: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Performs Deformable Convolution v2, described in
    `Deformable ConvNets v2: More Deformable, Better Results
    <https://arxiv.org/abs/1811.11168>`__ if :attr:`mask` is not ``None``, and
    Deformable Convolution, described in
    `Deformable Convolutional Networks
    <https://arxiv.org/abs/1703.06211>`__, if :attr:`mask` is ``None``.

    Args:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
            offsets to be applied for each position in the convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights,
            split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
            masks to be applied for each position in the convolution kernel. Default: None

    Returns:
        Tensor[batch_size, out_channels, out_height, out_width]: result of convolution

    Examples::
        >>> input = torch.rand(4, 3, 10, 10)
        >>> kh, kw = 3, 3
        >>> weight = torch.rand(5, 3, kh, kw)
        >>> # offset and mask should have the same spatial size as the output
        >>> # of the convolution. In this case, for an input of 10, stride of 1
        >>> # and kernel size of 3, without padding, the output size is 8
        >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
        >>> mask = torch.rand(4, kh * kw, 8, 8)
        >>> out = deform_conv2d(input, offset, weight, mask=mask)
        >>> print(out.shape)
        torch.Size([4, 5, 8, 8])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(deform_conv2d)
    _assert_has_ops()
    out_channels = weight.shape[0]
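
    # Record whether a real modulation mask was supplied; when mask or bias is
    # omitted, substitute an empty mask / zero bias so the custom operator
    # always receives concrete tensors.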
    use_mask = mask is not None

    if mask is None:
        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)

    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    _, n_in_channels, _, _ = input.shape
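
    # Infer group counts from the tensor shapes: each offset group contributes
    # 2 * kernel_h * kernel_w channels (one 2D offset per kernel location), and
    # the number of weight groups is in_channels divided by the weight's
    # per-group channel dimension.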
    n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
    n_weight_grps = n_in_channels // weight.shape[1]

    if n_offset_grps == 0:
        raise RuntimeError(
            "the shape of the offset tensor at dimension 1 is not valid. It should "
            "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
            f"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={2 * weights_h * weights_w}"
        )

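    # Forward all arguments to the torchvision C++/CUDA operator.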
    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,
    )


class DeformConv2d(nn.Module):
    """
    See :func:`deform_conv2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        self.weight = Parameter(
            torch.empty(out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1])
        )

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
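        # Same default initialization as nn.Conv2d: Kaiming uniform for the
        # weight and a fan-in based uniform range for the bias.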
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
            offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
                offsets to be applied for each position in the convolution kernel.
            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
                masks to be applied for each position in the convolution kernel.
        """
        return deform_conv2d(
            input,
            offset,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=mask,
        )

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"{self.in_channels}"
            f", {self.out_channels}"
            f", kernel_size={self.kernel_size}"
            f", stride={self.stride}"
        )
        s += f", padding={self.padding}" if self.padding != (0, 0) else ""
        s += f", dilation={self.dilation}" if self.dilation != (1, 1) else ""
        s += f", groups={self.groups}" if self.groups != 1 else ""
        s += ", bias=False" if self.bias is None else ""
        s += ")"

        return s
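

# A minimal usage sketch (illustrative only, values chosen for the example):
# the offset tensor needs 2 * kernel_h * kernel_w channels per offset group and
# the same spatial size as the convolution output.
#
#     layer = DeformConv2d(in_channels=3, out_channels=5, kernel_size=3)
#     x = torch.rand(4, 3, 10, 10)
#     offset = torch.rand(4, 2 * 3 * 3, 8, 8)  # 3x3 kernel on a 10x10 input -> 8x8 output
#     out = layer(x, offset)                   # torch.Size([4, 5, 8, 8])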