Commit 7cee8797 authored by Jon Crall's avatar Jon Crall Committed by Kai Chen
Browse files

Add in_channels kwarg to backbones. (#1475)

* Add in_channels kwarg to backbones

* explicit import in HRNet doctest
parent 4d9a5f47
...@@ -201,6 +201,7 @@ class HRNet(nn.Module): ...@@ -201,6 +201,7 @@ class HRNet(nn.Module):
Args: Args:
extra (dict): detailed configuration for each stage of HRNet. extra (dict): detailed configuration for each stage of HRNet.
in_channels (int): Number of input image channels. Normally 3.
conv_cfg (dict): dictionary to construct and config conv layer. conv_cfg (dict): dictionary to construct and config conv layer.
norm_cfg (dict): dictionary to construct and config norm layer. norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely, norm_eval (bool): Whether to set norm layers to eval mode, namely,
...@@ -210,12 +211,52 @@ class HRNet(nn.Module): ...@@ -210,12 +211,52 @@ class HRNet(nn.Module):
memory while slowing down the training speed. memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity. in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import HRNet
>>> import torch
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
""" """
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
def __init__(self, def __init__(self,
extra, extra,
in_channels=3,
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
norm_eval=True, norm_eval=True,
...@@ -235,7 +276,7 @@ class HRNet(nn.Module): ...@@ -235,7 +276,7 @@ class HRNet(nn.Module):
self.conv1 = build_conv_layer( self.conv1 = build_conv_layer(
self.conv_cfg, self.conv_cfg,
3, in_channels,
64, 64,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
......
...@@ -335,6 +335,7 @@ class ResNet(nn.Module): ...@@ -335,6 +335,7 @@ class ResNet(nn.Module):
Args: Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4. num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage. strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage. dilations (Sequence[int]): Dilation of each stage.
...@@ -378,6 +379,7 @@ class ResNet(nn.Module): ...@@ -378,6 +379,7 @@ class ResNet(nn.Module):
def __init__(self, def __init__(self,
depth, depth,
in_channels=3,
num_stages=4, num_stages=4,
strides=(1, 2, 2, 2), strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1), dilations=(1, 1, 1, 1),
...@@ -426,7 +428,7 @@ class ResNet(nn.Module): ...@@ -426,7 +428,7 @@ class ResNet(nn.Module):
self.stage_blocks = stage_blocks[:num_stages] self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = 64 self.inplanes = 64
self._make_stem_layer() self._make_stem_layer(in_channels)
self.res_layers = [] self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks): for i, num_blocks in enumerate(self.stage_blocks):
...@@ -464,10 +466,10 @@ class ResNet(nn.Module): ...@@ -464,10 +466,10 @@ class ResNet(nn.Module):
def norm1(self): def norm1(self):
return getattr(self, self.norm1_name) return getattr(self, self.norm1_name)
def _make_stem_layer(self): def _make_stem_layer(self, in_channels):
self.conv1 = build_conv_layer( self.conv1 = build_conv_layer(
self.conv_cfg, self.conv_cfg,
3, in_channels,
64, 64,
kernel_size=7, kernel_size=7,
stride=2, stride=2,
......
...@@ -160,6 +160,7 @@ class ResNeXt(ResNet): ...@@ -160,6 +160,7 @@ class ResNeXt(ResNet):
Args: Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4. num_stages (int): Resnet stages, normally 4.
groups (int): Group of resnext. groups (int): Group of resnext.
base_width (int): Base width of resnext. base_width (int): Base width of resnext.
......
...@@ -11,6 +11,26 @@ from ..registry import BACKBONES ...@@ -11,6 +11,26 @@ from ..registry import BACKBONES
@BACKBONES.register_module @BACKBONES.register_module
class SSDVGG(VGG): class SSDVGG(VGG):
"""VGG Backbone network for single-shot-detection
Args:
input_size (int): width and height of input, from {300, 512}.
depth (int): Depth of vgg, from {11, 13, 16, 19}.
out_indices (Sequence[int]): Output from which stages.
Example:
>>> self = SSDVGG(input_size=300, depth=11)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 300, 300)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 1024, 19, 19)
(1, 512, 10, 10)
(1, 256, 5, 5)
(1, 256, 3, 3)
(1, 256, 1, 1)
"""
extra_setting = { extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256), 300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128), 512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
...@@ -24,6 +44,7 @@ class SSDVGG(VGG): ...@@ -24,6 +44,7 @@ class SSDVGG(VGG):
out_indices=(3, 4), out_indices=(3, 4),
out_feature_indices=(22, 34), out_feature_indices=(22, 34),
l2_norm_scale=20.): l2_norm_scale=20.):
# TODO: in_channels for mmcv.VGG
super(SSDVGG, self).__init__( super(SSDVGG, self).__init__(
depth, depth,
with_last_pool=with_last_pool, with_last_pool=with_last_pool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment