Commit a9e21cf7 authored by myownskyW7's avatar myownskyW7 Committed by Kai Chen
Browse files

Support models without FPN (#133)

* add two stage w/o neck and w/ upperneck

* add rpn r50 c4

* update c4 configs

* fix

* config update

* update config

* minor update

* mask rcnn support c4 train and test

* lr fix

* cascade support upper_neck

* add cascade c4 config

* update config

* update

* update res_layer to new interface

* refactoring

* c4 configs update

* refactoring

* update rpn_c4 config

* rename upper_neck as shared_head

* update

* update configs

* update

* update c4 configs

* update according to commits

* update
parent 90096804
import logging
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from ..backbones import ResNet, make_res_layer
from ..registry import SHARED_HEADS
@SHARED_HEADS.register_module
class ResLayer(nn.Module):
def __init__(self,
depth,
stage=3,
stride=2,
dilation=1,
style='pytorch',
normalize=dict(type='BN', frozen=False),
norm_eval=True,
with_cp=False,
dcn=None):
super(ResLayer, self).__init__()
self.norm_eval = norm_eval
self.normalize = normalize
self.stage = stage
block, stage_blocks = ResNet.arch_settings[depth]
stage_block = stage_blocks[stage]
planes = 64 * 2**stage
inplanes = 64 * 2**(stage - 1) * block.expansion
res_layer = make_res_layer(
block,
inplanes,
planes,
stage_block,
stride=stride,
dilation=dilation,
style=style,
with_cp=with_cp,
normalize=self.normalize,
dcn=dcn)
self.add_module('layer{}'.format(stage + 1), res_layer)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
res_layer = getattr(self, 'layer{}'.format(self.stage + 1))
out = res_layer(x)
return out
def train(self, mode=True):
super(ResLayer, self).train(mode)
if self.norm_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment