Unverified commit 3dd2a21b authored by Yvette Zhao and committed by GitHub

Add type hints for runner/base_runner (#2003)



* Add type hints for runner

* refine

* fix error

* refine

* refine format
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent bb710675
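The annotations below are exactly what a static type checker such as mypy consumes. For illustration only (not part of this diff), a minimal sketch of the kind of call they now validate; ToyModel, the work_dir and the hyper-parameters are assumptions:

    import logging
    import torch
    from mmcv.runner import EpochBasedRunner

    class ToyModel(torch.nn.Module):
        # BaseRunner expects the model (or a batch_processor) to provide
        # train_step(); this toy module exists only for this example.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        def train_step(self, data_batch, optimizer, **kwargs):
            loss = self.linear(data_batch).sum()
            return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=1)

    model = ToyModel()
    runner = EpochBasedRunner(
        model,                                         # torch.nn.Module
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
        work_dir='./work_dir',                         # Optional[str]
        logger=logging.getLogger(__name__),            # Optional[logging.Logger]
        max_epochs=2)                                  # Optional[int]
    # A call such as max_epochs='2' would now be flagged by mypy.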
@@ -4,9 +4,13 @@ import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
no_type_check)
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import mmcv
from ..parallel import is_module_wrapper
@@ -49,14 +53,14 @@ class BaseRunner(metaclass=ABCMeta):
"""
def __init__(self,
model,
batch_processor=None,
optimizer=None,
work_dir=None,
logger=None,
meta=None,
max_iters=None,
max_epochs=None):
model: torch.nn.Module,
batch_processor: Optional[Callable] = None,
optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
work_dir: Optional[str] = None,
logger: Optional[logging.Logger] = None,
meta: Optional[Dict] = None,
max_iters: Optional[int] = None,
max_epochs: Optional[int] = None) -> None:
if batch_processor is not None:
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
@@ -106,8 +110,8 @@ class BaseRunner(metaclass=ABCMeta):
self.logger = logger
self.meta = meta
# create work_dir
if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir)
if isinstance(work_dir, str):
self.work_dir: Optional[str] = osp.abspath(work_dir)
mmcv.mkdir_or_exist(self.work_dir)
elif work_dir is None:
self.work_dir = None
@@ -122,8 +126,8 @@ class BaseRunner(metaclass=ABCMeta):
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode = None
self._hooks = []
self.mode: Optional[str] = None
self._hooks: List[Hook] = []
self._epoch = 0
self._iter = 0
self._inner_iter = 0
@@ -138,38 +142,38 @@ class BaseRunner(metaclass=ABCMeta):
self.log_buffer = LogBuffer()
@property
def model_name(self):
def model_name(self) -> str:
"""str: Name of the model, usually the module class name."""
return self._model_name
@property
def rank(self):
def rank(self) -> int:
"""int: Rank of current process. (distributed training)"""
return self._rank
@property
def world_size(self):
def world_size(self) -> int:
"""int: Number of processes participating in the job.
(distributed training)"""
return self._world_size
@property
def hooks(self):
def hooks(self) -> List[Hook]:
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
def epoch(self):
def epoch(self) -> int:
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
def iter(self) -> int:
"""int: Current iteration."""
return self._iter
@property
def inner_iter(self):
def inner_iter(self) -> int:
"""int: Iteration in an epoch."""
return self._inner_iter
@@ -192,19 +196,20 @@ class BaseRunner(metaclass=ABCMeta):
pass
@abstractmethod
def run(self, data_loaders, workflow, **kwargs):
def run(self, data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]], **kwargs) -> Any:
pass
@abstractmethod
def save_checkpoint(self,
out_dir,
filename_tmpl,
save_optimizer=True,
meta=None,
create_symlink=True):
out_dir: str,
filename_tmpl: str,
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
pass
def current_lr(self):
def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
"""Get current learning rates.
Returns:
@@ -212,6 +217,7 @@ class BaseRunner(metaclass=ABCMeta):
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
lr: Union[List[float], Dict[str, List[float]]]
if isinstance(self.optimizer, torch.optim.Optimizer):
lr = [group['lr'] for group in self.optimizer.param_groups]
elif isinstance(self.optimizer, dict):
@@ -223,7 +229,7 @@ class BaseRunner(metaclass=ABCMeta):
'lr is not applicable because optimizer does not exist.')
return lr
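The new return annotation covers both shapes current_lr() can produce. A short sketch, assuming a runner plus toy modules gen and disc already exist (the GAN-style optimizer names are illustrative):

    # single optimizer -> List[float], one entry per parameter group
    runner.optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    assert runner.current_lr() == [0.1]

    # dict of optimizers -> Dict[str, List[float]]
    runner.optimizer = dict(
        generator=torch.optim.SGD(gen.parameters(), lr=0.1),
        discriminator=torch.optim.SGD(disc.parameters(), lr=0.01))
    assert runner.current_lr() == dict(generator=[0.1], discriminator=[0.01])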
def current_momentum(self):
def current_momentum(self) -> Union[List[float], Dict[str, List[float]]]:
"""Get current momentums.
Returns:
@@ -254,7 +260,9 @@ class BaseRunner(metaclass=ABCMeta):
momentums[name] = _get_momentum(optim)
return momentums
def register_hook(self, hook, priority='NORMAL'):
def register_hook(self,
hook: Hook,
priority: Union[int, str, Priority] = 'NORMAL') -> None:
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
@@ -271,18 +279,18 @@ class BaseRunner(metaclass=ABCMeta):
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
hook.priority = priority # type: ignore
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
if priority >= self._hooks[i].priority: # type: ignore
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
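With priority widened to Union[int, str, Priority], all of the following calls type-check; get_priority() normalizes each form before the sorted insertion above. my_hook and other_hook are placeholders:

    from mmcv.runner import IterTimerHook, Priority

    runner.register_hook(IterTimerHook(), priority='LOW')           # str
    runner.register_hook(my_hook, priority=30)                      # int; lower value = higher priority
    runner.register_hook(other_hook, priority=Priority.VERY_HIGH)   # Priority member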
def register_hook_from_cfg(self, hook_cfg):
def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
"""Register a hook from its cfg.
Args:
@@ -298,7 +306,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
self.register_hook(hook, priority=priority)
def call_hook(self, fn_name):
def call_hook(self, fn_name: str) -> None:
"""Call all hooks.
Args:
@@ -308,14 +316,14 @@ class BaseRunner(metaclass=ABCMeta):
for hook in self._hooks:
getattr(hook, fn_name)(self)
def get_hook_info(self):
def get_hook_info(self) -> str:
# Get hooks info in each stage
stage_hook_map = {stage: [] for stage in Hook.stages}
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name
priority = Priority(hook.priority).name # type: ignore
except ValueError:
priority = hook.priority
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
@@ -331,11 +339,13 @@ class BaseRunner(metaclass=ABCMeta):
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)
def load_checkpoint(self,
filename,
map_location='cpu',
strict=False,
revise_keys=[(r'^module.', '')]):
def load_checkpoint(
self,
filename: str,
map_location: Union[str, Callable] = 'cpu',
strict: bool = False,
revise_keys: List = [(r'^module.', '')],
) -> Union[Dict, OrderedDict]:
return load_checkpoint(
self.model,
filename,
@@ -344,10 +354,11 @@ class BaseRunner(metaclass=ABCMeta):
self.logger,
revise_keys=revise_keys)
@no_type_check
def resume(self,
checkpoint,
resume_optimizer=True,
map_location='default'):
checkpoint: str,
resume_optimizer: bool = True,
map_location: Union[str, Callable] = 'default') -> None:
if map_location == 'default':
if torch.cuda.is_available():
device_id = torch.cuda.current_device()
@@ -398,7 +409,7 @@ class BaseRunner(metaclass=ABCMeta):
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def register_lr_hook(self, lr_config):
def register_lr_hook(self, lr_config: Union[Dict, Hook, None]) -> None:
if lr_config is None:
return
elif isinstance(lr_config, dict):
@@ -419,7 +430,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = lr_config
self.register_hook(hook, priority='VERY_HIGH')
def register_momentum_hook(self, momentum_config):
def register_momentum_hook(
self, momentum_config: Union[Dict, Hook, None]) -> None:
if momentum_config is None:
return
if isinstance(momentum_config, dict):
@@ -440,7 +452,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = momentum_config
self.register_hook(hook, priority='HIGH')
def register_optimizer_hook(self, optimizer_config):
def register_optimizer_hook(
self, optimizer_config: Union[Dict, Hook, None]) -> None:
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
@@ -450,7 +463,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = optimizer_config
self.register_hook(hook, priority='ABOVE_NORMAL')
def register_checkpoint_hook(self, checkpoint_config):
def register_checkpoint_hook(
self, checkpoint_config: Union[Dict, Hook, None]) -> None:
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
......@@ -460,7 +474,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = checkpoint_config
self.register_hook(hook, priority='NORMAL')
def register_logger_hooks(self, log_config):
def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
if log_config is None:
return
log_interval = log_config['interval']
@@ -469,7 +483,10 @@ class BaseRunner(metaclass=ABCMeta):
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config):
def register_timer_hook(
self,
timer_config: Union[Dict, Hook, None],
) -> None:
if timer_config is None:
return
if isinstance(timer_config, dict):
@@ -479,7 +496,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = timer_config
self.register_hook(hook, priority='LOW')
def register_custom_hooks(self, custom_config):
def register_custom_hooks(
self, custom_config: Union[List, Dict, Hook, None]) -> None:
if custom_config is None:
return
@@ -492,7 +510,10 @@ class BaseRunner(metaclass=ABCMeta):
else:
self.register_hook(item, priority='NORMAL')
def register_profiler_hook(self, profiler_config):
def register_profiler_hook(
self,
profiler_config: Union[Dict, Hook, None],
) -> None:
if profiler_config is None:
return
if isinstance(profiler_config, dict):
@@ -502,14 +523,15 @@ class BaseRunner(metaclass=ABCMeta):
hook = profiler_config
self.register_hook(hook)
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
def register_training_hooks(
self,
lr_config: Union[Dict, Hook, None],
optimizer_config: Union[Dict, Hook, None] = None,
checkpoint_config: Union[Dict, Hook, None] = None,
log_config: Optional[Dict] = None,
momentum_config: Union[Dict, Hook, None] = None,
timer_config: Union[Dict, Hook] = dict(type='IterTimerHook'),
custom_hooks_config: Union[List, Dict, Hook, None] = None) -> None:
"""Register default and custom hooks for training.
Default and custom hooks include:
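For illustration, a typical call into the now-annotated register_training_hooks with config dicts of the kind the Union[Dict, Hook, None] parameters accept; the concrete values and hook choices are examples only:

    runner.register_training_hooks(
        lr_config=dict(policy='step', step=[8, 11]),
        optimizer_config=dict(grad_clip=None),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]),
        momentum_config=None,
        custom_hooks_config=[dict(type='EMAHook', priority='HIGHEST')])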
@@ -276,7 +276,7 @@ class CheckpointLoader:
def load_checkpoint(
cls,
filename: str,
map_location: Optional[str] = None,
map_location: Union[str, Callable, None] = None,
logger: Optional[logging.Logger] = None
) -> Union[dict, OrderedDict]:
"""load checkpoint through URL scheme path.
@@ -301,8 +301,9 @@ class CheckpointLoader:
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint by local file path.
Args:
@@ -322,7 +323,7 @@ def load_from_local(
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(
filename: str,
map_location: Optional[str] = None,
map_location: Union[str, Callable, None] = None,
model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
@@ -351,8 +352,9 @@ def load_from_http(
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
@@ -385,7 +387,7 @@ def load_from_pavi(
@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
def load_from_ceph(filename: str,
map_location: Optional[str] = None,
map_location: Union[str, Callable, None] = None,
backend: str = 'petrel') -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
@@ -434,8 +436,9 @@ def load_from_ceph(filename: str,
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with modelzoo or
torchvision.
@@ -465,8 +468,9 @@ def load_from_torchvision(
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
@@ -509,8 +513,9 @@ def load_from_openmmlab(
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with mmcls.
Args:
@@ -531,7 +536,7 @@ def load_from_mmcls(
def _load_checkpoint(
filename: str,
map_location: Optional[str] = None,
map_location: Union[str, Callable, None] = None,
logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
"""Load checkpoint from somewhere (modelzoo, file, url).
@@ -553,9 +558,10 @@ def _load_checkpoint(
def _load_checkpoint_with_prefix(
prefix: str,
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
prefix: str,
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""Load partial pretrained model with specific prefix.
Args:
@@ -591,7 +597,7 @@ def _load_checkpoint_with_prefix(
def load_checkpoint(
model: torch.nn.Module,
filename: str,
map_location: Optional[str] = None,
map_location: Union[str, Callable, None] = None,
strict: bool = False,
logger: Optional[logging.Logger] = None,
revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
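The widened map_location annotation mirrors torch.load, which accepts either a device string or a callable taking (storage, location). A brief sketch; the checkpoint path is hypothetical:

    # string form: remap all tensors to CPU
    ckpt = load_checkpoint(model, 'work_dirs/latest.pth', map_location='cpu')

    # callable form: torch.load invokes it with (storage, location)
    ckpt = load_checkpoint(
        model, 'work_dirs/latest.pth',
        map_location=lambda storage, loc: storage.cuda(0))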
@@ -4,8 +4,10 @@ import platform
import shutil
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.utils.data import DataLoader
import mmcv
from .base_runner import BaseRunner
@@ -21,7 +23,7 @@ class EpochBasedRunner(BaseRunner):
This runner train models epoch by epoch.
"""
def run_iter(self, data_batch, train_mode, **kwargs):
def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
@@ -72,7 +74,11 @@ class EpochBasedRunner(BaseRunner):
del self.data_batch
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
def run(self,
data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]],
max_epochs: Optional[int] = None,
**kwargs) -> None:
"""Start running.
Args:
@@ -133,11 +139,11 @@ class EpochBasedRunner(BaseRunner):
self.call_hook('after_run')
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
out_dir: str,
filename_tmpl: str = 'epoch_{}.pth',
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
"""Save the checkpoint.
Args:
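The typed run() shown above pairs each DataLoader with a (mode, count) tuple. An illustrative epoch-based invocation, assuming the runner was constructed with max_epochs and that train_loader and val_loader exist:

    runner.run(
        data_loaders=[train_loader, val_loader],   # List[DataLoader]
        workflow=[('train', 2), ('val', 1)])       # List[Tuple[str, int]]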
@@ -119,10 +119,11 @@ class LrUpdaterHook(Hook):
]
self.base_lr.update({k: _base_lr})
else:
for group in runner.optimizer.param_groups:
for group in runner.optimizer.param_groups: # type: ignore
group.setdefault('initial_lr', group['lr'])
self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups
group['initial_lr']
for group in runner.optimizer.param_groups # type: ignore
]
def before_train_epoch(self, runner: 'runner.BaseRunner'):
@@ -138,6 +139,7 @@ class LrUpdaterHook(Hook):
def before_train_iter(self, runner: 'runner.BaseRunner'):
cur_iter = runner.iter
assert isinstance(self.warmup_iters, int)
if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner)
if self.warmup is None or cur_iter >= self.warmup_iters:
@@ -505,7 +507,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
# total lr_phases are separated as up and down
self.max_iter_per_phase = runner.max_iters // self.cyclic_times
iter_up_phase = int(self.step_ratio_up *
self.max_iter_per_phase) # type:ignore
self.max_iter_per_phase) # type: ignore
self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]])
self.lr_phases.append([
iter_up_phase, self.max_iter_per_phase, self.target_ratio[0],
@@ -513,8 +515,8 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
])
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
curr_iter = runner.iter % self.max_iter_per_phase
curr_cycle = runner.iter // self.max_iter_per_phase
curr_iter = runner.iter % self.max_iter_per_phase # type: ignore
curr_cycle = runner.iter // self.max_iter_per_phase # type: ignore
# Update weight decay
scale = self.gamma**curr_cycle
@@ -637,7 +639,8 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
k = type(runner.optimizer).__name__
_max_lr = format_param(k, runner.optimizer, self._max_lr)
self.base_lr = [lr / self.div_factor for lr in _max_lr]
for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
optim_param_groups = runner.optimizer.param_groups # type: ignore
for group, lr in zip(optim_param_groups, self.base_lr):
group.setdefault('initial_lr', lr)
if self.three_phase:
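To make the cyclic phase arithmetic above concrete with illustrative numbers: with runner.max_iters = 120, cyclic_times = 3 and step_ratio_up = 0.4, max_iter_per_phase = 120 // 3 = 40 and iter_up_phase = int(0.4 * 40) = 16, so each cycle ramps the learning rate up over iterations 0-16 and anneals it back down over 16-40, while curr_cycle = runner.iter // 40 drives the per-cycle scale gamma**curr_cycle.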
@@ -4,9 +4,11 @@ import platform
import shutil
import time
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union, no_type_check
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import mmcv
from .base_runner import BaseRunner
@@ -18,13 +20,13 @@ from .utils import get_host_info
class IterLoader:
def __init__(self, dataloader):
def __init__(self, dataloader: DataLoader):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._epoch = 0
@property
def epoch(self):
def epoch(self) -> int:
return self._epoch
def __next__(self):
@@ -88,7 +90,11 @@ class IterBasedRunner(BaseRunner):
del self.data_batch
self._inner_iter += 1
def run(self, data_loaders, workflow, max_iters=None, **kwargs):
def run(self,
data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]],
max_iters: Optional[int] = None,
**kwargs) -> None:
"""Start running.
Args:
@@ -141,10 +147,11 @@ class IterBasedRunner(BaseRunner):
self.call_hook('after_epoch')
self.call_hook('after_run')
@no_type_check
def resume(self,
checkpoint,
resume_optimizer=True,
map_location='default'):
checkpoint: str,
resume_optimizer: bool = True,
map_location: Union[str, Callable] = 'default') -> None:
"""Resume model from checkpoint.
Args:
@@ -180,12 +187,13 @@ class IterBasedRunner(BaseRunner):
self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
def save_checkpoint(self,
out_dir,
filename_tmpl='iter_{}.pth',
meta=None,
save_optimizer=True,
create_symlink=True):
def save_checkpoint( # type: ignore
self,
out_dir: str,
filename_tmpl: str = 'iter_{}.pth',
meta: Optional[Dict] = None,
save_optimizer: bool = True,
create_symlink: bool = True) -> None:
"""Save checkpoint to file.
Args:
@@ -261,9 +269,9 @@ class IterBasedRunner(BaseRunner):
will be triggered after default hooks.
"""
if checkpoint_config is not None:
checkpoint_config.setdefault('by_epoch', False)
checkpoint_config.setdefault('by_epoch', False) # type: ignore
if lr_config is not None:
lr_config.setdefault('by_epoch', False)
lr_config.setdefault('by_epoch', False) # type: ignore
if log_config is not None:
for info in log_config['hooks']:
info.setdefault('by_epoch', False)
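IterLoader, now annotated to take a DataLoader, cycles its loader indefinitely and counts wrap-arounds as epochs. A small sketch with a throwaway dataset; the import path assumes the class lives in the iter_based_runner module shown above:

    from torch.utils.data import DataLoader
    from mmcv.runner.iter_based_runner import IterLoader

    loader = DataLoader(range(4), batch_size=2)   # 2 batches per pass
    it = IterLoader(loader)
    for _ in range(4):                            # 4 batches = 2 full passes
        batch = next(it)
    assert it.epoch == 1                          # incremented on the wrap-around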