Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ea173c9f
Unverified
Commit
ea173c9f
authored
May 30, 2022
by
tripleMu
Committed by
GitHub
May 30, 2022
Browse files
Add type hints for mmcv/runer/hooks/logger (#2000)
* Fix * Fix
parent
c70fafeb
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
135 additions
and
119 deletions
+135
-119
mmcv/runner/hooks/logger/base.py
mmcv/runner/hooks/logger/base.py
+26
-21
mmcv/runner/hooks/logger/clearml.py
mmcv/runner/hooks/logger/clearml.py
+9
-7
mmcv/runner/hooks/logger/dvclive.py
mmcv/runner/hooks/logger/dvclive.py
+9
-8
mmcv/runner/hooks/logger/mlflow.py
mmcv/runner/hooks/logger/mlflow.py
+13
-11
mmcv/runner/hooks/logger/neptune.py
mmcv/runner/hooks/logger/neptune.py
+15
-13
mmcv/runner/hooks/logger/pavi.py
mmcv/runner/hooks/logger/pavi.py
+14
-13
mmcv/runner/hooks/logger/segmind.py
mmcv/runner/hooks/logger/segmind.py
+5
-5
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+9
-8
mmcv/runner/hooks/logger/text.py
mmcv/runner/hooks/logger/text.py
+19
-18
mmcv/runner/hooks/logger/wandb.py
mmcv/runner/hooks/logger/wandb.py
+16
-15
No files found.
mmcv/runner/hooks/logger/base.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
numbers
import
numbers
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
Dict
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -23,10 +24,10 @@ class LoggerHook(Hook):
...
@@ -23,10 +24,10 @@ class LoggerHook(Hook):
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
def
__init__
(
self
,
def
__init__
(
self
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
):
by_epoch
:
bool
=
True
):
self
.
interval
=
interval
self
.
interval
=
interval
self
.
ignore_last
=
ignore_last
self
.
ignore_last
=
ignore_last
self
.
reset_flag
=
reset_flag
self
.
reset_flag
=
reset_flag
...
@@ -37,7 +38,9 @@ class LoggerHook(Hook):
...
@@ -37,7 +38,9 @@ class LoggerHook(Hook):
pass
pass
@
staticmethod
@
staticmethod
def
is_scalar
(
val
,
include_np
=
True
,
include_torch
=
True
):
def
is_scalar
(
val
,
include_np
:
bool
=
True
,
include_torch
:
bool
=
True
)
->
bool
:
"""Tell the input variable is a scalar or not.
"""Tell the input variable is a scalar or not.
Args:
Args:
...
@@ -57,7 +60,7 @@ class LoggerHook(Hook):
...
@@ -57,7 +60,7 @@ class LoggerHook(Hook):
else
:
else
:
return
False
return
False
def
get_mode
(
self
,
runner
):
def
get_mode
(
self
,
runner
)
->
str
:
if
runner
.
mode
==
'train'
:
if
runner
.
mode
==
'train'
:
if
'time'
in
runner
.
log_buffer
.
output
:
if
'time'
in
runner
.
log_buffer
.
output
:
mode
=
'train'
mode
=
'train'
...
@@ -70,7 +73,7 @@ class LoggerHook(Hook):
...
@@ -70,7 +73,7 @@ class LoggerHook(Hook):
f
'but got
{
runner
.
mode
}
'
)
f
'but got
{
runner
.
mode
}
'
)
return
mode
return
mode
def
get_epoch
(
self
,
runner
):
def
get_epoch
(
self
,
runner
)
->
int
:
if
runner
.
mode
==
'train'
:
if
runner
.
mode
==
'train'
:
epoch
=
runner
.
epoch
+
1
epoch
=
runner
.
epoch
+
1
elif
runner
.
mode
==
'val'
:
elif
runner
.
mode
==
'val'
:
...
@@ -82,7 +85,7 @@ class LoggerHook(Hook):
...
@@ -82,7 +85,7 @@ class LoggerHook(Hook):
f
'but got
{
runner
.
mode
}
'
)
f
'but got
{
runner
.
mode
}
'
)
return
epoch
return
epoch
def
get_iter
(
self
,
runner
,
inner_iter
=
False
)
:
def
get_iter
(
self
,
runner
,
inner_iter
:
bool
=
False
)
->
int
:
"""Get the current training iteration step."""
"""Get the current training iteration step."""
if
self
.
by_epoch
and
inner_iter
:
if
self
.
by_epoch
and
inner_iter
:
current_iter
=
runner
.
inner_iter
+
1
current_iter
=
runner
.
inner_iter
+
1
...
@@ -90,7 +93,7 @@ class LoggerHook(Hook):
...
@@ -90,7 +93,7 @@ class LoggerHook(Hook):
current_iter
=
runner
.
iter
+
1
current_iter
=
runner
.
iter
+
1
return
current_iter
return
current_iter
def
get_lr_tags
(
self
,
runner
):
def
get_lr_tags
(
self
,
runner
)
->
Dict
[
str
,
float
]
:
tags
=
{}
tags
=
{}
lrs
=
runner
.
current_lr
()
lrs
=
runner
.
current_lr
()
if
isinstance
(
lrs
,
dict
):
if
isinstance
(
lrs
,
dict
):
...
@@ -100,7 +103,7 @@ class LoggerHook(Hook):
...
@@ -100,7 +103,7 @@ class LoggerHook(Hook):
tags
[
'learning_rate'
]
=
lrs
[
0
]
tags
[
'learning_rate'
]
=
lrs
[
0
]
return
tags
return
tags
def
get_momentum_tags
(
self
,
runner
):
def
get_momentum_tags
(
self
,
runner
)
->
Dict
[
str
,
float
]
:
tags
=
{}
tags
=
{}
momentums
=
runner
.
current_momentum
()
momentums
=
runner
.
current_momentum
()
if
isinstance
(
momentums
,
dict
):
if
isinstance
(
momentums
,
dict
):
...
@@ -110,12 +113,14 @@ class LoggerHook(Hook):
...
@@ -110,12 +113,14 @@ class LoggerHook(Hook):
tags
[
'momentum'
]
=
momentums
[
0
]
tags
[
'momentum'
]
=
momentums
[
0
]
return
tags
return
tags
def
get_loggable_tags
(
self
,
def
get_loggable_tags
(
self
,
runner
,
runner
,
allow_scalar
=
True
,
allow_scalar
:
bool
=
True
,
allow_text
=
False
,
allow_text
:
bool
=
False
,
add_mode
=
True
,
add_mode
:
bool
=
True
,
tags_to_skip
=
(
'time'
,
'data_time'
)):
tags_to_skip
:
tuple
=
(
'time'
,
'data_time'
)
)
->
Dict
:
tags
=
{}
tags
=
{}
for
var
,
val
in
runner
.
log_buffer
.
output
.
items
():
for
var
,
val
in
runner
.
log_buffer
.
output
.
items
():
if
var
in
tags_to_skip
:
if
var
in
tags_to_skip
:
...
@@ -131,16 +136,16 @@ class LoggerHook(Hook):
...
@@ -131,16 +136,16 @@ class LoggerHook(Hook):
tags
.
update
(
self
.
get_momentum_tags
(
runner
))
tags
.
update
(
self
.
get_momentum_tags
(
runner
))
return
tags
return
tags
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
for
hook
in
runner
.
hooks
[::
-
1
]:
for
hook
in
runner
.
hooks
[::
-
1
]:
if
isinstance
(
hook
,
LoggerHook
):
if
isinstance
(
hook
,
LoggerHook
):
hook
.
reset_flag
=
True
hook
.
reset_flag
=
True
break
break
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
)
->
None
:
runner
.
log_buffer
.
clear
()
# clear logs of last epoch
runner
.
log_buffer
.
clear
()
# clear logs of last epoch
def
after_train_iter
(
self
,
runner
):
def
after_train_iter
(
self
,
runner
)
->
None
:
if
self
.
by_epoch
and
self
.
every_n_inner_iters
(
runner
,
self
.
interval
):
if
self
.
by_epoch
and
self
.
every_n_inner_iters
(
runner
,
self
.
interval
):
runner
.
log_buffer
.
average
(
self
.
interval
)
runner
.
log_buffer
.
average
(
self
.
interval
)
elif
not
self
.
by_epoch
and
self
.
every_n_iters
(
runner
,
self
.
interval
):
elif
not
self
.
by_epoch
and
self
.
every_n_iters
(
runner
,
self
.
interval
):
...
@@ -154,13 +159,13 @@ class LoggerHook(Hook):
...
@@ -154,13 +159,13 @@ class LoggerHook(Hook):
if
self
.
reset_flag
:
if
self
.
reset_flag
:
runner
.
log_buffer
.
clear_output
()
runner
.
log_buffer
.
clear_output
()
def
after_train_epoch
(
self
,
runner
):
def
after_train_epoch
(
self
,
runner
)
->
None
:
if
runner
.
log_buffer
.
ready
:
if
runner
.
log_buffer
.
ready
:
self
.
log
(
runner
)
self
.
log
(
runner
)
if
self
.
reset_flag
:
if
self
.
reset_flag
:
runner
.
log_buffer
.
clear_output
()
runner
.
log_buffer
.
clear_output
()
def
after_val_epoch
(
self
,
runner
):
def
after_val_epoch
(
self
,
runner
)
->
None
:
runner
.
log_buffer
.
average
()
runner
.
log_buffer
.
average
()
self
.
log
(
runner
)
self
.
log
(
runner
)
if
self
.
reset_flag
:
if
self
.
reset_flag
:
...
...
mmcv/runner/hooks/logger/clearml.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
Optional
from
...dist_utils
import
master_only
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
...
@@ -29,11 +31,11 @@ class ClearMLLoggerHook(LoggerHook):
...
@@ -29,11 +31,11 @@ class ClearMLLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
init_kwargs
=
None
,
init_kwargs
:
Optional
[
Dict
]
=
None
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
):
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_clearml
()
self
.
import_clearml
()
self
.
init_kwargs
=
init_kwargs
self
.
init_kwargs
=
init_kwargs
...
@@ -47,14 +49,14 @@ class ClearMLLoggerHook(LoggerHook):
...
@@ -47,14 +49,14 @@ class ClearMLLoggerHook(LoggerHook):
self
.
clearml
=
clearml
self
.
clearml
=
clearml
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
task_kwargs
=
self
.
init_kwargs
if
self
.
init_kwargs
else
{}
task_kwargs
=
self
.
init_kwargs
if
self
.
init_kwargs
else
{}
self
.
task
=
self
.
clearml
.
Task
.
init
(
**
task_kwargs
)
self
.
task
=
self
.
clearml
.
Task
.
init
(
**
task_kwargs
)
self
.
task_logger
=
self
.
task
.
get_logger
()
self
.
task_logger
=
self
.
task
.
get_logger
()
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
tags
=
self
.
get_loggable_tags
(
runner
)
for
tag
,
val
in
tags
.
items
():
for
tag
,
val
in
tags
.
items
():
self
.
task_logger
.
report_scalar
(
tag
,
tag
,
val
,
self
.
task_logger
.
report_scalar
(
tag
,
tag
,
val
,
...
...
mmcv/runner/hooks/logger/dvclive.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Optional
from
...dist_utils
import
master_only
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
..hook
import
HOOKS
...
@@ -31,17 +32,17 @@ class DvcliveLoggerHook(LoggerHook):
...
@@ -31,17 +32,17 @@ class DvcliveLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
model_file
=
None
,
model_file
:
Optional
[
str
]
=
None
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
,
by_epoch
:
bool
=
True
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
model_file
=
model_file
self
.
model_file
=
model_file
self
.
import_dvclive
(
**
kwargs
)
self
.
import_dvclive
(
**
kwargs
)
def
import_dvclive
(
self
,
**
kwargs
):
def
import_dvclive
(
self
,
**
kwargs
)
->
None
:
try
:
try
:
from
dvclive
import
Live
from
dvclive
import
Live
except
ImportError
:
except
ImportError
:
...
@@ -50,7 +51,7 @@ class DvcliveLoggerHook(LoggerHook):
...
@@ -50,7 +51,7 @@ class DvcliveLoggerHook(LoggerHook):
self
.
dvclive
=
Live
(
**
kwargs
)
self
.
dvclive
=
Live
(
**
kwargs
)
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
if
tags
:
self
.
dvclive
.
set_step
(
self
.
get_iter
(
runner
))
self
.
dvclive
.
set_step
(
self
.
get_iter
(
runner
))
...
@@ -58,7 +59,7 @@ class DvcliveLoggerHook(LoggerHook):
...
@@ -58,7 +59,7 @@ class DvcliveLoggerHook(LoggerHook):
self
.
dvclive
.
log
(
k
,
v
)
self
.
dvclive
.
log
(
k
,
v
)
@
master_only
@
master_only
def
after_train_epoch
(
self
,
runner
):
def
after_train_epoch
(
self
,
runner
)
->
None
:
super
().
after_train_epoch
(
runner
)
super
().
after_train_epoch
(
runner
)
if
self
.
model_file
is
not
None
:
if
self
.
model_file
is
not
None
:
runner
.
save_checkpoint
(
runner
.
save_checkpoint
(
...
...
mmcv/runner/hooks/logger/mlflow.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
Optional
from
mmcv.utils
import
TORCH_VERSION
from
mmcv.utils
import
TORCH_VERSION
from
...dist_utils
import
master_only
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
..hook
import
HOOKS
...
@@ -33,20 +35,20 @@ class MlflowLoggerHook(LoggerHook):
...
@@ -33,20 +35,20 @@ class MlflowLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
exp_name
=
None
,
exp_name
:
Optional
[
str
]
=
None
,
tags
=
None
,
tags
:
Optional
[
Dict
]
=
None
,
log_model
=
True
,
log_model
:
bool
=
True
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
):
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_mlflow
()
self
.
import_mlflow
()
self
.
exp_name
=
exp_name
self
.
exp_name
=
exp_name
self
.
tags
=
tags
self
.
tags
=
tags
self
.
log_model
=
log_model
self
.
log_model
=
log_model
def
import_mlflow
(
self
):
def
import_mlflow
(
self
)
->
None
:
try
:
try
:
import
mlflow
import
mlflow
import
mlflow.pytorch
as
mlflow_pytorch
import
mlflow.pytorch
as
mlflow_pytorch
...
@@ -57,7 +59,7 @@ class MlflowLoggerHook(LoggerHook):
...
@@ -57,7 +59,7 @@ class MlflowLoggerHook(LoggerHook):
self
.
mlflow_pytorch
=
mlflow_pytorch
self
.
mlflow_pytorch
=
mlflow_pytorch
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
if
self
.
exp_name
is
not
None
:
if
self
.
exp_name
is
not
None
:
self
.
mlflow
.
set_experiment
(
self
.
exp_name
)
self
.
mlflow
.
set_experiment
(
self
.
exp_name
)
...
@@ -65,13 +67,13 @@ class MlflowLoggerHook(LoggerHook):
...
@@ -65,13 +67,13 @@ class MlflowLoggerHook(LoggerHook):
self
.
mlflow
.
set_tags
(
self
.
tags
)
self
.
mlflow
.
set_tags
(
self
.
tags
)
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
if
tags
:
self
.
mlflow
.
log_metrics
(
tags
,
step
=
self
.
get_iter
(
runner
))
self
.
mlflow
.
log_metrics
(
tags
,
step
=
self
.
get_iter
(
runner
))
@
master_only
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
if
self
.
log_model
:
if
self
.
log_model
:
self
.
mlflow_pytorch
.
log_model
(
self
.
mlflow_pytorch
.
log_model
(
runner
.
model
,
runner
.
model
,
...
...
mmcv/runner/hooks/logger/neptune.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
Optional
from
...dist_utils
import
master_only
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
...
@@ -42,19 +44,19 @@ class NeptuneLoggerHook(LoggerHook):
...
@@ -42,19 +44,19 @@ class NeptuneLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
init_kwargs
=
None
,
init_kwargs
:
Optional
[
Dict
]
=
None
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
True
,
reset_flag
:
bool
=
True
,
with_step
=
True
,
with_step
:
bool
=
True
,
by_epoch
=
True
):
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_neptune
()
self
.
import_neptune
()
self
.
init_kwargs
=
init_kwargs
self
.
init_kwargs
=
init_kwargs
self
.
with_step
=
with_step
self
.
with_step
=
with_step
def
import_neptune
(
self
):
def
import_neptune
(
self
)
->
None
:
try
:
try
:
import
neptune.new
as
neptune
import
neptune.new
as
neptune
except
ImportError
:
except
ImportError
:
...
@@ -64,24 +66,24 @@ class NeptuneLoggerHook(LoggerHook):
...
@@ -64,24 +66,24 @@ class NeptuneLoggerHook(LoggerHook):
self
.
run
=
None
self
.
run
=
None
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
if
self
.
init_kwargs
:
if
self
.
init_kwargs
:
self
.
run
=
self
.
neptune
.
init
(
**
self
.
init_kwargs
)
self
.
run
=
self
.
neptune
.
init
(
**
self
.
init_kwargs
)
else
:
else
:
self
.
run
=
self
.
neptune
.
init
()
self
.
run
=
self
.
neptune
.
init
()
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
if
tags
:
for
tag_name
,
tag_value
in
tags
.
items
():
for
tag_name
,
tag_value
in
tags
.
items
():
if
self
.
with_step
:
if
self
.
with_step
:
self
.
run
[
tag_name
].
log
(
self
.
run
[
tag_name
].
log
(
# type: ignore
tag_value
,
step
=
self
.
get_iter
(
runner
))
tag_value
,
step
=
self
.
get_iter
(
runner
))
else
:
else
:
tags
[
'global_step'
]
=
self
.
get_iter
(
runner
)
tags
[
'global_step'
]
=
self
.
get_iter
(
runner
)
self
.
run
[
tag_name
].
log
(
tags
)
self
.
run
[
tag_name
].
log
(
tags
)
# type: ignore
@
master_only
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
self
.
run
.
stop
()
self
.
run
.
stop
()
# type: ignore
mmcv/runner/hooks/logger/pavi.py
View file @
ea173c9f
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
json
import
json
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Dict
,
Optional
import
torch
import
torch
import
yaml
import
yaml
...
@@ -32,14 +33,14 @@ class PaviLoggerHook(LoggerHook):
...
@@ -32,14 +33,14 @@ class PaviLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
init_kwargs
=
None
,
init_kwargs
:
Optional
[
Dict
]
=
None
,
add_graph
=
False
,
add_graph
:
bool
=
False
,
add_last_ckpt
=
False
,
add_last_ckpt
:
bool
=
False
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
,
by_epoch
:
bool
=
True
,
img_key
=
'img_info'
):
img_key
:
str
=
'img_info'
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
init_kwargs
=
init_kwargs
self
.
init_kwargs
=
init_kwargs
self
.
add_graph
=
add_graph
self
.
add_graph
=
add_graph
...
@@ -47,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
...
@@ -47,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
self
.
img_key
=
img_key
self
.
img_key
=
img_key
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
try
:
try
:
from
pavi
import
SummaryWriter
from
pavi
import
SummaryWriter
...
@@ -85,7 +86,7 @@ class PaviLoggerHook(LoggerHook):
...
@@ -85,7 +86,7 @@ class PaviLoggerHook(LoggerHook):
self
.
init_kwargs
[
'session_text'
]
=
session_text
self
.
init_kwargs
[
'session_text'
]
=
session_text
self
.
writer
=
SummaryWriter
(
**
self
.
init_kwargs
)
self
.
writer
=
SummaryWriter
(
**
self
.
init_kwargs
)
def
get_step
(
self
,
runner
):
def
get_step
(
self
,
runner
)
->
int
:
"""Get the total training step/epoch."""
"""Get the total training step/epoch."""
if
self
.
get_mode
(
runner
)
==
'val'
and
self
.
by_epoch
:
if
self
.
get_mode
(
runner
)
==
'val'
and
self
.
by_epoch
:
return
self
.
get_epoch
(
runner
)
return
self
.
get_epoch
(
runner
)
...
@@ -93,14 +94,14 @@ class PaviLoggerHook(LoggerHook):
...
@@ -93,14 +94,14 @@ class PaviLoggerHook(LoggerHook):
return
self
.
get_iter
(
runner
)
return
self
.
get_iter
(
runner
)
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
,
add_mode
=
False
)
tags
=
self
.
get_loggable_tags
(
runner
,
add_mode
=
False
)
if
tags
:
if
tags
:
self
.
writer
.
add_scalars
(
self
.
writer
.
add_scalars
(
self
.
get_mode
(
runner
),
tags
,
self
.
get_step
(
runner
))
self
.
get_mode
(
runner
),
tags
,
self
.
get_step
(
runner
))
@
master_only
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
if
self
.
add_last_ckpt
:
if
self
.
add_last_ckpt
:
ckpt_path
=
osp
.
join
(
runner
.
work_dir
,
'latest.pth'
)
ckpt_path
=
osp
.
join
(
runner
.
work_dir
,
'latest.pth'
)
if
osp
.
islink
(
ckpt_path
):
if
osp
.
islink
(
ckpt_path
):
...
@@ -118,7 +119,7 @@ class PaviLoggerHook(LoggerHook):
...
@@ -118,7 +119,7 @@ class PaviLoggerHook(LoggerHook):
self
.
writer
.
close
()
self
.
writer
.
close
()
@
master_only
@
master_only
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
)
->
None
:
if
runner
.
epoch
==
0
and
self
.
add_graph
:
if
runner
.
epoch
==
0
and
self
.
add_graph
:
if
is_module_wrapper
(
runner
.
model
):
if
is_module_wrapper
(
runner
.
model
):
_model
=
runner
.
model
.
module
_model
=
runner
.
model
.
module
...
...
mmcv/runner/hooks/logger/segmind.py
View file @
ea173c9f
...
@@ -23,14 +23,14 @@ class SegmindLoggerHook(LoggerHook):
...
@@ -23,14 +23,14 @@ class SegmindLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
):
by_epoch
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_segmind
()
self
.
import_segmind
()
def
import_segmind
(
self
):
def
import_segmind
(
self
)
->
None
:
try
:
try
:
import
segmind
import
segmind
except
ImportError
:
except
ImportError
:
...
@@ -40,7 +40,7 @@ class SegmindLoggerHook(LoggerHook):
...
@@ -40,7 +40,7 @@ class SegmindLoggerHook(LoggerHook):
self
.
mlflow_log
=
segmind
.
utils
.
logging_utils
.
try_mlflow_log
self
.
mlflow_log
=
segmind
.
utils
.
logging_utils
.
try_mlflow_log
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
if
tags
:
# logging metrics to segmind
# logging metrics to segmind
...
...
mmcv/runner/hooks/logger/tensorboard.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Optional
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
...dist_utils
import
master_only
from
...dist_utils
import
master_only
...
@@ -23,16 +24,16 @@ class TensorboardLoggerHook(LoggerHook):
...
@@ -23,16 +24,16 @@ class TensorboardLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
log_dir
=
None
,
log_dir
:
Optional
[
str
]
=
None
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
by_epoch
=
True
):
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
log_dir
=
log_dir
self
.
log_dir
=
log_dir
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
if
(
TORCH_VERSION
==
'parrots'
if
(
TORCH_VERSION
==
'parrots'
or
digit_version
(
TORCH_VERSION
)
<
digit_version
(
'1.1'
)):
or
digit_version
(
TORCH_VERSION
)
<
digit_version
(
'1.1'
)):
...
@@ -55,7 +56,7 @@ class TensorboardLoggerHook(LoggerHook):
...
@@ -55,7 +56,7 @@ class TensorboardLoggerHook(LoggerHook):
self
.
writer
=
SummaryWriter
(
self
.
log_dir
)
self
.
writer
=
SummaryWriter
(
self
.
log_dir
)
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
,
allow_text
=
True
)
tags
=
self
.
get_loggable_tags
(
runner
,
allow_text
=
True
)
for
tag
,
val
in
tags
.
items
():
for
tag
,
val
in
tags
.
items
():
if
isinstance
(
val
,
str
):
if
isinstance
(
val
,
str
):
...
@@ -64,5 +65,5 @@ class TensorboardLoggerHook(LoggerHook):
...
@@ -64,5 +65,5 @@ class TensorboardLoggerHook(LoggerHook):
self
.
writer
.
add_scalar
(
tag
,
val
,
self
.
get_iter
(
runner
))
self
.
writer
.
add_scalar
(
tag
,
val
,
self
.
get_iter
(
runner
))
@
master_only
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
self
.
writer
.
close
()
self
.
writer
.
close
()
mmcv/runner/hooks/logger/text.py
View file @
ea173c9f
...
@@ -3,6 +3,7 @@ import datetime
...
@@ -3,6 +3,7 @@ import datetime
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Dict
,
Optional
,
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -53,15 +54,15 @@ class TextLoggerHook(LoggerHook):
...
@@ -53,15 +54,15 @@ class TextLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
by_epoch
=
True
,
by_epoch
:
bool
=
True
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
interval_exp_name
=
1000
,
interval_exp_name
:
int
=
1000
,
out_dir
=
None
,
out_dir
:
Optional
[
str
]
=
None
,
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
),
out_suffix
:
Union
[
str
,
tuple
]
=
(
'.log.json'
,
'.log'
,
'.py'
),
keep_local
=
True
,
keep_local
:
bool
=
True
,
file_client_args
=
None
):
file_client_args
:
Optional
[
Dict
]
=
None
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
by_epoch
=
by_epoch
self
.
by_epoch
=
by_epoch
self
.
time_sec_tot
=
0
self
.
time_sec_tot
=
0
...
@@ -85,7 +86,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -85,7 +86,7 @@ class TextLoggerHook(LoggerHook):
self
.
file_client
=
FileClient
.
infer_client
(
file_client_args
,
self
.
file_client
=
FileClient
.
infer_client
(
file_client_args
,
self
.
out_dir
)
self
.
out_dir
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
if
self
.
out_dir
is
not
None
:
if
self
.
out_dir
is
not
None
:
...
@@ -105,7 +106,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -105,7 +106,7 @@ class TextLoggerHook(LoggerHook):
if
runner
.
meta
is
not
None
:
if
runner
.
meta
is
not
None
:
self
.
_dump_log
(
runner
.
meta
,
runner
)
self
.
_dump_log
(
runner
.
meta
,
runner
)
def
_get_max_memory
(
self
,
runner
):
def
_get_max_memory
(
self
,
runner
)
->
int
:
device
=
getattr
(
runner
.
model
,
'output_device'
,
None
)
device
=
getattr
(
runner
.
model
,
'output_device'
,
None
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
device
=
device
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
device
=
device
)
mem_mb
=
torch
.
tensor
([
int
(
mem
)
//
(
1024
*
1024
)],
mem_mb
=
torch
.
tensor
([
int
(
mem
)
//
(
1024
*
1024
)],
...
@@ -115,7 +116,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -115,7 +116,7 @@ class TextLoggerHook(LoggerHook):
dist
.
reduce
(
mem_mb
,
0
,
op
=
dist
.
ReduceOp
.
MAX
)
dist
.
reduce
(
mem_mb
,
0
,
op
=
dist
.
ReduceOp
.
MAX
)
return
mem_mb
.
item
()
return
mem_mb
.
item
()
def
_log_info
(
self
,
log_dict
,
runner
):
def
_log_info
(
self
,
log_dict
:
Dict
,
runner
)
->
None
:
# print exp name for users to distinguish experiments
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
# at every ``interval_exp_name`` iterations and the end of each epoch
if
runner
.
meta
is
not
None
and
'exp_name'
in
runner
.
meta
:
if
runner
.
meta
is
not
None
and
'exp_name'
in
runner
.
meta
:
...
@@ -129,9 +130,9 @@ class TextLoggerHook(LoggerHook):
...
@@ -129,9 +130,9 @@ class TextLoggerHook(LoggerHook):
lr_str
=
[]
lr_str
=
[]
for
k
,
val
in
log_dict
[
'lr'
].
items
():
for
k
,
val
in
log_dict
[
'lr'
].
items
():
lr_str
.
append
(
f
'lr_
{
k
}
:
{
val
:.
3
e
}
'
)
lr_str
.
append
(
f
'lr_
{
k
}
:
{
val
:.
3
e
}
'
)
lr_str
=
' '
.
join
(
lr_str
)
lr_str
=
' '
.
join
(
lr_str
)
# type: ignore
else
:
else
:
lr_str
=
f
'lr:
{
log_dict
[
"lr"
]:.
3
e
}
'
lr_str
=
f
'lr:
{
log_dict
[
"lr"
]:.
3
e
}
'
# type: ignore
# by epoch: Epoch [4][100/1000]
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
# by iter: Iter [100/100000]
...
@@ -181,7 +182,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -181,7 +182,7 @@ class TextLoggerHook(LoggerHook):
runner
.
logger
.
info
(
log_str
)
runner
.
logger
.
info
(
log_str
)
def
_dump_log
(
self
,
log_dict
,
runner
):
def
_dump_log
(
self
,
log_dict
:
Dict
,
runner
)
->
None
:
# dump log in json format
# dump log in json format
json_log
=
OrderedDict
()
json_log
=
OrderedDict
()
for
k
,
v
in
log_dict
.
items
():
for
k
,
v
in
log_dict
.
items
():
...
@@ -200,7 +201,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -200,7 +201,7 @@ class TextLoggerHook(LoggerHook):
else
:
else
:
return
items
return
items
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
OrderedDict
:
if
'eval_iter_num'
in
runner
.
log_buffer
.
output
:
if
'eval_iter_num'
in
runner
.
log_buffer
.
output
:
# this doesn't modify runner.iter and is regardless of by_epoch
# this doesn't modify runner.iter and is regardless of by_epoch
cur_iter
=
runner
.
log_buffer
.
output
.
pop
(
'eval_iter_num'
)
cur_iter
=
runner
.
log_buffer
.
output
.
pop
(
'eval_iter_num'
)
...
@@ -228,13 +229,13 @@ class TextLoggerHook(LoggerHook):
...
@@ -228,13 +229,13 @@ class TextLoggerHook(LoggerHook):
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
log_dict
[
'memory'
]
=
self
.
_get_max_memory
(
runner
)
log_dict
[
'memory'
]
=
self
.
_get_max_memory
(
runner
)
log_dict
=
dict
(
log_dict
,
**
runner
.
log_buffer
.
output
)
log_dict
=
dict
(
log_dict
,
**
runner
.
log_buffer
.
output
)
# type: ignore
self
.
_log_info
(
log_dict
,
runner
)
self
.
_log_info
(
log_dict
,
runner
)
self
.
_dump_log
(
log_dict
,
runner
)
self
.
_dump_log
(
log_dict
,
runner
)
return
log_dict
return
log_dict
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
# copy or upload logs to self.out_dir
# copy or upload logs to self.out_dir
if
self
.
out_dir
is
not
None
:
if
self
.
out_dir
is
not
None
:
for
filename
in
scandir
(
runner
.
work_dir
,
self
.
out_suffix
,
True
):
for
filename
in
scandir
(
runner
.
work_dir
,
self
.
out_suffix
,
True
):
...
...
mmcv/runner/hooks/logger/wandb.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Dict
,
Optional
,
Union
from
mmcv.utils
import
scandir
from
mmcv.utils
import
scandir
from
...dist_utils
import
master_only
from
...dist_utils
import
master_only
...
@@ -48,15 +49,15 @@ class WandbLoggerHook(LoggerHook):
...
@@ -48,15 +49,15 @@ class WandbLoggerHook(LoggerHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
init_kwargs
=
None
,
init_kwargs
:
Optional
[
Dict
]
=
None
,
interval
=
10
,
interval
:
int
=
10
,
ignore_last
=
True
,
ignore_last
:
bool
=
True
,
reset_flag
=
False
,
reset_flag
:
bool
=
False
,
commit
=
True
,
commit
:
bool
=
True
,
by_epoch
=
True
,
by_epoch
:
bool
=
True
,
with_step
=
True
,
with_step
:
bool
=
True
,
log_artifact
=
True
,
log_artifact
:
bool
=
True
,
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
)):
out_suffix
:
Union
[
str
,
tuple
]
=
(
'.log.json'
,
'.log'
,
'.py'
)):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_wandb
()
self
.
import_wandb
()
self
.
init_kwargs
=
init_kwargs
self
.
init_kwargs
=
init_kwargs
...
@@ -65,7 +66,7 @@ class WandbLoggerHook(LoggerHook):
...
@@ -65,7 +66,7 @@ class WandbLoggerHook(LoggerHook):
self
.
log_artifact
=
log_artifact
self
.
log_artifact
=
log_artifact
self
.
out_suffix
=
out_suffix
self
.
out_suffix
=
out_suffix
def
import_wandb
(
self
):
def
import_wandb
(
self
)
->
None
:
try
:
try
:
import
wandb
import
wandb
except
ImportError
:
except
ImportError
:
...
@@ -74,17 +75,17 @@ class WandbLoggerHook(LoggerHook):
...
@@ -74,17 +75,17 @@ class WandbLoggerHook(LoggerHook):
self
.
wandb
=
wandb
self
.
wandb
=
wandb
@
master_only
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
if
self
.
wandb
is
None
:
if
self
.
wandb
is
None
:
self
.
import_wandb
()
self
.
import_wandb
()
if
self
.
init_kwargs
:
if
self
.
init_kwargs
:
self
.
wandb
.
init
(
**
self
.
init_kwargs
)
self
.
wandb
.
init
(
**
self
.
init_kwargs
)
# type: ignore
else
:
else
:
self
.
wandb
.
init
()
self
.
wandb
.
init
()
# type: ignore
@
master_only
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
if
tags
:
if
self
.
with_step
:
if
self
.
with_step
:
...
@@ -95,7 +96,7 @@ class WandbLoggerHook(LoggerHook):
...
@@ -95,7 +96,7 @@ class WandbLoggerHook(LoggerHook):
self
.
wandb
.
log
(
tags
,
commit
=
self
.
commit
)
self
.
wandb
.
log
(
tags
,
commit
=
self
.
commit
)
@
master_only
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
if
self
.
log_artifact
:
if
self
.
log_artifact
:
wandb_artifact
=
self
.
wandb
.
Artifact
(
wandb_artifact
=
self
.
wandb
.
Artifact
(
name
=
'artifacts'
,
type
=
'model'
)
name
=
'artifacts'
,
type
=
'model'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment