Commit e7466395 authored by Cao Yuhang, committed by Kai Chen

Support FP16 training (#520)

* add fp16 support

* fpn does not need bn normalize

* refactor wrapped bn

* fix bug of retinanet

* add fp16 ssd300 voc, cascade r50, cascade mask r50

* fix bug in cascade rcnn testing

* add support to fix bn training

* add fix bn cfg

* delete fixbn cfg, mv fixbn fp16 to a new branch

* fix cascade mask fp16 bug in test

* fix bug in cascade mask rcnn fp16 test

* add more fp16 cfgs

* add fp16 fast-r50 and faster-dconv-r50

* add fp16 test, minor fix

* clean code

* fix config work_dir name

* add patch func, refactor code

* fix format

* clean code

* move convert rois to single_level_extractor

* fix bug for cascade mask, the seg mask is ndarray

* refactor code, add two decorator force_fp32 and auto_fp16

* add fp16_enabled attribute

* add more comment, fix format and test assertion

* fix pep8 format error

* format comment and api

* rename distribute to distributed, fix dict copy

* rename function name

* move function, add comment

* remove unused parameter

* mv decorators into decorators.py, hook related functions to hook

* add auto_fp16 to forward of semantic head

* add auto_fp16 to all heads and fpn

* add docstrings and minor bug fix

* simple refactoring

* bug fix for patching forward method

* roi extractor in fp32 mode

* fix flake8 error

* fix ci error

* add fp16 support to ga head

* remove parallel test assert

* minor fix

* add comment in build_optimizer

* fix typo in comment

* fix typo enable --> enabling

* update README
parent bffa0510
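Before the diffs, a quick orientation on how the pieces fit together: a config entry turns mixed precision on, and the two new decorators mark which methods run in fp16 and which must stay in fp32. The snippet below is only a usage sketch; ToyHead is a made-up module, and the fp16 = dict(loss_scale=512.) config field is an assumption based on the fp16 configs this commit adds.

import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import auto_fp16, force_fp32

# In a config, fp16 training would be switched on with a loss-scaling factor,
# e.g. (field name assumed from the fp16 configs added in this commit):
#     fp16 = dict(loss_scale=512.)


class ToyHead(nn.Module):
    # Illustrative module, not part of the commit.

    def __init__(self):
        super(ToyHead, self).__init__()
        self.fp16_enabled = False  # flipped on once the model is fp16-wrapped
        self.fc = nn.Linear(256, 81)

    @auto_fp16()
    def forward(self, x):
        # Marked methods run in half precision when fp16_enabled is True.
        return self.fc(x)

    @force_fp32(apply_to=('cls_score', ))
    def loss(self, cls_score, labels):
        # Losses are computed in fp32 for numerical stability.
        return F.cross_entropy(cls_score, labels)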
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import xavier_init
+from mmdet.core import auto_fp16
 from ..registry import NECKS
 from ..utils import ConvModule
@@ -29,6 +30,7 @@ class FPN(nn.Module):
         self.num_outs = num_outs
         self.activation = activation
         self.relu_before_extra_convs = relu_before_extra_convs
+        self.fp16_enabled = False
         if end_level == -1:
             self.backbone_end_level = self.num_ins
@@ -94,6 +96,7 @@ class FPN(nn.Module):
             if isinstance(m, nn.Conv2d):
                 xavier_init(m, distribution='uniform')
+    @auto_fp16()
     def forward(self, inputs):
         assert len(inputs) == len(self.in_channels)
......
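The @auto_fp16() marker on FPN.forward is the entry point into half precision. The real decorator lives in mmdet.core; the following is only a simplified sketch of the behaviour it suggests, assuming it casts floating-point tensor inputs to half precision whenever the module's fp16_enabled flag is set.

import functools
import torch


def _cast_fp16(x):
    # Recursively cast floating-point tensors in (nested) lists/tuples to half.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.half()
    if isinstance(x, (list, tuple)):
        return type(x)(_cast_fp16(v) for v in x)
    return x


def auto_fp16_sketch(func):
    # Simplified stand-in for mmdet.core.auto_fp16: only casts inputs when the
    # owning module has opted in via its fp16_enabled attribute.
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not getattr(self, 'fp16_enabled', False):
            return func(self, *args, **kwargs)
        args = tuple(_cast_fp16(a) for a in args)
        kwargs = {k: _cast_fp16(v) for k, v in kwargs.items()}
        return func(self, *args, **kwargs)
    return wrapper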
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 from mmdet import ops
+from mmdet.core import force_fp32
 from ..registry import ROI_EXTRACTORS
@@ -31,6 +32,7 @@ class SingleRoIExtractor(nn.Module):
         self.out_channels = out_channels
         self.featmap_strides = featmap_strides
         self.finest_scale = finest_scale
+        self.fp16_enabled = False
     @property
     def num_inputs(self):
@@ -70,6 +72,7 @@ class SingleRoIExtractor(nn.Module):
         target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
         return target_lvls
+    @force_fp32(apply_to=('feats',), out_fp16=True)
     def forward(self, feats, rois):
         if len(feats) == 1:
             return self.roi_layers[0](feats[0], rois)
@@ -77,8 +80,8 @@ class SingleRoIExtractor(nn.Module):
         out_size = self.roi_layers[0].out_size
         num_levels = len(feats)
         target_lvls = self.map_roi_levels(rois, num_levels)
-        roi_feats = torch.cuda.FloatTensor(rois.size()[0], self.out_channels,
-                                           out_size, out_size).fill_(0)
+        roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels,
+                                       out_size, out_size)
         for i in range(num_levels):
             inds = target_lvls == i
             if inds.any():
......
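The RoI extractor is the opposite case: per the commit message it stays in fp32 mode, so force_fp32 casts the feature maps to fp32 for RoIAlign/RoIPool, and out_fp16=True hands the pooled features back to the fp16 heads. Again a rough sketch of that behaviour, not the mmdet.core implementation:

import functools
import torch


def _cast(x, dtype):
    # Recursively cast floating-point tensors in (nested) lists/tuples.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype)
    if isinstance(x, (list, tuple)):
        return type(x)(_cast(v, dtype) for v in x)
    return x


def force_fp32_sketch(apply_to=None, out_fp16=False):
    # Simplified stand-in for mmdet.core.force_fp32: run the decorated method
    # on fp32 copies of the named arguments, optionally returning fp16.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not getattr(self, 'fp16_enabled', False):
                return func(self, *args, **kwargs)
            arg_names = func.__code__.co_varnames[1:func.__code__.co_argcount]
            targets = set(arg_names if apply_to is None else apply_to)
            new_args = [
                _cast(value, torch.float) if name in targets else value
                for name, value in zip(arg_names, args)
            ]
            new_kwargs = {
                k: _cast(v, torch.float) if k in targets else v
                for k, v in kwargs.items()
            }
            out = func(self, *new_args, **new_kwargs)
            return _cast(out, torch.half) if out_fp16 else out
        return wrapper
    return decorator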
@@ -4,6 +4,7 @@
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
+from mmdet.core import auto_fp16
 from ..backbones import ResNet, make_res_layer
 from ..registry import SHARED_HEADS
@@ -25,6 +26,7 @@ class ResLayer(nn.Module):
         self.norm_eval = norm_eval
         self.norm_cfg = norm_cfg
         self.stage = stage
+        self.fp16_enabled = False
         block, stage_blocks = ResNet.arch_settings[depth]
         stage_block = stage_blocks[stage]
         planes = 64 * 2**stage
@@ -56,6 +58,7 @@ class ResLayer(nn.Module):
         else:
             raise TypeError('pretrained must be a str or None')
+    @auto_fp16()
     def forward(self, x):
         res_layer = getattr(self, 'layer{}'.format(self.stage + 1))
         out = res_layer(x)
......
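The shared ResLayer head follows the same pattern as the FPN above. On the training side, the commit message mentions moving the fp16 hook functions into hooks; what such a hook adds over a plain optimizer step is loss scaling, so that small fp16 gradients do not underflow. A generic, purely illustrative loss-scaling step (the commit's actual hook also maintains fp32 copies of the weights, which is omitted here):

def scaled_backward_step(loss, model, optimizer, loss_scale=512.):
    # Static loss scaling: scale the loss before backward so that small fp16
    # gradients survive, then unscale the gradients before the update.
    optimizer.zero_grad()
    (loss * loss_scale).backward()
    for param in model.parameters():
        if param.grad is not None:
            param.grad.div_(loss_scale)
    optimizer.step()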
@@ -11,7 +11,7 @@
 from mmcv.runner import load_checkpoint, get_dist_info
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmdet.apis import init_dist
-from mmdet.core import results2json, coco_eval
+from mmdet.core import results2json, coco_eval, wrap_fp16_model
 from mmdet.datasets import build_dataloader, get_dataset
 from mmdet.models import build_detector
@@ -157,6 +157,9 @@
     # build the model and load checkpoint
     model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
     # old versions did not save class info in checkpoints, this walkaround is
     # for backward compatibility
......
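wrap_fp16_model, called above before loading the checkpoint, is what puts the detector into half precision at test time. A rough sketch of what such a wrapper typically does, assuming the usual recipe of converting weights to half while keeping normalization layers in fp32 and enabling the fp16_enabled flags; the real implementation is in mmdet.core and may differ:

import torch.nn as nn


def wrap_fp16_model_sketch(model):
    # Rough stand-in for mmdet.core.wrap_fp16_model (an assumption, not the
    # commit's exact code): convert parameters and buffers to half precision,
    # keep BatchNorm in fp32 for stable statistics, and turn on the
    # fp16_enabled flag that auto_fp16/force_fp32 check.
    model.half()
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.float()
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True
    return model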