"...pipelines/animatediff/pipeline_animatediff_sdxl.py" did not exist on "f782ca112a30b5e022a74ae029d47f4c62f7fca4"
resnext.py 7.64 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
3
4
import math

import torch.nn as nn

yhcao6's avatar
yhcao6 committed
5
from mmdet.ops import DeformConv, ModulatedDeformConv
pangjm's avatar
pangjm committed
6
from .resnet import Bottleneck as _Bottleneck
yhcao6's avatar
yhcao6 committed
7
from .resnet import ResNet
Kai Chen's avatar
Kai Chen committed
8
from ..registry import BACKBONES
9
from ..utils import build_conv_layer, build_norm_layer
pangjm's avatar
pangjm committed
10
11


pangjm's avatar
pangjm committed
12
class Bottleneck(_Bottleneck):
    """ResNeXt bottleneck block.

    Subclasses the ResNet bottleneck. After the parent constructor has run,
    conv1/conv2/conv3 and norm1/norm2/norm3 are rebuilt so that the inner
    path uses the grouped ResNeXt width instead of ``planes``. The forward
    pass is inherited unchanged (no ``forward`` is defined here).
    """

    def __init__(self, groups=1, base_width=4, *args, **kwargs):
        """Bottleneck block for ResNeXt.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            groups (int): Number of groups (cardinality) of the 3x3 conv.
                With ``groups == 1`` the inner width is simply ``planes``.
            base_width (int): Per-group width relative to 64 planes. For
                ``groups > 1`` the inner width is
                ``floor(planes * base_width / 64) * groups``.
            *args: Positional arguments forwarded to the ResNet bottleneck.
            **kwargs: Keyword arguments forwarded to the ResNet bottleneck
                (``make_res_layer`` passes ``inplanes``, ``planes``,
                ``stride``, ``dilation``, ``downsample``, ``style``,
                ``with_cp``, ``conv_cfg``, ``norm_cfg``, ``dcn``, ``gcb``).

        Note:
            When deformable conv is used for conv2, its group count is
            ``dcn['groups']`` if the dcn config provides it, otherwise the
            block's own ``groups``. This way a dcn config that omits
            'groups' keeps the ResNeXt cardinality instead of silently
            falling back to a dense (groups=1) conv.
        """
        super(Bottleneck, self).__init__(*args, **kwargs)

        # Inner (bottleneck) width; always a multiple of `groups`.
        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes * (base_width / 64)) * groups

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        # 1x1 reduce: inplanes -> width (stride set by the parent's style).
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = self.dcn.get('modulated', False)

        if not self.with_dcn or fallback_on_stride:
            # Regular grouped 3x3 conv.
            self.conv2 = build_conv_layer(
                self.conv_cfg,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            # An explicit dcn['groups'] wins; otherwise keep the block's
            # cardinality rather than collapsing to a dense conv.
            dcn_groups = self.dcn.get('groups', groups)
            deformable_groups = self.dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                # presumably 2 offsets per 3x3 sampling location (2 * 9)
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                # presumably 2 offsets + 1 mask per 3x3 location (3 * 9)
                offset_channels = 27
            # Predicts the offsets (and mask, if modulated) consumed by conv2;
            # geometry matches conv2 so the spatial sizes line up.
            self.conv2_offset = nn.Conv2d(
                width,
                deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation)
            self.conv2 = conv_op(
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=dcn_groups,
                deformable_groups=deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)

        # 1x1 expand: width -> planes * expansion.
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
pangjm's avatar
pangjm committed
92
93
94
95
96
97
98
99
100
101
102


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   groups=1,
                   base_width=4,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=None,
                   dcn=None,
                   gcb=None):
    """Build one ResNeXt stage as an ``nn.Sequential`` of ``blocks`` blocks.

    Args:
        block (type): Block class to instantiate; must expose ``expansion``.
        inplanes (int): Input channels of the stage.
        planes (int): Base channels; each block outputs
            ``planes * block.expansion`` channels.
        blocks (int): Number of blocks. The first block is always built,
            so values below 1 still yield a single block.
        stride (int): Stride of the first block; later blocks use 1.
        dilation (int): Dilation passed to every block.
        groups (int): Cardinality passed to every block.
        base_width (int): Per-group base width passed to every block.
        style (str): 'pytorch' or 'caffe', passed to every block.
        with_cp (bool): Checkpointing flag passed to every block.
        conv_cfg (dict | None): Conv config for the downsample conv and blocks.
        norm_cfg (dict | None): Norm config for the downsample norm and
            blocks. ``None`` means ``dict(type='BN')`` (a fresh dict per
            call, so configs are never shared between calls by accident).
        dcn (dict | None): Deformable-conv config passed to every block.
        gcb (dict | None): ``gcb`` config passed to every block.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    if norm_cfg is None:
        norm_cfg = dict(type='BN')

    out_planes = planes * block.expansion

    downsample = None
    if stride != 1 or inplanes != out_planes:
        # Project the identity branch so it matches the first block's output
        # resolution and channel count.
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                out_planes,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, out_planes)[1],
        )

    # Arguments identical for every block in the stage.
    common = dict(
        planes=planes,
        dilation=dilation,
        groups=groups,
        base_width=base_width,
        style=style,
        with_cp=with_cp,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        dcn=dcn,
        gcb=gcb)

    # Only the first block gets the stage stride and the downsample branch;
    # the remaining blocks map out_planes -> out_planes at stride 1.
    layers = [
        block(
            inplanes=inplanes,
            stride=stride,
            downsample=downsample,
            **common)
    ]
    layers.extend(
        block(inplanes=out_planes, stride=1, **common)
        for _ in range(1, blocks))

    return nn.Sequential(*layers)


Kai Chen's avatar
Kai Chen committed
157
@BACKBONES.register_module
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Args:
        depth (int): Depth of resnext, from {50, 101, 152} (the keys of
            ``arch_settings``).
        groups (int): Cardinality, i.e. number of groups of the 3x3 conv in
            every bottleneck. Default: 1.
        base_width (int): Per-group base width, defined relative to a stage
            with 64 planes. Default: 4.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
        dcn (dict, optional): Deformable-conv config, passed only to stages
            whose entry in ``stage_with_dcn`` is true.
        gcb (dict, optional): ``gcb`` config, passed only to stages whose
            entry in ``stage_with_gcb`` is true.

    All arguments other than ``groups`` and ``base_width`` are forwarded to
    ``ResNet`` as keyword arguments.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        super(ResNeXt, self).__init__(**kwargs)
        self.groups = groups
        self.base_width = base_width

        # Rebuild every stage with the grouped ResNeXt blocks; the add_module
        # calls below presumably replace stages the ResNet constructor
        # registered under the same 'layerN' names — TODO confirm in ResNet.
        # 64 assumes the ResNet stem outputs 64 channels.
        self.inplanes = 64
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = self.strides[i]
            dilation = self.dilations[i]
            # DCN / gcb are enabled per stage.
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            # Base channels double at every stage: 64, 128, 256, 512, ...
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                groups=self.groups,
                base_width=self.base_width,
                style=self.style,
                with_cp=self.with_cp,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                dcn=dcn,
                gcb=gcb)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        # Re-apply freezing, since the stages were just rebuilt.
        self._freeze_stages()