efficient_ffn.py 3.67 KB
Newer Older
zhe chen's avatar
zhe chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, xavier_init)
from mmcv.cnn.bricks.registry import (FEEDFORWARD_NETWORK, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner import force_fp32
from mmcv.runner.base_module import BaseModule
from mmcv.utils import deprecated_api_warning, to_2tuple
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_


@FEEDFORWARD_NETWORK.register_module()
class EfficientFFN(BaseModule):
    """Feed-forward network whose hidden layer is evaluated in chunks.

    The hidden layer is partitioned into ``split`` independent branches.
    Branch ``i`` computes ``fc2_i(drop(act(fc1_i(x))))`` followed by another
    ``drop``; the branch outputs are summed and one shared bias
    (``fc2_bias``) is added. Each branch has a hidden width of
    ``feedforward_channels // split``.

    When ``use_checkpoint`` is enabled, every branch is run through
    ``torch.utils.checkpoint`` (presumably to lower peak activation memory,
    since only one branch's hidden activations need to be alive at a time).

    Args:
        embed_dims (int): Input and output feature size of the FFN.
            Defaults to 256.
        feedforward_channels (int): Total hidden width, shared out between
            the branches. Defaults to 1024.
        num_fcs (int): Must be at least 2. It is stored on the module but
            every branch always consists of exactly two linear layers.
            Defaults to 2.
        act_cfg (dict): Config passed to ``build_activation_layer``.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        ffn_drop (float): Probability of the ``nn.Dropout`` applied after
            the activation and after the second linear layer of every
            branch. Defaults to 0.
        dropout_layer (dict, optional): Config passed to ``build_dropout``
            for the layer applied to the summed output. ``nn.Identity`` is
            used when it is empty or None. Defaults to None.
        add_identity (bool): Whether to add the identity (shortcut) tensor
            to the output. Defaults to True.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Defaults to None.
        split (int): Number of branches the hidden layer is divided into.
            Must be positive. Defaults to 4.
        use_checkpoint (bool): Whether to wrap each branch in
            ``torch.utils.checkpoint.checkpoint`` when the input requires
            grad. Defaults to False.
        **kwargs: Accepted for config compatibility and ignored.

    Raises:
        AssertionError: If ``num_fcs`` is smaller than 2.
        ValueError: If ``split`` is smaller than 1.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='EfficientFFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 split=4,
                 use_checkpoint=False,
                 **kwargs):
        super(EfficientFFN, self).__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
                             f'than 2. got {num_fcs}.'
        if split < 1:
            # Without this check no linear layer would be created and
            # forward() would silently reduce to "identity + bias".
            raise ValueError(
                f'split should be a positive integer. got {split}.')
        if feedforward_channels % split != 0:
            # Floor division below drops the remainder, so the effective
            # hidden width is smaller than requested; make that visible.
            warnings.warn(
                f'feedforward_channels ({feedforward_channels}) is not '
                f'divisible by split ({split}); the effective hidden width '
                f'is {(feedforward_channels // split) * split}.')

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(ffn_drop)
        self.use_checkpoint = use_checkpoint
        self.split = split

        # The layers are registered as ``fc1_{i}`` / ``fc2_{i}`` attributes
        # (not an nn.ModuleList) so the parameter names in the state_dict
        # stay exactly as existing checkpoints expect them.
        hidden_channels = feedforward_channels // split
        for i in range(split):
            setattr(self, f'fc1_{i}',
                    nn.Linear(embed_dims, hidden_channels, bias=True))
        for i in range(split):
            # The output projections carry no bias of their own; a single
            # shared bias is added once after the branches are summed.
            setattr(self, f'fc2_{i}',
                    nn.Linear(hidden_channels, embed_dims, bias=False))

        # Shared output bias, initialised to zeros (not with the uniform
        # scheme nn.Linear would use for its own bias).
        self.fc2_bias = nn.Parameter(
            torch.zeros(embed_dims), requires_grad=True)

        if dropout_layer:
            # Imported here because the module-level imports of this file do
            # not provide ``build_dropout``.
            from mmcv.cnn.bricks.drop import build_dropout
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({'residual': 'identity'}, cls_name='EfficientFFN')
    def forward(self, x, identity=None):
        """Run the chunked FFN.

        Args:
            x (Tensor): Input features; the last dimension is fed to the
                ``fc1_{i}`` linear layers.
            identity (Tensor, optional): Tensor added to the output when
                ``add_identity`` is True. ``x`` is used when it is None.
                Defaults to None.

        Returns:
            Tensor: ``dropout_layer(out)`` if ``add_identity`` is False,
            otherwise ``identity + dropout_layer(out)``, where ``out`` is
            the sum of all branch outputs plus ``fc2_bias``.
        """

        def _inner_forward(x, i):
            # One branch: fc1 -> activation -> dropout -> fc2 -> dropout.
            fc1 = getattr(self, f'fc1_{i}')
            x = fc1(x)
            x = self.activate(x)
            x = self.drop(x)
            fc2 = getattr(self, f'fc2_{i}')
            x = fc2(x)
            x = self.drop(x)
            return x

        out = 0
        for i in range(self.split):
            # Checkpointing is only used when the input takes part in
            # autograd; otherwise the branch is evaluated directly.
            if self.use_checkpoint and x.requires_grad:
                out = out + checkpoint.checkpoint(_inner_forward, x, i)
            else:
                out = out + _inner_forward(x, i)

        out = out + self.fc2_bias

        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)