Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ea173c9f
Unverified
Commit
ea173c9f
authored
May 30, 2022
by
tripleMu
Committed by
GitHub
May 30, 2022
Browse files
Add type hints for mmcv/runner/hooks/logger (#2000)
* Fix * Fix
parent
c70fafeb
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
135 additions
and
119 deletions
+135
-119
mmcv/runner/hooks/logger/base.py
mmcv/runner/hooks/logger/base.py
+26
-21
mmcv/runner/hooks/logger/clearml.py
mmcv/runner/hooks/logger/clearml.py
+9
-7
mmcv/runner/hooks/logger/dvclive.py
mmcv/runner/hooks/logger/dvclive.py
+9
-8
mmcv/runner/hooks/logger/mlflow.py
mmcv/runner/hooks/logger/mlflow.py
+13
-11
mmcv/runner/hooks/logger/neptune.py
mmcv/runner/hooks/logger/neptune.py
+15
-13
mmcv/runner/hooks/logger/pavi.py
mmcv/runner/hooks/logger/pavi.py
+14
-13
mmcv/runner/hooks/logger/segmind.py
mmcv/runner/hooks/logger/segmind.py
+5
-5
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+9
-8
mmcv/runner/hooks/logger/text.py
mmcv/runner/hooks/logger/text.py
+19
-18
mmcv/runner/hooks/logger/wandb.py
mmcv/runner/hooks/logger/wandb.py
+16
-15
No files found.
mmcv/runner/hooks/logger/base.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
import
numbers
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
Dict
import
numpy
as
np
import
torch
...
...
@@ -23,10 +24,10 @@ class LoggerHook(Hook):
__metaclass__
=
ABCMeta
def
__init__
(
self
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
by_epoch
=
True
):
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
by_epoch
:
bool
=
True
):
self
.
interval
=
interval
self
.
ignore_last
=
ignore_last
self
.
reset_flag
=
reset_flag
...
...
@@ -37,7 +38,9 @@ class LoggerHook(Hook):
pass
@
staticmethod
def
is_scalar
(
val
,
include_np
=
True
,
include_torch
=
True
):
def
is_scalar
(
val
,
include_np
:
bool
=
True
,
include_torch
:
bool
=
True
)
->
bool
:
"""Tell the input variable is a scalar or not.
Args:
...
...
@@ -57,7 +60,7 @@ class LoggerHook(Hook):
else
:
return
False
def
get_mode
(
self
,
runner
):
def
get_mode
(
self
,
runner
)
->
str
:
if
runner
.
mode
==
'train'
:
if
'time'
in
runner
.
log_buffer
.
output
:
mode
=
'train'
...
...
@@ -70,7 +73,7 @@ class LoggerHook(Hook):
f
'but got
{
runner
.
mode
}
'
)
return
mode
def
get_epoch
(
self
,
runner
):
def
get_epoch
(
self
,
runner
)
->
int
:
if
runner
.
mode
==
'train'
:
epoch
=
runner
.
epoch
+
1
elif
runner
.
mode
==
'val'
:
...
...
@@ -82,7 +85,7 @@ class LoggerHook(Hook):
f
'but got
{
runner
.
mode
}
'
)
return
epoch
def
get_iter
(
self
,
runner
,
inner_iter
=
False
)
:
def
get_iter
(
self
,
runner
,
inner_iter
:
bool
=
False
)
->
int
:
"""Get the current training iteration step."""
if
self
.
by_epoch
and
inner_iter
:
current_iter
=
runner
.
inner_iter
+
1
...
...
@@ -90,7 +93,7 @@ class LoggerHook(Hook):
current_iter
=
runner
.
iter
+
1
return
current_iter
def
get_lr_tags
(
self
,
runner
):
def
get_lr_tags
(
self
,
runner
)
->
Dict
[
str
,
float
]
:
tags
=
{}
lrs
=
runner
.
current_lr
()
if
isinstance
(
lrs
,
dict
):
...
...
@@ -100,7 +103,7 @@ class LoggerHook(Hook):
tags
[
'learning_rate'
]
=
lrs
[
0
]
return
tags
def
get_momentum_tags
(
self
,
runner
):
def
get_momentum_tags
(
self
,
runner
)
->
Dict
[
str
,
float
]
:
tags
=
{}
momentums
=
runner
.
current_momentum
()
if
isinstance
(
momentums
,
dict
):
...
...
@@ -110,12 +113,14 @@ class LoggerHook(Hook):
tags
[
'momentum'
]
=
momentums
[
0
]
return
tags
def
get_loggable_tags
(
self
,
runner
,
allow_scalar
=
True
,
allow_text
=
False
,
add_mode
=
True
,
tags_to_skip
=
(
'time'
,
'data_time'
)):
def
get_loggable_tags
(
self
,
runner
,
allow_scalar
:
bool
=
True
,
allow_text
:
bool
=
False
,
add_mode
:
bool
=
True
,
tags_to_skip
:
tuple
=
(
'time'
,
'data_time'
)
)
->
Dict
:
tags
=
{}
for
var
,
val
in
runner
.
log_buffer
.
output
.
items
():
if
var
in
tags_to_skip
:
...
...
@@ -131,16 +136,16 @@ class LoggerHook(Hook):
tags
.
update
(
self
.
get_momentum_tags
(
runner
))
return
tags
def before_run(self, runner) -> None:
    """Mark only the last registered ``LoggerHook`` as the one that
    resets the log buffer after logging.

    Walks ``runner.hooks`` from the end and flips ``reset_flag`` on the
    first ``LoggerHook`` found, so the buffer is cleared exactly once.
    """
    for registered_hook in reversed(runner.hooks):
        if isinstance(registered_hook, LoggerHook):
            registered_hook.reset_flag = True
            return
def before_epoch(self, runner) -> None:
    """Drop accumulated values so stats from the previous epoch do not
    leak into this one."""
    runner.log_buffer.clear()
def
after_train_iter
(
self
,
runner
):
def
after_train_iter
(
self
,
runner
)
->
None
:
if
self
.
by_epoch
and
self
.
every_n_inner_iters
(
runner
,
self
.
interval
):
runner
.
log_buffer
.
average
(
self
.
interval
)
elif
not
self
.
by_epoch
and
self
.
every_n_iters
(
runner
,
self
.
interval
):
...
...
@@ -154,13 +159,13 @@ class LoggerHook(Hook):
if
self
.
reset_flag
:
runner
.
log_buffer
.
clear_output
()
def after_train_epoch(self, runner) -> None:
    """Flush epoch-level logs if the buffer is ready, then optionally
    clear the output buffer (only the last LoggerHook has
    ``reset_flag`` set, so clearing happens once per epoch)."""
    buffer = runner.log_buffer
    if buffer.ready:
        self.log(runner)
    if self.reset_flag:
        buffer.clear_output()
def
after_val_epoch
(
self
,
runner
):
def
after_val_epoch
(
self
,
runner
)
->
None
:
runner
.
log_buffer
.
average
()
self
.
log
(
runner
)
if
self
.
reset_flag
:
...
...
mmcv/runner/hooks/logger/clearml.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
Optional
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
...
...
@@ -29,11 +31,11 @@ class ClearMLLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
init_kwargs
=
None
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
by_epoch
=
True
):
init_kwargs
:
Optional
[
Dict
]
=
None
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_clearml
()
self
.
init_kwargs
=
init_kwargs
...
...
@@ -47,14 +49,14 @@ class ClearMLLoggerHook(LoggerHook):
self
.
clearml
=
clearml
@
master_only
def before_run(self, runner) -> None:
    """Create the ClearML task (from ``init_kwargs`` if given) and
    cache its logger for later ``log`` calls."""
    super().before_run(runner)
    init_args = self.init_kwargs or {}
    self.task = self.clearml.Task.init(**init_args)
    self.task_logger = self.task.get_logger()
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
for
tag
,
val
in
tags
.
items
():
self
.
task_logger
.
report_scalar
(
tag
,
tag
,
val
,
...
...
mmcv/runner/hooks/logger/dvclive.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
from
pathlib
import
Path
from
typing
import
Optional
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
...
...
@@ -31,17 +32,17 @@ class DvcliveLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
model_file
=
None
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
by_epoch
=
True
,
model_file
:
Optional
[
str
]
=
None
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
by_epoch
:
bool
=
True
,
**
kwargs
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
model_file
=
model_file
self
.
import_dvclive
(
**
kwargs
)
def
import_dvclive
(
self
,
**
kwargs
):
def
import_dvclive
(
self
,
**
kwargs
)
->
None
:
try
:
from
dvclive
import
Live
except
ImportError
:
...
...
@@ -50,7 +51,7 @@ class DvcliveLoggerHook(LoggerHook):
self
.
dvclive
=
Live
(
**
kwargs
)
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
self
.
dvclive
.
set_step
(
self
.
get_iter
(
runner
))
...
...
@@ -58,7 +59,7 @@ class DvcliveLoggerHook(LoggerHook):
self
.
dvclive
.
log
(
k
,
v
)
@
master_only
def
after_train_epoch
(
self
,
runner
):
def
after_train_epoch
(
self
,
runner
)
->
None
:
super
().
after_train_epoch
(
runner
)
if
self
.
model_file
is
not
None
:
runner
.
save_checkpoint
(
...
...
mmcv/runner/hooks/logger/mlflow.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
Optional
from
mmcv.utils
import
TORCH_VERSION
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
...
...
@@ -33,20 +35,20 @@ class MlflowLoggerHook(LoggerHook):
"""
def __init__(self,
             exp_name: Optional[str] = None,
             tags: Optional[Dict] = None,
             log_model: bool = True,
             interval: int = 10,
             ignore_last: bool = True,
             reset_flag: bool = False,
             by_epoch: bool = True):
    """Hook that logs metrics (and optionally the model) to MLflow.

    Args:
        exp_name (str, optional): MLflow experiment name to activate.
        tags (dict, optional): Tags set on the MLflow run.
        log_model (bool): Whether to log the model at ``after_run``.
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the last iterations of each epoch if
            fewer than ``interval`` remain.
        reset_flag (bool): Whether to clear the output buffer after
            logging.
        by_epoch (bool): Whether an epoch-based runner is used.
    """
    super().__init__(interval, ignore_last, reset_flag, by_epoch)
    # Import eagerly so a missing mlflow install fails at construction,
    # not mid-training.
    self.import_mlflow()
    self.exp_name = exp_name
    self.tags = tags
    self.log_model = log_model
