Unverified Commit 921e433e authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Support normalization layers in RetinaHead (#557)

* support GN in RetinaHead

* add conv_cfg argument for RetinaHead

* add model upgrading tool and update retinanet model urls

* minor fix for regex strings
parent 51904cbc
......@@ -109,15 +109,15 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m
| Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:--------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
| R-50-FPN | caffe | 1x | 6.7 | 0.468 | 9.4 | 35.8 | - |
| R-50-FPN | pytorch | 1x | 6.9 | 0.496 | 9.1 | 35.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_1x_20181125-3d3c2142.pth) |
| R-50-FPN | pytorch | 2x | 6.9 | 0.496 | 9.1 | 36.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_2x_20181125-e0dbec97.pth) |
| R-50-FPN | pytorch | 1x | 6.9 | 0.496 | 9.1 | 35.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_1x_20181125-7b0c2548.pth) |
| R-50-FPN | pytorch | 2x | 6.9 | 0.496 | 9.1 | 36.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_2x_20181125-8b724df2.pth) |
| R-101-FPN | caffe | 1x | 9.2 | 0.614 | 8.2 | 37.8 | - |
| R-101-FPN | pytorch | 1x | 9.6 | 0.643 | 8.1 | 37.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_1x_20181129-f738a02f.pth) |
| R-101-FPN | pytorch | 2x | 9.6 | 0.643 | 8.1 | 38.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_2x_20181129-f654534b.pth) |
| X-101-32x4d-FPN | pytorch | 1x| 10.8 | 0.792 | 6.7 | 38.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_1x_20181218-c140fb82.pth)
| X-101-32x4d-FPN | pytorch | 2x| 10.8 | 0.792 | 6.7 | 39.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_2x_20181218-605dcd0a.pth)
| X-101-64x4d-FPN | pytorch | 1x| 14.6 | 1.128 | 5.3 | 40.0 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_1x_20181218-2f6f778b.pth)
| X-101-64x4d-FPN | pytorch | 2x| 14.6 | 1.128 | 5.3 | 39.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_2x_20181218-2f598dc5.pth)
| R-101-FPN | pytorch | 1x | 9.6 | 0.643 | 8.1 | 37.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_1x_20181129-f016f384.pth) |
| R-101-FPN | pytorch | 2x | 9.6 | 0.643 | 8.1 | 38.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r101_fpn_2x_20181129-72c14526.pth) |
| X-101-32x4d-FPN | pytorch | 1x | 10.8 | 0.792 | 6.7 | 38.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_1x_20181218-c84d7dfc.pth) |
| X-101-32x4d-FPN | pytorch | 2x | 10.8 | 0.792 | 6.7 | 39.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_32x4d_fpn_2x_20181218-8596452d.pth) |
| X-101-64x4d-FPN | pytorch | 1x | 14.6 | 1.128 | 5.3 | 40.0 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_1x_20181218-a0a22662.pth) |
| X-101-64x4d-FPN | pytorch | 2x | 14.6 | 1.128 | 5.3 | 39.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_x101_64x4d_fpn_2x_20181218-5e88d045.pth) |
### Cascade R-CNN
......
......@@ -4,7 +4,7 @@ from mmcv.cnn import normal_init
from .anchor_head import AnchorHead
from ..registry import HEADS
from ..utils import bias_init_with_prob
from ..utils import bias_init_with_prob, ConvModule
@HEADS.register_module
......@@ -16,10 +16,14 @@ class RetinaHead(AnchorHead):
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
conv_cfg=None,
normalize=None,
**kwargs):
self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
self.conv_cfg = conv_cfg
self.normalize = normalize
octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
......@@ -38,9 +42,25 @@ class RetinaHead(AnchorHead):
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
normalize=self.normalize,
bias=self.normalize is None))
self.reg_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
normalize=self.normalize,
bias=self.normalize is None))
self.retina_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
......@@ -51,9 +71,9 @@ class RetinaHead(AnchorHead):
def init_weights(self):
for m in self.cls_convs:
normal_init(m, std=0.01)
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m, std=0.01)
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01)
......@@ -62,9 +82,9 @@ class RetinaHead(AnchorHead):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = self.relu(cls_conv(cls_feat))
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = self.relu(reg_conv(reg_feat))
reg_feat = reg_conv(reg_feat)
cls_score = self.retina_cls(cls_feat)
bbox_pred = self.retina_reg(reg_feat)
return cls_score, bbox_pred
import argparse
import re
from collections import OrderedDict
import torch
def convert(in_file, out_file):
    """Convert keys in checkpoints.

    There can be some breaking changes during the development of mmdetection,
    and this tool is used for upgrading checkpoints trained with old versions
    to the latest one.

    The only upgrade applied here is the RetinaNet one: the stacked convs
    moved from ``nn.Conv2d`` to ``ConvModule``, so a key such as
    ``cls_convs.0.weight`` is renamed to ``cls_convs.0.conv.weight`` (same
    for ``reg_convs`` and for ``bias``). Every other key of the state dict,
    and every other entry of the checkpoint, is written back unchanged.
    Keys that already contain the ``conv.`` segment are left alone, so
    running the tool on an upgraded checkpoint is a no-op.

    Args:
        in_file (str): Path of the checkpoint to read. It must hold a
            ``'state_dict'`` entry.
        out_file (str): Path the upgraded checkpoint is saved to.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    # The dots are escaped so they only match literal separators, the layer
    # index may have several digits, and the pattern is anchored to a module
    # path boundary on the left and to the end of the key on the right. Only
    # the trailing parameter name is rewritten, never an earlier "weight" or
    # "bias" substring that happens to appear in a parent module name.
    legacy_key = re.compile(
        r'((?:^|\.)(?:cls_convs|reg_convs)\.\d+\.)(weight|bias)$')

    # NOTE: torch.load unpickles the file; only run this tool on checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(in_file)
    in_state_dict = checkpoint.pop('state_dict')
    out_state_dict = OrderedDict()
    for key, val in in_state_dict.items():
        # Use ConvModule instead of nn.Conv2d in RetinaNet
        # cls_convs.0.weight -> cls_convs.0.conv.weight
        # Non-matching keys come back from sub() untouched.
        out_state_dict[legacy_key.sub(r'\1conv.\2', key)] = val
    checkpoint['state_dict'] = out_state_dict
    torch.save(checkpoint, out_file)
def main():
    """Command-line entry point: upgrade one checkpoint file.

    Reads two positional arguments, the input checkpoint path and the output
    checkpoint path, and hands them to ``convert``.
    """
    cli = argparse.ArgumentParser(description='Upgrade model version')
    positional = (
        ('in_file', 'input checkpoint file'),
        ('out_file', 'output checkpoint file'),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()
    convert(options.in_file, options.out_file)


# Run the converter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment