Commit 618dca08 authored by Jon Crall's avatar Jon Crall Committed by Kai Chen
Browse files

Fix AnchorHead in_channels (#1506)

* test that all configs can be loaded

* Use in_channels correctly in anchor_head and guided_anchor_head

* Fix lint errors. Only tests a subset of configs

* remove local config

* fix yapf

* Remove slower tests

* Remove debug code

* trigger travis
parent 9d767a03
......@@ -19,7 +19,7 @@ class AnchorHead(nn.Module):
num_classes (int): Number of categories including the background
category.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels of the feature map.
feat_channels (int): Number of hidden channels. Used in child classes.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
......@@ -47,7 +47,6 @@ class AnchorHead(nn.Module):
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
super(AnchorHead, self).__init__()
# NOTE: in_channels is only used in child classes (e.g. RetinaHead)
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
......@@ -82,9 +81,9 @@ class AnchorHead(nn.Module):
self._init_layers()
def _init_layers(self):
self.conv_cls = nn.Conv2d(self.feat_channels,
self.conv_cls = nn.Conv2d(self.in_channels,
self.num_anchors * self.cls_out_channels, 1)
self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
def init_weights(self):
normal_init(self.conv_cls, std=0.01)
......@@ -231,8 +230,7 @@ class AnchorHead(nn.Module):
Example:
>>> import mmcv
>>> self = AnchorHead(num_classes=9, in_channels=1,
>>> feat_channels=1)
>>> self = AnchorHead(num_classes=9, in_channels=1)
>>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
>>> cfg = mmcv.Config(dict(
>>> score_thr=0.00,
......
......@@ -72,7 +72,7 @@ class GuidedAnchorHead(AnchorHead):
Args:
num_classes (int): Number of classes.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels of the feature map.
feat_channels (int): Number of hidden channels.
octave_base_scale (int): Base octave scale of each level of
feature map.
scales_per_octave (int): Number of octave scales in each level of
......@@ -170,11 +170,10 @@ class GuidedAnchorHead(AnchorHead):
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
1)
self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
self.conv_shape = nn.Conv2d(self.in_channels, self.num_anchors * 2, 1)
self.feature_adaption = FeatureAdaption(
self.feat_channels,
self.in_channels,
self.feat_channels,
kernel_size=3,
deformable_groups=self.deformable_groups)
......
from os.path import dirname, exists, join
def _get_config_directory():
""" Find the predefined detector config directory """
try:
# Assume we are running in the source mmdetection repo
repo_dpath = dirname(dirname(__file__))
except NameError:
# For IPython development when this __file__ is not defined
import mmdet
repo_dpath = dirname(dirname(mmdet.__file__))
config_dpath = join(repo_dpath, 'configs')
if not exists(config_dpath):
raise Exception('Cannot find config path')
return config_dpath
def test_config_build_detector():
"""
Test that all detection models defined in the configs can be initialized.
"""
from xdoctest.utils import import_module_from_path
from mmdet.models import build_detector
config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))
# import glob
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
# Only tests a representative subset of configurations
config_names = [
# 'dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x.py',
# 'dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py',
# 'dcn/faster_rcnn_dpool_r50_fpn_1x.py',
'dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x.py',
# 'dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py',
# 'dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py',
# 'dcn/faster_rcnn_mdpool_r50_fpn_1x.py',
# 'dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py',
# 'dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x.py',
# ---
# 'htc/htc_x101_32x4d_fpn_20e_16gpu.py',
'htc/htc_without_semantic_r50_fpn_1x.py',
# 'htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py',
# 'htc/htc_x101_64x4d_fpn_20e_16gpu.py',
# 'htc/htc_r50_fpn_1x.py',
# 'htc/htc_r101_fpn_20e.py',
# 'htc/htc_r50_fpn_20e.py',
# ---
'cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py',
# 'cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py',
# ---
# 'scratch/scratch_faster_rcnn_r50_fpn_gn_6x.py',
# 'scratch/scratch_mask_rcnn_r50_fpn_gn_6x.py',
# ---
# 'grid_rcnn/grid_rcnn_gn_head_x101_32x4d_fpn_2x.py',
'grid_rcnn/grid_rcnn_gn_head_r50_fpn_2x.py',
# ---
'double_heads/dh_faster_rcnn_r50_fpn_1x.py',
# ---
'empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py',
# 'empirical_attention/faster_rcnn_r50_fpn_attention_1111_1x.py',
# 'empirical_attention/faster_rcnn_r50_fpn_attention_0010_1x.py',
# 'empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py',
# ---
# 'ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py',
# 'ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py',
# 'ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py',
# ---
# 'guided_anchoring/ga_faster_x101_32x4d_fpn_1x.py',
# 'guided_anchoring/ga_rpn_x101_32x4d_fpn_1x.py',
# 'guided_anchoring/ga_retinanet_r50_caffe_fpn_1x.py',
# 'guided_anchoring/ga_fast_r50_caffe_fpn_1x.py',
# 'guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x.py',
# 'guided_anchoring/ga_rpn_r101_caffe_rpn_1x.py',
# 'guided_anchoring/ga_faster_r50_caffe_fpn_1x.py',
'guided_anchoring/ga_rpn_r50_caffe_fpn_1x.py',
# ---
'foveabox/fovea_r50_fpn_4gpu_1x.py',
# 'foveabox/fovea_align_gn_ms_r101_fpn_4gpu_2x.py',
# 'foveabox/fovea_align_gn_r50_fpn_4gpu_2x.py',
# 'foveabox/fovea_align_gn_r101_fpn_4gpu_2x.py',
'foveabox/fovea_align_gn_ms_r50_fpn_4gpu_2x.py',
# ---
# 'hrnet/cascade_rcnn_hrnetv2p_w32_20e.py',
# 'hrnet/mask_rcnn_hrnetv2p_w32_1x.py',
# 'hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e.py',
# 'hrnet/htc_hrnetv2p_w32_20e.py',
# 'hrnet/faster_rcnn_hrnetv2p_w18_1x.py',
# 'hrnet/mask_rcnn_hrnetv2p_w18_1x.py',
# 'hrnet/faster_rcnn_hrnetv2p_w32_1x.py',
# 'hrnet/faster_rcnn_hrnetv2p_w40_1x.py',
'hrnet/fcos_hrnetv2p_w32_gn_1x_4gpu.py',
# ---
# 'gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py',
# 'gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py',
'gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py',
# 'gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py',
# ---
# 'wider_face/ssd300_wider_face.py',
# ---
'pascal_voc/ssd300_voc.py',
'pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py',
'pascal_voc/ssd512_voc.py',
# ---
# 'gcnet/mask_rcnn_r4_gcb_c3-c5_r50_fpn_syncbn_1x.py',
# 'gcnet/mask_rcnn_r16_gcb_c3-c5_r50_fpn_syncbn_1x.py',
# 'gcnet/mask_rcnn_r4_gcb_c3-c5_r50_fpn_1x.py',
# 'gcnet/mask_rcnn_r16_gcb_c3-c5_r50_fpn_1x.py',
'gcnet/mask_rcnn_r50_fpn_sbn_1x.py',
# ---
'gn/mask_rcnn_r50_fpn_gn_contrib_2x.py',
# 'gn/mask_rcnn_r50_fpn_gn_2x.py',
# 'gn/mask_rcnn_r101_fpn_gn_2x.py',
# ---
# 'reppoints/reppoints_moment_x101_dcn_fpn_2x.py',
'reppoints/reppoints_moment_r50_fpn_2x.py',
# 'reppoints/reppoints_moment_x101_dcn_fpn_2x_mt.py',
'reppoints/reppoints_partial_minmax_r50_fpn_1x.py',
'reppoints/bbox_r50_grid_center_fpn_1x.py',
# 'reppoints/reppoints_moment_r101_dcn_fpn_2x.py',
# 'reppoints/reppoints_moment_r101_fpn_2x_mt.py',
# 'reppoints/reppoints_moment_r50_fpn_2x_mt.py',
'reppoints/reppoints_minmax_r50_fpn_1x.py',
# 'reppoints/reppoints_moment_r50_fpn_1x.py',
# 'reppoints/reppoints_moment_r101_fpn_2x.py',
# 'reppoints/reppoints_moment_r101_dcn_fpn_2x_mt.py',
'reppoints/bbox_r50_grid_fpn_1x.py',
# ---
# 'fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x.py',
# 'fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu.py',
'fcos/fcos_r50_caffe_fpn_gn_1x_4gpu.py',
# ---
'albu_example/mask_rcnn_r50_fpn_1x.py',
# ---
'libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py',
# 'libra_rcnn/libra_retinanet_r50_fpn_1x.py',
# 'libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py',
# 'libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py',
# 'libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py',
# ---
# 'ghm/retinanet_ghm_r50_fpn_1x.py',
# ---
# 'fp16/retinanet_r50_fpn_fp16_1x.py',
'fp16/mask_rcnn_r50_fpn_fp16_1x.py',
'fp16/faster_rcnn_r50_fpn_fp16_1x.py'
]
print('Using {} config files'.format(len(config_names)))
for config_fname in config_names:
config_fpath = join(config_dpath, config_fname)
config_mod = import_module_from_path(config_fpath)
config_mod.model
config_mod.train_cfg
config_mod.test_cfg
print('Building detector, config_fpath = {!r}'.format(config_fpath))
# Remove pretrained keys to allow for testing in an offline environment
if 'pretrained' in config_mod.model:
config_mod.model['pretrained'] = None
detector = build_detector(
config_mod.model,
train_cfg=config_mod.train_cfg,
test_cfg=config_mod.test_cfg)
assert detector is not None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment