"googlemock/test/gmock-actions_test.cc" did not exist on "62a35fbc5d316c4f82f37cb6e183eb8de210a94c"
cbnet_channel_mapper.py 711 Bytes
Newer Older
zhe chen's avatar
zhe chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from mmdet.models.builder import NECKS
from mmdet.models.necks import ChannelMapper


@NECKS.register_module()
class CBChannelMapper(ChannelMapper):
    """ChannelMapper that accepts the output format of a composite backbone.

    ``inputs`` may be either a single sequence of feature maps, or a
    list/tuple of such sequences (one per backbone output). The same
    inherited ``ChannelMapper`` layers are applied to every sequence, so
    their weights are shared across backbone outputs.

    Args:
        cb_idx (int): Index of the backbone output that is mapped in eval
            mode when several outputs are supplied. Defaults to 1.
        **kwargs: Forwarded unchanged to ``ChannelMapper.__init__``.
    """

    def __init__(self, cb_idx=1, **kwargs):
        super().__init__(**kwargs)
        self.cb_idx = cb_idx

    def forward(self, inputs):
        """Apply the inherited channel mapping to one or more feature sets.

        Args:
            inputs: A sequence of feature maps, or a list/tuple of such
                sequences. Nesting is detected by checking whether
                ``inputs[0]`` is itself a list or tuple.

        Returns:
            When ``self.training`` is true: a list holding one
            ``ChannelMapper.forward`` result per feature set (a one-element
            list when ``inputs`` is not nested).
            Otherwise: a single ``ChannelMapper.forward`` result, computed on
            ``inputs[self.cb_idx]`` when ``inputs`` is nested, or on
            ``inputs`` itself when it is not.
        """
        nested = isinstance(inputs[0], (list, tuple))
        groups = inputs if nested else [inputs]

        # Bind the parent forward once: zero-argument super() cannot be used
        # inside a comprehension on Python < 3.12 (the comprehension is its
        # own function scope there).
        map_features = super().forward

        if self.training:
            return [map_features(x) for x in groups]

        if not nested:
            # Only one feature set exists, so there is nothing to select.
            # Indexing the wrapped list with cb_idx (default 1) would raise
            # IndexError here.
            return map_features(inputs)
        return map_features(groups[self.cb_idx])