Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
45fa3e44
Unverified
Commit
45fa3e44
authored
May 18, 2022
by
Zaida Zhou
Committed by
GitHub
May 18, 2022
Browse files
Add pyupgrade pre-commit hook (#1937)
* add pyupgrade * add options for pyupgrade * minor refinement
parent
c561264d
Changes
110
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
108 additions
and
117 deletions
+108
-117
mmcv/runner/hooks/logger/segmind.py
mmcv/runner/hooks/logger/segmind.py
+1
-2
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+2
-3
mmcv/runner/hooks/logger/text.py
mmcv/runner/hooks/logger/text.py
+9
-10
mmcv/runner/hooks/logger/wandb.py
mmcv/runner/hooks/logger/wandb.py
+2
-3
mmcv/runner/hooks/lr_updater.py
mmcv/runner/hooks/lr_updater.py
+12
-12
mmcv/runner/hooks/momentum_updater.py
mmcv/runner/hooks/momentum_updater.py
+6
-6
mmcv/runner/hooks/optimizer.py
mmcv/runner/hooks/optimizer.py
+3
-5
mmcv/runner/iter_based_runner.py
mmcv/runner/iter_based_runner.py
+1
-1
mmcv/tensorrt/tensorrt_utils.py
mmcv/tensorrt/tensorrt_utils.py
+2
-2
mmcv/utils/config.py
mmcv/utils/config.py
+17
-17
mmcv/utils/timer.py
mmcv/utils/timer.py
+1
-1
mmcv/video/optflow.py
mmcv/video/optflow.py
+5
-5
mmcv/visualization/optflow.py
mmcv/visualization/optflow.py
+0
-1
setup.py
setup.py
+3
-4
tests/test_arraymisc.py
tests/test_arraymisc.py
+0
-1
tests/test_cnn/test_conv_module.py
tests/test_cnn/test_conv_module.py
+1
-1
tests/test_fileclient.py
tests/test_fileclient.py
+40
-40
tests/test_fileio.py
tests/test_fileio.py
+1
-1
tests/test_ops/test_bbox.py
tests/test_ops/test_bbox.py
+1
-1
tests/test_ops/test_bilinear_grid_sample.py
tests/test_ops/test_bilinear_grid_sample.py
+1
-1
No files found.
mmcv/runner/hooks/logger/segmind.py
View file @
45fa3e44
...
@@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
...
@@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
ignore_last
=
True
,
ignore_last
=
True
,
reset_flag
=
False
,
reset_flag
=
False
,
by_epoch
=
True
):
by_epoch
=
True
):
super
(
SegmindLoggerHook
,
self
).
__init__
(
interval
,
ignore_last
,
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
reset_flag
,
by_epoch
)
self
.
import_segmind
()
self
.
import_segmind
()
def
import_segmind
(
self
):
def
import_segmind
(
self
):
...
...
mmcv/runner/hooks/logger/tensorboard.py
View file @
45fa3e44
...
@@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
...
@@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
ignore_last
=
True
,
ignore_last
=
True
,
reset_flag
=
False
,
reset_flag
=
False
,
by_epoch
=
True
):
by_epoch
=
True
):
super
(
TensorboardLoggerHook
,
self
).
__init__
(
interval
,
ignore_last
,
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
reset_flag
,
by_epoch
)
self
.
log_dir
=
log_dir
self
.
log_dir
=
log_dir
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
super
(
TensorboardLoggerHook
,
self
).
before_run
(
runner
)
super
().
before_run
(
runner
)
if
(
TORCH_VERSION
==
'parrots'
if
(
TORCH_VERSION
==
'parrots'
or
digit_version
(
TORCH_VERSION
)
<
digit_version
(
'1.1'
)):
or
digit_version
(
TORCH_VERSION
)
<
digit_version
(
'1.1'
)):
try
:
try
:
...
...
mmcv/runner/hooks/logger/text.py
View file @
45fa3e44
...
@@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
),
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
),
keep_local
=
True
,
keep_local
=
True
,
file_client_args
=
None
):
file_client_args
=
None
):
super
(
TextLoggerHook
,
self
).
__init__
(
interval
,
ignore_last
,
reset_flag
,
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
by_epoch
)
self
.
by_epoch
=
by_epoch
self
.
by_epoch
=
by_epoch
self
.
time_sec_tot
=
0
self
.
time_sec_tot
=
0
self
.
interval_exp_name
=
interval_exp_name
self
.
interval_exp_name
=
interval_exp_name
...
@@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
self
.
out_dir
)
self
.
out_dir
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
super
(
TextLoggerHook
,
self
).
before_run
(
runner
)
super
().
before_run
(
runner
)
if
self
.
out_dir
is
not
None
:
if
self
.
out_dir
is
not
None
:
self
.
file_client
=
FileClient
.
infer_client
(
self
.
file_client_args
,
self
.
file_client
=
FileClient
.
infer_client
(
self
.
file_client_args
,
...
@@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
...
@@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
basename
=
osp
.
basename
(
runner
.
work_dir
.
rstrip
(
osp
.
sep
))
basename
=
osp
.
basename
(
runner
.
work_dir
.
rstrip
(
osp
.
sep
))
self
.
out_dir
=
self
.
file_client
.
join_path
(
self
.
out_dir
,
basename
)
self
.
out_dir
=
self
.
file_client
.
join_path
(
self
.
out_dir
,
basename
)
runner
.
logger
.
info
(
runner
.
logger
.
info
(
(
f
'Text logs will be saved to
{
self
.
out_dir
}
by '
f
'Text logs will be saved to
{
self
.
out_dir
}
by '
f
'
{
self
.
file_client
.
name
}
after the training process.'
)
)
f
'
{
self
.
file_client
.
name
}
after the training process.'
)
self
.
start_iter
=
runner
.
iter
self
.
start_iter
=
runner
.
iter
self
.
json_log_path
=
osp
.
join
(
runner
.
work_dir
,
self
.
json_log_path
=
osp
.
join
(
runner
.
work_dir
,
...
@@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
...
@@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
local_filepath
=
osp
.
join
(
runner
.
work_dir
,
filename
)
local_filepath
=
osp
.
join
(
runner
.
work_dir
,
filename
)
out_filepath
=
self
.
file_client
.
join_path
(
out_filepath
=
self
.
file_client
.
join_path
(
self
.
out_dir
,
filename
)
self
.
out_dir
,
filename
)
with
open
(
local_filepath
,
'r'
)
as
f
:
with
open
(
local_filepath
)
as
f
:
self
.
file_client
.
put_text
(
f
.
read
(),
out_filepath
)
self
.
file_client
.
put_text
(
f
.
read
(),
out_filepath
)
runner
.
logger
.
info
(
runner
.
logger
.
info
(
(
f
'The file
{
local_filepath
}
has been uploaded to '
f
'The file
{
local_filepath
}
has been uploaded to '
f
'
{
out_filepath
}
.'
)
)
f
'
{
out_filepath
}
.'
)
if
not
self
.
keep_local
:
if
not
self
.
keep_local
:
os
.
remove
(
local_filepath
)
os
.
remove
(
local_filepath
)
runner
.
logger
.
info
(
runner
.
logger
.
info
(
(
f
'
{
local_filepath
}
was removed due to the '
f
'
{
local_filepath
}
was removed due to the '
'`self.keep_local=False`'
)
)
'`self.keep_local=False`'
)
mmcv/runner/hooks/logger/wandb.py
View file @
45fa3e44
...
@@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
...
@@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
with_step
=
True
,
with_step
=
True
,
log_artifact
=
True
,
log_artifact
=
True
,
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
)):
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
)):
super
(
WandbLoggerHook
,
self
).
__init__
(
interval
,
ignore_last
,
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
reset_flag
,
by_epoch
)
self
.
import_wandb
()
self
.
import_wandb
()
self
.
init_kwargs
=
init_kwargs
self
.
init_kwargs
=
init_kwargs
self
.
commit
=
commit
self
.
commit
=
commit
...
@@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):
...
@@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
super
(
WandbLoggerHook
,
self
).
before_run
(
runner
)
super
().
before_run
(
runner
)
if
self
.
wandb
is
None
:
if
self
.
wandb
is
None
:
self
.
import_wandb
()
self
.
import_wandb
()
if
self
.
init_kwargs
:
if
self
.
init_kwargs
:
...
...
mmcv/runner/hooks/lr_updater.py
View file @
45fa3e44
...
@@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
...
@@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
class
FixedLrUpdaterHook
(
LrUpdaterHook
):
class
FixedLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
super
(
FixedLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
return
base_lr
return
base_lr
...
@@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
...
@@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self
.
step
=
step
self
.
step
=
step
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
super
(
StepLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
...
@@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
...
@@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
def
__init__
(
self
,
gamma
,
**
kwargs
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
self
.
gamma
=
gamma
self
.
gamma
=
gamma
super
(
ExpLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
...
@@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
...
@@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
self
.
power
=
power
self
.
power
=
power
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
super
(
PolyLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
...
@@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
...
@@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
power
=
power
self
.
power
=
power
super
(
InvLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
...
@@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
self
.
min_lr_ratio
=
min_lr_ratio
super
(
CosineAnnealingLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
...
@@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
self
.
start_percent
=
start_percent
self
.
start_percent
=
start_percent
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
self
.
min_lr_ratio
=
min_lr_ratio
super
(
FlatCosineAnnealingLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
...
@@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
...
@@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
self
.
restart_weights
=
restart_weights
self
.
restart_weights
=
restart_weights
assert
(
len
(
self
.
periods
)
==
len
(
self
.
restart_weights
)
assert
(
len
(
self
.
periods
)
==
len
(
self
.
restart_weights
)
),
'periods and restart_weights should have the same length.'
),
'periods and restart_weights should have the same length.'
super
(
CosineRestartLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
cumulative_periods
=
[
self
.
cumulative_periods
=
[
sum
(
self
.
periods
[
0
:
i
+
1
])
for
i
in
range
(
0
,
len
(
self
.
periods
))
sum
(
self
.
periods
[
0
:
i
+
1
])
for
i
in
range
(
0
,
len
(
self
.
periods
))
...
@@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
...
@@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
assert
not
by_epoch
,
\
assert
not
by_epoch
,
\
'currently only support "by_epoch" = False'
'currently only support "by_epoch" = False'
super
(
CyclicLrUpdaterHook
,
self
).
__init__
(
by_epoch
,
**
kwargs
)
super
().
__init__
(
by_epoch
,
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
super
(
CyclicLrUpdaterHook
,
self
).
before_run
(
runner
)
super
().
before_run
(
runner
)
# initiate lr_phases
# initiate lr_phases
# total lr_phases are separated as up and down
# total lr_phases are separated as up and down
self
.
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
self
.
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
...
@@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
...
@@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
self
.
final_div_factor
=
final_div_factor
self
.
final_div_factor
=
final_div_factor
self
.
three_phase
=
three_phase
self
.
three_phase
=
three_phase
self
.
lr_phases
=
[]
# init lr_phases
self
.
lr_phases
=
[]
# init lr_phases
super
(
OneCycleLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
if
hasattr
(
self
,
'total_steps'
):
if
hasattr
(
self
,
'total_steps'
):
...
@@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
self
.
min_lr_ratio
=
min_lr_ratio
super
(
LinearAnnealingLrUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
,
base_lr
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
...
...
mmcv/runner/hooks/momentum_updater.py
View file @
45fa3e44
...
@@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
...
@@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
self
.
step
=
step
self
.
step
=
step
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
min_momentum
=
min_momentum
self
.
min_momentum
=
min_momentum
super
(
StepMomentumUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_momentum
(
self
,
runner
,
base_momentum
):
def
get_momentum
(
self
,
runner
,
base_momentum
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
...
@@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
...
@@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
assert
(
min_momentum
is
None
)
^
(
min_momentum_ratio
is
None
)
assert
(
min_momentum
is
None
)
^
(
min_momentum_ratio
is
None
)
self
.
min_momentum
=
min_momentum
self
.
min_momentum
=
min_momentum
self
.
min_momentum_ratio
=
min_momentum_ratio
self
.
min_momentum_ratio
=
min_momentum_ratio
super
(
CosineAnnealingMomentumUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_momentum
(
self
,
runner
,
base_momentum
):
def
get_momentum
(
self
,
runner
,
base_momentum
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
...
@@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
...
@@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
assert
(
min_momentum
is
None
)
^
(
min_momentum_ratio
is
None
)
assert
(
min_momentum
is
None
)
^
(
min_momentum_ratio
is
None
)
self
.
min_momentum
=
min_momentum
self
.
min_momentum
=
min_momentum
self
.
min_momentum_ratio
=
min_momentum_ratio
self
.
min_momentum_ratio
=
min_momentum_ratio
super
(
LinearAnnealingMomentumUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_momentum
(
self
,
runner
,
base_momentum
):
def
get_momentum
(
self
,
runner
,
base_momentum
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
...
@@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
...
@@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
# currently only support by_epoch=False
# currently only support by_epoch=False
assert
not
by_epoch
,
\
assert
not
by_epoch
,
\
'currently only support "by_epoch" = False'
'currently only support "by_epoch" = False'
super
(
CyclicMomentumUpdaterHook
,
self
).
__init__
(
by_epoch
,
**
kwargs
)
super
().
__init__
(
by_epoch
,
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
super
(
CyclicMomentumUpdaterHook
,
self
).
before_run
(
runner
)
super
().
before_run
(
runner
)
# initiate momentum_phases
# initiate momentum_phases
# total momentum_phases are separated as up and down
# total momentum_phases are separated as up and down
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
...
@@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
...
@@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
self
.
anneal_func
=
annealing_linear
self
.
anneal_func
=
annealing_linear
self
.
three_phase
=
three_phase
self
.
three_phase
=
three_phase
self
.
momentum_phases
=
[]
# init momentum_phases
self
.
momentum_phases
=
[]
# init momentum_phases
super
(
OneCycleMomentumUpdaterHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
if
isinstance
(
runner
.
optimizer
,
dict
):
if
isinstance
(
runner
.
optimizer
,
dict
):
...
...
mmcv/runner/hooks/optimizer.py
View file @
45fa3e44
...
@@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
...
@@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
"""
"""
def
__init__
(
self
,
cumulative_iters
=
1
,
**
kwargs
):
def
__init__
(
self
,
cumulative_iters
=
1
,
**
kwargs
):
super
(
GradientCumulativeOptimizerHook
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
assert
isinstance
(
cumulative_iters
,
int
)
and
cumulative_iters
>
0
,
\
assert
isinstance
(
cumulative_iters
,
int
)
and
cumulative_iters
>
0
,
\
f
'cumulative_iters only accepts positive int, but got '
\
f
'cumulative_iters only accepts positive int, but got '
\
...
@@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
...
@@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
GradientCumulativeFp16OptimizerHook
,
super
().
__init__
(
*
args
,
**
kwargs
)
self
).
__init__
(
*
args
,
**
kwargs
)
def
after_train_iter
(
self
,
runner
):
def
after_train_iter
(
self
,
runner
):
if
not
self
.
initialized
:
if
not
self
.
initialized
:
...
@@ -490,8 +489,7 @@ else:
...
@@ -490,8 +489,7 @@ else:
iters gradient cumulating."""
iters gradient cumulating."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
GradientCumulativeFp16OptimizerHook
,
super
().
__init__
(
*
args
,
**
kwargs
)
self
).
__init__
(
*
args
,
**
kwargs
)
def
after_train_iter
(
self
,
runner
):
def
after_train_iter
(
self
,
runner
):
if
not
self
.
initialized
:
if
not
self
.
initialized
:
...
...
mmcv/runner/iter_based_runner.py
View file @
45fa3e44
...
@@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
...
@@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
if
log_config
is
not
None
:
if
log_config
is
not
None
:
for
info
in
log_config
[
'hooks'
]:
for
info
in
log_config
[
'hooks'
]:
info
.
setdefault
(
'by_epoch'
,
False
)
info
.
setdefault
(
'by_epoch'
,
False
)
super
(
IterBasedRunner
,
self
).
register_training_hooks
(
super
().
register_training_hooks
(
lr_config
=
lr_config
,
lr_config
=
lr_config
,
momentum_config
=
momentum_config
,
momentum_config
=
momentum_config
,
optimizer_config
=
optimizer_config
,
optimizer_config
=
optimizer_config
,
...
...
mmcv/tensorrt/tensorrt_utils.py
View file @
45fa3e44
...
@@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
...
@@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
msg
+=
reset_style
msg
+=
reset_style
warnings
.
warn
(
msg
)
warnings
.
warn
(
msg
)
device
=
torch
.
device
(
'cuda:{
}'
.
format
(
device_id
)
)
device
=
torch
.
device
(
f
'cuda:
{
device_id
}
'
)
# create builder and network
# create builder and network
logger
=
trt
.
Logger
(
log_level
)
logger
=
trt
.
Logger
(
log_level
)
builder
=
trt
.
Builder
(
logger
)
builder
=
trt
.
Builder
(
logger
)
...
@@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
...
@@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
msg
+=
reset_style
msg
+=
reset_style
warnings
.
warn
(
msg
)
warnings
.
warn
(
msg
)
super
(
TRTWrapper
,
self
).
__init__
()
super
().
__init__
()
self
.
engine
=
engine
self
.
engine
=
engine
if
isinstance
(
self
.
engine
,
str
):
if
isinstance
(
self
.
engine
,
str
):
self
.
engine
=
load_trt_engine
(
engine
)
self
.
engine
=
load_trt_engine
(
engine
)
...
...
mmcv/utils/config.py
View file @
45fa3e44
...
@@ -39,7 +39,7 @@ class ConfigDict(Dict):
...
@@ -39,7 +39,7 @@ class ConfigDict(Dict):
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
):
try
:
try
:
value
=
super
(
ConfigDict
,
self
).
__getattr__
(
name
)
value
=
super
().
__getattr__
(
name
)
except
KeyError
:
except
KeyError
:
ex
=
AttributeError
(
f
"'
{
self
.
__class__
.
__name__
}
' object has no "
ex
=
AttributeError
(
f
"'
{
self
.
__class__
.
__name__
}
' object has no "
f
"attribute '
{
name
}
'"
)
f
"attribute '
{
name
}
'"
)
...
@@ -96,7 +96,7 @@ class Config:
...
@@ -96,7 +96,7 @@ class Config:
@
staticmethod
@
staticmethod
def
_validate_py_syntax
(
filename
):
def
_validate_py_syntax
(
filename
):
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
filename
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
# Setting encoding explicitly to resolve coding issue on windows
content
=
f
.
read
()
content
=
f
.
read
()
try
:
try
:
...
@@ -116,7 +116,7 @@ class Config:
...
@@ -116,7 +116,7 @@ class Config:
fileBasename
=
file_basename
,
fileBasename
=
file_basename
,
fileBasenameNoExtension
=
file_basename_no_extension
,
fileBasenameNoExtension
=
file_basename_no_extension
,
fileExtname
=
file_extname
)
fileExtname
=
file_extname
)
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
filename
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
# Setting encoding explicitly to resolve coding issue on windows
config_file
=
f
.
read
()
config_file
=
f
.
read
()
for
key
,
value
in
support_templates
.
items
():
for
key
,
value
in
support_templates
.
items
():
...
@@ -130,7 +130,7 @@ class Config:
...
@@ -130,7 +130,7 @@ class Config:
def
_pre_substitute_base_vars
(
filename
,
temp_config_name
):
def
_pre_substitute_base_vars
(
filename
,
temp_config_name
):
"""Substitute base variable placehoders to string, so that parsing
"""Substitute base variable placehoders to string, so that parsing
would work."""
would work."""
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
filename
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
# Setting encoding explicitly to resolve coding issue on windows
config_file
=
f
.
read
()
config_file
=
f
.
read
()
base_var_dict
=
{}
base_var_dict
=
{}
...
@@ -183,7 +183,7 @@ class Config:
...
@@ -183,7 +183,7 @@ class Config:
check_file_exist
(
filename
)
check_file_exist
(
filename
)
fileExtname
=
osp
.
splitext
(
filename
)[
1
]
fileExtname
=
osp
.
splitext
(
filename
)[
1
]
if
fileExtname
not
in
[
'.py'
,
'.json'
,
'.yaml'
,
'.yml'
]:
if
fileExtname
not
in
[
'.py'
,
'.json'
,
'.yaml'
,
'.yml'
]:
raise
I
OError
(
'Only py/yml/yaml/json type are supported now!'
)
raise
O
S
Error
(
'Only py/yml/yaml/json type are supported now!'
)
with
tempfile
.
TemporaryDirectory
()
as
temp_config_dir
:
with
tempfile
.
TemporaryDirectory
()
as
temp_config_dir
:
temp_config_file
=
tempfile
.
NamedTemporaryFile
(
temp_config_file
=
tempfile
.
NamedTemporaryFile
(
...
@@ -236,7 +236,7 @@ class Config:
...
@@ -236,7 +236,7 @@ class Config:
warnings
.
warn
(
warning_msg
,
DeprecationWarning
)
warnings
.
warn
(
warning_msg
,
DeprecationWarning
)
cfg_text
=
filename
+
'
\n
'
cfg_text
=
filename
+
'
\n
'
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
filename
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
# Setting encoding explicitly to resolve coding issue on windows
cfg_text
+=
f
.
read
()
cfg_text
+=
f
.
read
()
...
@@ -356,7 +356,7 @@ class Config:
...
@@ -356,7 +356,7 @@ class Config:
:obj:`Config`: Config obj.
:obj:`Config`: Config obj.
"""
"""
if
file_format
not
in
[
'.py'
,
'.json'
,
'.yaml'
,
'.yml'
]:
if
file_format
not
in
[
'.py'
,
'.json'
,
'.yaml'
,
'.yml'
]:
raise
I
OError
(
'Only py/yml/yaml/json type are supported now!'
)
raise
O
S
Error
(
'Only py/yml/yaml/json type are supported now!'
)
if
file_format
!=
'.py'
and
'dict('
in
cfg_str
:
if
file_format
!=
'.py'
and
'dict('
in
cfg_str
:
# check if users specify a wrong suffix for python
# check if users specify a wrong suffix for python
warnings
.
warn
(
warnings
.
warn
(
...
@@ -396,16 +396,16 @@ class Config:
...
@@ -396,16 +396,16 @@ class Config:
if
isinstance
(
filename
,
Path
):
if
isinstance
(
filename
,
Path
):
filename
=
str
(
filename
)
filename
=
str
(
filename
)
super
(
Config
,
self
).
__setattr__
(
'_cfg_dict'
,
ConfigDict
(
cfg_dict
))
super
().
__setattr__
(
'_cfg_dict'
,
ConfigDict
(
cfg_dict
))
super
(
Config
,
self
).
__setattr__
(
'_filename'
,
filename
)
super
().
__setattr__
(
'_filename'
,
filename
)
if
cfg_text
:
if
cfg_text
:
text
=
cfg_text
text
=
cfg_text
elif
filename
:
elif
filename
:
with
open
(
filename
,
'r'
)
as
f
:
with
open
(
filename
)
as
f
:
text
=
f
.
read
()
text
=
f
.
read
()
else
:
else
:
text
=
''
text
=
''
super
(
Config
,
self
).
__setattr__
(
'_text'
,
text
)
super
().
__setattr__
(
'_text'
,
text
)
@
property
@
property
def
filename
(
self
):
def
filename
(
self
):
...
@@ -556,9 +556,9 @@ class Config:
...
@@ -556,9 +556,9 @@ class Config:
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
_cfg_dict
,
_filename
,
_text
=
state
_cfg_dict
,
_filename
,
_text
=
state
super
(
Config
,
self
).
__setattr__
(
'_cfg_dict'
,
_cfg_dict
)
super
().
__setattr__
(
'_cfg_dict'
,
_cfg_dict
)
super
(
Config
,
self
).
__setattr__
(
'_filename'
,
_filename
)
super
().
__setattr__
(
'_filename'
,
_filename
)
super
(
Config
,
self
).
__setattr__
(
'_text'
,
_text
)
super
().
__setattr__
(
'_text'
,
_text
)
def
dump
(
self
,
file
=
None
):
def
dump
(
self
,
file
=
None
):
"""Dumps config into a file or returns a string representation of the
"""Dumps config into a file or returns a string representation of the
...
@@ -584,7 +584,7 @@ class Config:
...
@@ -584,7 +584,7 @@ class Config:
will be dumped. Defaults to None.
will be dumped. Defaults to None.
"""
"""
import
mmcv
import
mmcv
cfg_dict
=
super
(
Config
,
self
).
__getattribute__
(
'_cfg_dict'
).
to_dict
()
cfg_dict
=
super
().
__getattribute__
(
'_cfg_dict'
).
to_dict
()
if
file
is
None
:
if
file
is
None
:
if
self
.
filename
is
None
or
self
.
filename
.
endswith
(
'.py'
):
if
self
.
filename
is
None
or
self
.
filename
.
endswith
(
'.py'
):
return
self
.
pretty_text
return
self
.
pretty_text
...
@@ -638,8 +638,8 @@ class Config:
...
@@ -638,8 +638,8 @@ class Config:
subkey
=
key_list
[
-
1
]
subkey
=
key_list
[
-
1
]
d
[
subkey
]
=
v
d
[
subkey
]
=
v
cfg_dict
=
super
(
Config
,
self
).
__getattribute__
(
'_cfg_dict'
)
cfg_dict
=
super
().
__getattribute__
(
'_cfg_dict'
)
super
(
Config
,
self
).
__setattr__
(
super
().
__setattr__
(
'_cfg_dict'
,
'_cfg_dict'
,
Config
.
_merge_a_into_b
(
Config
.
_merge_a_into_b
(
option_cfg_dict
,
cfg_dict
,
allow_list_keys
=
allow_list_keys
))
option_cfg_dict
,
cfg_dict
,
allow_list_keys
=
allow_list_keys
))
...
...
mmcv/utils/timer.py
View file @
45fa3e44
...
@@ -6,7 +6,7 @@ class TimerError(Exception):
...
@@ -6,7 +6,7 @@ class TimerError(Exception):
def
__init__
(
self
,
message
):
def
__init__
(
self
,
message
):
self
.
message
=
message
self
.
message
=
message
super
(
TimerError
,
self
).
__init__
(
message
)
super
().
__init__
(
message
)
class
Timer
:
class
Timer
:
...
...
mmcv/video/optflow.py
View file @
45fa3e44
...
@@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
...
@@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
try
:
try
:
header
=
f
.
read
(
4
).
decode
(
'utf-8'
)
header
=
f
.
read
(
4
).
decode
(
'utf-8'
)
except
Exception
:
except
Exception
:
raise
I
OError
(
f
'Invalid flow file:
{
flow_or_path
}
'
)
raise
O
S
Error
(
f
'Invalid flow file:
{
flow_or_path
}
'
)
else
:
else
:
if
header
!=
'PIEH'
:
if
header
!=
'PIEH'
:
raise
I
OError
(
f
'Invalid flow file:
{
flow_or_path
}
, '
raise
O
S
Error
(
f
'Invalid flow file:
{
flow_or_path
}
, '
'header does not contain PIEH'
)
'header does not contain PIEH'
)
w
=
np
.
fromfile
(
f
,
np
.
int32
,
1
).
squeeze
()
w
=
np
.
fromfile
(
f
,
np
.
int32
,
1
).
squeeze
()
...
@@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
...
@@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
assert
concat_axis
in
[
0
,
1
]
assert
concat_axis
in
[
0
,
1
]
cat_flow
=
imread
(
flow_or_path
,
flag
=
'unchanged'
)
cat_flow
=
imread
(
flow_or_path
,
flag
=
'unchanged'
)
if
cat_flow
.
ndim
!=
2
:
if
cat_flow
.
ndim
!=
2
:
raise
I
OError
(
raise
O
S
Error
(
f
'
{
flow_or_path
}
is not a valid quantized flow file, '
f
'
{
flow_or_path
}
is not a valid quantized flow file, '
f
'its dimension is
{
cat_flow
.
ndim
}
.'
)
f
'its dimension is
{
cat_flow
.
ndim
}
.'
)
assert
cat_flow
.
shape
[
concat_axis
]
%
2
==
0
assert
cat_flow
.
shape
[
concat_axis
]
%
2
==
0
...
@@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
...
@@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
"""
"""
if
not
quantize
:
if
not
quantize
:
with
open
(
filename
,
'wb'
)
as
f
:
with
open
(
filename
,
'wb'
)
as
f
:
f
.
write
(
'PIEH'
.
encode
(
'utf-8'
)
)
f
.
write
(
b
'PIEH'
)
np
.
array
([
flow
.
shape
[
1
],
flow
.
shape
[
0
]],
dtype
=
np
.
int32
).
tofile
(
f
)
np
.
array
([
flow
.
shape
[
1
],
flow
.
shape
[
0
]],
dtype
=
np
.
int32
).
tofile
(
f
)
flow
=
flow
.
astype
(
np
.
float32
)
flow
=
flow
.
astype
(
np
.
float32
)
flow
.
tofile
(
f
)
flow
.
tofile
(
f
)
...
@@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
...
@@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
assert
dx
.
shape
==
dy
.
shape
assert
dx
.
shape
==
dy
.
shape
assert
dx
.
ndim
==
2
or
(
dx
.
ndim
==
3
and
dx
.
shape
[
-
1
]
==
1
)
assert
dx
.
ndim
==
2
or
(
dx
.
ndim
==
3
and
dx
.
shape
[
-
1
]
==
1
)
dx
,
dy
=
[
dequantize
(
d
,
-
max_val
,
max_val
,
255
)
for
d
in
[
dx
,
dy
]
]
dx
,
dy
=
(
dequantize
(
d
,
-
max_val
,
max_val
,
255
)
for
d
in
[
dx
,
dy
]
)
if
denorm
:
if
denorm
:
dx
*=
dx
.
shape
[
1
]
dx
*=
dx
.
shape
[
1
]
...
...
mmcv/visualization/optflow.py
View file @
45fa3e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
__future__
import
division
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
...
setup.py
View file @
45fa3e44
...
@@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):
...
@@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):
def
get_version
():
def
get_version
():
version_file
=
'mmcv/version.py'
version_file
=
'mmcv/version.py'
with
open
(
version_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
version_file
,
encoding
=
'utf-8'
)
as
f
:
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
return
locals
()[
'__version__'
]
return
locals
()[
'__version__'
]
...
@@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
...
@@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
yield
info
yield
info
def
parse_require_file
(
fpath
):
def
parse_require_file
(
fpath
):
with
open
(
fpath
,
'r'
)
as
f
:
with
open
(
fpath
)
as
f
:
for
line
in
f
.
readlines
():
for
line
in
f
.
readlines
():
line
=
line
.
strip
()
line
=
line
.
strip
()
if
line
and
not
line
.
startswith
(
'#'
):
if
line
and
not
line
.
startswith
(
'#'
):
for
info
in
parse_line
(
line
):
yield
from
parse_line
(
line
)
yield
info
def
gen_packages_items
():
def
gen_packages_items
():
if
exists
(
require_fpath
):
if
exists
(
require_fpath
):
...
...
tests/test_arraymisc.py
View file @
45fa3e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
__future__
import
division
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
...
tests/test_cnn/test_conv_module.py
View file @
45fa3e44
...
@@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
...
@@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
groups
=
1
,
groups
=
1
,
bias
=
True
,
bias
=
True
,
norm_cfg
=
None
):
norm_cfg
=
None
):
super
(
ExampleConv
,
self
).
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
kernel_size
=
kernel_size
self
.
kernel_size
=
kernel_size
...
...
tests/test_fileclient.py
View file @
45fa3e44
...
@@ -202,21 +202,22 @@ class TestFileClient:
...
@@ -202,21 +202,22 @@ class TestFileClient:
# test `list_dir_or_file`
# test `list_dir_or_file`
with
build_temporary_directory
()
as
tmp_dir
:
with
build_temporary_directory
()
as
tmp_dir
:
# 1. list directories and files
# 1. list directories and files
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
))
==
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
))
==
{
[
'dir1'
,
'dir2'
,
'text1.txt'
,
'text2.txt'
])
'dir1'
,
'dir2'
,
'text1.txt'
,
'text2.txt'
}
# 2. list directories and files recursively
# 2. list directories and files recursively
assert
set
(
disk_backend
.
list_dir_or_file
(
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
recursive
=
True
))
==
set
([
tmp_dir
,
recursive
=
True
))
==
{
'dir1'
,
'dir1'
,
osp
.
join
(
'dir1'
,
'text3.txt'
),
'dir2'
,
osp
.
join
(
'dir1'
,
'text3.txt'
),
'dir2'
,
osp
.
join
(
'dir2'
,
'dir3'
),
osp
.
join
(
'dir2'
,
'dir3'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
])
}
# 3. only list directories
# 3. only list directories
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
))
==
set
([
'dir1'
,
'dir2'
])
tmp_dir
,
list_file
=
False
))
==
{
'dir1'
,
'dir2'
}
with
pytest
.
raises
(
with
pytest
.
raises
(
TypeError
,
TypeError
,
match
=
'`suffix` should be None when `list_dir` is True'
):
match
=
'`suffix` should be None when `list_dir` is True'
):
...
@@ -227,30 +228,30 @@ class TestFileClient:
...
@@ -227,30 +228,30 @@ class TestFileClient:
# 4. only list directories recursively
# 4. only list directories recursively
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
,
recursive
=
True
))
==
set
(
tmp_dir
,
list_file
=
False
,
recursive
=
True
))
==
{
[
'dir1'
,
'dir2'
,
'dir1'
,
'dir2'
,
osp
.
join
(
'dir2'
,
'dir3'
)])
osp
.
join
(
'dir2'
,
'dir3'
)
}
# 5. only list files
# 5. only list files
assert
set
(
disk_backend
.
list_dir_or_file
(
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
))
==
set
([
'text1.txt'
,
'text2.txt'
])
tmp_dir
,
list_dir
=
False
))
==
{
'text1.txt'
,
'text2.txt'
}
# 6. only list files recursively
# 6. only list files recursively
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
recursive
=
True
))
==
set
([
tmp_dir
,
list_dir
=
False
,
recursive
=
True
))
==
{
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
])
}
# 7. only list files ending with suffix
# 7. only list files ending with suffix
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
))
==
set
([
'text1.txt'
,
'text2.txt'
])
suffix
=
'.txt'
))
==
{
'text1.txt'
,
'text2.txt'
}
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
tmp_dir
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
suffix
=
(
'.txt'
,
'.jpg'
)))
==
{
'text1.txt'
,
'text2.txt'
}
'.jpg'
)))
==
set
([
'text1.txt'
,
'text2.txt'
])
with
pytest
.
raises
(
with
pytest
.
raises
(
TypeError
,
TypeError
,
match
=
'`suffix` must be a string or tuple of strings'
):
match
=
'`suffix` must be a string or tuple of strings'
):
...
@@ -260,22 +261,22 @@ class TestFileClient:
...
@@ -260,22 +261,22 @@ class TestFileClient:
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
,
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
,
recursive
=
True
))
==
set
([
recursive
=
True
))
==
{
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
'text1.txt'
,
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
'text1.txt'
,
'text2.txt'
'text2.txt'
])
}
# 7. only list files ending with suffix
# 7. only list files ending with suffix
assert
set
(
assert
set
(
disk_backend
.
list_dir_or_file
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
tmp_dir
,
list_dir
=
False
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
'.jpg'
),
suffix
=
(
'.txt'
,
'.jpg'
),
recursive
=
True
))
==
set
([
recursive
=
True
))
==
{
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
])
}
@
patch
(
'ceph.S3Client'
,
MockS3Client
)
@
patch
(
'ceph.S3Client'
,
MockS3Client
)
def
test_ceph_backend
(
self
):
def
test_ceph_backend
(
self
):
...
@@ -463,21 +464,21 @@ class TestFileClient:
...
@@ -463,21 +464,21 @@ class TestFileClient:
with
build_temporary_directory
()
as
tmp_dir
:
with
build_temporary_directory
()
as
tmp_dir
:
# 1. list directories and files
# 1. list directories and files
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
))
==
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
))
==
{
[
'dir1'
,
'dir2'
,
'text1.txt'
,
'text2.txt'
])
'dir1'
,
'dir2'
,
'text1.txt'
,
'text2.txt'
}
# 2. list directories and files recursively
# 2. list directories and files recursively
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
recursive
=
True
))
==
{
tmp_dir
,
recursive
=
True
))
==
set
([
'dir1'
,
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'dir2'
,
'/'
.
join
(
'dir1'
,
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'dir2'
,
(
'dir2'
,
'dir3'
)),
'/'
.
join
(
'/'
.
join
((
'dir2'
,
'dir3'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
])
}
# 3. only list directories
# 3. only list directories
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
))
==
set
([
'dir1'
,
'dir2'
])
tmp_dir
,
list_file
=
False
))
==
{
'dir1'
,
'dir2'
}
with
pytest
.
raises
(
with
pytest
.
raises
(
TypeError
,
TypeError
,
match
=
(
'`list_dir` should be False when `suffix` is not '
match
=
(
'`list_dir` should be False when `suffix` is not '
...
@@ -489,31 +490,30 @@ class TestFileClient:
...
@@ -489,31 +490,30 @@ class TestFileClient:
# 4. only list directories recursively
# 4. only list directories recursively
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
,
recursive
=
True
))
==
set
(
tmp_dir
,
list_file
=
False
,
recursive
=
True
))
==
{
[
'dir1'
,
'dir2'
,
'/'
.
join
((
'dir2'
,
'dir3'
))])
'dir1'
,
'dir2'
,
'/'
.
join
((
'dir2'
,
'dir3'
))
}
# 5. only list files
# 5. only list files
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
petrel_backend
.
list_dir_or_file
(
list_dir
=
False
))
==
set
(
tmp_dir
,
list_dir
=
False
))
==
{
'text1.txt'
,
'text2.txt'
}
[
'text1.txt'
,
'text2.txt'
])
# 6. only list files recursively
# 6. only list files recursively
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
recursive
=
True
))
==
set
([
tmp_dir
,
list_dir
=
False
,
recursive
=
True
))
==
{
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
])
}
# 7. only list files ending with suffix
# 7. only list files ending with suffix
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
))
==
set
([
'text1.txt'
,
'text2.txt'
])
suffix
=
'.txt'
))
==
{
'text1.txt'
,
'text2.txt'
}
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
tmp_dir
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
suffix
=
(
'.txt'
,
'.jpg'
)))
==
{
'text1.txt'
,
'text2.txt'
}
'.jpg'
)))
==
set
([
'text1.txt'
,
'text2.txt'
])
with
pytest
.
raises
(
with
pytest
.
raises
(
TypeError
,
TypeError
,
match
=
'`suffix` must be a string or tuple of strings'
):
match
=
'`suffix` must be a string or tuple of strings'
):
...
@@ -523,22 +523,22 @@ class TestFileClient:
...
@@ -523,22 +523,22 @@ class TestFileClient:
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
,
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
,
recursive
=
True
))
==
set
([
recursive
=
True
))
==
{
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'text1.txt'
,
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'text1.txt'
,
'text2.txt'
'text2.txt'
])
}
# 7. only list files ending with suffix
# 7. only list files ending with suffix
assert
set
(
assert
set
(
petrel_backend
.
list_dir_or_file
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
tmp_dir
,
list_dir
=
False
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
'.jpg'
),
suffix
=
(
'.txt'
,
'.jpg'
),
recursive
=
True
))
==
set
([
recursive
=
True
))
==
{
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
])
}
@
patch
(
'mc.MemcachedClient.GetInstance'
,
MockMemcachedClient
)
@
patch
(
'mc.MemcachedClient.GetInstance'
,
MockMemcachedClient
)
@
patch
(
'mc.pyvector'
,
MagicMock
)
@
patch
(
'mc.pyvector'
,
MagicMock
)
...
...
tests/test_fileio.py
View file @
45fa3e44
...
@@ -128,7 +128,7 @@ def test_register_handler():
...
@@ -128,7 +128,7 @@ def test_register_handler():
assert
content
==
'1.jpg
\n
2.jpg
\n
3.jpg
\n
4.jpg
\n
5.jpg'
assert
content
==
'1.jpg
\n
2.jpg
\n
3.jpg
\n
4.jpg
\n
5.jpg'
tmp_filename
=
osp
.
join
(
tempfile
.
gettempdir
(),
'mmcv_test.txt2'
)
tmp_filename
=
osp
.
join
(
tempfile
.
gettempdir
(),
'mmcv_test.txt2'
)
mmcv
.
dump
(
content
,
tmp_filename
)
mmcv
.
dump
(
content
,
tmp_filename
)
with
open
(
tmp_filename
,
'r'
)
as
f
:
with
open
(
tmp_filename
)
as
f
:
written
=
f
.
read
()
written
=
f
.
read
()
os
.
remove
(
tmp_filename
)
os
.
remove
(
tmp_filename
)
assert
written
==
'
\n
'
+
content
assert
written
==
'
\n
'
+
content
...
...
tests/test_ops/test_bbox.py
View file @
45fa3e44
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
mmcv.utils
import
IS_CUDA_AVAILABLE
,
IS_MLU_AVAILABLE
from
mmcv.utils
import
IS_CUDA_AVAILABLE
,
IS_MLU_AVAILABLE
class
TestBBox
(
object
)
:
class
TestBBox
:
def
_test_bbox_overlaps
(
self
,
device
=
'cpu'
,
dtype
=
torch
.
float
):
def
_test_bbox_overlaps
(
self
,
device
=
'cpu'
,
dtype
=
torch
.
float
):
from
mmcv.ops
import
bbox_overlaps
from
mmcv.ops
import
bbox_overlaps
...
...
tests/test_ops/test_bilinear_grid_sample.py
View file @
45fa3e44
...
@@ -4,7 +4,7 @@ import torch
...
@@ -4,7 +4,7 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
class
TestBilinearGridSample
(
object
)
:
class
TestBilinearGridSample
:
def
_test_bilinear_grid_sample
(
self
,
def
_test_bilinear_grid_sample
(
self
,
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment