Unverified Commit 36003b76 authored by youkaichao, committed by GitHub

[Feature] Add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode (#2807)
parent f01d301e
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from functools import partial
from typing import Dict, Optional, Tuple, Union
import torch
...@@ -14,6 +15,55 @@ from .norm import build_norm_layer
from .padding import build_padding_layer
def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
x: torch.Tensor):
"""
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
It leverages the associative law between convolution and affine transform,
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
It works for Eval mode of ConvBN blocks during validation, and can be used
for training as well. It reduces memory and computation cost.
Args:
bn (_BatchNorm): a BatchNorm module.
conv (nn._ConvNd): a conv module
x (torch.Tensor): Input feature map.
"""
# These lines of code are designed to deal with various cases
# like bn without affine transform, and conv without bias
weight_on_the_fly = conv.weight
if conv.bias is not None:
bias_on_the_fly = conv.bias
else:
bias_on_the_fly = torch.zeros_like(bn.running_var)
if bn.weight is not None:
bn_weight = bn.weight
else:
bn_weight = torch.ones_like(bn.running_var)
if bn.bias is not None:
bn_bias = bn.bias
else:
bn_bias = torch.zeros_like(bn.running_var)
# shape of [C_out, 1, 1, 1] in Conv2d
weight_coeff = torch.rsqrt(bn.running_var +
bn.eps).reshape([-1] + [1] *
(len(conv.weight.shape) - 1))
# shape of [C_out, 1, 1, 1] in Conv2d
coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
# shape of [C_out, C_in, k, k] in Conv2d
weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
# shape of [C_out] in Conv2d
bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() *\
(bias_on_the_fly - bn.running_mean)
return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
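The fusion above is the associativity noted in the docstring written out: in eval mode, bn(conv(x)) = gamma / sqrt(var + eps) * (W * x + b - mean) + beta, which is a single convolution with weight (gamma / sqrt(var + eps)) * W and bias beta + gamma / sqrt(var + eps) * (b - mean). A minimal equivalence check, not part of this commit, assuming fast_conv_bn_eval_forward is in scope and using only standard torch.nn modules:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8)
# populate non-trivial running statistics, then freeze them in eval mode
bn.train()
bn(conv(torch.rand(4, 3, 32, 32)))
bn.eval()

x = torch.rand(1, 3, 32, 32)
reference = bn(conv(x))                         # plain conv -> bn path
fused = fast_conv_bn_eval_forward(bn, conv, x)  # BN folded into the conv
assert torch.allclose(reference, fused, atol=1e-5)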
@MODELS.register_module()
class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers.
...@@ -65,6 +115,9 @@ class ConvModule(nn.Module):
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act').
fast_conv_bn_eval (bool): Whether to use the fast conv-bn forward when
the consecutive bn is in eval mode (during either training or
testing), as proposed in https://arxiv.org/abs/2305.11624.
Default: False.
""" """
_abbr_ = 'conv_block' _abbr_ = 'conv_block'
...@@ -84,7 +137,8 @@ class ConvModule(nn.Module): ...@@ -84,7 +137,8 @@ class ConvModule(nn.Module):
inplace: bool = True, inplace: bool = True,
with_spectral_norm: bool = False, with_spectral_norm: bool = False,
padding_mode: str = 'zeros', padding_mode: str = 'zeros',
order: tuple = ('conv', 'norm', 'act')): order: tuple = ('conv', 'norm', 'act'),
fast_conv_bn_eval: bool = False):
super().__init__() super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict) assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict) assert norm_cfg is None or isinstance(norm_cfg, dict)
...@@ -155,6 +209,16 @@ class ConvModule(nn.Module):
else:
self.norm_name = None  # type: ignore
# fast_conv_bn_eval works for conv + bn
# with `track_running_stats` option
if fast_conv_bn_eval and self.norm and isinstance(
self.norm, _BatchNorm) and self.norm.track_running_stats:
self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
self.norm, self.conv)
else:
self.fast_conv_bn_eval_forward = None # type: ignore
self.original_conv_forward = self.conv.forward
# build activation layer
if self.with_activation:
act_cfg_ = act_cfg.copy()  # type: ignore
...@@ -200,13 +264,77 @@ class ConvModule(nn.Module):
x: torch.Tensor,
activate: bool = True,
norm: bool = True) -> torch.Tensor:
layer_index = 0
while layer_index < len(self.order):
layer = self.order[layer_index]
if layer == 'conv':
if self.with_explicit_padding:
x = self.padding_layer(x)
# if the next operation is norm and we have a norm layer in
# eval mode and we have enabled fast_conv_bn_eval for the conv
# operator, then activate the optimized forward and skip the
# next norm operator since it has been fused
if layer_index + 1 < len(self.order) and \
self.order[layer_index + 1] == 'norm' and norm and \
self.with_norm and not self.norm.training and \
self.fast_conv_bn_eval_forward is not None:
self.conv.forward = self.fast_conv_bn_eval_forward
layer_index += 1
else:
self.conv.forward = self.original_conv_forward
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm:
x = self.norm(x)
elif layer == 'act' and activate and self.with_activation:
x = self.activate(x)
layer_index += 1
return x
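The dispatch above temporarily rebinds self.conv.forward: when the next entry in self.order is 'norm', the bn is in eval mode, and fast_conv_bn_eval is enabled, the fused forward runs and the separate norm step is skipped; otherwise the original conv forward is restored. A small illustrative sketch, not part of the commit, assuming ConvModule and torch are imported:

conv_module = ConvModule(
    3, 8, 3, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
x = torch.rand(1, 3, 32, 32)

conv_module.norm.eval()   # bn frozen -> fused conv+bn forward is used
out_fast = conv_module(x)

conv_module.norm.train()  # bn uses batch stats -> plain conv -> bn -> act path
out_plain = conv_module(x)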
@staticmethod
def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
bn: torch.nn.modules.batchnorm._BatchNorm,
fast_conv_bn_eval=True) -> 'ConvModule':
"""Create a ConvModule from a conv and a bn module."""
self = ConvModule.__new__(ConvModule)
super(ConvModule, self).__init__()
self.conv_cfg = None
self.norm_cfg = None
self.act_cfg = None
self.inplace = False
self.with_spectral_norm = False
self.with_explicit_padding = False
self.order = ('conv', 'norm', 'act')
self.with_norm = True
self.with_activation = False
self.with_bias = conv.bias is not None
# build convolution layer
self.conv = conv
# export the attributes of self.conv to a higher level for convenience
self.in_channels = self.conv.in_channels
self.out_channels = self.conv.out_channels
self.kernel_size = self.conv.kernel_size
self.stride = self.conv.stride
self.padding = self.conv.padding
self.dilation = self.conv.dilation
self.transposed = self.conv.transposed
self.output_padding = self.conv.output_padding
self.groups = self.conv.groups
# build normalization layers
self.norm_name, norm = 'bn', bn
self.add_module(self.norm_name, norm)
# fast_conv_bn_eval works for conv + bn
# with `track_running_stats` option
if fast_conv_bn_eval and self.norm and isinstance(
self.norm, _BatchNorm) and self.norm.track_running_stats:
self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
self.norm, self.conv)
else:
self.fast_conv_bn_eval_forward = None # type: ignore
self.original_conv_forward = self.conv.forward
return self
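A minimal usage sketch for create_from_conv_bn, not part of the commit: it wraps an existing conv / bn pair, for example one taken from a pretrained backbone, into a ConvModule so that the fused forward is available whenever the bn is in eval mode. It assumes torch.nn is imported as nn:

conv = nn.Conv2d(16, 32, 3, padding=1)
bn = nn.BatchNorm2d(32)
module = ConvModule.create_from_conv_bn(conv, bn, fast_conv_bn_eval=True)

module.norm.eval()  # eval-mode bn -> forward uses the fused conv+bn path
y = module(torch.rand(2, 16, 64, 64))
assert y.shape == (2, 32, 64, 64)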
...@@ -75,6 +75,16 @@ def test_conv_module():
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# conv + norm with fast mode
conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
conv.norm.eval()
x = torch.rand(1, 3, 256, 256)
fast_mode_output = conv(x)
conv.conv.forward = conv.original_conv_forward
plain_implementation = conv.activate(conv.norm(conv.conv(x)))
assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)
# conv + act
conv = ConvModule(3, 8, 2)
assert conv.with_activation
...