mlp.py 1.59 KB
Newer Older
1
from mmcv.cnn import ConvModule
2
from mmcv.runner import BaseModule
3
4
5
from torch import nn as nn


6
class MLP(BaseModule):
    """A simple MLP module.

    Pass features (B, C, N) through an MLP.

    Each entry of ``conv_channels`` adds one ``ConvModule`` with kernel
    size 1, padding 0 and a bias term. The layers are chained in order:
    every layer's output channel count is the next layer's input channel
    count.

    Args:
        in_channel (int): Number of channels of input features.
            Default: 18.
        conv_channels (tuple[int]): Out channels of each convolution, one
            layer per entry. An empty sequence builds an empty
            ``nn.Sequential``, which returns its input unchanged.
            Default: (256, 256).
        conv_cfg (dict): Config of convolution.
            Default: dict(type='Conv1d').
        norm_cfg (dict): Config of normalization.
            Default: dict(type='BN1d').
        act_cfg (dict): Config of activation.
            Default: dict(type='ReLU').
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Default: None.
    """

    # NOTE: the dict defaults below are shared between calls. This class
    # only forwards them to ConvModule and never mutates them; keep it so.
    def __init__(self,
                 in_channel=18,
                 conv_channels=(256, 256),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.mlp = nn.Sequential()
        prev_channels = in_channel
        for i, conv_channel in enumerate(conv_channels):
            # The submodule name 'layer{i}' becomes part of the state_dict
            # keys (e.g. 'mlp.layer0....'); keep it stable for checkpoints.
            self.mlp.add_module(
                f'layer{i}',
                ConvModule(
                    prev_channels,
                    conv_channel,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))
            # The next layer consumes this layer's output channels.
            prev_channels = conv_channel

    def forward(self, img_features):
        """Run the input through the stacked ConvModules.

        Args:
            img_features (torch.Tensor): Input features. With the default
                ``Conv1d`` config this is presumably (B, in_channel, N);
                confirm against callers.

        Returns:
            torch.Tensor: Output of the last layer (or the input itself
            when ``conv_channels`` was empty).
        """
        return self.mlp(img_features)