resnext.py 4.84 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
3
4
5
import math

import torch.nn as nn

from .resnet import ResNet
pangjm's avatar
pangjm committed
6
from .resnet import Bottleneck as _Bottleneck
Kai Chen's avatar
Kai Chen committed
7
from ..registry import BACKBONES
pangjm's avatar
pangjm committed
8
9


class Bottleneck(_Bottleneck):
    """ResNet bottleneck block whose middle 3x3 conv is grouped (ResNeXt).

    The parent ``_Bottleneck`` constructor is run first; it is expected to
    provide ``inplanes``, ``planes``, ``expansion``, ``conv1_stride``,
    ``conv2_stride`` and ``dilation``. The conv/BN layers are then rebuilt
    with the ResNeXt inner width.
    """

    def __init__(self, *args, groups=1, base_width=4, **kwargs):
        """Create a ResNeXt bottleneck block.

        Args:
            *args, **kwargs: Forwarded unchanged to ``_Bottleneck``.
            groups (int): Number of groups of the 3x3 convolution. With
                ``groups == 1`` the inner width is simply ``planes``.
            base_width (int): Per-group width factor; the inner width is
                ``floor(planes * base_width / 64) * groups``. Ignored when
                ``groups == 1``.

        Which layer carries the stride is taken from the parent's
        ``conv1_stride`` / ``conv2_stride``: for style "pytorch" it is the
        3x3 conv layer, for "caffe" the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__(*args, **kwargs)

        # Inner (grouped) channel count: plain ResNet width for one group,
        # otherwise floor(planes * base_width / 64) channels per group.
        width = (self.planes if groups == 1 else
                 math.floor(self.planes * (base_width / 64)) * groups)
        out_channels = self.planes * self.expansion

        # Replace the parent's conv/BN layers with width-adjusted ones.
        self.conv1 = nn.Conv2d(
            self.inplanes, width, kernel_size=1, stride=self.conv1_stride,
            bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        # Only the 3x3 conv is grouped; padding == dilation keeps the
        # spatial size unchanged apart from the stride.
        self.conv2 = nn.Conv2d(
            width, width, kernel_size=3, stride=self.conv2_stride,
            padding=self.dilation, dilation=self.dilation, groups=groups,
            bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   groups=1,
                   base_width=4,
                   style='pytorch',
                   with_cp=False):
    """Stack ``blocks`` residual blocks into a single stage.

    Only the first block receives ``stride`` and the (optional) projection
    shortcut; every later block uses stride 1 and takes
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Residual block class; must expose ``expansion`` and
            accept the keyword arguments passed below.
        inplanes (int): Input channels of the stage.
        planes (int): Base channels of each block; block outputs have
            ``planes * block.expansion`` channels.
        blocks (int): Number of blocks. A value below 1 still yields the
            first block.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        groups (int): Groups passed to every block.
        base_width (int): Base width passed to every block.
        style (str): ``'pytorch'`` or ``'caffe'``, passed to every block.
        with_cp (bool): Checkpoint flag passed to every block.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    out_channels = planes * block.expansion

    # A 1x1 projection shortcut is needed whenever the first block changes
    # the spatial resolution or the channel count.
    if stride == 1 and inplanes == out_channels:
        downsample = None
    else:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            nn.BatchNorm2d(out_channels),
        )

    common = dict(
        dilation=dilation,
        groups=groups,
        base_width=base_width,
        style=style,
        with_cp=with_cp)
    head = block(
        inplanes, planes, stride=stride, downsample=downsample, **common)
    tail = [
        block(out_channels, planes, stride=1, **common)
        for _ in range(1, blocks)
    ]
    return nn.Sequential(head, *tail)


@BACKBONES.register_module
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Reuses the ``ResNet`` constructor for the common setup, then rebuilds
    every residual stage with grouped ResNeXt bottleneck blocks.

    Args:
        depth (int): Depth of resnext, one of {50, 101, 152} (the keys of
            ``arch_settings``); 18 and 34 are not supported.
        num_stages (int): Resnet stages, normally 4.
        groups (int): Number of groups (cardinality) of the 3x3 conv in
            each bottleneck block.
        base_width (int): Per-group width factor; each block's inner width
            is ``floor(planes * base_width / 64) * groups`` when
            ``groups > 1``.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.

    All arguments except ``groups`` and ``base_width`` are forwarded to
    ``ResNet`` through ``**kwargs``.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        super(ResNeXt, self).__init__(**kwargs)
        self.groups = groups
        self.base_width = base_width

        # Rebuild every stage with the ResNeXt grouping settings. The
        # stages are registered as 'layer1', 'layer2', ...; add_module
        # replaces any submodule of the same name (presumably the stages
        # ResNet.__init__ already built -- confirm against ResNet).
        self.inplanes = 64
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = self.strides[i]
            dilation = self.dilations[i]
            # Base channel count doubles at every stage: 64, 128, 256, ...
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                groups=self.groups,
                base_width=self.base_width,
                style=self.style,
                with_cp=self.with_cp)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)