Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
45fa3e44
Unverified
Commit
45fa3e44
authored
May 18, 2022
by
Zaida Zhou
Committed by
GitHub
May 18, 2022
Browse files
Add pyupgrade pre-commit hook (#1937)
* add pyupgrade * add options for pyupgrade * minor refinement
parent
c561264d
Changes
110
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
54 additions
and
60 deletions
+54
-60
mmcv/cnn/resnet.py
mmcv/cnn/resnet.py
+4
-4
mmcv/cnn/utils/flops_counter.py
mmcv/cnn/utils/flops_counter.py
+2
-2
mmcv/cnn/utils/weight_init.py
mmcv/cnn/utils/weight_init.py
+2
-2
mmcv/cnn/vgg.py
mmcv/cnn/vgg.py
+2
-2
mmcv/device/mlu/data_parallel.py
mmcv/device/mlu/data_parallel.py
+1
-1
mmcv/fileio/file_client.py
mmcv/fileio/file_client.py
+16
-16
mmcv/fileio/handlers/pickle_handler.py
mmcv/fileio/handlers/pickle_handler.py
+2
-4
mmcv/image/geometric.py
mmcv/image/geometric.py
+1
-1
mmcv/onnx/onnx_utils/symbolic_helper.py
mmcv/onnx/onnx_utils/symbolic_helper.py
+1
-1
mmcv/ops/border_align.py
mmcv/ops/border_align.py
+1
-1
mmcv/ops/box_iou_rotated.py
mmcv/ops/box_iou_rotated.py
+1
-1
mmcv/ops/carafe.py
mmcv/ops/carafe.py
+3
-3
mmcv/ops/corner_pool.py
mmcv/ops/corner_pool.py
+1
-1
mmcv/ops/deform_conv.py
mmcv/ops/deform_conv.py
+2
-2
mmcv/ops/deform_roi_pool.py
mmcv/ops/deform_roi_pool.py
+3
-5
mmcv/ops/focal_loss.py
mmcv/ops/focal_loss.py
+2
-2
mmcv/ops/fused_bias_leakyrelu.py
mmcv/ops/fused_bias_leakyrelu.py
+1
-1
mmcv/ops/masked_conv.py
mmcv/ops/masked_conv.py
+3
-4
mmcv/ops/merge_cells.py
mmcv/ops/merge_cells.py
+3
-4
mmcv/ops/modulated_deform_conv.py
mmcv/ops/modulated_deform_conv.py
+3
-3
No files found.
mmcv/cnn/resnet.py
View file @
45fa3e44
...
@@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
...
@@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
downsample
=
None
,
downsample
=
None
,
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
):
super
(
BasicBlock
,
self
).
__init__
()
super
().
__init__
()
assert
style
in
[
'pytorch'
,
'caffe'
]
assert
style
in
[
'pytorch'
,
'caffe'
]
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
...
@@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
...
@@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
"""
super
(
Bottleneck
,
self
).
__init__
()
super
().
__init__
()
assert
style
in
[
'pytorch'
,
'caffe'
]
assert
style
in
[
'pytorch'
,
'caffe'
]
if
style
==
'pytorch'
:
if
style
==
'pytorch'
:
conv1_stride
=
1
conv1_stride
=
1
...
@@ -218,7 +218,7 @@ class ResNet(nn.Module):
...
@@ -218,7 +218,7 @@ class ResNet(nn.Module):
bn_eval
=
True
,
bn_eval
=
True
,
bn_frozen
=
False
,
bn_frozen
=
False
,
with_cp
=
False
):
with_cp
=
False
):
super
(
ResNet
,
self
).
__init__
()
super
().
__init__
()
if
depth
not
in
self
.
arch_settings
:
if
depth
not
in
self
.
arch_settings
:
raise
KeyError
(
f
'invalid depth
{
depth
}
for resnet'
)
raise
KeyError
(
f
'invalid depth
{
depth
}
for resnet'
)
assert
num_stages
>=
1
and
num_stages
<=
4
assert
num_stages
>=
1
and
num_stages
<=
4
...
@@ -293,7 +293,7 @@ class ResNet(nn.Module):
...
@@ -293,7 +293,7 @@ class ResNet(nn.Module):
return
tuple
(
outs
)
return
tuple
(
outs
)
def
train
(
self
,
mode
=
True
):
def
train
(
self
,
mode
=
True
):
super
(
ResNet
,
self
).
train
(
mode
)
super
().
train
(
mode
)
if
self
.
bn_eval
:
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
...
...
mmcv/cnn/utils/flops_counter.py
View file @
45fa3e44
...
@@ -277,10 +277,10 @@ def print_model_with_flops(model,
...
@@ -277,10 +277,10 @@ def print_model_with_flops(model,
return
', '
.
join
([
return
', '
.
join
([
params_to_string
(
params_to_string
(
accumulated_num_params
,
units
=
'M'
,
precision
=
precision
),
accumulated_num_params
,
units
=
'M'
,
precision
=
precision
),
'{
:.3%} Params'
.
format
(
accumulated_num_params
/
total_params
)
,
f
'
{
accumulated_num_params
/
total_params
:.
3
%
}
Params'
,
flops_to_string
(
flops_to_string
(
accumulated_flops_cost
,
units
=
units
,
precision
=
precision
),
accumulated_flops_cost
,
units
=
units
,
precision
=
precision
),
'{
:.3%} FLOPs'
.
format
(
accumulated_flops_cost
/
total_flops
)
,
f
'
{
accumulated_flops_cost
/
total_flops
:.
3
%
}
FLOPs'
,
self
.
original_extra_repr
()
self
.
original_extra_repr
()
])
])
...
...
mmcv/cnn/utils/weight_init.py
View file @
45fa3e44
...
@@ -129,7 +129,7 @@ def _get_bases_name(m):
...
@@ -129,7 +129,7 @@ def _get_bases_name(m):
return
[
b
.
__name__
for
b
in
m
.
__class__
.
__bases__
]
return
[
b
.
__name__
for
b
in
m
.
__class__
.
__bases__
]
class
BaseInit
(
object
)
:
class
BaseInit
:
def
__init__
(
self
,
*
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
def
__init__
(
self
,
*
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
self
.
wholemodule
=
False
self
.
wholemodule
=
False
...
@@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):
...
@@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
class
PretrainedInit
(
object
)
:
class
PretrainedInit
:
"""Initialize module by loading a pretrained model.
"""Initialize module by loading a pretrained model.
Args:
Args:
...
...
mmcv/cnn/vgg.py
View file @
45fa3e44
...
@@ -70,7 +70,7 @@ class VGG(nn.Module):
...
@@ -70,7 +70,7 @@ class VGG(nn.Module):
bn_frozen
=
False
,
bn_frozen
=
False
,
ceil_mode
=
False
,
ceil_mode
=
False
,
with_last_pool
=
True
):
with_last_pool
=
True
):
super
(
VGG
,
self
).
__init__
()
super
().
__init__
()
if
depth
not
in
self
.
arch_settings
:
if
depth
not
in
self
.
arch_settings
:
raise
KeyError
(
f
'invalid depth
{
depth
}
for vgg'
)
raise
KeyError
(
f
'invalid depth
{
depth
}
for vgg'
)
assert
num_stages
>=
1
and
num_stages
<=
5
assert
num_stages
>=
1
and
num_stages
<=
5
...
@@ -157,7 +157,7 @@ class VGG(nn.Module):
...
@@ -157,7 +157,7 @@ class VGG(nn.Module):
return
tuple
(
outs
)
return
tuple
(
outs
)
def
train
(
self
,
mode
=
True
):
def
train
(
self
,
mode
=
True
):
super
(
VGG
,
self
).
train
(
mode
)
super
().
train
(
mode
)
if
self
.
bn_eval
:
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
...
...
mmcv/device/mlu/data_parallel.py
View file @
45fa3e44
...
@@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
...
@@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
"""
"""
def
__init__
(
self
,
*
args
,
dim
=
0
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
dim
=
0
,
**
kwargs
):
super
(
MLUDataParallel
,
self
).
__init__
(
*
args
,
dim
=
dim
,
**
kwargs
)
super
().
__init__
(
*
args
,
dim
=
dim
,
**
kwargs
)
self
.
device_ids
=
[
0
]
self
.
device_ids
=
[
0
]
self
.
src_device_obj
=
torch
.
device
(
'mlu:0'
)
self
.
src_device_obj
=
torch
.
device
(
'mlu:0'
)
...
...
mmcv/fileio/file_client.py
View file @
45fa3e44
...
@@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
...
@@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
"""
"""
if
not
has_method
(
self
.
_client
,
'delete'
):
if
not
has_method
(
self
.
_client
,
'delete'
):
raise
NotImplementedError
(
raise
NotImplementedError
(
(
'Current version of Petrel Python SDK has not supported '
'Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev'
'the `delete` method, please use a higher version or dev'
' branch instead.'
)
)
' branch instead.'
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
...
@@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
...
@@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
if
not
(
has_method
(
self
.
_client
,
'contains'
)
if
not
(
has_method
(
self
.
_client
,
'contains'
)
and
has_method
(
self
.
_client
,
'isdir'
)):
and
has_method
(
self
.
_client
,
'isdir'
)):
raise
NotImplementedError
(
raise
NotImplementedError
(
(
'Current version of Petrel Python SDK has not supported '
'Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher'
'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.'
)
)
'version or dev branch instead.'
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
...
@@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
...
@@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
"""
"""
if
not
has_method
(
self
.
_client
,
'isdir'
):
if
not
has_method
(
self
.
_client
,
'isdir'
):
raise
NotImplementedError
(
raise
NotImplementedError
(
(
'Current version of Petrel Python SDK has not supported '
'Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev'
'the `isdir` method, please use a higher version or dev'
' branch instead.'
)
)
' branch instead.'
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
...
@@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
...
@@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
"""
"""
if
not
has_method
(
self
.
_client
,
'contains'
):
if
not
has_method
(
self
.
_client
,
'contains'
):
raise
NotImplementedError
(
raise
NotImplementedError
(
(
'Current version of Petrel Python SDK has not supported '
'Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or '
'the `contains` method, please use a higher version or '
'dev branch instead.'
)
)
'dev branch instead.'
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_map_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
filepath
=
self
.
_format_path
(
filepath
)
...
@@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
...
@@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
"""
"""
if
not
has_method
(
self
.
_client
,
'list'
):
if
not
has_method
(
self
.
_client
,
'list'
):
raise
NotImplementedError
(
raise
NotImplementedError
(
(
'Current version of Petrel Python SDK has not supported '
'Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev'
'the `list` method, please use a higher version or dev'
' branch instead.'
)
)
' branch instead.'
)
dir_path
=
self
.
_map_path
(
dir_path
)
dir_path
=
self
.
_map_path
(
dir_path
)
dir_path
=
self
.
_format_path
(
dir_path
)
dir_path
=
self
.
_format_path
(
dir_path
)
...
@@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
...
@@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns:
Returns:
str: Expected text reading from ``filepath``.
str: Expected text reading from ``filepath``.
"""
"""
with
open
(
filepath
,
'r'
,
encoding
=
encoding
)
as
f
:
with
open
(
filepath
,
encoding
=
encoding
)
as
f
:
value_buf
=
f
.
read
()
value_buf
=
f
.
read
()
return
value_buf
return
value_buf
...
...
mmcv/fileio/handlers/pickle_handler.py
View file @
45fa3e44
...
@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
...
@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
return
pickle
.
load
(
file
,
**
kwargs
)
return
pickle
.
load
(
file
,
**
kwargs
)
def
load_from_path
(
self
,
filepath
,
**
kwargs
):
def
load_from_path
(
self
,
filepath
,
**
kwargs
):
return
super
(
PickleHandler
,
self
).
load_from_path
(
return
super
().
load_from_path
(
filepath
,
mode
=
'rb'
,
**
kwargs
)
filepath
,
mode
=
'rb'
,
**
kwargs
)
def
dump_to_str
(
self
,
obj
,
**
kwargs
):
def
dump_to_str
(
self
,
obj
,
**
kwargs
):
kwargs
.
setdefault
(
'protocol'
,
2
)
kwargs
.
setdefault
(
'protocol'
,
2
)
...
@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
...
@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
pickle
.
dump
(
obj
,
file
,
**
kwargs
)
pickle
.
dump
(
obj
,
file
,
**
kwargs
)
def
dump_to_path
(
self
,
obj
,
filepath
,
**
kwargs
):
def
dump_to_path
(
self
,
obj
,
filepath
,
**
kwargs
):
super
(
PickleHandler
,
self
).
dump_to_path
(
super
().
dump_to_path
(
obj
,
filepath
,
mode
=
'wb'
,
**
kwargs
)
obj
,
filepath
,
mode
=
'wb'
,
**
kwargs
)
mmcv/image/geometric.py
View file @
45fa3e44
...
@@ -157,7 +157,7 @@ def imresize_to_multiple(img,
...
@@ -157,7 +157,7 @@ def imresize_to_multiple(img,
size
=
_scale_size
((
w
,
h
),
scale_factor
)
size
=
_scale_size
((
w
,
h
),
scale_factor
)
divisor
=
to_2tuple
(
divisor
)
divisor
=
to_2tuple
(
divisor
)
size
=
tuple
(
[
int
(
np
.
ceil
(
s
/
d
))
*
d
for
s
,
d
in
zip
(
size
,
divisor
)
]
)
size
=
tuple
(
int
(
np
.
ceil
(
s
/
d
))
*
d
for
s
,
d
in
zip
(
size
,
divisor
))
resized_img
,
w_scale
,
h_scale
=
imresize
(
resized_img
,
w_scale
,
h_scale
=
imresize
(
img
,
img
,
size
,
size
,
...
...
mmcv/onnx/onnx_utils/symbolic_helper.py
View file @
45fa3e44
...
@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
...
@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
raise
RuntimeError
(
raise
RuntimeError
(
"ONNX symbolic doesn't know to interpret ListConstruct node"
)
"ONNX symbolic doesn't know to interpret ListConstruct node"
)
raise
RuntimeError
(
'Unexpected node type: {
}'
.
format
(
value
.
node
().
kind
()
)
)
raise
RuntimeError
(
f
'Unexpected node type:
{
value
.
node
().
kind
()
}
'
)
def
_maybe_get_const
(
value
,
desc
):
def
_maybe_get_const
(
value
,
desc
):
...
...
mmcv/ops/border_align.py
View file @
45fa3e44
...
@@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
...
@@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
"""
"""
def
__init__
(
self
,
pool_size
):
def
__init__
(
self
,
pool_size
):
super
(
BorderAlign
,
self
).
__init__
()
super
().
__init__
()
self
.
pool_size
=
pool_size
self
.
pool_size
=
pool_size
def
forward
(
self
,
input
,
boxes
):
def
forward
(
self
,
input
,
boxes
):
...
...
mmcv/ops/box_iou_rotated.py
View file @
45fa3e44
...
@@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
...
@@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
if
aligned
:
if
aligned
:
ious
=
bboxes1
.
new_zeros
(
rows
)
ious
=
bboxes1
.
new_zeros
(
rows
)
else
:
else
:
ious
=
bboxes1
.
new_zeros
(
(
rows
*
cols
)
)
ious
=
bboxes1
.
new_zeros
(
rows
*
cols
)
if
not
clockwise
:
if
not
clockwise
:
flip_mat
=
bboxes1
.
new_ones
(
bboxes1
.
shape
[
-
1
])
flip_mat
=
bboxes1
.
new_ones
(
bboxes1
.
shape
[
-
1
])
flip_mat
[
-
1
]
=
-
1
flip_mat
[
-
1
]
=
-
1
...
...
mmcv/ops/carafe.py
View file @
45fa3e44
...
@@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
...
@@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
class
CARAFENaive
(
Module
):
class
CARAFENaive
(
Module
):
def
__init__
(
self
,
kernel_size
,
group_size
,
scale_factor
):
def
__init__
(
self
,
kernel_size
,
group_size
,
scale_factor
):
super
(
CARAFENaive
,
self
).
__init__
()
super
().
__init__
()
assert
isinstance
(
kernel_size
,
int
)
and
isinstance
(
assert
isinstance
(
kernel_size
,
int
)
and
isinstance
(
group_size
,
int
)
and
isinstance
(
scale_factor
,
int
)
group_size
,
int
)
and
isinstance
(
scale_factor
,
int
)
...
@@ -195,7 +195,7 @@ class CARAFE(Module):
...
@@ -195,7 +195,7 @@ class CARAFE(Module):
"""
"""
def
__init__
(
self
,
kernel_size
,
group_size
,
scale_factor
):
def
__init__
(
self
,
kernel_size
,
group_size
,
scale_factor
):
super
(
CARAFE
,
self
).
__init__
()
super
().
__init__
()
assert
isinstance
(
kernel_size
,
int
)
and
isinstance
(
assert
isinstance
(
kernel_size
,
int
)
and
isinstance
(
group_size
,
int
)
and
isinstance
(
scale_factor
,
int
)
group_size
,
int
)
and
isinstance
(
scale_factor
,
int
)
...
@@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
...
@@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
encoder_kernel
=
3
,
encoder_kernel
=
3
,
encoder_dilation
=
1
,
encoder_dilation
=
1
,
compressed_channels
=
64
):
compressed_channels
=
64
):
super
(
CARAFEPack
,
self
).
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
self
.
scale_factor
=
scale_factor
self
.
scale_factor
=
scale_factor
self
.
up_kernel
=
up_kernel
self
.
up_kernel
=
up_kernel
...
...
mmcv/ops/corner_pool.py
View file @
45fa3e44
...
@@ -125,7 +125,7 @@ class CornerPool(nn.Module):
...
@@ -125,7 +125,7 @@ class CornerPool(nn.Module):
}
}
def
__init__
(
self
,
mode
):
def
__init__
(
self
,
mode
):
super
(
CornerPool
,
self
).
__init__
()
super
().
__init__
()
assert
mode
in
self
.
pool_functions
assert
mode
in
self
.
pool_functions
self
.
mode
=
mode
self
.
mode
=
mode
self
.
corner_pool
=
self
.
pool_functions
[
mode
]
self
.
corner_pool
=
self
.
pool_functions
[
mode
]
...
...
mmcv/ops/deform_conv.py
View file @
45fa3e44
...
@@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
...
@@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
deform_groups
:
int
=
1
,
deform_groups
:
int
=
1
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
im2col_step
:
int
=
32
)
->
None
:
im2col_step
:
int
=
32
)
->
None
:
super
(
DeformConv2d
,
self
).
__init__
()
super
().
__init__
()
assert
not
bias
,
\
assert
not
bias
,
\
f
'bias=
{
bias
}
is not supported in DeformConv2d.'
f
'bias=
{
bias
}
is not supported in DeformConv2d.'
...
@@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
...
@@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
_version
=
2
_version
=
2
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DeformConv2dPack
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
conv_offset
=
nn
.
Conv2d
(
self
.
conv_offset
=
nn
.
Conv2d
(
self
.
in_channels
,
self
.
in_channels
,
self
.
deform_groups
*
2
*
self
.
kernel_size
[
0
]
*
self
.
kernel_size
[
1
],
self
.
deform_groups
*
2
*
self
.
kernel_size
[
0
]
*
self
.
kernel_size
[
1
],
...
...
mmcv/ops/deform_roi_pool.py
View file @
45fa3e44
...
@@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
...
@@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
spatial_scale
=
1.0
,
spatial_scale
=
1.0
,
sampling_ratio
=
0
,
sampling_ratio
=
0
,
gamma
=
0.1
):
gamma
=
0.1
):
super
(
DeformRoIPool
,
self
).
__init__
()
super
().
__init__
()
self
.
output_size
=
_pair
(
output_size
)
self
.
output_size
=
_pair
(
output_size
)
self
.
spatial_scale
=
float
(
spatial_scale
)
self
.
spatial_scale
=
float
(
spatial_scale
)
self
.
sampling_ratio
=
int
(
sampling_ratio
)
self
.
sampling_ratio
=
int
(
sampling_ratio
)
...
@@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
...
@@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
spatial_scale
=
1.0
,
spatial_scale
=
1.0
,
sampling_ratio
=
0
,
sampling_ratio
=
0
,
gamma
=
0.1
):
gamma
=
0.1
):
super
(
DeformRoIPoolPack
,
self
).
__init__
(
output_size
,
spatial_scale
,
super
().
__init__
(
output_size
,
spatial_scale
,
sampling_ratio
,
gamma
)
sampling_ratio
,
gamma
)
self
.
output_channels
=
output_channels
self
.
output_channels
=
output_channels
self
.
deform_fc_channels
=
deform_fc_channels
self
.
deform_fc_channels
=
deform_fc_channels
...
@@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
...
@@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
spatial_scale
=
1.0
,
spatial_scale
=
1.0
,
sampling_ratio
=
0
,
sampling_ratio
=
0
,
gamma
=
0.1
):
gamma
=
0.1
):
super
(
ModulatedDeformRoIPoolPack
,
super
().
__init__
(
output_size
,
spatial_scale
,
sampling_ratio
,
gamma
)
self
).
__init__
(
output_size
,
spatial_scale
,
sampling_ratio
,
gamma
)
self
.
output_channels
=
output_channels
self
.
output_channels
=
output_channels
self
.
deform_fc_channels
=
deform_fc_channels
self
.
deform_fc_channels
=
deform_fc_channels
...
...
mmcv/ops/focal_loss.py
View file @
45fa3e44
...
@@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
...
@@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
class
SigmoidFocalLoss
(
nn
.
Module
):
class
SigmoidFocalLoss
(
nn
.
Module
):
def
__init__
(
self
,
gamma
,
alpha
,
weight
=
None
,
reduction
=
'mean'
):
def
__init__
(
self
,
gamma
,
alpha
,
weight
=
None
,
reduction
=
'mean'
):
super
(
SigmoidFocalLoss
,
self
).
__init__
()
super
().
__init__
()
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
alpha
=
alpha
self
.
alpha
=
alpha
self
.
register_buffer
(
'weight'
,
weight
)
self
.
register_buffer
(
'weight'
,
weight
)
...
@@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
...
@@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
class
SoftmaxFocalLoss
(
nn
.
Module
):
class
SoftmaxFocalLoss
(
nn
.
Module
):
def
__init__
(
self
,
gamma
,
alpha
,
weight
=
None
,
reduction
=
'mean'
):
def
__init__
(
self
,
gamma
,
alpha
,
weight
=
None
,
reduction
=
'mean'
):
super
(
SoftmaxFocalLoss
,
self
).
__init__
()
super
().
__init__
()
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
alpha
=
alpha
self
.
alpha
=
alpha
self
.
register_buffer
(
'weight'
,
weight
)
self
.
register_buffer
(
'weight'
,
weight
)
...
...
mmcv/ops/fused_bias_leakyrelu.py
View file @
45fa3e44
...
@@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
...
@@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
"""
"""
def
__init__
(
self
,
num_channels
,
negative_slope
=
0.2
,
scale
=
2
**
0.5
):
def
__init__
(
self
,
num_channels
,
negative_slope
=
0.2
,
scale
=
2
**
0.5
):
super
(
FusedBiasLeakyReLU
,
self
).
__init__
()
super
().
__init__
()
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
num_channels
))
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
num_channels
))
self
.
negative_slope
=
negative_slope
self
.
negative_slope
=
negative_slope
...
...
mmcv/ops/masked_conv.py
View file @
45fa3e44
...
@@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
...
@@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
dilation
=
1
,
dilation
=
1
,
groups
=
1
,
groups
=
1
,
bias
=
True
):
bias
=
True
):
super
(
MaskedConv2d
,
super
().
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
self
).
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
)
padding
,
dilation
,
groups
,
bias
)
def
forward
(
self
,
input
,
mask
=
None
):
def
forward
(
self
,
input
,
mask
=
None
):
if
mask
is
None
:
# fallback to the normal Conv2d
if
mask
is
None
:
# fallback to the normal Conv2d
return
super
(
MaskedConv2d
,
self
).
forward
(
input
)
return
super
().
forward
(
input
)
else
:
else
:
return
masked_conv2d
(
input
,
mask
,
self
.
weight
,
self
.
bias
,
return
masked_conv2d
(
input
,
mask
,
self
.
weight
,
self
.
bias
,
self
.
padding
)
self
.
padding
)
mmcv/ops/merge_cells.py
View file @
45fa3e44
...
@@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
...
@@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
input_conv_cfg
=
None
,
input_conv_cfg
=
None
,
input_norm_cfg
=
None
,
input_norm_cfg
=
None
,
upsample_mode
=
'nearest'
):
upsample_mode
=
'nearest'
):
super
(
BaseMergeCell
,
self
).
__init__
()
super
().
__init__
()
assert
upsample_mode
in
[
'nearest'
,
'bilinear'
]
assert
upsample_mode
in
[
'nearest'
,
'bilinear'
]
self
.
with_out_conv
=
with_out_conv
self
.
with_out_conv
=
with_out_conv
self
.
with_input1_conv
=
with_input1_conv
self
.
with_input1_conv
=
with_input1_conv
...
@@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
...
@@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
class
SumCell
(
BaseMergeCell
):
class
SumCell
(
BaseMergeCell
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
(
SumCell
,
self
).
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
super
().
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
def
_binary_op
(
self
,
x1
,
x2
):
def
_binary_op
(
self
,
x1
,
x2
):
return
x1
+
x2
return
x1
+
x2
...
@@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
...
@@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
class
ConcatCell
(
BaseMergeCell
):
class
ConcatCell
(
BaseMergeCell
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
(
ConcatCell
,
self
).
__init__
(
in_channels
*
2
,
out_channels
,
super
().
__init__
(
in_channels
*
2
,
out_channels
,
**
kwargs
)
**
kwargs
)
def
_binary_op
(
self
,
x1
,
x2
):
def
_binary_op
(
self
,
x1
,
x2
):
ret
=
torch
.
cat
([
x1
,
x2
],
dim
=
1
)
ret
=
torch
.
cat
([
x1
,
x2
],
dim
=
1
)
...
...
mmcv/ops/modulated_deform_conv.py
View file @
45fa3e44
...
@@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
...
@@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
groups
=
1
,
groups
=
1
,
deform_groups
=
1
,
deform_groups
=
1
,
bias
=
True
):
bias
=
True
):
super
(
ModulatedDeformConv2d
,
self
).
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
kernel_size
=
_pair
(
kernel_size
)
self
.
kernel_size
=
_pair
(
kernel_size
)
...
@@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
...
@@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
_version
=
2
_version
=
2
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
ModulatedDeformConv2dPack
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
conv_offset
=
nn
.
Conv2d
(
self
.
conv_offset
=
nn
.
Conv2d
(
self
.
in_channels
,
self
.
in_channels
,
self
.
deform_groups
*
3
*
self
.
kernel_size
[
0
]
*
self
.
kernel_size
[
1
],
self
.
deform_groups
*
3
*
self
.
kernel_size
[
0
]
*
self
.
kernel_size
[
1
],
...
@@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
...
@@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
self
.
init_weights
()
self
.
init_weights
()
def
init_weights
(
self
):
def
init_weights
(
self
):
super
(
ModulatedDeformConv2dPack
,
self
).
init_weights
()
super
().
init_weights
()
if
hasattr
(
self
,
'conv_offset'
):
if
hasattr
(
self
,
'conv_offset'
):
self
.
conv_offset
.
weight
.
data
.
zero_
()
self
.
conv_offset
.
weight
.
data
.
zero_
()
self
.
conv_offset
.
bias
.
data
.
zero_
()
self
.
conv_offset
.
bias
.
data
.
zero_
()
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment