Unverified Commit ea173c9f authored by tripleMu's avatar tripleMu Committed by GitHub
Browse files

Add type hints for mmcv/runer/hooks/logger (#2000)

* Fix

* Fix
parent c70fafeb
# Copyright (c) OpenMMLab. All rights reserved.
import numbers
from abc import ABCMeta, abstractmethod
from typing import Dict
import numpy as np
import torch
......@@ -23,10 +24,10 @@ class LoggerHook(Hook):
__metaclass__ = ABCMeta
def __init__(self,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True):
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True):
self.interval = interval
self.ignore_last = ignore_last
self.reset_flag = reset_flag
......@@ -37,7 +38,9 @@ class LoggerHook(Hook):
pass
@staticmethod
def is_scalar(val, include_np=True, include_torch=True):
def is_scalar(val,
include_np: bool = True,
include_torch: bool = True) -> bool:
"""Tell the input variable is a scalar or not.
Args:
......@@ -57,7 +60,7 @@ class LoggerHook(Hook):
else:
return False
def get_mode(self, runner):
def get_mode(self, runner) -> str:
if runner.mode == 'train':
if 'time' in runner.log_buffer.output:
mode = 'train'
......@@ -70,7 +73,7 @@ class LoggerHook(Hook):
f'but got {runner.mode}')
return mode
def get_epoch(self, runner):
def get_epoch(self, runner) -> int:
if runner.mode == 'train':
epoch = runner.epoch + 1
elif runner.mode == 'val':
......@@ -82,7 +85,7 @@ class LoggerHook(Hook):
f'but got {runner.mode}')
return epoch
def get_iter(self, runner, inner_iter=False):
def get_iter(self, runner, inner_iter: bool = False) -> int:
"""Get the current training iteration step."""
if self.by_epoch and inner_iter:
current_iter = runner.inner_iter + 1
......@@ -90,7 +93,7 @@ class LoggerHook(Hook):
current_iter = runner.iter + 1
return current_iter
def get_lr_tags(self, runner):
def get_lr_tags(self, runner) -> Dict[str, float]:
tags = {}
lrs = runner.current_lr()
if isinstance(lrs, dict):
......@@ -100,7 +103,7 @@ class LoggerHook(Hook):
tags['learning_rate'] = lrs[0]
return tags
def get_momentum_tags(self, runner):
def get_momentum_tags(self, runner) -> Dict[str, float]:
tags = {}
momentums = runner.current_momentum()
if isinstance(momentums, dict):
......@@ -110,12 +113,14 @@ class LoggerHook(Hook):
tags['momentum'] = momentums[0]
return tags
def get_loggable_tags(self,
runner,
allow_scalar=True,
allow_text=False,
add_mode=True,
tags_to_skip=('time', 'data_time')):
def get_loggable_tags(
self,
runner,
allow_scalar: bool = True,
allow_text: bool = False,
add_mode: bool = True,
tags_to_skip: tuple = ('time', 'data_time')
) -> Dict:
tags = {}
for var, val in runner.log_buffer.output.items():
if var in tags_to_skip:
......@@ -131,16 +136,16 @@ class LoggerHook(Hook):
tags.update(self.get_momentum_tags(runner))
return tags
def before_run(self, runner):
def before_run(self, runner) -> None:
for hook in runner.hooks[::-1]:
if isinstance(hook, LoggerHook):
hook.reset_flag = True
break
def before_epoch(self, runner):
def before_epoch(self, runner) -> None:
runner.log_buffer.clear() # clear logs of last epoch
def after_train_iter(self, runner):
def after_train_iter(self, runner) -> None:
if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
runner.log_buffer.average(self.interval)
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
......@@ -154,13 +159,13 @@ class LoggerHook(Hook):
if self.reset_flag:
runner.log_buffer.clear_output()
def after_train_epoch(self, runner):
def after_train_epoch(self, runner) -> None:
if runner.log_buffer.ready:
self.log(runner)
if self.reset_flag:
runner.log_buffer.clear_output()
def after_val_epoch(self, runner):
def after_val_epoch(self, runner) -> None:
runner.log_buffer.average()
self.log(runner)
if self.reset_flag:
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......@@ -29,11 +31,11 @@ class ClearMLLoggerHook(LoggerHook):
"""
def __init__(self,
init_kwargs=None,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True):
init_kwargs: Optional[Dict] = None,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_clearml()
self.init_kwargs = init_kwargs
......@@ -47,14 +49,14 @@ class ClearMLLoggerHook(LoggerHook):
self.clearml = clearml
@master_only
def before_run(self, runner):
def before_run(self, runner) -> None:
super().before_run(runner)
task_kwargs = self.init_kwargs if self.init_kwargs else {}
self.task = self.clearml.Task.init(**task_kwargs)
self.task_logger = self.task.get_logger()
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
for tag, val in tags.items():
self.task_logger.report_scalar(tag, tag, val,
......
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Optional
from ...dist_utils import master_only
from ..hook import HOOKS
......@@ -31,17 +32,17 @@ class DvcliveLoggerHook(LoggerHook):
"""
def __init__(self,
model_file=None,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True,
model_file: Optional[str] = None,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True,
**kwargs):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.model_file = model_file
self.import_dvclive(**kwargs)
def import_dvclive(self, **kwargs):
def import_dvclive(self, **kwargs) -> None:
try:
from dvclive import Live
except ImportError:
......@@ -50,7 +51,7 @@ class DvcliveLoggerHook(LoggerHook):
self.dvclive = Live(**kwargs)
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
if tags:
self.dvclive.set_step(self.get_iter(runner))
......@@ -58,7 +59,7 @@ class DvcliveLoggerHook(LoggerHook):
self.dvclive.log(k, v)
@master_only
def after_train_epoch(self, runner):
def after_train_epoch(self, runner) -> None:
super().after_train_epoch(runner)
if self.model_file is not None:
runner.save_checkpoint(
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from mmcv.utils import TORCH_VERSION
from ...dist_utils import master_only
from ..hook import HOOKS
......@@ -33,20 +35,20 @@ class MlflowLoggerHook(LoggerHook):
"""
def __init__(self,
exp_name=None,
tags=None,
log_model=True,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True):
exp_name: Optional[str] = None,
tags: Optional[Dict] = None,
log_model: bool = True,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_mlflow()
self.exp_name = exp_name
self.tags = tags
self.log_model = log_model
def import_mlflow(self):
def import_mlflow(self) -> None:
try:
import mlflow
import mlflow.pytorch as mlflow_pytorch
......@@ -57,7 +59,7 @@ class MlflowLoggerHook(LoggerHook):
self.mlflow_pytorch = mlflow_pytorch
@master_only
def before_run(self, runner):
def before_run(self, runner) -> None:
super().before_run(runner)
if self.exp_name is not None:
self.mlflow.set_experiment(self.exp_name)
......@@ -65,13 +67,13 @@ class MlflowLoggerHook(LoggerHook):
self.mlflow.set_tags(self.tags)
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
if tags:
self.mlflow.log_metrics(tags, step=self.get_iter(runner))
@master_only
def after_run(self, runner):
def after_run(self, runner) -> None:
if self.log_model:
self.mlflow_pytorch.log_model(
runner.model,
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......@@ -42,19 +44,19 @@ class NeptuneLoggerHook(LoggerHook):
"""
def __init__(self,
init_kwargs=None,
interval=10,
ignore_last=True,
reset_flag=True,
with_step=True,
by_epoch=True):
init_kwargs: Optional[Dict] = None,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = True,
with_step: bool = True,
by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_neptune()
self.init_kwargs = init_kwargs
self.with_step = with_step
def import_neptune(self):
def import_neptune(self) -> None:
try:
import neptune.new as neptune
except ImportError:
......@@ -64,24 +66,24 @@ class NeptuneLoggerHook(LoggerHook):
self.run = None
@master_only
def before_run(self, runner):
def before_run(self, runner) -> None:
if self.init_kwargs:
self.run = self.neptune.init(**self.init_kwargs)
else:
self.run = self.neptune.init()
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
if tags:
for tag_name, tag_value in tags.items():
if self.with_step:
self.run[tag_name].log(
self.run[tag_name].log( # type: ignore
tag_value, step=self.get_iter(runner))
else:
tags['global_step'] = self.get_iter(runner)
self.run[tag_name].log(tags)
self.run[tag_name].log(tags) # type: ignore
@master_only
def after_run(self, runner):
self.run.stop()
def after_run(self, runner) -> None:
self.run.stop() # type: ignore
......@@ -2,6 +2,7 @@
import json
import os
import os.path as osp
from typing import Dict, Optional
import torch
import yaml
......@@ -32,14 +33,14 @@ class PaviLoggerHook(LoggerHook):
"""
def __init__(self,
init_kwargs=None,
add_graph=False,
add_last_ckpt=False,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True,
img_key='img_info'):
init_kwargs: Optional[Dict] = None,
add_graph: bool = False,
add_last_ckpt: bool = False,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True,
img_key: str = 'img_info'):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
......@@ -47,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
self.img_key = img_key
@master_only
def before_run(self, runner):
def before_run(self, runner) -> None:
super().before_run(runner)
try:
from pavi import SummaryWriter
......@@ -85,7 +86,7 @@ class PaviLoggerHook(LoggerHook):
self.init_kwargs['session_text'] = session_text
self.writer = SummaryWriter(**self.init_kwargs)
def get_step(self, runner):
def get_step(self, runner) -> int:
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner)
......@@ -93,14 +94,14 @@ class PaviLoggerHook(LoggerHook):
return self.get_iter(runner)
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner, add_mode=False)
if tags:
self.writer.add_scalars(
self.get_mode(runner), tags, self.get_step(runner))
@master_only
def after_run(self, runner):
def after_run(self, runner) -> None:
if self.add_last_ckpt:
ckpt_path = osp.join(runner.work_dir, 'latest.pth')
if osp.islink(ckpt_path):
......@@ -118,7 +119,7 @@ class PaviLoggerHook(LoggerHook):
self.writer.close()
@master_only
def before_epoch(self, runner):
def before_epoch(self, runner) -> None:
if runner.epoch == 0 and self.add_graph:
if is_module_wrapper(runner.model):
_model = runner.model.module
......
......@@ -23,14 +23,14 @@ class SegmindLoggerHook(LoggerHook):
"""
def __init__(self,
interval=10,
ignore_last=True,
reset_flag=False,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch=True):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_segmind()
def import_segmind(self):
def import_segmind(self) -> None:
try:
import segmind
except ImportError:
......@@ -40,7 +40,7 @@ class SegmindLoggerHook(LoggerHook):
self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
if tags:
# logging metrics to segmind
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional
from mmcv.utils import TORCH_VERSION, digit_version
from ...dist_utils import master_only
......@@ -23,16 +24,16 @@ class TensorboardLoggerHook(LoggerHook):
"""
def __init__(self,
log_dir=None,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True):
log_dir: Optional[str] = None,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.log_dir = log_dir
@master_only
def before_run(self, runner):
def before_run(self, runner) -> None:
super().before_run(runner)
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.1')):
......@@ -55,7 +56,7 @@ class TensorboardLoggerHook(LoggerHook):
self.writer = SummaryWriter(self.log_dir)
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner, allow_text=True)
for tag, val in tags.items():
if isinstance(val, str):
......@@ -64,5 +65,5 @@ class TensorboardLoggerHook(LoggerHook):
self.writer.add_scalar(tag, val, self.get_iter(runner))
@master_only
def after_run(self, runner):
def after_run(self, runner) -> None:
self.writer.close()
......@@ -3,6 +3,7 @@ import datetime
import os
import os.path as osp
from collections import OrderedDict
from typing import Dict, Optional, Union
import torch
import torch.distributed as dist
......@@ -53,15 +54,15 @@ class TextLoggerHook(LoggerHook):
"""
def __init__(self,
by_epoch=True,
interval=10,
ignore_last=True,
reset_flag=False,
interval_exp_name=1000,
out_dir=None,
out_suffix=('.log.json', '.log', '.py'),
keep_local=True,
file_client_args=None):
by_epoch: bool = True,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
interval_exp_name: int = 1000,
out_dir: Optional[str] = None,
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
keep_local: bool = True,
file_client_args: Optional[Dict] = None):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.by_epoch = by_epoch
self.time_sec_tot = 0
......@@ -85,7 +86,7 @@ class TextLoggerHook(LoggerHook):
self.file_client = FileClient.infer_client(file_client_args,
self.out_dir)
def before_run(self, runner):
def before_run(self, runner) -> None:
super().before_run(runner)
if self.out_dir is not None:
......@@ -105,7 +106,7 @@ class TextLoggerHook(LoggerHook):
if runner.meta is not None:
self._dump_log(runner.meta, runner)
def _get_max_memory(self, runner):
def _get_max_memory(self, runner) -> int:
device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
......@@ -115,7 +116,7 @@ class TextLoggerHook(LoggerHook):
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
return mem_mb.item()
def _log_info(self, log_dict, runner):
def _log_info(self, log_dict: Dict, runner) -> None:
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta:
......@@ -129,9 +130,9 @@ class TextLoggerHook(LoggerHook):
lr_str = []
for k, val in log_dict['lr'].items():
lr_str.append(f'lr_{k}: {val:.3e}')
lr_str = ' '.join(lr_str)
lr_str = ' '.join(lr_str) # type: ignore
else:
lr_str = f'lr: {log_dict["lr"]:.3e}'
lr_str = f'lr: {log_dict["lr"]:.3e}' # type: ignore
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
......@@ -181,7 +182,7 @@ class TextLoggerHook(LoggerHook):
runner.logger.info(log_str)
def _dump_log(self, log_dict, runner):
def _dump_log(self, log_dict: Dict, runner) -> None:
# dump log in json format
json_log = OrderedDict()
for k, v in log_dict.items():
......@@ -200,7 +201,7 @@ class TextLoggerHook(LoggerHook):
else:
return items
def log(self, runner):
def log(self, runner) -> OrderedDict:
if 'eval_iter_num' in runner.log_buffer.output:
# this doesn't modify runner.iter and is regardless of by_epoch
cur_iter = runner.log_buffer.output.pop('eval_iter_num')
......@@ -228,13 +229,13 @@ class TextLoggerHook(LoggerHook):
if torch.cuda.is_available():
log_dict['memory'] = self._get_max_memory(runner)
log_dict = dict(log_dict, **runner.log_buffer.output)
log_dict = dict(log_dict, **runner.log_buffer.output) # type: ignore
self._log_info(log_dict, runner)
self._dump_log(log_dict, runner)
return log_dict
def after_run(self, runner):
def after_run(self, runner) -> None:
# copy or upload logs to self.out_dir
if self.out_dir is not None:
for filename in scandir(runner.work_dir, self.out_suffix, True):
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, Optional, Union
from mmcv.utils import scandir
from ...dist_utils import master_only
......@@ -48,15 +49,15 @@ class WandbLoggerHook(LoggerHook):
"""
def __init__(self,
init_kwargs=None,
interval=10,
ignore_last=True,
reset_flag=False,
commit=True,
by_epoch=True,
with_step=True,
log_artifact=True,
out_suffix=('.log.json', '.log', '.py')):
init_kwargs: Optional[Dict] = None,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
commit: bool = True,
by_epoch: bool = True,
with_step: bool = True,
log_artifact: bool = True,
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb()
self.init_kwargs = init_kwargs
......@@ -65,7 +66,7 @@ class WandbLoggerHook(LoggerHook):
self.log_artifact = log_artifact
self.out_suffix = out_suffix
def import_wandb(self):
def import_wandb(self) -> None:
try:
import wandb
except ImportError:
......@@ -74,17 +75,17 @@ class WandbLoggerHook(LoggerHook):
self.wandb = wandb
@master_only
def before_run(self, runner):
def before_run(self, runner) -> None:
super().before_run(runner)
if self.wandb is None:
self.import_wandb()
if self.init_kwargs:
self.wandb.init(**self.init_kwargs)
self.wandb.init(**self.init_kwargs) # type: ignore
else:
self.wandb.init()
self.wandb.init() # type: ignore
@master_only
def log(self, runner):
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
if tags:
if self.with_step:
......@@ -95,7 +96,7 @@ class WandbLoggerHook(LoggerHook):
self.wandb.log(tags, commit=self.commit)
@master_only
def after_run(self, runner):
def after_run(self, runner) -> None:
if self.log_artifact:
wandb_artifact = self.wandb.Artifact(
name='artifacts', type='model')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment