Unverified commit d510b8b1, authored by takuoko, committed by GitHub

[Feature] Support LayerScale in FFN (#2451)



* add layer scale

* add layer scale

* add layer scale

* Update mmcv/cnn/bricks/transformer.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/bricks/transformer.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add layer scale

* move LayerScale

* add layer_scale_init_value

* add typehint

* fix for tensor with any dim

* fix layer scale rule

* fix layer scale rule

* fix test

* add docs
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Parent commit: 43360703
@@ -31,6 +31,7 @@ Module
GeneralizedAttention
HSigmoid
HSwish
+LayerScale
Linear
MaxPool2d
MaxPool3d
@@ -31,6 +31,7 @@ Module
GeneralizedAttention
HSigmoid
HSwish
+LayerScale
Linear
MaxPool2d
MaxPool3d
...@@ -14,7 +14,7 @@ from .non_local import NonLocal1d, NonLocal2d, NonLocal3d ...@@ -14,7 +14,7 @@ from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer from .padding import build_padding_layer
from .plugin import build_plugin_layer from .plugin import build_plugin_layer
from .scale import Scale from .scale import LayerScale, Scale
from .swish import Swish from .swish import Swish
from .upsample import build_upsample_layer from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
...@@ -28,5 +28,5 @@ __all__ = [ ...@@ -28,5 +28,5 @@ __all__ = [
'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d', 'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding', 'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
'Conv3d', 'Dropout', 'DropPath' 'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
] ]
@@ -19,3 +19,39 @@ class Scale(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale
+
+
+class LayerScale(nn.Module):
+    """LayerScale layer.
+
+    Args:
+        dim (int): Dimension of input features.
+        inplace (bool): Whether to perform the operation in-place.
+            Default: False.
+        data_format (str): The input data format, could be 'channels_last'
+            or 'channels_first', representing (B, N, C) and
+            (B, C, H, W) format data respectively. Default: 'channels_last'.
+        scale (float): Initial value of the scale factor. Default: 1e-5.
+    """
+
+    def __init__(self,
+                 dim: int,
+                 inplace: bool = False,
+                 data_format: str = 'channels_last',
+                 scale: float = 1e-5):
+        super().__init__()
+        assert data_format in ('channels_last', 'channels_first'), \
+            "'data_format' could only be channels_last or channels_first."
+        self.inplace = inplace
+        self.data_format = data_format
+        self.weight = nn.Parameter(torch.ones(dim) * scale)
+
+    def forward(self, x) -> torch.Tensor:
+        if self.data_format == 'channels_first':
+            shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
+        else:
+            shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
+        if self.inplace:
+            return x.mul_(self.weight.view(*shape))
+        else:
+            return x * self.weight.view(*shape)
@@ -15,6 +15,7 @@ from mmengine.utils import deprecated_api_warning, to_2tuple
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from .drop import build_dropout
+from .scale import LayerScale

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
@@ -572,6 +573,8 @@ class FFN(BaseModule):
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
+        layer_scale_init_value (float): Initial value of the scale factor in
+            LayerScale. Default: 0. (LayerScale disabled).
    """

    @deprecated_api_warning(
@@ -588,7 +591,8 @@ class FFN(BaseModule):
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
-                 init_cfg=None):
+                 init_cfg=None,
+                 layer_scale_init_value=0.):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
@@ -611,6 +615,11 @@ class FFN(BaseModule):
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

+        if layer_scale_init_value > 0:
+            self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
+        else:
+            self.gamma2 = nn.Identity()
+
    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.
@@ -618,6 +627,7 @@
        The function would add x to the output tensor if residue is None.
        """
        out = self.layers(x)
+        out = self.gamma2(out)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
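The new argument is off by default (layer_scale_init_value=0.), so existing FFN configs behave exactly as before; a positive value scales the MLP output with LayerScale before the residual addition. A hedged construction sketch (argument values are illustrative; the import path is assumed from the repository layout shown above):

import torch

from mmcv.cnn.bricks.transformer import FFN

ffn = FFN(
    embed_dims=256,
    feedforward_channels=1024,
    layer_scale_init_value=1e-5)  # > 0 enables gamma2 = LayerScale(embed_dims)
x = torch.rand(2, 20, 256)
assert ffn(x).shape == (2, 20, 256)  # scaled MLP output plus the identity path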
# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
import torch

-from mmcv.cnn.bricks import Scale
+from mmcv.cnn.bricks import LayerScale, Scale


def test_scale():
@@ -20,3 +21,58 @@ def test_scale():
    x = torch.rand(1, 3, 64, 64)
    output = scale(x)
    assert output.shape == (1, 3, 64, 64)
+
+
+def test_layer_scale():
+    with pytest.raises(AssertionError):
+        cfg = dict(
+            dim=10,
+            data_format='BNC',
+        )
+        LayerScale(**cfg)
+
+    # test init
+    cfg = dict(dim=10)
+    ls = LayerScale(**cfg)
+    assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)
+
+    # test forward
+    # test channels_last
+    cfg = dict(dim=256, inplace=False, data_format='channels_last')
+    ls_channels_last = LayerScale(**cfg)
+    x = torch.randn((4, 49, 256))
+    out = ls_channels_last(x)
+    assert tuple(out.size()) == (4, 49, 256)
+    assert torch.equal(x * 1e-5, out)
+
+    # test channels_last 2d
+    cfg = dict(dim=256, inplace=False, data_format='channels_last')
+    ls_channels_last = LayerScale(**cfg)
+    x = torch.randn((4, 7, 49, 256))
+    out = ls_channels_last(x)
+    assert tuple(out.size()) == (4, 7, 49, 256)
+    assert torch.equal(x * 1e-5, out)
+
+    # test channels_first
+    cfg = dict(dim=256, inplace=False, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7)
+    assert torch.equal(x * 1e-5, out)
+
+    # test channels_first 3D
+    cfg = dict(dim=256, inplace=False, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7, 7)
+    assert torch.equal(x * 1e-5, out)
+
+    # test inplace True
+    cfg = dict(dim=256, inplace=True, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7)
+    assert x is out
...@@ -538,7 +538,6 @@ def test_ffn(): ...@@ -538,7 +538,6 @@ def test_ffn():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# num_fcs should be no less than 2 # num_fcs should be no less than 2
FFN(num_fcs=1) FFN(num_fcs=1)
FFN(dropout=0, add_residual=True)
ffn = FFN(dropout=0, add_identity=True) ffn = FFN(dropout=0, add_identity=True)
input_tensor = torch.rand(2, 20, 256) input_tensor = torch.rand(2, 20, 256)
...@@ -553,6 +552,13 @@ def test_ffn(): ...@@ -553,6 +552,13 @@ def test_ffn():
ffn(input_tensor, identity=residual).sum(), ffn(input_tensor, identity=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum()) ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
# test with layer_scale
ffn = FFN(dropout=0, add_identity=True, layer_scale_init_value=0.1)
input_tensor = torch.rand(2, 20, 256)
input_tensor_nbc = input_tensor.transpose(0, 1)
assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available') @pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda(): def test_basetransformerlayer_cuda():
......