def
import_mlflow
(
self
):
def
import_mlflow
(
self
)
->
None
:
try
:
import
mlflow
import
mlflow.pytorch
as
mlflow_pytorch
...
...
@@ -57,7 +59,7 @@ class MlflowLoggerHook(LoggerHook):
self
.
mlflow_pytorch
=
mlflow_pytorch
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
if
self
.
exp_name
is
not
None
:
self
.
mlflow
.
set_experiment
(
self
.
exp_name
)
...
...
@@ -65,13 +67,13 @@ class MlflowLoggerHook(LoggerHook):
self
.
mlflow
.
set_tags
(
self
.
tags
)
@
master_only
def log(self, runner) -> None:
    """Send the current loggable scalars to MLflow, keyed by the
    current training iteration."""
    scalars = self.get_loggable_tags(runner)
    if not scalars:
        return
    self.mlflow.log_metrics(scalars, step=self.get_iter(runner))
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
if
self
.
log_model
:
self
.
mlflow_pytorch
.
log_model
(
runner
.
model
,
...
...
mmcv/runner/hooks/logger/neptune.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
Optional
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
...
...
@@ -42,19 +44,19 @@ class NeptuneLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
init_kwargs
=
None
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
True
,
with_step
=
True
,
by_epoch
=
True
):
init_kwargs
:
Optional
[
Dict
]
=
None
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
True
,
with_step
:
bool
=
True
,
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_neptune
()
self
.
init_kwargs
=
init_kwargs
self
.
with_step
=
with_step
def
import_neptune
(
self
):
def
import_neptune
(
self
)
->
None
:
try
:
import
neptune.new
as
neptune
except
ImportError
:
...
...
@@ -64,24 +66,24 @@ class NeptuneLoggerHook(LoggerHook):
self
.
run
=
None
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
if
self
.
init_kwargs
:
self
.
run
=
self
.
neptune
.
init
(
**
self
.
init_kwargs
)
else
:
self
.
run
=
self
.
neptune
.
init
()
@
master_only
def log(self, runner) -> None:
    """Log loggable tags to the Neptune run.

    When ``with_step`` is True each value is logged with the current
    iteration as the explicit ``step``; otherwise the iteration is
    injected into the payload under ``'global_step'``.
    """
    tags = self.get_loggable_tags(runner)
    if not tags:
        return
    # Snapshot the items before looping: the ``with_step=False`` branch
    # mutates ``tags`` inside the loop, which would otherwise raise
    # "RuntimeError: dictionary changed size during iteration" on the
    # second tag.
    for tag_name, tag_value in list(tags.items()):
        if self.with_step:
            self.run[tag_name].log(  # type: ignore
                tag_value, step=self.get_iter(runner))
        else:
            tags['global_step'] = self.get_iter(runner)
            self.run[tag_name].log(tags)  # type: ignore
