Unverified commit 26aba2f5, authored by Mashiro, committed by GitHub

[Fix] Fix an incorrect assignment in the BaseTransformerLayer (#1670)

* fix BaseTransformerLayer

* Add a BaseTransformerLayer unit test covering ffn_cfgs without embed_dims

* Optimize code logic
parent 26c095dc
@@ -745,7 +745,7 @@ class BaseTransformerLayer(BaseModule):
         assert len(ffn_cfgs) == num_ffns
         for ffn_index in range(num_ffns):
             if 'embed_dims' not in ffn_cfgs[ffn_index]:
-                ffn_cfgs['embed_dims'] = self.embed_dims
+                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
             else:
                 assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
             self.ffns.append(
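For context, a minimal standalone sketch of why the old assignment was wrong (illustrative only, not the MMCV source; the variable names and config values are assumptions): once a single FFN config dict has been expanded into one copy per FFN, ffn_cfgs is a list, so indexing it with the string key raises a TypeError and the per-FFN configs never receive the default embed_dims. Indexing by ffn_index first, as the fix does, writes the default into the correct per-FFN dict.

import copy

# Illustrative sketch of the fix, not the MMCV implementation itself.
# Assumed setup: a single FFN config without 'embed_dims', expanded into
# one copy per FFN, mirroring what BaseTransformerLayer does internally.
embed_dims = 256
num_ffns = 2
ffn_cfg = dict(type='FFN', feedforward_channels=1024)  # no 'embed_dims'
ffn_cfgs = [copy.deepcopy(ffn_cfg) for _ in range(num_ffns)]

for ffn_index in range(num_ffns):
    if 'embed_dims' not in ffn_cfgs[ffn_index]:
        # Old (buggy) line: ffn_cfgs['embed_dims'] = embed_dims
        # -> TypeError, because ffn_cfgs is a list, not the per-index dict.
        ffn_cfgs[ffn_index]['embed_dims'] = embed_dims  # fixed assignment
    else:
        assert ffn_cfgs[ffn_index]['embed_dims'] == embed_dims

assert all(cfg['embed_dims'] == embed_dims for cfg in ffn_cfgs)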
@@ -575,8 +575,27 @@ def test_basetransformerlayer_cuda():
     assert x.shape == torch.Size([2, 10, 256])
 
 
-def test_basetransformerlayer():
+@pytest.mark.parametrize('embed_dims', [False, 256])
+def test_basetransformerlayer(embed_dims):
     attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
+    if embed_dims:
+        ffn_cfgs = dict(
+            type='FFN',
+            embed_dims=embed_dims,
+            feedforward_channels=1024,
+            num_fcs=2,
+            ffn_drop=0.,
+            act_cfg=dict(type='ReLU', inplace=True),
+        )
+    else:
+        ffn_cfgs = dict(
+            type='FFN',
+            feedforward_channels=1024,
+            num_fcs=2,
+            ffn_drop=0.,
+            act_cfg=dict(type='ReLU', inplace=True),
+        )
     feedforward_channels = 2048
     ffn_dropout = 0.1
     operation_order = ('self_attn', 'norm', 'ffn', 'norm')
@@ -584,6 +603,7 @@ def test_basetransformerlayer():
     # test deprecated_args
     baselayer = BaseTransformerLayer(
         attn_cfgs=attn_cfgs,
+        ffn_cfgs=ffn_cfgs,
         feedforward_channels=feedforward_channels,
         ffn_dropout=ffn_dropout,
         operation_order=operation_order)
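As a rough usage sketch of the behaviour the new test exercises (assuming an MMCV 1.x build with this fix applied, where BaseTransformerLayer lives in mmcv.cnn.bricks.transformer; the concrete values are illustrative and not part of this diff), an ffn_cfgs dict that omits embed_dims should now inherit it from the layer instead of raising:

from mmcv.cnn.bricks.transformer import BaseTransformerLayer

# Illustrative usage: ffn_cfgs deliberately omits 'embed_dims', so it should
# fall back to the layer's embed_dims (256 here, taken from the attention cfg).
layer = BaseTransformerLayer(
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', feedforward_channels=1024),
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

assert layer.ffns[0].embed_dims == 256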