mlp.py 2.11 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Tuple

4
from mmcv.cnn import ConvModule
5
from mmengine.model import BaseModule
6
from torch import Tensor
7
8
from torch import nn as nn

9
10
from mmdet3d.utils import ConfigType, OptMultiConfig

11

12
class MLP(BaseModule):
    """A simple MLP module.

    Pass features (B, C, N) through an MLP built as a stack of
    ``ConvModule`` layers, one layer per entry of ``conv_channels``.

    Args:
        in_channel (int): Number of channels of input features.
            Defaults to 18.
        conv_channels (Tuple[int, ...]): Out channels of each convolution
            layer, in order. Defaults to (256, 256).
        conv_cfg (:obj:`ConfigDict` or dict): Config dict for convolution
            layer. Defaults to dict(type='Conv1d').
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN1d').
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to dict(type='ReLU').
        init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
            optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channel: int = 18,
                 conv_channels: Tuple[int, ...] = (256, 256),
                 conv_cfg: ConfigType = dict(type='Conv1d'),
                 norm_cfg: ConfigType = dict(type='BN1d'),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # Sub-modules are registered as ``layer0``, ``layer1``, ... so the
        # parameter names are ``mlp.layer{i}.*``; keep this naming stable
        # to stay compatible with existing checkpoints.
        self.mlp = nn.Sequential()
        prev_channels = in_channel
        for i, conv_channel in enumerate(conv_channels):
            # Kernel size 1 with no padding: with the default Conv1d config
            # each layer only mixes channels and keeps the last dimension.
            self.mlp.add_module(
                f'layer{i}',
                ConvModule(
                    prev_channels,
                    conv_channel,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))
            # The next layer consumes this layer's output channels.
            prev_channels = conv_channel

    def forward(self, img_features: Tensor) -> Tensor:
        """Run the input through every MLP layer in order.

        Args:
            img_features (Tensor): Input features, expected as (B, C, N)
                with C equal to ``in_channel``.

        Returns:
            Tensor: Output of the last layer. With the default Conv1d config
            this is (B, conv_channels[-1], N). If ``conv_channels`` is empty
            the input is returned unchanged.
        """
        return self.mlp(img_features)