from functools import partial
from typing import Callable, List, Optional

from torch import nn
from models.bricks.misc import Conv2dNormActivation


class ChannelMapper(nn.Module):
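    """Map multi-scale backbone features to a common channel width.

    Each input level is projected with a Conv2dNormActivation (1x1 conv by
    default). When ``num_outs`` exceeds ``len(in_channels)``, extra levels
    are generated with stride-2 3x3 convs on top of the last feature map.
    """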
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_outs: int,
        kernel_size: int = 1,
        stride: int = 1,
        groups: int = 1,
        norm_layer=partial(nn.GroupNorm, 32),
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
        inplace: bool = True,
        bias: Optional[bool] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.convs = nn.ModuleList()
        self.num_channels = [out_channels] * num_outs
        # one projection conv (1x1 by default) per input feature level
        for in_channel in in_channels:
            self.convs.append(
                Conv2dNormActivation(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=(kernel_size - 1) // 2,
                    bias=bias,
                    groups=groups,
                    dilation=dilation,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    inplace=inplace,
                )
            )
        # extra levels use stride-2 3x3 convs: the first consumes the last
        # input map, each later one consumes the previous extra output
        in_channel = in_channels[-1]
        for _ in range(num_outs - len(in_channels)):
            self.convs.append(
                Conv2dNormActivation(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=bias,
                    groups=groups,
                    dilation=dilation,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    inplace=inplace,
                )
            )
            in_channel = out_channels
        
        self.init_weights()
    
    def init_weights(self):
        # Xavier-initialize every conv weight; zero biases where present
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight, gain=1)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
    
    def forward(self, inputs):
        # inputs: an ordered mapping of multi-scale feature maps, e.g.
        # {"c3": ..., "c4": ..., "c5": ...} from a backbone
        inputs = list(inputs.values())
        assert len(inputs) == len(self.in_channels)
        outs = [conv(x) for conv, x in zip(self.convs, inputs)]
        # extra levels: the first downsamples the last raw input, the rest
        # chain on the previous extra output
        for i in range(len(inputs), len(self.convs)):
            source = inputs[-1] if i == len(inputs) else outs[-1]
            outs.append(self.convs[i](source))
        return outs
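

if __name__ == "__main__":
    # Minimal usage sketch (assumption: the channel counts and shapes below
    # are illustrative, roughly matching a ResNet-50 C3-C5 backbone).
    import torch

    mapper = ChannelMapper(in_channels=[512, 1024, 2048], out_channels=256, num_outs=4)
    feats = {
        "c3": torch.randn(1, 512, 64, 64),
        "c4": torch.randn(1, 1024, 32, 32),
        "c5": torch.randn(1, 2048, 16, 16),
    }
    outs = mapper(feats)
    # expected: four 256-channel maps; the extra level halves c5's resolution
    print([tuple(o.shape) for o in outs])
    # [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16), (1, 256, 8, 8)]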