# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule

from .se_layer import SELayer


class InvertedResidual(BaseModule):
    """Inverted Residual Block.

    The block chains, in order: an optional 1x1 expand convolution, a
    depthwise convolution, an optional SE layer and a 1x1 linear convolution
    that has no activation. When ``stride == 1`` and
    ``in_channels == out_channels`` the (optionally drop-path'ed) result is
    added to the input as a residual shortcut.

    Args:
        in_channels (int): The input channels of this module.
        out_channels (int): The output channels of this module.
        mid_channels (int): The input channels of the depthwise convolution.
            The expand convolution is only built when this differs from
            ``in_channels``.
        kernel_size (int): The kernel size of the depthwise convolution.
            Defaults to 3.
        stride (int): The stride of the depthwise convolution. Must be 1
            or 2. Defaults to 1.
        se_cfg (dict, optional): Config dict for se layer. Defaults to None,
            which means no se layer.
        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='ReLU')``.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.

    Raises:
        AssertionError: If ``stride`` is not 1 or 2, or if ``se_cfg`` is
            neither None nor a dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_path_rate=0.,
                 with_cp=False,
                 init_cfg=None):
        super(InvertedResidual, self).__init__(init_cfg)

        # Validate arguments before deriving any state from them.
        # AssertionError is kept (rather than ValueError/TypeError) so that
        # existing callers catching it keep working.
        assert stride in [1, 2], \
            f'stride must be 1 or 2, but got {stride}.'
        assert se_cfg is None or isinstance(se_cfg, dict), \
            f'se_cfg must be None or a dict, but got {type(se_cfg)}.'

        # NOTE: the ``norm_cfg`` / ``act_cfg`` defaults are dict objects
        # shared across all calls; this class only forwards them and never
        # mutates them in place.

        # The identity shortcut is only shape-compatible when the block
        # neither downsamples nor changes the channel count.
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        self.with_cp = with_cp
        # Stochastic depth on the residual branch; a no-op when disabled.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.with_se = se_cfg is not None
        # The 1x1 expansion is skipped when it would not change the width.
        self.with_expand_conv = (mid_channels != in_channels)

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        # groups == channels makes this convolution depthwise.
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if self.with_se:
            self.se = SELayer(**se_cfg)
        # Linear projection: ``act_cfg=None`` so no activation is applied.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """

        def _inner_forward(x):
            # Kept as a closure so the whole block can be handed to
            # ``cp.checkpoint`` as a single function of the input tensor.
            out = x

            if self.with_expand_conv:
                out = self.expand_conv(out)

            out = self.depthwise_conv(out)

            if self.with_se:
                out = self.se(out)

            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + self.drop_path(out)
            else:
                return out

        # Checkpointing is only used when the input requires gradients;
        # otherwise the block is run directly.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out