nostem_regnet.py 3.54 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
4
from typing import Tuple

import torch.nn as nn
zhangwenwei's avatar
Regnet  
zhangwenwei committed
5
from mmdet.models.backbones import RegNet
6
from torch import Tensor
zhangwenwei's avatar
Regnet  
zhangwenwei committed
7

8
from mmdet3d.registry import MODELS
9
from mmdet3d.utils import OptMultiConfig
10

zhangwenwei's avatar
Regnet  
zhangwenwei committed
11

12
@MODELS.register_module()
class NoStemRegNet(RegNet):
    """RegNet backbone without Stem for 3D detection.

    More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_ .

    Args:
        arch (dict): The parameter of RegNets.

            - w0 (int): Initial width.
            - wa (float): Slope of width.
            - wm (float): Quantization parameter to quantize the width.
            - depth (int): Depth of the backbone.
            - group_w (int): Width of group.
            - bot_mul (float): Bottleneck ratio, i.e. expansion of bottleneck.
        strides (Sequence[int]): Strides of the first block of each stage.
        base_channels (int): Base channels after stem layer.
        in_channels (int): Number of input image channels. Normally 3.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
        init_cfg (dict or list[dict], optional): Initialization config,
            passed to the parent constructor as ``init_cfg``.
            Defaults to None.
        **kwargs: Every argument other than ``arch`` and ``init_cfg`` is
            forwarded unchanged to the parent ``RegNet`` constructor.

    Example:
        >>> from mmdet3d.models import NoStemRegNet
        >>> import torch
        >>> self = NoStemRegNet(
        ...     arch=dict(
        ...         w0=88,
        ...         wa=26.31,
        ...         wm=2.25,
        ...         group_w=48,
        ...         depth=25,
        ...         bot_mul=1.0))
        >>> self.eval()
        >>> inputs = torch.rand(1, 64, 16, 16)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 96, 8, 8)
        (1, 192, 4, 4)
        (1, 432, 2, 2)
        (1, 1008, 1, 1)
    """

    def __init__(self,
                 arch: dict,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        super().__init__(arch, init_cfg=init_cfg, **kwargs)

    def _make_stem_layer(self, in_channels: int,
                         base_channels: int) -> None:
        """Skip stem construction (intentional no-op override).

        Overrides the parent hook so that no stem layer is initialized,
        since the 3D detector's voxel encoder works like a stem layer.

        Args:
            in_channels (int): Unused; kept for signature compatibility.
            base_channels (int): Unused; kept for signature compatibility.
        """
        # Nothing is built and nothing is returned on purpose.
        return

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward function of backbone.

        Args:
            x (torch.Tensor): Features in shape (N, C, H, W).

        Returns:
            tuple[torch.Tensor]: Multi-scale features.
        """
        outs = []
        # ``self.res_layers`` holds attribute names; each stage module is
        # looked up by name and applied in order (no stem runs before them).
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            # Only stages selected through ``out_indices`` are returned.
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)