@
master_only
def after_run(self, runner) -> None:
    """Shut down the Neptune run once training is finished."""
    self.run.stop()  # type: ignore
mmcv/runner/hooks/logger/pavi.py
View file @
ea173c9f
...
...
@@ -2,6 +2,7 @@
import
json
import
os
import
os.path
as
osp
from
typing
import
Dict
,
Optional
import
torch
import
yaml
...
...
@@ -32,14 +33,14 @@ class PaviLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
init_kwargs
=
None
,
add_graph
=
False
,
add_last_ckpt
=
False
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
by_epoch
=
True
,
img_key
=
'img_info'
):
init_kwargs
:
Optional
[
Dict
]
=
None
,
add_graph
:
bool
=
False
,
add_last_ckpt
:
bool
=
False
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
by_epoch
:
bool
=
True
,
img_key
:
str
=
'img_info'
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
init_kwargs
=
init_kwargs
self
.
add_graph
=
add_graph
...
...
@@ -47,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
self
.
img_key
=
img_key
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
try
:
from
pavi
import
SummaryWriter
...
...
@@ -85,7 +86,7 @@ class PaviLoggerHook(LoggerHook):
self
.
init_kwargs
[
'session_text'
]
=
session_text
self
.
writer
=
SummaryWriter
(
**
self
.
init_kwargs
)
def
get_step
(
self
,
runner
):
def
get_step
(
self
,
runner
)
->
int
:
"""Get the total training step/epoch."""
if
self
.
get_mode
(
runner
)
==
'val'
and
self
.
by_epoch
:
return
self
.
get_epoch
(
runner
)
...
...
@@ -93,14 +94,14 @@ class PaviLoggerHook(LoggerHook):
return
self
.
get_iter
(
runner
)
@
master_only
def log(self, runner) -> None:
    """Push the current loggable scalars to the Pavi summary writer
    under the current mode, keyed by the step/epoch from
    ``get_step``."""
    scalars = self.get_loggable_tags(runner, add_mode=False)
    if not scalars:
        return
    self.writer.add_scalars(
        self.get_mode(runner), scalars, self.get_step(runner))
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
if
self
.
add_last_ckpt
:
ckpt_path
=
osp
.
join
(
runner
.
work_dir
,
'latest.pth'
)
if
osp
.
islink
(
ckpt_path
):
...
...
@@ -118,7 +119,7 @@ class PaviLoggerHook(LoggerHook):
self
.
writer
.
close
()
@
master_only
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
)
->
None
:
if
runner
.
epoch
==
0
and
self
.
add_graph
:
if
is_module_wrapper
(
runner
.
model
):
_model
=
runner
.
model
.
module
...
...
mmcv/runner/hooks/logger/segmind.py
View file @
ea173c9f
...
...
@@ -23,14 +23,14 @@ class SegmindLoggerHook(LoggerHook):
"""
def __init__(self,
             interval: int = 10,
             ignore_last: bool = True,
             reset_flag: bool = False,
             by_epoch: bool = True):
    """Hook that logs metrics to Segmind.

    Args:
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the last iterations of each epoch if
            fewer than ``interval`` remain.
        reset_flag (bool): Whether to clear the output buffer after
            logging.
        by_epoch (bool): Whether an epoch-based runner is used.
    """
    # ``by_epoch`` was the only parameter left unannotated; annotate it
    # for consistency with the other logger hooks in this module.
    super().__init__(interval, ignore_last, reset_flag, by_epoch)
    # Import eagerly so a missing segmind install fails at construction.
    self.import_segmind()
def
import_segmind
(
self
):
def
import_segmind
(
self
)
->
None
:
try
:
import
segmind
except
ImportError
:
...
...
@@ -40,7 +40,7 @@ class SegmindLoggerHook(LoggerHook):
self
.
mlflow_log
=
segmind
.
utils
.
logging_utils
.
try_mlflow_log
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
# logging metrics to segmind
...
...
mmcv/runner/hooks/logger/tensorboard.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
typing
import
Optional
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
...dist_utils
import
master_only
...
...
@@ -23,16 +24,16 @@ class TensorboardLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
log_dir
=
None
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
by_epoch
=
True
):
log_dir
:
Optional
[
str
]
=
None
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
by_epoch
:
bool
=
True
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
log_dir
=
log_dir
@
master_only
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
if
(
TORCH_VERSION
==
'parrots'
or
digit_version
(
TORCH_VERSION
)
<
digit_version
(
'1.1'
)):
...
...
@@ -55,7 +56,7 @@ class TensorboardLoggerHook(LoggerHook):
self
.
writer
=
SummaryWriter
(
self
.
log_dir
)
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
,
allow_text
=
True
)
for
tag
,
val
in
tags
.
items
():
if
isinstance
(
val
,
str
):
...
...
@@ -64,5 +65,5 @@ class TensorboardLoggerHook(LoggerHook):
self
.
writer
.
add_scalar
(
tag
,
val
,
self
.
get_iter
(
runner
))
@
master_only
def after_run(self, runner) -> None:
    """Release the TensorBoard writer once training is finished."""
    self.writer.close()
mmcv/runner/hooks/logger/text.py
View file @
ea173c9f
...
...
@@ -3,6 +3,7 @@ import datetime
import
os
import
os.path
as
osp
from
collections
import
OrderedDict
from
typing
import
Dict
,
Optional
,
Union
import
torch
import
torch.distributed
as
dist
...
...
@@ -53,15 +54,15 @@ class TextLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
by_epoch
=
True
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
interval_exp_name
=
1000
,
out_dir
=
None
,
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
),
keep_local
=
True
,
file_client_args
=
None
):
by_epoch
:
bool
=
True
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
interval_exp_name
:
int
=
1000
,
out_dir
:
Optional
[
str
]
=
None
,
out_suffix
:
Union
[
str
,
tuple
]
=
(
'.log.json'
,
'.log'
,
'.py'
),
keep_local
:
bool
=
True
,
file_client_args
:
Optional
[
Dict
]
=
None
):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
by_epoch
=
by_epoch
self
.
time_sec_tot
=
0
...
...
@@ -85,7 +86,7 @@ class TextLoggerHook(LoggerHook):
self
.
file_client
=
FileClient
.
infer_client
(
file_client_args
,
self
.
out_dir
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
)
->
None
:
super
().
before_run
(
runner
)
if
self
.
out_dir
is
not
None
:
...
...
@@ -105,7 +106,7 @@ class TextLoggerHook(LoggerHook):
if
runner
.
meta
is
not
None
:
self
.
_dump_log
(
runner
.
meta
,
runner
)
def
_get_max_memory
(
self
,
runner
):
def
_get_max_memory
(
self
,
runner
)
->
int
:
device
=
getattr
(
runner
.
model
,
'output_device'
,
None
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
device
=
device
)
mem_mb
=
torch
.
tensor
([
int
(
mem
)
//
(
1024
*
1024
)],
...
...
@@ -115,7 +116,7 @@ class TextLoggerHook(LoggerHook):
dist
.
reduce
(
mem_mb
,
0
,
op
=
dist
.
ReduceOp
.
MAX
)
return
mem_mb
.
item
()
def
_log_info
(
self
,
log_dict
,
runner
):
def
_log_info
(
self
,
log_dict
:
Dict
,
runner
)
->
None
:
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if
runner
.
meta
is
not
None
and
'exp_name'
in
runner
.
meta
:
...
...
@@ -129,9 +130,9 @@ class TextLoggerHook(LoggerHook):
lr_str
=
[]
for
k
,
val
in
log_dict
[
'lr'
].
items
():
lr_str
.
append
(
f
'lr_
{
k
}
:
{
val
:.
3
e
}
'
)
lr_str
=
' '
.
join
(
lr_str
)
lr_str
=
' '
.
join
(
lr_str
)
# type: ignore
else
:
lr_str
=
f
'lr:
{
log_dict
[
"lr"
]:.
3
e
}
'
lr_str
=
f
'lr:
{
log_dict
[
"lr"
]:.
3
e
}
'
# type: ignore
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
...
...
@@ -181,7 +182,7 @@ class TextLoggerHook(LoggerHook):
runner
.
logger
.
info
(
log_str
)
def
_dump_log
(
self
,
log_dict
,
runner
):
def
_dump_log
(
self
,
log_dict
:
Dict
,
runner
)
->
None
:
# dump log in json format
json_log
=
OrderedDict
()
for
k
,
v
in
log_dict
.
items
():
...
...
@@ -200,7 +201,7 @@ class TextLoggerHook(LoggerHook):
else
:
return
items
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
OrderedDict
:
if
'eval_iter_num'
in
runner
.
log_buffer
.
output
:
# this doesn't modify runner.iter and is regardless of by_epoch
cur_iter
=
runner
.
log_buffer
.
output
.
pop
(
'eval_iter_num'
)
...
...
@@ -228,13 +229,13 @@ class TextLoggerHook(LoggerHook):
if
torch
.
cuda
.
is_available
():
log_dict
[
'memory'
]
=
self
.
_get_max_memory
(
runner
)
log_dict
=
dict
(
log_dict
,
**
runner
.
log_buffer
.
output
)
log_dict
=
dict
(
log_dict
,
**
runner
.
log_buffer
.
output
)
# type: ignore
self
.
_log_info
(
log_dict
,
runner
)
self
.
_dump_log
(
log_dict
,
runner
)
return
log_dict
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
# copy or upload logs to self.out_dir
if
self
.
out_dir
is
not
None
:
for
filename
in
scandir
(
runner
.
work_dir
,
self
.
out_suffix
,
True
):
...
...
mmcv/runner/hooks/logger/wandb.py
View file @
ea173c9f
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
typing
import
Dict
,
Optional
,
Union
from
mmcv.utils
import
scandir
from
...dist_utils
import
master_only
...
...
@@ -48,15 +49,15 @@ class WandbLoggerHook(LoggerHook):
"""
def
__init__
(
self
,
init_kwargs
=
None
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
,
commit
=
True
,
by_epoch
=
True
,
with_step
=
True
,
log_artifact
=
True
,
out_suffix
=
(
'.log.json'
,
'.log'
,
'.py'
)):
init_kwargs
:
Optional
[
Dict
]
=
None
,
interval
:
int
=
10
,
ignore_last
:
bool
=
True
,
reset_flag
:
bool
=
False
,
commit
:
bool
=
True
,
by_epoch
:
bool
=
True
,
with_step
:
bool
=
True
,
log_artifact
:
bool
=
True
,
out_suffix
:
Union
[
str
,
tuple
]
=
(
'.log.json'
,
'.log'
,
'.py'
)):
super
().
__init__
(
interval
,
ignore_last
,
reset_flag
,
by_epoch
)
self
.
import_wandb
()
self
.
init_kwargs
=
init_kwargs
...
...
@@ -65,7 +66,7 @@ class WandbLoggerHook(LoggerHook):
self
.
log_artifact
=
log_artifact
self
.
out_suffix
=
out_suffix
def
import_wandb
(
self
):
def
import_wandb
(
self
)
->
None
:
try
:
import
wandb
except
ImportError
:
...
...
@@ -74,17 +75,17 @@ class WandbLoggerHook(LoggerHook):
self
.
wandb
=
wandb
@
master_only
def before_run(self, runner) -> None:
    """Initialise the wandb run, lazily importing wandb if the earlier
    import did not happen."""
    super().before_run(runner)
    if self.wandb is None:
        self.import_wandb()
    # ``init()`` with an empty kwargs dict is identical to ``init()``.
    init_args = self.init_kwargs if self.init_kwargs else {}
    self.wandb.init(**init_args)  # type: ignore
@
master_only
def
log
(
self
,
runner
):
def
log
(
self
,
runner
)
->
None
:
tags
=
self
.
get_loggable_tags
(
runner
)
if
tags
:
if
self
.
with_step
:
...
...
@@ -95,7 +96,7 @@ class WandbLoggerHook(LoggerHook):
self
.
wandb
.
log
(
tags
,
commit
=
self
.
commit
)
@
master_only
def
after_run
(
self
,
runner
):
def
after_run
(
self
,
runner
)
->
None
:
if
self
.
log_artifact
:
wandb_artifact
=
self
.
wandb
.
Artifact
(
name
=
'artifacts'
,
type
=
'model'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment