# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from __future__ import absolute_import, division, print_function

import math
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_

from ..functions import DCNv3Function, dcnv3_core_pytorch

try:
    from DCNv4.functions import DCNv4Function
except ImportError:
    warnings.warn('DCNv4 is not installed; InternImage supports the DCNv4 op, '
                  'but DCNv3(use_dcn_v4_op=True) will not work without it.')

class to_channels_first(nn.Module):
    """Permute a channels-last tensor (N, H, W, C) to channels-first (N, C, H, W)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):
    """Permute a channels-first tensor (N, C, H, W) to channels-last (N, H, W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


def build_act_layer(act_layer):
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
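
# A minimal usage sketch for the two factories above (the sizes and layout
# choices are illustrative assumptions, not requirements):
#
#     ln = build_norm_layer(64, 'LN', 'channels_first', 'channels_first')
#     act = build_act_layer('GELU')
#     y = act(ln(torch.randn(2, 64, 8, 8)))   # stays (N, C, H, W)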


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))

    # Bit trick: a power of 2 has exactly one set bit, so n & (n - 1) == 0.
    return (n & (n - 1) == 0) and n != 0

class CenterFeatureScaleModule(nn.Module):

    def forward(self,
                query,
                center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        # Linear projection of the query followed by a sigmoid, giving a
        # per-group gate in (0, 1).
        center_feature_scale = F.linear(query,
                                        weight=center_feature_scale_proj_weight,
                                        bias=center_feature_scale_proj_bias).sigmoid()
        return center_feature_scale
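
# Shape sketch (illustrative sizes): with query (N, H, W, C) and a projection
# weight of shape (group, C), F.linear maps C -> group:
#
#     cfs = CenterFeatureScaleModule()
#     s = cfs(torch.randn(2, 8, 8, 64), torch.zeros(4, 64), torch.zeros(4))
#     # s: (2, 8, 8, 4), each value in (0, 1)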
class DCNv3_pytorch(nn.Module):
    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False):
        """
        DCNv3 module (pure-PyTorch reference implementation).

        :param channels: number of input/output channels
        :param kernel_size: deformable sampling kernel size
        :param dw_kernel_size: depthwise conv kernel size (defaults to kernel_size)
        :param stride: sampling stride
        :param pad: sampling padding
        :param dilation: sampling dilation
        :param group: number of deformable groups
        :param offset_scale: scale factor applied to the predicted offsets
        :param act_layer: activation in the depthwise branch ('ReLU'/'SiLU'/'GELU')
        :param norm_layer: normalization in the depthwise branch ('BN'/'LN')
        :param center_feature_scale: enable the center feature scaling branch
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # Each group's channel count should preferably be a power of 2, which is
        # more efficient in the CUDA implementation.
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                'Setting channels in DCNv3 so that each group has a power-of-2 '
                'number of channels is more efficient in the CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.center_feature_scale = center_feature_scale

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input                       (N, H, W, C)
        :return output                     (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x

        # NHWC -> NCHW for the depthwise conv; dw_conv converts back to NHWC internally
        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)
        return x
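
# A minimal smoke-test sketch for DCNv3_pytorch (pure PyTorch, so CPU suffices;
# the sizes below are illustrative assumptions):
#
#     m = DCNv3_pytorch(channels=64, kernel_size=3, pad=1, group=4)
#     y = m(torch.randn(2, 16, 16, 64))   # NHWC in -> NHWC out
#     assert y.shape == (2, 16, 16, 64)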


class DCNv3(nn.Module):
    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False,
            use_dcn_v4_op=False,
    ):
        """
        DCNv3 module (CUDA-op wrapper; can optionally dispatch to the DCNv4 op).

        :param channels: number of input/output channels
        :param kernel_size: deformable sampling kernel size
        :param dw_kernel_size: depthwise conv kernel size (defaults to kernel_size)
        :param stride: sampling stride
        :param pad: sampling padding
        :param dilation: sampling dilation
        :param group: number of deformable groups
        :param offset_scale: scale factor applied to the predicted offsets
        :param act_layer: activation in the depthwise branch ('ReLU'/'SiLU'/'GELU')
        :param norm_layer: normalization in the depthwise branch ('BN'/'LN')
        :param center_feature_scale: enable the center feature scaling branch
        :param use_dcn_v4_op: use the DCNv4 op (requires the DCNv4 package)
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # Each group's channel count should preferably be a power of 2, which is
        # more efficient in the CUDA implementation.
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                'Setting channels in DCNv3 so that each group has a power-of-2 '
                'number of channels is more efficient in the CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.center_feature_scale = center_feature_scale
        self.use_dcn_v4_op = use_dcn_v4_op

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input                       (N, H, W, C)
        :return output                     (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x
        dtype = x.dtype

        # NHWC -> NCHW for the depthwise conv; dw_conv converts back to NHWC internally
        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        if not self.use_dcn_v4_op:
            mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
            x = DCNv3Function.apply(
                x, offset, mask,
                self.kernel_size, self.kernel_size,
                self.stride, self.stride,
                self.pad, self.pad,
                self.dilation, self.dilation,
                self.group, self.group_channels,
                self.offset_scale,
                256)
        else:
            # DCNv4 combines offset and weight mask into one tensor `offset_mask`.
            # The following code is to align DCNv3 and DCNv4
            offset = offset.view(N, H, W, self.group, -1)
            mask = F.softmax(mask, -1)
            mask = mask.view(N, H, W, self.group, -1)
            offset_mask = torch.cat([offset, mask], -1).view(N, H, W, -1).contiguous()

            # For efficiency, DCNv4 requires the last dimension of offset_mask to
            # be a multiple of 8, so zero-pad it up to the next multiple (e.g. with
            # group=4 and kernel_size=3, K3 = 4 * (18 + 9) = 108 -> K3_pad = 112).
            K3 = offset_mask.size(-1)
            K3_pad = int(math.ceil(K3 / 8) * 8)
            pad_dim = K3_pad - K3
            offset_mask = torch.cat([offset_mask, offset_mask.new_zeros([*offset_mask.size()[:3], pad_dim])], -1)

            x = DCNv4Function.apply(
                x, offset_mask,
                self.kernel_size, self.kernel_size,
                self.stride, self.stride,
                self.pad, self.pad,
                self.dilation, self.dilation,
                self.group, self.group_channels,
                self.offset_scale,
                256,
                False
            )

        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)
        return x
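
# A minimal usage sketch (assumptions: a CUDA device and the compiled DCNv3
# extension are available; use_dcn_v4_op=True additionally requires the DCNv4
# package; the sizes are illustrative):
#
#     m = DCNv3(channels=64, kernel_size=3, pad=1, group=4).cuda()
#     x = torch.randn(2, 16, 16, 64, device='cuda')   # NHWC layout
#     y = m(x)                                        # -> (2, 16, 16, 64)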