OpenDAS / MMCV / Commits / 8d7e0a6d
"server/vscode:/vscode.git/clone" did not exist on "d461d955c3ac32efc195ec7871632162cd08e833"
[Refactor]: add init_cfg in transformer base classes (#946)

Unverified commit 8d7e0a6d, authored Apr 14, 2021 by ZhangShilong, committed by GitHub on Apr 14, 2021.
Parent: 79f8cbd6
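The practical effect of this refactor is that an init_cfg passed to these transformer bricks now reaches BaseModule.__init__, so init_weights() can apply the configured initializer; before this change the super().__init__() calls took no arguments and init_cfg was silently ignored. Besides the init_cfg plumbing, the diff below also broadcasts a single attn_mask Tensor to every attention op in BaseTransformerLayer and applies dropout on FFN's non-residual return path. A minimal sketch of what the init_cfg plumbing enables, assuming an mmcv build that includes this commit; the argument values and the Xavier initializer choice are illustrative, not taken from the diff:

    # Minimal sketch: init_cfg now propagates to BaseModule, so init_weights()
    # applies the requested initializer. Values below are illustrative.
    import torch
    from mmcv.cnn.bricks.transformer import FFN

    ffn = FFN(
        embed_dims=256,
        feedforward_channels=1024,
        init_cfg=dict(type='Xavier', layer='Linear'))  # forwarded to BaseModule
    ffn.init_weights()                   # applies the Xavier init from init_cfg

    out = ffn(torch.rand(2, 100, 256))   # FFN acts on the last (embed) dim
    print(out.shape)                     # torch.Size([2, 100, 256])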
Changes: 1 changed file with 12 additions and 5 deletions

mmcv/cnn/bricks/transformer.py  +12 -5
mmcv/cnn/bricks/transformer.py (view file @ 8d7e0a6d)
 import copy
 import warnings

+import torch
 import torch.nn as nn

 from mmcv import ConfigDict

@@ -53,7 +54,7 @@ class MultiheadAttention(BaseModule):
                  dropout=0.,
                  init_cfg=None,
                  **kwargs):
-        super(MultiheadAttention, self).__init__()
+        super(MultiheadAttention, self).__init__(init_cfg)
         self.embed_dims = embed_dims
         self.num_heads = num_heads
         self.dropout = dropout

@@ -162,7 +163,7 @@ class FFN(BaseModule):
                  dropout=0.,
                  add_residual=True,
                  init_cfg=None):
-        super(FFN, self).__init__()
+        super(FFN, self).__init__(init_cfg)
         assert num_fcs >= 2, 'num_fcs should be no less ' \
             f'than 2. got {num_fcs}.'
         self.embed_dims = embed_dims

@@ -193,7 +194,7 @@ class FFN(BaseModule):
         """
         out = self.layers(x)
         if not self.add_residual:
-            return out
+            return self.dropout(out)
         if residual is None:
             residual = x
         return residual + self.dropout(out)

@@ -246,7 +247,7 @@ class BaseTransformerLayer(BaseModule):
                  ffn_num_fcs=2,
                  init_cfg=None):
-        super(BaseTransformerLayer, self).__init__()
+        super(BaseTransformerLayer, self).__init__(init_cfg)
         assert set(operation_order) & set(
             ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
             set(operation_order), f'The operation_order of' \

@@ -338,6 +339,12 @@ class BaseTransformerLayer(BaseModule):
         inp_residual = query
         if attn_masks is None:
             attn_masks = [None for _ in range(self.num_attn)]
+        elif isinstance(attn_masks, torch.Tensor):
+            attn_masks = [
+                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+            ]
+            warnings.warn(f'Use same attn_mask in all attentions in '
+                          f'{self.__class__.__name__} ')
         else:
             assert len(attn_masks) == self.num_attn, f'The length of ' \
                 f'attn_masks {len(attn_masks)} must be equal ' \

@@ -407,7 +414,7 @@ class TransformerLayerSequence(BaseModule):
     """

     def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
-        super(TransformerLayerSequence, self).__init__()
+        super(TransformerLayerSequence, self).__init__(init_cfg)
         if isinstance(transformerlayers, ConfigDict):
             transformerlayers = [
                 copy.deepcopy(transformerlayers) for _ in range(num_layers)
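The new elif branch in BaseTransformerLayer.forward means a single torch.Tensor passed as attn_masks is now deep-copied once per attention operation (with a warning) instead of failing the length assertion in the else branch. A rough usage sketch of that behavior, assuming an mmcv build with this commit; the constructor and forward argument names (attn_cfgs, feedforward_channels, operation_order, attn_masks) and all values are assumptions about this mmcv version, not shown in the diff:

    # Rough sketch: one attn_mask Tensor is broadcast (deep-copied) to every
    # attention op in the layer instead of tripping the length check.
    # Constructor/forward argument names are assumptions for this mmcv version.
    import torch
    from mmcv.cnn.bricks.transformer import BaseTransformerLayer

    layer = BaseTransformerLayer(
        attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
        feedforward_channels=512,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'),
        init_cfg=None)                  # now forwarded to BaseModule

    query = torch.rand(4, 2, 256)       # (num_query, batch, embed_dims)
    mask = torch.zeros(4, 4, dtype=torch.bool)

    # Single Tensor instead of a list: copied for each attention op, warns once.
    out = layer(query, attn_masks=mask)
    print(out.shape)                     # torch.Size([4, 2, 256])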