"vscode:/vscode.git/clone" did not exist on "a67d9e6ff1d60d458338e3af3535f6e4fa0f0268"
Unverified Commit 088fde35 authored by Shilong Zhang, committed by GitHub

Avoid bc-breaking of importing `MultiScaleDeformableAttention` (#1100)

* avoid bc-breaking

* fix function name

* fix typo

* fix import

* add import warning

* remove wrapper

* remove unit test

* add dep warning

* add dep warning
parent a5d4c652
@@ -12,6 +12,22 @@ from .drop import build_dropout
 from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
                        TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa: F401
+    warnings.warn(
+        ImportWarning(
+            '``MultiScaleDeformableAttention`` has been moved to '
+            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa: E501
+            '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa: E501
+            'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa: E501
+        ))
+except ImportError:
+    warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
+                  '``mmcv.ops.multi_scale_deform_attn``. '
+                  'You should install ``mmcv-full`` if you need this module.')
+
 def build_positional_encoding(cfg, default_args=None):
     """Builder for Position Encoding."""
@@ -56,9 +72,9 @@ class MultiheadAttention(BaseModule):
             when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
-        batch_first (bool): Key, Query and Value are shape of
-            (batch, n, embed_dim)
-            or (n, batch, embed_dim). Default to False.
+        batch_first (bool): When it is True, Key, Query and Value are shape of
+            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+            Default to False.
     """
     def __init__(self,
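Before the next hunk, a short illustration of what ``batch_first`` changes in practice. A usage sketch under the post-commit API shown in this diff (the ``embed_dims`` and shape values are arbitrary illustrative choices):

    import torch
    from mmcv.cnn.bricks.transformer import MultiheadAttention

    attn = MultiheadAttention(embed_dims=256, num_heads=8, batch_first=True)
    query = torch.rand(4, 100, 256)    # (batch, n, embed_dim) since batch_first=True
    out = attn(query=query)            # key and value default to query
    assert out.shape == (4, 100, 256)  # output stays in batch_first layout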
@@ -88,10 +104,12 @@ class MultiheadAttention(BaseModule):
         if self.batch_first:

             def _bnc_to_nbc(forward):
-                """This function can adjust the shape of dataflow('key',
-                'query', 'value') from batch_first (batch, num_query,
-                embed_dims) to num_query_first (num_query, batch,
-                embed_dims)."""
+                """Because the dataflow('key', 'query', 'value') of
+                ``torch.nn.MultiheadAttention`` is (num_query, batch,
+                embed_dims), we should adjust the shape of dataflow from
+                batch_first (batch, num_query, embed_dims) to num_query_first
+                (num_query, batch, embed_dims), and recover ``attn_output``
+                from num_query_first to batch_first."""

                 def forward_wrapper(**kwargs):
                     convert_keys = ('key', 'query', 'value')
...
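The diff is truncated here; the remainder applies the wrapper to ``self.attn.forward``. A self-contained sketch of the technique on a bare ``torch.nn.MultiheadAttention``, a reconstruction under stated assumptions rather than the commit's exact code:

    import torch
    import torch.nn as nn

    def _bnc_to_nbc(forward):
        """Transpose (batch, n, embed_dims) inputs to (n, batch, embed_dims)
        before the wrapped forward, then transpose the output back."""
        def forward_wrapper(**kwargs):
            convert_keys = ('key', 'query', 'value')
            for k in convert_keys:
                if kwargs.get(k) is not None:
                    kwargs[k] = kwargs[k].transpose(0, 1)  # (b, n, c) -> (n, b, c)
            attn_output, attn_output_weights = forward(**kwargs)
            # recover attn_output from num_query_first to batch_first
            return attn_output.transpose(0, 1), attn_output_weights
        return forward_wrapper

    attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
    attn.forward = _bnc_to_nbc(attn.forward)  # instance attribute shadows the class method
    q = torch.rand(4, 100, 256)               # batch_first layout
    out, _ = attn(query=q, key=q, value=q)
    assert out.shape == (4, 100, 256)

Monkeypatching the bound ``forward`` on the instance keeps ``nn.Module.__call__`` (and hence hooks) working, which is why the commit wraps the method instead of subclassing.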