# modified from https://github.com/Haiyang-W/DSVT import warnings from typing import Optional, Sequence, Tuple from mmengine.model import BaseModule from torch import Tensor from torch import nn as nn from mmdet3d.registry import MODELS from mmdet3d.utils import OptMultiConfig class BasicResBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, padding: int = 1, downsample: bool = False, ) -> None: super().__init__() self.conv1 = nn.Conv2d( inplanes, planes, kernel_size=3, stride=stride, padding=padding, bias=False) self.bn1 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01) self.relu1 = nn.ReLU() self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01) self.relu2 = nn.ReLU() self.downsample = downsample if self.downsample: self.downsample_layer = nn.Sequential( nn.Conv2d( inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False), nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)) self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu1(out) out = self.conv2(out) out = self.bn2(out) if self.downsample: identity = self.downsample_layer(x) out += identity out = self.relu2(out) return out @MODELS.register_module() class ResSECOND(BaseModule): """Backbone network for DSVT. The difference between `ResSECOND` and `SECOND` is that the basic block in this module contains residual layers. Args: in_channels (int): Input channels. out_channels (list[int]): Output channels for multi-scale feature maps. blocks_nums (list[int]): Number of blocks in each stage. layer_strides (list[int]): Strides of each stage. norm_cfg (dict): Config dict of normalization layers. conv_cfg (dict): Config dict of convolutional layers. """ def __init__(self, in_channels: int = 128, out_channels: Sequence[int] = [128, 128, 256], blocks_nums: Sequence[int] = [1, 2, 2], layer_strides: Sequence[int] = [2, 2, 2], init_cfg: OptMultiConfig = None, pretrained: Optional[str] = None) -> None: super(ResSECOND, self).__init__(init_cfg=init_cfg) assert len(layer_strides) == len(blocks_nums) assert len(out_channels) == len(blocks_nums) in_filters = [in_channels, *out_channels[:-1]] blocks = [] for i, block_num in enumerate(blocks_nums): cur_layers = [ BasicResBlock( in_filters[i], out_channels[i], stride=layer_strides[i], downsample=True) ] for _ in range(block_num): cur_layers.append( BasicResBlock(out_channels[i], out_channels[i])) blocks.append(nn.Sequential(*cur_layers)) self.blocks = nn.Sequential(*blocks) assert not (init_cfg and pretrained), \ 'init_cfg and pretrained cannot be setting at the same time' if isinstance(pretrained, str): warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 'please use "init_cfg" instead') self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) else: self.init_cfg = dict(type='Kaiming', layer='Conv2d') def forward(self, x: Tensor) -> Tuple[Tensor, ...]: """Forward function. Args: x (torch.Tensor): Input with shape (N, C, H, W). Returns: tuple[torch.Tensor]: Multi-scale features. """ outs = [] for i in range(len(self.blocks)): x = self.blocks[i](x) outs.append(x) return tuple(outs)