Unverified Commit ea173c9f authored by tripleMu's avatar tripleMu Committed by GitHub
Browse files

Add type hints for mmcv/runer/hooks/logger (#2000)

* Fix

* Fix
parent c70fafeb
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numbers import numbers
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Dict
import numpy as np import numpy as np
import torch import torch
...@@ -23,10 +24,10 @@ class LoggerHook(Hook): ...@@ -23,10 +24,10 @@ class LoggerHook(Hook):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self, def __init__(self,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True): by_epoch: bool = True):
self.interval = interval self.interval = interval
self.ignore_last = ignore_last self.ignore_last = ignore_last
self.reset_flag = reset_flag self.reset_flag = reset_flag
...@@ -37,7 +38,9 @@ class LoggerHook(Hook): ...@@ -37,7 +38,9 @@ class LoggerHook(Hook):
pass pass
@staticmethod @staticmethod
def is_scalar(val, include_np=True, include_torch=True): def is_scalar(val,
include_np: bool = True,
include_torch: bool = True) -> bool:
"""Tell the input variable is a scalar or not. """Tell the input variable is a scalar or not.
Args: Args:
...@@ -57,7 +60,7 @@ class LoggerHook(Hook): ...@@ -57,7 +60,7 @@ class LoggerHook(Hook):
else: else:
return False return False
def get_mode(self, runner): def get_mode(self, runner) -> str:
if runner.mode == 'train': if runner.mode == 'train':
if 'time' in runner.log_buffer.output: if 'time' in runner.log_buffer.output:
mode = 'train' mode = 'train'
...@@ -70,7 +73,7 @@ class LoggerHook(Hook): ...@@ -70,7 +73,7 @@ class LoggerHook(Hook):
f'but got {runner.mode}') f'but got {runner.mode}')
return mode return mode
def get_epoch(self, runner): def get_epoch(self, runner) -> int:
if runner.mode == 'train': if runner.mode == 'train':
epoch = runner.epoch + 1 epoch = runner.epoch + 1
elif runner.mode == 'val': elif runner.mode == 'val':
...@@ -82,7 +85,7 @@ class LoggerHook(Hook): ...@@ -82,7 +85,7 @@ class LoggerHook(Hook):
f'but got {runner.mode}') f'but got {runner.mode}')
return epoch return epoch
def get_iter(self, runner, inner_iter=False): def get_iter(self, runner, inner_iter: bool = False) -> int:
"""Get the current training iteration step.""" """Get the current training iteration step."""
if self.by_epoch and inner_iter: if self.by_epoch and inner_iter:
current_iter = runner.inner_iter + 1 current_iter = runner.inner_iter + 1
...@@ -90,7 +93,7 @@ class LoggerHook(Hook): ...@@ -90,7 +93,7 @@ class LoggerHook(Hook):
current_iter = runner.iter + 1 current_iter = runner.iter + 1
return current_iter return current_iter
def get_lr_tags(self, runner): def get_lr_tags(self, runner) -> Dict[str, float]:
tags = {} tags = {}
lrs = runner.current_lr() lrs = runner.current_lr()
if isinstance(lrs, dict): if isinstance(lrs, dict):
...@@ -100,7 +103,7 @@ class LoggerHook(Hook): ...@@ -100,7 +103,7 @@ class LoggerHook(Hook):
tags['learning_rate'] = lrs[0] tags['learning_rate'] = lrs[0]
return tags return tags
def get_momentum_tags(self, runner): def get_momentum_tags(self, runner) -> Dict[str, float]:
tags = {} tags = {}
momentums = runner.current_momentum() momentums = runner.current_momentum()
if isinstance(momentums, dict): if isinstance(momentums, dict):
...@@ -110,12 +113,14 @@ class LoggerHook(Hook): ...@@ -110,12 +113,14 @@ class LoggerHook(Hook):
tags['momentum'] = momentums[0] tags['momentum'] = momentums[0]
return tags return tags
def get_loggable_tags(self, def get_loggable_tags(
runner, self,
allow_scalar=True, runner,
allow_text=False, allow_scalar: bool = True,
add_mode=True, allow_text: bool = False,
tags_to_skip=('time', 'data_time')): add_mode: bool = True,
tags_to_skip: tuple = ('time', 'data_time')
) -> Dict:
tags = {} tags = {}
for var, val in runner.log_buffer.output.items(): for var, val in runner.log_buffer.output.items():
if var in tags_to_skip: if var in tags_to_skip:
...@@ -131,16 +136,16 @@ class LoggerHook(Hook): ...@@ -131,16 +136,16 @@ class LoggerHook(Hook):
tags.update(self.get_momentum_tags(runner)) tags.update(self.get_momentum_tags(runner))
return tags return tags
def before_run(self, runner): def before_run(self, runner) -> None:
for hook in runner.hooks[::-1]: for hook in runner.hooks[::-1]:
if isinstance(hook, LoggerHook): if isinstance(hook, LoggerHook):
hook.reset_flag = True hook.reset_flag = True
break break
def before_epoch(self, runner): def before_epoch(self, runner) -> None:
runner.log_buffer.clear() # clear logs of last epoch runner.log_buffer.clear() # clear logs of last epoch
def after_train_iter(self, runner): def after_train_iter(self, runner) -> None:
if self.by_epoch and self.every_n_inner_iters(runner, self.interval): if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
runner.log_buffer.average(self.interval) runner.log_buffer.average(self.interval)
elif not self.by_epoch and self.every_n_iters(runner, self.interval): elif not self.by_epoch and self.every_n_iters(runner, self.interval):
...@@ -154,13 +159,13 @@ class LoggerHook(Hook): ...@@ -154,13 +159,13 @@ class LoggerHook(Hook):
if self.reset_flag: if self.reset_flag:
runner.log_buffer.clear_output() runner.log_buffer.clear_output()
def after_train_epoch(self, runner): def after_train_epoch(self, runner) -> None:
if runner.log_buffer.ready: if runner.log_buffer.ready:
self.log(runner) self.log(runner)
if self.reset_flag: if self.reset_flag:
runner.log_buffer.clear_output() runner.log_buffer.clear_output()
def after_val_epoch(self, runner): def after_val_epoch(self, runner) -> None:
runner.log_buffer.average() runner.log_buffer.average()
self.log(runner) self.log(runner)
if self.reset_flag: if self.reset_flag:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -29,11 +31,11 @@ class ClearMLLoggerHook(LoggerHook): ...@@ -29,11 +31,11 @@ class ClearMLLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
init_kwargs=None, init_kwargs: Optional[Dict] = None,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True): by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_clearml() self.import_clearml()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
...@@ -47,14 +49,14 @@ class ClearMLLoggerHook(LoggerHook): ...@@ -47,14 +49,14 @@ class ClearMLLoggerHook(LoggerHook):
self.clearml = clearml self.clearml = clearml
@master_only @master_only
def before_run(self, runner): def before_run(self, runner) -> None:
super().before_run(runner) super().before_run(runner)
task_kwargs = self.init_kwargs if self.init_kwargs else {} task_kwargs = self.init_kwargs if self.init_kwargs else {}
self.task = self.clearml.Task.init(**task_kwargs) self.task = self.clearml.Task.init(**task_kwargs)
self.task_logger = self.task.get_logger() self.task_logger = self.task.get_logger()
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
for tag, val in tags.items(): for tag, val in tags.items():
self.task_logger.report_scalar(tag, tag, val, self.task_logger.report_scalar(tag, tag, val,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path from pathlib import Path
from typing import Optional
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
...@@ -31,17 +32,17 @@ class DvcliveLoggerHook(LoggerHook): ...@@ -31,17 +32,17 @@ class DvcliveLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
model_file=None, model_file: Optional[str] = None,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True, by_epoch: bool = True,
**kwargs): **kwargs):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.model_file = model_file self.model_file = model_file
self.import_dvclive(**kwargs) self.import_dvclive(**kwargs)
def import_dvclive(self, **kwargs): def import_dvclive(self, **kwargs) -> None:
try: try:
from dvclive import Live from dvclive import Live
except ImportError: except ImportError:
...@@ -50,7 +51,7 @@ class DvcliveLoggerHook(LoggerHook): ...@@ -50,7 +51,7 @@ class DvcliveLoggerHook(LoggerHook):
self.dvclive = Live(**kwargs) self.dvclive = Live(**kwargs)
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
self.dvclive.set_step(self.get_iter(runner)) self.dvclive.set_step(self.get_iter(runner))
...@@ -58,7 +59,7 @@ class DvcliveLoggerHook(LoggerHook): ...@@ -58,7 +59,7 @@ class DvcliveLoggerHook(LoggerHook):
self.dvclive.log(k, v) self.dvclive.log(k, v)
@master_only @master_only
def after_train_epoch(self, runner): def after_train_epoch(self, runner) -> None:
super().after_train_epoch(runner) super().after_train_epoch(runner)
if self.model_file is not None: if self.model_file is not None:
runner.save_checkpoint( runner.save_checkpoint(
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
...@@ -33,20 +35,20 @@ class MlflowLoggerHook(LoggerHook): ...@@ -33,20 +35,20 @@ class MlflowLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
exp_name=None, exp_name: Optional[str] = None,
tags=None, tags: Optional[Dict] = None,
log_model=True, log_model: bool = True,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True): by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_mlflow() self.import_mlflow()
self.exp_name = exp_name self.exp_name = exp_name
self.tags = tags self.tags = tags
self.log_model = log_model self.log_model = log_model
def import_mlflow(self): def import_mlflow(self) -> None:
try: try:
import mlflow import mlflow
import mlflow.pytorch as mlflow_pytorch import mlflow.pytorch as mlflow_pytorch
...@@ -57,7 +59,7 @@ class MlflowLoggerHook(LoggerHook): ...@@ -57,7 +59,7 @@ class MlflowLoggerHook(LoggerHook):
self.mlflow_pytorch = mlflow_pytorch self.mlflow_pytorch = mlflow_pytorch
@master_only @master_only
def before_run(self, runner): def before_run(self, runner) -> None:
super().before_run(runner) super().before_run(runner)
if self.exp_name is not None: if self.exp_name is not None:
self.mlflow.set_experiment(self.exp_name) self.mlflow.set_experiment(self.exp_name)
...@@ -65,13 +67,13 @@ class MlflowLoggerHook(LoggerHook): ...@@ -65,13 +67,13 @@ class MlflowLoggerHook(LoggerHook):
self.mlflow.set_tags(self.tags) self.mlflow.set_tags(self.tags)
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
self.mlflow.log_metrics(tags, step=self.get_iter(runner)) self.mlflow.log_metrics(tags, step=self.get_iter(runner))
@master_only @master_only
def after_run(self, runner): def after_run(self, runner) -> None:
if self.log_model: if self.log_model:
self.mlflow_pytorch.log_model( self.mlflow_pytorch.log_model(
runner.model, runner.model,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -42,19 +44,19 @@ class NeptuneLoggerHook(LoggerHook): ...@@ -42,19 +44,19 @@ class NeptuneLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
init_kwargs=None, init_kwargs: Optional[Dict] = None,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=True, reset_flag: bool = True,
with_step=True, with_step: bool = True,
by_epoch=True): by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_neptune() self.import_neptune()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
self.with_step = with_step self.with_step = with_step
def import_neptune(self): def import_neptune(self) -> None:
try: try:
import neptune.new as neptune import neptune.new as neptune
except ImportError: except ImportError:
...@@ -64,24 +66,24 @@ class NeptuneLoggerHook(LoggerHook): ...@@ -64,24 +66,24 @@ class NeptuneLoggerHook(LoggerHook):
self.run = None self.run = None
@master_only @master_only
def before_run(self, runner): def before_run(self, runner) -> None:
if self.init_kwargs: if self.init_kwargs:
self.run = self.neptune.init(**self.init_kwargs) self.run = self.neptune.init(**self.init_kwargs)
else: else:
self.run = self.neptune.init() self.run = self.neptune.init()
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
for tag_name, tag_value in tags.items(): for tag_name, tag_value in tags.items():
if self.with_step: if self.with_step:
self.run[tag_name].log( self.run[tag_name].log( # type: ignore
tag_value, step=self.get_iter(runner)) tag_value, step=self.get_iter(runner))
else: else:
tags['global_step'] = self.get_iter(runner) tags['global_step'] = self.get_iter(runner)
self.run[tag_name].log(tags) self.run[tag_name].log(tags) # type: ignore
@master_only @master_only
def after_run(self, runner): def after_run(self, runner) -> None:
self.run.stop() self.run.stop() # type: ignore
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import json import json
import os import os
import os.path as osp import os.path as osp
from typing import Dict, Optional
import torch import torch
import yaml import yaml
...@@ -32,14 +33,14 @@ class PaviLoggerHook(LoggerHook): ...@@ -32,14 +33,14 @@ class PaviLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
init_kwargs=None, init_kwargs: Optional[Dict] = None,
add_graph=False, add_graph: bool = False,
add_last_ckpt=False, add_last_ckpt: bool = False,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True, by_epoch: bool = True,
img_key='img_info'): img_key: str = 'img_info'):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
self.add_graph = add_graph self.add_graph = add_graph
...@@ -47,7 +48,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -47,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
self.img_key = img_key self.img_key = img_key
@master_only @master_only
def before_run(self, runner): def before_run(self, runner) -> None:
super().before_run(runner) super().before_run(runner)
try: try:
from pavi import SummaryWriter from pavi import SummaryWriter
...@@ -85,7 +86,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -85,7 +86,7 @@ class PaviLoggerHook(LoggerHook):
self.init_kwargs['session_text'] = session_text self.init_kwargs['session_text'] = session_text
self.writer = SummaryWriter(**self.init_kwargs) self.writer = SummaryWriter(**self.init_kwargs)
def get_step(self, runner): def get_step(self, runner) -> int:
"""Get the total training step/epoch.""" """Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch: if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner) return self.get_epoch(runner)
...@@ -93,14 +94,14 @@ class PaviLoggerHook(LoggerHook): ...@@ -93,14 +94,14 @@ class PaviLoggerHook(LoggerHook):
return self.get_iter(runner) return self.get_iter(runner)
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner, add_mode=False) tags = self.get_loggable_tags(runner, add_mode=False)
if tags: if tags:
self.writer.add_scalars( self.writer.add_scalars(
self.get_mode(runner), tags, self.get_step(runner)) self.get_mode(runner), tags, self.get_step(runner))
@master_only @master_only
def after_run(self, runner): def after_run(self, runner) -> None:
if self.add_last_ckpt: if self.add_last_ckpt:
ckpt_path = osp.join(runner.work_dir, 'latest.pth') ckpt_path = osp.join(runner.work_dir, 'latest.pth')
if osp.islink(ckpt_path): if osp.islink(ckpt_path):
...@@ -118,7 +119,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -118,7 +119,7 @@ class PaviLoggerHook(LoggerHook):
self.writer.close() self.writer.close()
@master_only @master_only
def before_epoch(self, runner): def before_epoch(self, runner) -> None:
if runner.epoch == 0 and self.add_graph: if runner.epoch == 0 and self.add_graph:
if is_module_wrapper(runner.model): if is_module_wrapper(runner.model):
_model = runner.model.module _model = runner.model.module
......
...@@ -23,14 +23,14 @@ class SegmindLoggerHook(LoggerHook): ...@@ -23,14 +23,14 @@ class SegmindLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True): by_epoch=True):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_segmind() self.import_segmind()
def import_segmind(self): def import_segmind(self) -> None:
try: try:
import segmind import segmind
except ImportError: except ImportError:
...@@ -40,7 +40,7 @@ class SegmindLoggerHook(LoggerHook): ...@@ -40,7 +40,7 @@ class SegmindLoggerHook(LoggerHook):
self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
# logging metrics to segmind # logging metrics to segmind
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
from typing import Optional
from mmcv.utils import TORCH_VERSION, digit_version from mmcv.utils import TORCH_VERSION, digit_version
from ...dist_utils import master_only from ...dist_utils import master_only
...@@ -23,16 +24,16 @@ class TensorboardLoggerHook(LoggerHook): ...@@ -23,16 +24,16 @@ class TensorboardLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
log_dir=None, log_dir: Optional[str] = None,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
by_epoch=True): by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.log_dir = log_dir self.log_dir = log_dir
@master_only @master_only
def before_run(self, runner): def before_run(self, runner) -> None:
super().before_run(runner) super().before_run(runner)
if (TORCH_VERSION == 'parrots' if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.1')): or digit_version(TORCH_VERSION) < digit_version('1.1')):
...@@ -55,7 +56,7 @@ class TensorboardLoggerHook(LoggerHook): ...@@ -55,7 +56,7 @@ class TensorboardLoggerHook(LoggerHook):
self.writer = SummaryWriter(self.log_dir) self.writer = SummaryWriter(self.log_dir)
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner, allow_text=True) tags = self.get_loggable_tags(runner, allow_text=True)
for tag, val in tags.items(): for tag, val in tags.items():
if isinstance(val, str): if isinstance(val, str):
...@@ -64,5 +65,5 @@ class TensorboardLoggerHook(LoggerHook): ...@@ -64,5 +65,5 @@ class TensorboardLoggerHook(LoggerHook):
self.writer.add_scalar(tag, val, self.get_iter(runner)) self.writer.add_scalar(tag, val, self.get_iter(runner))
@master_only @master_only
def after_run(self, runner): def after_run(self, runner) -> None:
self.writer.close() self.writer.close()
...@@ -3,6 +3,7 @@ import datetime ...@@ -3,6 +3,7 @@ import datetime
import os import os
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -53,15 +54,15 @@ class TextLoggerHook(LoggerHook): ...@@ -53,15 +54,15 @@ class TextLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
by_epoch=True, by_epoch: bool = True,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
interval_exp_name=1000, interval_exp_name: int = 1000,
out_dir=None, out_dir: Optional[str] = None,
out_suffix=('.log.json', '.log', '.py'), out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
keep_local=True, keep_local: bool = True,
file_client_args=None): file_client_args: Optional[Dict] = None):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.by_epoch = by_epoch self.by_epoch = by_epoch
self.time_sec_tot = 0 self.time_sec_tot = 0
...@@ -85,7 +86,7 @@ class TextLoggerHook(LoggerHook): ...@@ -85,7 +86,7 @@ class TextLoggerHook(LoggerHook):
self.file_client = FileClient.infer_client(file_client_args, self.file_client = FileClient.infer_client(file_client_args,
self.out_dir) self.out_dir)
def before_run(self, runner): def before_run(self, runner) -> None:
super().before_run(runner) super().before_run(runner)
if self.out_dir is not None: if self.out_dir is not None:
...@@ -105,7 +106,7 @@ class TextLoggerHook(LoggerHook): ...@@ -105,7 +106,7 @@ class TextLoggerHook(LoggerHook):
if runner.meta is not None: if runner.meta is not None:
self._dump_log(runner.meta, runner) self._dump_log(runner.meta, runner)
def _get_max_memory(self, runner): def _get_max_memory(self, runner) -> int:
device = getattr(runner.model, 'output_device', None) device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device) mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)], mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
...@@ -115,7 +116,7 @@ class TextLoggerHook(LoggerHook): ...@@ -115,7 +116,7 @@ class TextLoggerHook(LoggerHook):
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
return mem_mb.item() return mem_mb.item()
def _log_info(self, log_dict, runner): def _log_info(self, log_dict: Dict, runner) -> None:
# print exp name for users to distinguish experiments # print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch # at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta: if runner.meta is not None and 'exp_name' in runner.meta:
...@@ -129,9 +130,9 @@ class TextLoggerHook(LoggerHook): ...@@ -129,9 +130,9 @@ class TextLoggerHook(LoggerHook):
lr_str = [] lr_str = []
for k, val in log_dict['lr'].items(): for k, val in log_dict['lr'].items():
lr_str.append(f'lr_{k}: {val:.3e}') lr_str.append(f'lr_{k}: {val:.3e}')
lr_str = ' '.join(lr_str) lr_str = ' '.join(lr_str) # type: ignore
else: else:
lr_str = f'lr: {log_dict["lr"]:.3e}' lr_str = f'lr: {log_dict["lr"]:.3e}' # type: ignore
# by epoch: Epoch [4][100/1000] # by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000] # by iter: Iter [100/100000]
...@@ -181,7 +182,7 @@ class TextLoggerHook(LoggerHook): ...@@ -181,7 +182,7 @@ class TextLoggerHook(LoggerHook):
runner.logger.info(log_str) runner.logger.info(log_str)
def _dump_log(self, log_dict, runner): def _dump_log(self, log_dict: Dict, runner) -> None:
# dump log in json format # dump log in json format
json_log = OrderedDict() json_log = OrderedDict()
for k, v in log_dict.items(): for k, v in log_dict.items():
...@@ -200,7 +201,7 @@ class TextLoggerHook(LoggerHook): ...@@ -200,7 +201,7 @@ class TextLoggerHook(LoggerHook):
else: else:
return items return items
def log(self, runner): def log(self, runner) -> OrderedDict:
if 'eval_iter_num' in runner.log_buffer.output: if 'eval_iter_num' in runner.log_buffer.output:
# this doesn't modify runner.iter and is regardless of by_epoch # this doesn't modify runner.iter and is regardless of by_epoch
cur_iter = runner.log_buffer.output.pop('eval_iter_num') cur_iter = runner.log_buffer.output.pop('eval_iter_num')
...@@ -228,13 +229,13 @@ class TextLoggerHook(LoggerHook): ...@@ -228,13 +229,13 @@ class TextLoggerHook(LoggerHook):
if torch.cuda.is_available(): if torch.cuda.is_available():
log_dict['memory'] = self._get_max_memory(runner) log_dict['memory'] = self._get_max_memory(runner)
log_dict = dict(log_dict, **runner.log_buffer.output) log_dict = dict(log_dict, **runner.log_buffer.output) # type: ignore
self._log_info(log_dict, runner) self._log_info(log_dict, runner)
self._dump_log(log_dict, runner) self._dump_log(log_dict, runner)
return log_dict return log_dict
def after_run(self, runner): def after_run(self, runner) -> None:
# copy or upload logs to self.out_dir # copy or upload logs to self.out_dir
if self.out_dir is not None: if self.out_dir is not None:
for filename in scandir(runner.work_dir, self.out_suffix, True): for filename in scandir(runner.work_dir, self.out_suffix, True):
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
from typing import Dict, Optional, Union
from mmcv.utils import scandir from mmcv.utils import scandir
from ...dist_utils import master_only from ...dist_utils import master_only
...@@ -48,15 +49,15 @@ class WandbLoggerHook(LoggerHook): ...@@ -48,15 +49,15 @@ class WandbLoggerHook(LoggerHook):
""" """
def __init__(self, def __init__(self,
init_kwargs=None, init_kwargs: Optional[Dict] = None,
interval=10, interval: int = 10,
ignore_last=True, ignore_last: bool = True,
reset_flag=False, reset_flag: bool = False,
commit=True, commit: bool = True,
by_epoch=True, by_epoch: bool = True,
with_step=True, with_step: bool = True,
log_artifact=True, log_artifact: bool = True,
out_suffix=('.log.json', '.log', '.py')): out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb() self.import_wandb()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
...@@ -65,7 +66,7 @@ class WandbLoggerHook(LoggerHook): ...@@ -65,7 +66,7 @@ class WandbLoggerHook(LoggerHook):
self.log_artifact = log_artifact self.log_artifact = log_artifact
self.out_suffix = out_suffix self.out_suffix = out_suffix
def import_wandb(self): def import_wandb(self) -> None:
try: try:
import wandb import wandb
except ImportError: except ImportError:
...@@ -74,17 +75,17 @@ class WandbLoggerHook(LoggerHook): ...@@ -74,17 +75,17 @@ class WandbLoggerHook(LoggerHook):
self.wandb = wandb self.wandb = wandb
@master_only @master_only
def before_run(self, runner): def before_run(self, runner) -> None:
super().before_run(runner) super().before_run(runner)
if self.wandb is None: if self.wandb is None:
self.import_wandb() self.import_wandb()
if self.init_kwargs: if self.init_kwargs:
self.wandb.init(**self.init_kwargs) self.wandb.init(**self.init_kwargs) # type: ignore
else: else:
self.wandb.init() self.wandb.init() # type: ignore
@master_only @master_only
def log(self, runner): def log(self, runner) -> None:
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
if self.with_step: if self.with_step:
...@@ -95,7 +96,7 @@ class WandbLoggerHook(LoggerHook): ...@@ -95,7 +96,7 @@ class WandbLoggerHook(LoggerHook):
self.wandb.log(tags, commit=self.commit) self.wandb.log(tags, commit=self.commit)
@master_only @master_only
def after_run(self, runner): def after_run(self, runner) -> None:
if self.log_artifact: if self.log_artifact:
wandb_artifact = self.wandb.Artifact( wandb_artifact = self.wandb.Artifact(
name='artifacts', type='model') name='artifacts', type='model')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment