Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
45fa3e44
Unverified
Commit
45fa3e44
authored
May 18, 2022
by
Zaida Zhou
Committed by
GitHub
May 18, 2022
Browse files
Add pyupgrade pre-commit hook (#1937)
* add pyupgrade * add options for pyupgrade * minor refinement
parent
c561264d
Changes
110
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
37 additions
and
36 deletions
+37
-36
.dev_scripts/visualize_lr.py
.dev_scripts/visualize_lr.py
+3
-3
.pre-commit-config.yaml
.pre-commit-config.yaml
+5
-1
docs/en/conf.py
docs/en/conf.py
+1
-1
docs/zh_cn/conf.py
docs/zh_cn/conf.py
+1
-1
examples/train.py
examples/train.py
+1
-1
mmcv/cnn/alexnet.py
mmcv/cnn/alexnet.py
+1
-1
mmcv/cnn/bricks/activation.py
mmcv/cnn/bricks/activation.py
+1
-1
mmcv/cnn/bricks/context_block.py
mmcv/cnn/bricks/context_block.py
+1
-1
mmcv/cnn/bricks/conv_module.py
mmcv/cnn/bricks/conv_module.py
+2
-2
mmcv/cnn/bricks/conv_ws.py
mmcv/cnn/bricks/conv_ws.py
+1
-1
mmcv/cnn/bricks/depthwise_separable_conv_module.py
mmcv/cnn/bricks/depthwise_separable_conv_module.py
+1
-1
mmcv/cnn/bricks/drop.py
mmcv/cnn/bricks/drop.py
+1
-1
mmcv/cnn/bricks/generalized_attention.py
mmcv/cnn/bricks/generalized_attention.py
+1
-1
mmcv/cnn/bricks/hsigmoid.py
mmcv/cnn/bricks/hsigmoid.py
+1
-1
mmcv/cnn/bricks/hswish.py
mmcv/cnn/bricks/hswish.py
+1
-1
mmcv/cnn/bricks/non_local.py
mmcv/cnn/bricks/non_local.py
+4
-7
mmcv/cnn/bricks/scale.py
mmcv/cnn/bricks/scale.py
+1
-1
mmcv/cnn/bricks/swish.py
mmcv/cnn/bricks/swish.py
+1
-1
mmcv/cnn/bricks/transformer.py
mmcv/cnn/bricks/transformer.py
+8
-8
mmcv/cnn/bricks/upsample.py
mmcv/cnn/bricks/upsample.py
+1
-1
No files found.
.dev_scripts/visualize_lr.py
View file @
45fa3e44
...
@@ -42,7 +42,7 @@ def parse_args():
...
@@ -42,7 +42,7 @@ def parse_args():
class
SimpleModel
(
nn
.
Module
):
class
SimpleModel
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
SimpleModel
,
self
).
__init__
()
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
1
,
1
,
1
)
self
.
conv
=
nn
.
Conv2d
(
1
,
1
,
1
)
def
train_step
(
self
,
*
args
,
**
kwargs
):
def
train_step
(
self
,
*
args
,
**
kwargs
):
...
@@ -159,13 +159,13 @@ def run(cfg, logger):
...
@@ -159,13 +159,13 @@ def run(cfg, logger):
def
plot_lr_curve
(
json_file
,
cfg
):
def
plot_lr_curve
(
json_file
,
cfg
):
data_dict
=
dict
(
LearningRate
=
[],
Momentum
=
[])
data_dict
=
dict
(
LearningRate
=
[],
Momentum
=
[])
assert
os
.
path
.
isfile
(
json_file
)
assert
os
.
path
.
isfile
(
json_file
)
with
open
(
json_file
,
'r'
)
as
f
:
with
open
(
json_file
)
as
f
:
for
line
in
f
:
for
line
in
f
:
log
=
json
.
loads
(
line
.
strip
())
log
=
json
.
loads
(
line
.
strip
())
data_dict
[
'LearningRate'
].
append
(
log
[
'lr'
])
data_dict
[
'LearningRate'
].
append
(
log
[
'lr'
])
data_dict
[
'Momentum'
].
append
(
log
[
'momentum'
])
data_dict
[
'Momentum'
].
append
(
log
[
'momentum'
])
wind_w
,
wind_h
=
[
int
(
size
)
for
size
in
cfg
.
window_size
.
split
(
'*'
)
]
wind_w
,
wind_h
=
(
int
(
size
)
for
size
in
cfg
.
window_size
.
split
(
'*'
)
)
# if legend is None, use {filename}_{key} as legend
# if legend is None, use {filename}_{key} as legend
fig
,
axes
=
plt
.
subplots
(
2
,
1
,
figsize
=
(
wind_w
,
wind_h
))
fig
,
axes
=
plt
.
subplots
(
2
,
1
,
figsize
=
(
wind_w
,
wind_h
))
plt
.
subplots_adjust
(
hspace
=
0.5
)
plt
.
subplots_adjust
(
hspace
=
0.5
)
...
...
.pre-commit-config.yaml
View file @
45fa3e44
...
@@ -43,7 +43,11 @@ repos:
...
@@ -43,7 +43,11 @@ repos:
hooks
:
hooks
:
-
id
:
docformatter
-
id
:
docformatter
args
:
[
"
--in-place"
,
"
--wrap-descriptions"
,
"
79"
]
args
:
[
"
--in-place"
,
"
--wrap-descriptions"
,
"
79"
]
-
repo
:
https://github.com/asottile/pyupgrade
rev
:
v2.32.1
hooks
:
-
id
:
pyupgrade
args
:
[
"
--py36-plus"
]
-
repo
:
https://github.com/open-mmlab/pre-commit-hooks
-
repo
:
https://github.com/open-mmlab/pre-commit-hooks
rev
:
v0.2.0
# Use the ref you want to point at
rev
:
v0.2.0
# Use the ref you want to point at
hooks
:
hooks
:
...
...
docs/en/conf.py
View file @
45fa3e44
...
@@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
...
@@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'../..'
))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'../..'
))
version_file
=
'../../mmcv/version.py'
version_file
=
'../../mmcv/version.py'
with
open
(
version_file
,
'r'
)
as
f
:
with
open
(
version_file
)
as
f
:
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
__version__
=
locals
()[
'__version__'
]
__version__
=
locals
()[
'__version__'
]
...
...
docs/zh_cn/conf.py
View file @
45fa3e44
...
@@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
...
@@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'../..'
))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'../..'
))
version_file
=
'../../mmcv/version.py'
version_file
=
'../../mmcv/version.py'
with
open
(
version_file
,
'r'
)
as
f
:
with
open
(
version_file
)
as
f
:
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
__version__
=
locals
()[
'__version__'
]
__version__
=
locals
()[
'__version__'
]
...
...
examples/train.py
View file @
45fa3e44
...
@@ -14,7 +14,7 @@ from mmcv.utils import get_logger
...
@@ -14,7 +14,7 @@ from mmcv.utils import get_logger
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
6
,
5
)
self
.
conv1
=
nn
.
Conv2d
(
3
,
6
,
5
)
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
)
...
...
mmcv/cnn/alexnet.py
View file @
45fa3e44
...
@@ -12,7 +12,7 @@ class AlexNet(nn.Module):
...
@@ -12,7 +12,7 @@ class AlexNet(nn.Module):
"""
"""
def
__init__
(
self
,
num_classes
=-
1
):
def
__init__
(
self
,
num_classes
=-
1
):
super
(
AlexNet
,
self
).
__init__
()
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
features
=
nn
.
Sequential
(
self
.
features
=
nn
.
Sequential
(
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
11
,
stride
=
4
,
padding
=
2
),
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
11
,
stride
=
4
,
padding
=
2
),
...
...
mmcv/cnn/bricks/activation.py
View file @
45fa3e44
...
@@ -29,7 +29,7 @@ class Clamp(nn.Module):
...
@@ -29,7 +29,7 @@ class Clamp(nn.Module):
"""
"""
def
__init__
(
self
,
min
=-
1.
,
max
=
1.
):
def
__init__
(
self
,
min
=-
1.
,
max
=
1.
):
super
(
Clamp
,
self
).
__init__
()
super
().
__init__
()
self
.
min
=
min
self
.
min
=
min
self
.
max
=
max
self
.
max
=
max
...
...
mmcv/cnn/bricks/context_block.py
View file @
45fa3e44
...
@@ -38,7 +38,7 @@ class ContextBlock(nn.Module):
...
@@ -38,7 +38,7 @@ class ContextBlock(nn.Module):
ratio
,
ratio
,
pooling_type
=
'att'
,
pooling_type
=
'att'
,
fusion_types
=
(
'channel_add'
,
)):
fusion_types
=
(
'channel_add'
,
)):
super
(
ContextBlock
,
self
).
__init__
()
super
().
__init__
()
assert
pooling_type
in
[
'avg'
,
'att'
]
assert
pooling_type
in
[
'avg'
,
'att'
]
assert
isinstance
(
fusion_types
,
(
list
,
tuple
))
assert
isinstance
(
fusion_types
,
(
list
,
tuple
))
valid_fusion_types
=
[
'channel_add'
,
'channel_mul'
]
valid_fusion_types
=
[
'channel_add'
,
'channel_mul'
]
...
...
mmcv/cnn/bricks/conv_module.py
View file @
45fa3e44
...
@@ -83,7 +83,7 @@ class ConvModule(nn.Module):
...
@@ -83,7 +83,7 @@ class ConvModule(nn.Module):
with_spectral_norm
=
False
,
with_spectral_norm
=
False
,
padding_mode
=
'zeros'
,
padding_mode
=
'zeros'
,
order
=
(
'conv'
,
'norm'
,
'act'
)):
order
=
(
'conv'
,
'norm'
,
'act'
)):
super
(
ConvModule
,
self
).
__init__
()
super
().
__init__
()
assert
conv_cfg
is
None
or
isinstance
(
conv_cfg
,
dict
)
assert
conv_cfg
is
None
or
isinstance
(
conv_cfg
,
dict
)
assert
norm_cfg
is
None
or
isinstance
(
norm_cfg
,
dict
)
assert
norm_cfg
is
None
or
isinstance
(
norm_cfg
,
dict
)
assert
act_cfg
is
None
or
isinstance
(
act_cfg
,
dict
)
assert
act_cfg
is
None
or
isinstance
(
act_cfg
,
dict
)
...
@@ -96,7 +96,7 @@ class ConvModule(nn.Module):
...
@@ -96,7 +96,7 @@ class ConvModule(nn.Module):
self
.
with_explicit_padding
=
padding_mode
not
in
official_padding_mode
self
.
with_explicit_padding
=
padding_mode
not
in
official_padding_mode
self
.
order
=
order
self
.
order
=
order
assert
isinstance
(
self
.
order
,
tuple
)
and
len
(
self
.
order
)
==
3
assert
isinstance
(
self
.
order
,
tuple
)
and
len
(
self
.
order
)
==
3
assert
set
(
order
)
==
set
([
'conv'
,
'norm'
,
'act'
])
assert
set
(
order
)
==
{
'conv'
,
'norm'
,
'act'
}
self
.
with_norm
=
norm_cfg
is
not
None
self
.
with_norm
=
norm_cfg
is
not
None
self
.
with_activation
=
act_cfg
is
not
None
self
.
with_activation
=
act_cfg
is
not
None
...
...
mmcv/cnn/bricks/conv_ws.py
View file @
45fa3e44
...
@@ -35,7 +35,7 @@ class ConvWS2d(nn.Conv2d):
...
@@ -35,7 +35,7 @@ class ConvWS2d(nn.Conv2d):
groups
=
1
,
groups
=
1
,
bias
=
True
,
bias
=
True
,
eps
=
1e-5
):
eps
=
1e-5
):
super
(
ConvWS2d
,
self
).
__init__
(
super
().
__init__
(
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
...
...
mmcv/cnn/bricks/depthwise_separable_conv_module.py
View file @
45fa3e44
...
@@ -59,7 +59,7 @@ class DepthwiseSeparableConvModule(nn.Module):
...
@@ -59,7 +59,7 @@ class DepthwiseSeparableConvModule(nn.Module):
pw_norm_cfg
=
'default'
,
pw_norm_cfg
=
'default'
,
pw_act_cfg
=
'default'
,
pw_act_cfg
=
'default'
,
**
kwargs
):
**
kwargs
):
super
(
DepthwiseSeparableConvModule
,
self
).
__init__
()
super
().
__init__
()
assert
'groups'
not
in
kwargs
,
'groups should not be specified'
assert
'groups'
not
in
kwargs
,
'groups should not be specified'
# if norm/activation config of depthwise/pointwise ConvModule is not
# if norm/activation config of depthwise/pointwise ConvModule is not
...
...
mmcv/cnn/bricks/drop.py
View file @
45fa3e44
...
@@ -37,7 +37,7 @@ class DropPath(nn.Module):
...
@@ -37,7 +37,7 @@ class DropPath(nn.Module):
"""
"""
def
__init__
(
self
,
drop_prob
=
0.1
):
def
__init__
(
self
,
drop_prob
=
0.1
):
super
(
DropPath
,
self
).
__init__
()
super
().
__init__
()
self
.
drop_prob
=
drop_prob
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
mmcv/cnn/bricks/generalized_attention.py
View file @
45fa3e44
...
@@ -54,7 +54,7 @@ class GeneralizedAttention(nn.Module):
...
@@ -54,7 +54,7 @@ class GeneralizedAttention(nn.Module):
q_stride
=
1
,
q_stride
=
1
,
attention_type
=
'1111'
):
attention_type
=
'1111'
):
super
(
GeneralizedAttention
,
self
).
__init__
()
super
().
__init__
()
# hard range means local range for non-local operation
# hard range means local range for non-local operation
self
.
position_embedding_dim
=
(
self
.
position_embedding_dim
=
(
...
...
mmcv/cnn/bricks/hsigmoid.py
View file @
45fa3e44
...
@@ -27,7 +27,7 @@ class HSigmoid(nn.Module):
...
@@ -27,7 +27,7 @@ class HSigmoid(nn.Module):
"""
"""
def
__init__
(
self
,
bias
=
3.0
,
divisor
=
6.0
,
min_value
=
0.0
,
max_value
=
1.0
):
def
__init__
(
self
,
bias
=
3.0
,
divisor
=
6.0
,
min_value
=
0.0
,
max_value
=
1.0
):
super
(
HSigmoid
,
self
).
__init__
()
super
().
__init__
()
warnings
.
warn
(
warnings
.
warn
(
'In MMCV v1.4.4, we modified the default value of args to align '
'In MMCV v1.4.4, we modified the default value of args to align '
'with PyTorch official. Previous Implementation: '
'with PyTorch official. Previous Implementation: '
...
...
mmcv/cnn/bricks/hswish.py
View file @
45fa3e44
...
@@ -22,7 +22,7 @@ class HSwish(nn.Module):
...
@@ -22,7 +22,7 @@ class HSwish(nn.Module):
"""
"""
def
__init__
(
self
,
inplace
=
False
):
def
__init__
(
self
,
inplace
=
False
):
super
(
HSwish
,
self
).
__init__
()
super
().
__init__
()
self
.
act
=
nn
.
ReLU6
(
inplace
)
self
.
act
=
nn
.
ReLU6
(
inplace
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
mmcv/cnn/bricks/non_local.py
View file @
45fa3e44
...
@@ -40,7 +40,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
...
@@ -40,7 +40,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
norm_cfg
=
None
,
norm_cfg
=
None
,
mode
=
'embedded_gaussian'
,
mode
=
'embedded_gaussian'
,
**
kwargs
):
**
kwargs
):
super
(
_NonLocalNd
,
self
).
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
reduction
=
reduction
self
.
reduction
=
reduction
self
.
use_scale
=
use_scale
self
.
use_scale
=
use_scale
...
@@ -228,8 +228,7 @@ class NonLocal1d(_NonLocalNd):
...
@@ -228,8 +228,7 @@ class NonLocal1d(_NonLocalNd):
sub_sample
=
False
,
sub_sample
=
False
,
conv_cfg
=
dict
(
type
=
'Conv1d'
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
**
kwargs
):
**
kwargs
):
super
(
NonLocal1d
,
self
).
__init__
(
super
().
__init__
(
in_channels
,
conv_cfg
=
conv_cfg
,
**
kwargs
)
in_channels
,
conv_cfg
=
conv_cfg
,
**
kwargs
)
self
.
sub_sample
=
sub_sample
self
.
sub_sample
=
sub_sample
...
@@ -262,8 +261,7 @@ class NonLocal2d(_NonLocalNd):
...
@@ -262,8 +261,7 @@ class NonLocal2d(_NonLocalNd):
sub_sample
=
False
,
sub_sample
=
False
,
conv_cfg
=
dict
(
type
=
'Conv2d'
),
conv_cfg
=
dict
(
type
=
'Conv2d'
),
**
kwargs
):
**
kwargs
):
super
(
NonLocal2d
,
self
).
__init__
(
super
().
__init__
(
in_channels
,
conv_cfg
=
conv_cfg
,
**
kwargs
)
in_channels
,
conv_cfg
=
conv_cfg
,
**
kwargs
)
self
.
sub_sample
=
sub_sample
self
.
sub_sample
=
sub_sample
...
@@ -293,8 +291,7 @@ class NonLocal3d(_NonLocalNd):
...
@@ -293,8 +291,7 @@ class NonLocal3d(_NonLocalNd):
sub_sample
=
False
,
sub_sample
=
False
,
conv_cfg
=
dict
(
type
=
'Conv3d'
),
conv_cfg
=
dict
(
type
=
'Conv3d'
),
**
kwargs
):
**
kwargs
):
super
(
NonLocal3d
,
self
).
__init__
(
super
().
__init__
(
in_channels
,
conv_cfg
=
conv_cfg
,
**
kwargs
)
in_channels
,
conv_cfg
=
conv_cfg
,
**
kwargs
)
self
.
sub_sample
=
sub_sample
self
.
sub_sample
=
sub_sample
if
sub_sample
:
if
sub_sample
:
...
...
mmcv/cnn/bricks/scale.py
View file @
45fa3e44
...
@@ -14,7 +14,7 @@ class Scale(nn.Module):
...
@@ -14,7 +14,7 @@ class Scale(nn.Module):
"""
"""
def
__init__
(
self
,
scale
=
1.0
):
def
__init__
(
self
,
scale
=
1.0
):
super
(
Scale
,
self
).
__init__
()
super
().
__init__
()
self
.
scale
=
nn
.
Parameter
(
torch
.
tensor
(
scale
,
dtype
=
torch
.
float
))
self
.
scale
=
nn
.
Parameter
(
torch
.
tensor
(
scale
,
dtype
=
torch
.
float
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
mmcv/cnn/bricks/swish.py
View file @
45fa3e44
...
@@ -19,7 +19,7 @@ class Swish(nn.Module):
...
@@ -19,7 +19,7 @@ class Swish(nn.Module):
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Swish
,
self
).
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
x
*
torch
.
sigmoid
(
x
)
return
x
*
torch
.
sigmoid
(
x
)
mmcv/cnn/bricks/transformer.py
View file @
45fa3e44
...
@@ -96,7 +96,7 @@ class AdaptivePadding(nn.Module):
...
@@ -96,7 +96,7 @@ class AdaptivePadding(nn.Module):
"""
"""
def
__init__
(
self
,
kernel_size
=
1
,
stride
=
1
,
dilation
=
1
,
padding
=
'corner'
):
def
__init__
(
self
,
kernel_size
=
1
,
stride
=
1
,
dilation
=
1
,
padding
=
'corner'
):
super
(
AdaptivePadding
,
self
).
__init__
()
super
().
__init__
()
assert
padding
in
(
'same'
,
'corner'
)
assert
padding
in
(
'same'
,
'corner'
)
kernel_size
=
to_2tuple
(
kernel_size
)
kernel_size
=
to_2tuple
(
kernel_size
)
...
@@ -190,7 +190,7 @@ class PatchEmbed(BaseModule):
...
@@ -190,7 +190,7 @@ class PatchEmbed(BaseModule):
norm_cfg
=
None
,
norm_cfg
=
None
,
input_size
=
None
,
input_size
=
None
,
init_cfg
=
None
):
init_cfg
=
None
):
super
(
PatchEmbed
,
self
).
__init__
(
init_cfg
=
init_cfg
)
super
().
__init__
(
init_cfg
=
init_cfg
)
self
.
embed_dims
=
embed_dims
self
.
embed_dims
=
embed_dims
if
stride
is
None
:
if
stride
is
None
:
...
@@ -435,7 +435,7 @@ class MultiheadAttention(BaseModule):
...
@@ -435,7 +435,7 @@ class MultiheadAttention(BaseModule):
init_cfg
=
None
,
init_cfg
=
None
,
batch_first
=
False
,
batch_first
=
False
,
**
kwargs
):
**
kwargs
):
super
(
MultiheadAttention
,
self
).
__init__
(
init_cfg
)
super
().
__init__
(
init_cfg
)
if
'dropout'
in
kwargs
:
if
'dropout'
in
kwargs
:
warnings
.
warn
(
warnings
.
warn
(
'The arguments `dropout` in MultiheadAttention '
'The arguments `dropout` in MultiheadAttention '
...
@@ -590,7 +590,7 @@ class FFN(BaseModule):
...
@@ -590,7 +590,7 @@ class FFN(BaseModule):
add_identity
=
True
,
add_identity
=
True
,
init_cfg
=
None
,
init_cfg
=
None
,
**
kwargs
):
**
kwargs
):
super
(
FFN
,
self
).
__init__
(
init_cfg
)
super
().
__init__
(
init_cfg
)
assert
num_fcs
>=
2
,
'num_fcs should be no less '
\
assert
num_fcs
>=
2
,
'num_fcs should be no less '
\
f
'than 2. got
{
num_fcs
}
.'
f
'than 2. got
{
num_fcs
}
.'
self
.
embed_dims
=
embed_dims
self
.
embed_dims
=
embed_dims
...
@@ -694,12 +694,12 @@ class BaseTransformerLayer(BaseModule):
...
@@ -694,12 +694,12 @@ class BaseTransformerLayer(BaseModule):
f
'to a dict named `ffn_cfgs`. '
,
DeprecationWarning
)
f
'to a dict named `ffn_cfgs`. '
,
DeprecationWarning
)
ffn_cfgs
[
new_name
]
=
kwargs
[
ori_name
]
ffn_cfgs
[
new_name
]
=
kwargs
[
ori_name
]
super
(
BaseTransformerLayer
,
self
).
__init__
(
init_cfg
)
super
().
__init__
(
init_cfg
)
self
.
batch_first
=
batch_first
self
.
batch_first
=
batch_first
assert
set
(
operation_order
)
&
set
(
assert
set
(
operation_order
)
&
{
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
])
==
\
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
}
==
\
set
(
operation_order
),
f
'The operation_order of'
\
set
(
operation_order
),
f
'The operation_order of'
\
f
'
{
self
.
__class__
.
__name__
}
should '
\
f
'
{
self
.
__class__
.
__name__
}
should '
\
f
'contains all four operation type '
\
f
'contains all four operation type '
\
...
@@ -880,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
...
@@ -880,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
"""
"""
def
__init__
(
self
,
transformerlayers
=
None
,
num_layers
=
None
,
init_cfg
=
None
):
def
__init__
(
self
,
transformerlayers
=
None
,
num_layers
=
None
,
init_cfg
=
None
):
super
(
TransformerLayerSequence
,
self
).
__init__
(
init_cfg
)
super
().
__init__
(
init_cfg
)
if
isinstance
(
transformerlayers
,
dict
):
if
isinstance
(
transformerlayers
,
dict
):
transformerlayers
=
[
transformerlayers
=
[
copy
.
deepcopy
(
transformerlayers
)
for
_
in
range
(
num_layers
)
copy
.
deepcopy
(
transformerlayers
)
for
_
in
range
(
num_layers
)
...
...
mmcv/cnn/bricks/upsample.py
View file @
45fa3e44
...
@@ -26,7 +26,7 @@ class PixelShufflePack(nn.Module):
...
@@ -26,7 +26,7 @@ class PixelShufflePack(nn.Module):
def
__init__
(
self
,
in_channels
,
out_channels
,
scale_factor
,
def
__init__
(
self
,
in_channels
,
out_channels
,
scale_factor
,
upsample_kernel
):
upsample_kernel
):
super
(
PixelShufflePack
,
self
).
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
scale_factor
=
scale_factor
self
.
scale_factor
=
scale_factor
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment