Unverified Commit 926ac07b authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

Move fuse conv bn to mmcv (#382)

* move fuse conv bn to mmcv

* update doc

* update test conv bn

* rename

* fix doc and variable name

* change func name
parent d6789862
...@@ -10,8 +10,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, ...@@ -10,8 +10,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
is_norm) is_norm)
from .resnet import ResNet, make_res_layer from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init, from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
get_model_complexity_info, kaiming_init, normal_init, fuse_conv_bn, get_model_complexity_info, kaiming_init,
uniform_init, xavier_init) normal_init, uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer from .vgg import VGG, make_vgg_layer
__all__ = [ __all__ = [
...@@ -24,5 +24,5 @@ __all__ = [ ...@@ -24,5 +24,5 @@ __all__ = [
'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d', 'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d' 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .flops_counter import get_model_complexity_info from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (bias_init_with_prob, caffe2_xavier_init, from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
constant_init, kaiming_init, normal_init, constant_init, kaiming_init, normal_init,
uniform_init, xavier_init) uniform_init, xavier_init)
...@@ -7,5 +8,5 @@ from .weight_init import (bias_init_with_prob, caffe2_xavier_init, ...@@ -7,5 +8,5 @@ from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
__all__ = [ __all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init', 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init', 'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
'xavier_init' 'xavier_init', 'fuse_conv_bn'
] ]
import torch
import torch.nn as nn
def _fuse_conv_bn(conv, bn):
"""Fuse conv and bn into one module.
Args:
conv (nn.Module): Conv to be fused.
bn (nn.Module): BN to be fused.
Returns:
nn.Module: Fused module.
"""
conv_w = conv.weight
conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
bn.running_mean)
factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
conv.weight = nn.Parameter(conv_w *
factor.reshape([conv.out_channels, 1, 1, 1]))
conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
return conv
def fuse_conv_bn(module):
    """Recursively fuse conv and bn layers throughout a module.

    During inference a batch norm layer applies only a fixed per-channel
    affine transform based on its running statistics, which makes it
    possible to fold it into the preceding conv layer, saving computation
    and simplifying the network structure.

    Args:
        module (nn.Module): Module to be fused (mutated in place).

    Returns:
        nn.Module: The fused module.
    """
    pending_conv = None
    pending_name = None
    for child_name, child in module.named_children():
        is_bn = isinstance(
            child, (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm))
        if is_bn:
            # A BN can only be fused when a conv precedes it.
            if pending_conv is not None:
                module._modules[pending_name] = _fuse_conv_bn(
                    pending_conv, child)
                # Substitute an Identity for the BN instead of deleting
                # it, so the module's structure stays intact.
                module._modules[child_name] = nn.Identity()
                pending_conv = None
        elif isinstance(child, nn.Conv2d):
            pending_conv = child
            pending_name = child_name
        else:
            fuse_conv_bn(child)
    return module
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, fuse_conv_bn
def test_fuse_conv_bn():
    """Check that the fused model reproduces the unfused model's outputs.

    ``fuse_conv_bn`` mutates its argument in place, so the reference
    output must be computed *before* fusing (the original test compared
    the fused model against itself, which passes vacuously). The model
    must also be in eval mode so that BN uses the running statistics
    that get folded into the conv.
    """
    inputs = torch.rand((1, 3, 5, 5))
    modules = nn.ModuleList()
    modules.append(nn.BatchNorm2d(3))
    modules.append(ConvModule(3, 5, 3, norm_cfg=dict(type='BN')))
    modules.append(ConvModule(5, 5, 3, norm_cfg=dict(type='BN')))
    modules = nn.Sequential(*modules)
    modules.eval()
    unfused_outputs = modules(inputs)
    fused_modules = fuse_conv_bn(modules)
    fused_modules.eval()
    # Fusing reorders floating-point operations, so exact equality
    # cannot be expected; allow a small numerical tolerance.
    assert torch.allclose(unfused_outputs, fused_modules(inputs), atol=1e-5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment