Unverified Commit c522b47e authored by Tong Gao, committed by GitHub

fix the wrong function reference bug in BaseTransformerLayer when batch_first is True (#1418)

parent 426e229d
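Why the old approach broke: the code removed in the first hunk monkey-patched `self.attn.forward` in `__init__` with a wrapper that closes over the original bound `forward`. `copy.deepcopy` treats plain Python functions as atomic, so a deep-copied layer keeps the same wrapper object, whose closure still points at the original `nn.MultiheadAttention` instance; the copy then silently runs attention with the original module's weights, which is the wrong function reference named in the commit title. The sketch below is a minimal, self-contained illustration of that failure mode; `Wrapped` and `forward_wrapper` are illustrative names, not mmcv code.

```python
import copy

import torch
import torch.nn as nn


class Wrapped(nn.Module):
    """Illustration only: patching a bound method in __init__ leaves deep
    copies pointing back at the original module."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        original_forward = self.linear.forward  # bound to *this* linear

        def forward_wrapper(x):
            # The closure captures the original bound method above.
            return original_forward(x)

        # Instance-level patch, the same pattern as the code removed below.
        self.linear.forward = forward_wrapper


original = Wrapped()
clone = copy.deepcopy(original)

# deepcopy clones the parameters, but functions are copied atomically, so the
# clone's patched ``forward`` still closes over the original linear layer.
with torch.no_grad():
    clone.linear.weight.zero_()
    clone.linear.bias.zero_()

x = torch.randn(1, 4)
# A correct clone would now return zeros; instead it reproduces the original
# module's output because the stale closure is still in use.
assert torch.allclose(clone.linear(x), original.linear(x))
```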
@@ -102,27 +102,6 @@ class MultiheadAttention(BaseModule):
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)
        if self.batch_first:

            def _bnc_to_nbc(forward):
                """Because the dataflow('key', 'query', 'value') of
                ``torch.nn.MultiheadAttention`` is (num_query, batch,
                embed_dims), we should adjust the shape of dataflow from
                batch_first (batch, num_query, embed_dims) to num_query_first
                (num_query, batch, embed_dims), and recover ``attn_output``
                from num_query_first to batch_first."""

                def forward_wrapper(**kwargs):
                    convert_keys = ('key', 'query', 'value')
                    for key in kwargs.keys():
                        if key in convert_keys:
                            kwargs[key] = kwargs[key].transpose(0, 1)
                    attn_output, attn_output_weights = forward(**kwargs)
                    return attn_output.transpose(0, 1), attn_output_weights

                return forward_wrapper

            self.attn.forward = _bnc_to_nbc(self.attn.forward)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
@@ -199,6 +178,17 @@ class MultiheadAttention(BaseModule):
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), we should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query, batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
@@ -206,6 +196,9 @@ class MultiheadAttention(BaseModule):
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))
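With the fix, the batch-first transposition happens inside `MultiheadAttention.forward` itself, so nothing on `self.attn` is patched and deep copies behave exactly like their source. A minimal usage sketch follows, assuming an mmcv build that includes this patch; the shapes mirror the regression test added in the remaining hunks, which deep-copies a batch-first `BaseTransformerLayer` and runs it on CUDA.

```python
import copy

import torch

from mmcv.cnn.bricks.transformer import MultiheadAttention

# batch_first=True means inputs/outputs are (batch, num_query, embed_dims).
attn = MultiheadAttention(embed_dims=256, num_heads=8, batch_first=True)
clone = copy.deepcopy(attn)  # safe now: no stale bound-method reference

x = torch.rand(2, 10, 256)
out = clone(query=x)  # key/value/identity default to query (self-attention)
assert out.shape == (2, 10, 256)  # output stays batch-first
```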
import copy
import pytest
import torch
@@ -5,6 +7,7 @@ from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer,
                                         MultiheadAttention,
                                         TransformerLayerSequence)
from mmcv.runner import ModuleList
def test_multiheadattention():
@@ -92,6 +95,28 @@ def test_ffn():
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():
    # Test that BaseTransformerLayer behaves consistently after being
    # deep-copied.
    operation_order = ('self_attn', 'ffn')
    baselayer = BaseTransformerLayer(
        operation_order=operation_order,
        batch_first=True,
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
        ),
    )
    baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)])
    baselayers.to('cuda')
    x = torch.rand(2, 10, 256).cuda()
    for m in baselayers:
        x = m(x)
    assert x.shape == torch.Size([2, 10, 256])


def test_basetransformerlayer():
    attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    feedforward_channels = 2048