"docs/vscode:/vscode.git/clone" did not exist on "32b85dfa8d4a5fa54469ddc72be89d827c1ee9d6"
Unverified commit f64d4858, authored by youkaichao, committed by GitHub

rename fast_conv_bn_eval to efficient_conv_bn_eval (#2884)

parent ad7284e8
@@ -15,7 +15,8 @@ from .norm import build_norm_layer
 from .padding import build_padding_layer
 
 
-def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
-                              x: torch.Tensor):
+def efficient_conv_bn_eval_forward(bn: _BatchNorm,
+                                   conv: nn.modules.conv._ConvNd,
+                                   x: torch.Tensor):
     """
     Implementation based on https://arxiv.org/abs/2305.11624
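For readers landing here from the rename: the function being renamed implements the trick from https://arxiv.org/abs/2305.11624, folding an eval-mode BN's affine transform into the conv's weight and bias so the separate normalization step disappears. A minimal standalone sketch of that folding, assuming a Conv2d followed by a BatchNorm2d with affine parameters and tracked running stats; `fold_bn_into_conv` is an illustrative name, not the code in this commit:

    import torch
    import torch.nn.functional as F
    from torch import nn

    def fold_bn_into_conv(bn: nn.BatchNorm2d, conv: nn.Conv2d,
                          x: torch.Tensor) -> torch.Tensor:
        # per-channel scale: gamma / sqrt(running_var + eps)
        scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
        # fold the scale into the conv weight ([C_out, 1, 1, 1] broadcasts)
        weight = conv.weight * scale.reshape(-1, 1, 1, 1)
        # fold mean and shift into the bias: beta + scale * (b - running_mean)
        bias = (conv.bias if conv.bias is not None
                else torch.zeros_like(bn.running_mean))
        bias = bn.bias + scale * (bias - bn.running_mean)
        return F.conv2d(x, weight, bias, conv.stride, conv.padding,
                        conv.dilation, conv.groups)

    # equivalence check against the plain conv -> bn path in eval mode
    conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8).eval()
    with torch.no_grad():
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 1.5)
        x = torch.randn(2, 3, 16, 16)
        assert torch.allclose(fold_bn_into_conv(bn, conv, x), bn(conv(x)),
                              atol=1e-5)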
@@ -115,9 +116,9 @@ class ConvModule(nn.Module):
             sequence of "conv", "norm" and "act". Common examples are
             ("conv", "norm", "act") and ("act", "conv", "norm").
             Default: ('conv', 'norm', 'act').
-        fast_conv_bn_eval (bool): Whether use fast conv when the consecutive
-            bn is in eval mode (either training or testing), as proposed in
-            https://arxiv.org/abs/2305.11624 . Default: False.
+        efficient_conv_bn_eval (bool): Whether use efficient conv when the
+            consecutive bn is in eval mode (either training or testing), as
+            proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
     """
 
     _abbr_ = 'conv_block'
@@ -138,7 +139,7 @@ class ConvModule(nn.Module):
                  with_spectral_norm: bool = False,
                  padding_mode: str = 'zeros',
                  order: tuple = ('conv', 'norm', 'act'),
-                 fast_conv_bn_eval: bool = False):
+                 efficient_conv_bn_eval: bool = False):
         super().__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -209,7 +210,7 @@ class ConvModule(nn.Module):
         else:
             self.norm_name = None  # type: ignore
 
-        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
+        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
 
         # build activation layer
         if self.with_activation:
@@ -263,15 +264,16 @@ class ConvModule(nn.Module):
                 if self.with_explicit_padding:
                     x = self.padding_layer(x)
                 # if the next operation is norm and we have a norm layer in
-                # eval mode and we have enabled fast_conv_bn_eval for the conv
-                # operator, then activate the optimized forward and skip the
-                # next norm operator since it has been fused
+                # eval mode and we have enabled `efficient_conv_bn_eval` for
+                # the conv operator, then activate the optimized forward and
+                # skip the next norm operator since it has been fused
                 if layer_index + 1 < len(self.order) and \
                         self.order[layer_index + 1] == 'norm' and norm and \
                         self.with_norm and not self.norm.training and \
-                        self.fast_conv_bn_eval_forward is not None:
-                    self.conv.forward = partial(self.fast_conv_bn_eval_forward,
-                                                self.norm, self.conv)
+                        self.efficient_conv_bn_eval_forward is not None:
+                    self.conv.forward = partial(
+                        self.efficient_conv_bn_eval_forward, self.norm,
+                        self.conv)
                     layer_index += 1
                     x = self.conv(x)
                     del self.conv.forward
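The forward logic above leans on a small Python pattern worth spelling out: assigning to `self.conv.forward` creates an instance attribute that shadows the class-level `forward` for the next `self.conv(x)` call, and `del self.conv.forward` removes the shadow so attribute lookup falls back to the class method. A standalone sketch of the pattern (the traced function and names are hypothetical):

    from functools import partial

    import torch
    from torch import nn

    def fused_like_forward(extra: str, module: nn.Linear,
                           x: torch.Tensor) -> torch.Tensor:
        # stand-in for the fused conv-bn forward; the leading args were
        # pre-bound with functools.partial, mirroring the code above
        print(f'patched forward ({extra})')
        return nn.Linear.forward(module, x)

    layer = nn.Linear(4, 4)
    x = torch.randn(2, 4)
    # the instance attribute shadows the class method for calls via layer(x)
    layer.forward = partial(fused_like_forward, 'one call only', layer)
    layer(x)           # prints: patched forward (one call only)
    del layer.forward  # restore: lookup falls back to the class method
    layer(x)           # ordinary nn.Linear.forward again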
@@ -284,20 +286,20 @@ class ConvModule(nn.Module):
             layer_index += 1
         return x
 
-    def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True):
-        # fast_conv_bn_eval works for conv + bn
+    def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
+        # efficient_conv_bn_eval works for conv + bn
         # with `track_running_stats` option
-        if fast_conv_bn_eval and self.norm \
+        if efficient_conv_bn_eval and self.norm \
                 and isinstance(self.norm, _BatchNorm) \
                 and self.norm.track_running_stats:
-            self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward
+            self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward  # noqa: E501
         else:
-            self.fast_conv_bn_eval_forward = None  # type: ignore
+            self.efficient_conv_bn_eval_forward = None  # type: ignore
 
     @staticmethod
     def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
                             bn: torch.nn.modules.batchnorm._BatchNorm,
-                            fast_conv_bn_eval=True) -> 'ConvModule':
+                            efficient_conv_bn_eval=True) -> 'ConvModule':
         """Create a ConvModule from a conv and a bn module."""
         self = ConvModule.__new__(ConvModule)
         super(ConvModule, self).__init__()
@@ -331,6 +333,6 @@ class ConvModule(nn.Module):
         self.norm_name, norm = 'bn', bn
         self.add_module(self.norm_name, norm)
 
-        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
+        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
 
         return self
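Putting the renamed pieces together, usage looks roughly like the following; a sketch based only on the signatures visible in this diff, assuming the usual `mmcv.cnn` import path:

    import torch
    from torch import nn
    from mmcv.cnn import ConvModule

    # opt in at construction; the fused path only kicks in while the
    # norm layer is in eval mode
    block = ConvModule(3, 8, 3, norm_cfg=dict(type='BN'),
                       efficient_conv_bn_eval=True).eval()
    out = block(torch.randn(1, 3, 32, 32))

    # toggle it on an existing module
    block.turn_on_efficient_conv_bn_eval(False)

    # or build a ConvModule around an existing conv/bn pair
    conv, bn = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)
    block = ConvModule.create_from_conv_bn(conv, bn).eval()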
@@ -75,27 +75,30 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
 
-    # conv + norm with fast mode
-    fast_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    # conv + norm with efficient mode
+    efficient_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
     plain_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval()
-    for fast_param, plain_param in zip(fast_conv.state_dict().values(),
-                                       plain_conv.state_dict().values()):
-        plain_param.copy_(fast_param)
-    fast_mode_output = fast_conv(x)
+        3, 8, 2, norm_cfg=dict(type='BN'),
+        efficient_conv_bn_eval=False).eval()
+    for efficient_param, plain_param in zip(
+            efficient_conv.state_dict().values(),
+            plain_conv.state_dict().values()):
+        plain_param.copy_(efficient_param)
+    efficient_mode_output = efficient_conv(x)
     plain_mode_output = plain_conv(x)
-    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+    assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)
 
-    # `conv` attribute can be dynamically modified in fast mode
-    fast_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    # `conv` attribute can be dynamically modified in efficient mode
+    efficient_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
     new_conv = nn.Conv2d(3, 8, 2).eval()
-    fast_conv.conv = new_conv
-    fast_mode_output = fast_conv(x)
-    plain_mode_output = fast_conv.activate(fast_conv.norm(new_conv(x)))
-    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+    efficient_conv.conv = new_conv
+    efficient_mode_output = efficient_conv(x)
+    plain_mode_output = efficient_conv.activate(
+        efficient_conv.norm(new_conv(x)))
+    assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)
 
     # conv + act
     conv = ConvModule(3, 8, 2)
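A note on the `atol=1e-5` in these assertions: folding the BN scale into the conv weight reassociates floating-point multiplications, so the fused and plain paths agree only up to rounding error. A toy illustration of that reassociation:

    import torch

    torch.manual_seed(0)
    w, g, x = torch.randn(1000), torch.randn(1000), torch.randn(1000)
    folded = (w * g) * x  # scale folded into the weight first
    plain = g * (w * x)   # scale applied after the product, as BN would
    print((folded - plain).abs().max())  # tiny but generally nonzero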