Unverified Commit 26aba2f5 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix an incorrect assignment in the BaseTransformerLayer (#1670)

* fix BaseTransformerLayer

* Add BaseTransformerLayer unit test without ffn_cfg embed_dims

* Optimize code logic
parent 26c095dc
@@ -745,7 +745,7 @@ class BaseTransformerLayer(BaseModule):
             assert len(ffn_cfgs) == num_ffns
             for ffn_index in range(num_ffns):
                 if 'embed_dims' not in ffn_cfgs[ffn_index]:
-                    ffn_cfgs['embed_dims'] = self.embed_dims
+                    ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
                 else:
                     assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
                 self.ffns.append(
......
@@ -575,8 +575,27 @@ def test_basetransformerlayer_cuda():
     assert x.shape == torch.Size([2, 10, 256])


-def test_basetransformerlayer():
+@pytest.mark.parametrize('embed_dims', [False, 256])
+def test_basetransformerlayer(embed_dims):
     attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
+    if embed_dims:
+        ffn_cfgs = dict(
+            type='FFN',
+            embed_dims=embed_dims,
+            feedforward_channels=1024,
+            num_fcs=2,
+            ffn_drop=0.,
+            act_cfg=dict(type='ReLU', inplace=True),
+        )
+    else:
+        ffn_cfgs = dict(
+            type='FFN',
+            feedforward_channels=1024,
+            num_fcs=2,
+            ffn_drop=0.,
+            act_cfg=dict(type='ReLU', inplace=True),
+        )
     feedforward_channels = 2048
     ffn_dropout = 0.1
     operation_order = ('self_attn', 'norm', 'ffn', 'norm')
@@ -584,6 +603,7 @@ def test_basetransformerlayer():
     # test deprecated_args
     baselayer = BaseTransformerLayer(
         attn_cfgs=attn_cfgs,
+        ffn_cfgs=ffn_cfgs,
         feedforward_channels=feedforward_channels,
         ffn_dropout=ffn_dropout,
         operation_order=operation_order)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment