Unverified commit 3dd2a21b authored by Yvette Zhao, committed by GitHub

Add type hints for runner/base_runner (#2003)



* Add type hints for runner

* refine

* fix error

* refine

* refine format
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent bb710675
@@ -4,9 +4,13 @@ import logging
 import os.path as osp
 import warnings
 from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
+                    no_type_check)

 import torch
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader

 import mmcv
 from ..parallel import is_module_wrapper
@@ -49,14 +53,14 @@ class BaseRunner(metaclass=ABCMeta):
     """

     def __init__(self,
-                 model,
-                 batch_processor=None,
-                 optimizer=None,
-                 work_dir=None,
-                 logger=None,
-                 meta=None,
-                 max_iters=None,
-                 max_epochs=None):
+                 model: torch.nn.Module,
+                 batch_processor: Optional[Callable] = None,
+                 optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
+                 work_dir: Optional[str] = None,
+                 logger: Optional[logging.Logger] = None,
+                 meta: Optional[Dict] = None,
+                 max_iters: Optional[int] = None,
+                 max_epochs: Optional[int] = None) -> None:
         if batch_processor is not None:
             if not callable(batch_processor):
                 raise TypeError('batch_processor must be callable, '
@@ -106,8 +110,8 @@ class BaseRunner(metaclass=ABCMeta):
         self.logger = logger
         self.meta = meta
         # create work_dir
-        if mmcv.is_str(work_dir):
-            self.work_dir = osp.abspath(work_dir)
+        if isinstance(work_dir, str):
+            self.work_dir: Optional[str] = osp.abspath(work_dir)
             mmcv.mkdir_or_exist(self.work_dir)
         elif work_dir is None:
             self.work_dir = None
@@ -122,8 +126,8 @@ class BaseRunner(metaclass=ABCMeta):
         self._rank, self._world_size = get_dist_info()
         self.timestamp = get_time_str()
-        self.mode = None
-        self._hooks = []
+        self.mode: Optional[str] = None
+        self._hooks: List[Hook] = []
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
@@ -138,38 +142,38 @@ class BaseRunner(metaclass=ABCMeta):
         self.log_buffer = LogBuffer()

     @property
-    def model_name(self):
+    def model_name(self) -> str:
         """str: Name of the model, usually the module class name."""
         return self._model_name

     @property
-    def rank(self):
+    def rank(self) -> int:
         """int: Rank of current process. (distributed training)"""
         return self._rank

     @property
-    def world_size(self):
+    def world_size(self) -> int:
         """int: Number of processes participating in the job.
         (distributed training)"""
         return self._world_size

     @property
-    def hooks(self):
+    def hooks(self) -> List[Hook]:
         """list[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

     @property
-    def epoch(self):
+    def epoch(self) -> int:
         """int: Current epoch."""
         return self._epoch

     @property
-    def iter(self):
+    def iter(self) -> int:
         """int: Current iteration."""
         return self._iter

     @property
-    def inner_iter(self):
+    def inner_iter(self) -> int:
         """int: Iteration in an epoch."""
         return self._inner_iter
@@ -192,19 +196,20 @@ class BaseRunner(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def run(self, data_loaders, workflow, **kwargs):
+    def run(self, data_loaders: List[DataLoader],
+            workflow: List[Tuple[str, int]], **kwargs) -> Any:
         pass

     @abstractmethod
     def save_checkpoint(self,
-                        out_dir,
-                        filename_tmpl,
-                        save_optimizer=True,
-                        meta=None,
-                        create_symlink=True):
+                        out_dir: str,
+                        filename_tmpl: str,
+                        save_optimizer: bool = True,
+                        meta: Optional[Dict] = None,
+                        create_symlink: bool = True) -> None:
         pass

-    def current_lr(self):
+    def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
         """Get current learning rates.

         Returns:
@@ -212,6 +217,7 @@ class BaseRunner(metaclass=ABCMeta):
             param groups. If the runner has a dict of optimizers, this method
             will return a dict.
         """
+        lr: Union[List[float], Dict[str, List[float]]]
         if isinstance(self.optimizer, torch.optim.Optimizer):
             lr = [group['lr'] for group in self.optimizer.param_groups]
         elif isinstance(self.optimizer, dict):
@@ -223,7 +229,7 @@ class BaseRunner(metaclass=ABCMeta):
                 'lr is not applicable because optimizer does not exist.')
         return lr

-    def current_momentum(self):
+    def current_momentum(self) -> Union[List[float], Dict[str, List[float]]]:
         """Get current momentums.

         Returns:
@@ -254,7 +260,9 @@ class BaseRunner(metaclass=ABCMeta):
                 momentums[name] = _get_momentum(optim)
         return momentums

-    def register_hook(self, hook, priority='NORMAL'):
+    def register_hook(self,
+                      hook: Hook,
+                      priority: Union[int, str, Priority] = 'NORMAL') -> None:
         """Register a hook into the hook list.

         The hook will be inserted into a priority queue, with the specified
@@ -271,18 +279,18 @@ class BaseRunner(metaclass=ABCMeta):
         if hasattr(hook, 'priority'):
             raise ValueError('"priority" is a reserved attribute for hooks')
         priority = get_priority(priority)
-        hook.priority = priority
+        hook.priority = priority  # type: ignore
         # insert the hook to a sorted list
         inserted = False
         for i in range(len(self._hooks) - 1, -1, -1):
-            if priority >= self._hooks[i].priority:
+            if priority >= self._hooks[i].priority:  # type: ignore
                 self._hooks.insert(i + 1, hook)
                 inserted = True
                 break
         if not inserted:
             self._hooks.insert(0, hook)

-    def register_hook_from_cfg(self, hook_cfg):
+    def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
         """Register a hook from its cfg.

         Args:
@@ -298,7 +306,7 @@ class BaseRunner(metaclass=ABCMeta):
         hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
         self.register_hook(hook, priority=priority)

-    def call_hook(self, fn_name):
+    def call_hook(self, fn_name: str) -> None:
         """Call all hooks.

         Args:
@@ -308,14 +316,14 @@ class BaseRunner(metaclass=ABCMeta):
         for hook in self._hooks:
             getattr(hook, fn_name)(self)

-    def get_hook_info(self):
+    def get_hook_info(self) -> str:
         # Get hooks info in each stage
-        stage_hook_map = {stage: [] for stage in Hook.stages}
+        stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
         for hook in self.hooks:
             try:
-                priority = Priority(hook.priority).name
+                priority = Priority(hook.priority).name  # type: ignore
             except ValueError:
-                priority = hook.priority
+                priority = hook.priority  # type: ignore
             classname = hook.__class__.__name__
             hook_info = f'({priority:<12}) {classname:<35}'
             for trigger_stage in hook.get_triggered_stages():
@@ -331,11 +339,13 @@ class BaseRunner(metaclass=ABCMeta):
             stage_hook_infos.append(info)
         return '\n'.join(stage_hook_infos)

-    def load_checkpoint(self,
-                        filename,
-                        map_location='cpu',
-                        strict=False,
-                        revise_keys=[(r'^module.', '')]):
+    def load_checkpoint(
+            self,
+            filename: str,
+            map_location: Union[str, Callable] = 'cpu',
+            strict: bool = False,
+            revise_keys: List = [(r'^module.', '')],
+    ) -> Union[Dict, OrderedDict]:
         return load_checkpoint(
             self.model,
             filename,
@@ -344,10 +354,11 @@ class BaseRunner(metaclass=ABCMeta):
             self.logger,
             revise_keys=revise_keys)

+    @no_type_check
     def resume(self,
-               checkpoint,
-               resume_optimizer=True,
-               map_location='default'):
+               checkpoint: str,
+               resume_optimizer: bool = True,
+               map_location: Union[str, Callable] = 'default') -> None:
         if map_location == 'default':
             if torch.cuda.is_available():
                 device_id = torch.cuda.current_device()
@@ -398,7 +409,7 @@ class BaseRunner(metaclass=ABCMeta):
         self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

-    def register_lr_hook(self, lr_config):
+    def register_lr_hook(self, lr_config: Union[Dict, Hook, None]) -> None:
         if lr_config is None:
             return
         elif isinstance(lr_config, dict):
@@ -419,7 +430,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = lr_config
         self.register_hook(hook, priority='VERY_HIGH')

-    def register_momentum_hook(self, momentum_config):
+    def register_momentum_hook(
+            self, momentum_config: Union[Dict, Hook, None]) -> None:
         if momentum_config is None:
             return
         if isinstance(momentum_config, dict):
@@ -440,7 +452,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = momentum_config
         self.register_hook(hook, priority='HIGH')

-    def register_optimizer_hook(self, optimizer_config):
+    def register_optimizer_hook(
+            self, optimizer_config: Union[Dict, Hook, None]) -> None:
         if optimizer_config is None:
             return
         if isinstance(optimizer_config, dict):
@@ -450,7 +463,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = optimizer_config
         self.register_hook(hook, priority='ABOVE_NORMAL')

-    def register_checkpoint_hook(self, checkpoint_config):
+    def register_checkpoint_hook(
+            self, checkpoint_config: Union[Dict, Hook, None]) -> None:
         if checkpoint_config is None:
             return
         if isinstance(checkpoint_config, dict):
@@ -460,7 +474,7 @@ class BaseRunner(metaclass=ABCMeta):
             hook = checkpoint_config
         self.register_hook(hook, priority='NORMAL')

-    def register_logger_hooks(self, log_config):
+    def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
         if log_config is None:
             return
         log_interval = log_config['interval']
@@ -469,7 +483,10 @@ class BaseRunner(metaclass=ABCMeta):
                 info, HOOKS, default_args=dict(interval=log_interval))
             self.register_hook(logger_hook, priority='VERY_LOW')

-    def register_timer_hook(self, timer_config):
+    def register_timer_hook(
+            self,
+            timer_config: Union[Dict, Hook, None],
+    ) -> None:
         if timer_config is None:
             return
         if isinstance(timer_config, dict):
@@ -479,7 +496,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = timer_config
         self.register_hook(hook, priority='LOW')

-    def register_custom_hooks(self, custom_config):
+    def register_custom_hooks(
+            self, custom_config: Union[List, Dict, Hook, None]) -> None:
         if custom_config is None:
             return
@@ -492,7 +510,10 @@ class BaseRunner(metaclass=ABCMeta):
         else:
             self.register_hook(item, priority='NORMAL')

-    def register_profiler_hook(self, profiler_config):
+    def register_profiler_hook(
+            self,
+            profiler_config: Union[Dict, Hook, None],
+    ) -> None:
         if profiler_config is None:
             return
         if isinstance(profiler_config, dict):
@@ -502,14 +523,15 @@ class BaseRunner(metaclass=ABCMeta):
             hook = profiler_config
         self.register_hook(hook)

-    def register_training_hooks(self,
-                                lr_config,
-                                optimizer_config=None,
-                                checkpoint_config=None,
-                                log_config=None,
-                                momentum_config=None,
-                                timer_config=dict(type='IterTimerHook'),
-                                custom_hooks_config=None):
+    def register_training_hooks(
+            self,
+            lr_config: Union[Dict, Hook, None],
+            optimizer_config: Union[Dict, Hook, None] = None,
+            checkpoint_config: Union[Dict, Hook, None] = None,
+            log_config: Optional[Dict] = None,
+            momentum_config: Union[Dict, Hook, None] = None,
+            timer_config: Union[Dict, Hook] = dict(type='IterTimerHook'),
+            custom_hooks_config: Union[List, Dict, Hook, None] = None) -> None:
         """Register default and custom hooks for training.

         Default and custom hooks include:
......
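A minimal usage sketch (not part of this commit) of the newly annotated constructor and register_hook signature, exercised through EpochBasedRunner. ToyModel and every argument value below are illustrative placeholders, not code from mmcv.

# Hypothetical example; ToyModel and the work_dir/optimizer/logger values are made up.
import logging

import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner, IterTimerHook


class ToyModel(nn.Module):
    """Tiny stand-in model exposing the train_step interface runners expect."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)

    def train_step(self, data_batch, optimizer, **kwargs):
        loss = self.fc(data_batch).sum()
        return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=1)


model = ToyModel()
runner = EpochBasedRunner(
    model=model,                                             # torch.nn.Module
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),  # Optimizer, a dict of them, or None
    work_dir='./work_dir',                                   # Optional[str]
    logger=logging.getLogger(__name__),                      # Optional[logging.Logger]
    max_epochs=1)                                            # Optional[int]

# priority now accepts an int, a string name, or a Priority enum member
runner.register_hook(IterTimerHook(), priority='LOW')

# with a single optimizer, current_lr() returns List[float]; with a dict of
# optimizers it would return Dict[str, List[float]]
assert runner.current_lr() == [0.01]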
@@ -276,7 +276,7 @@ class CheckpointLoader:
     def load_checkpoint(
             cls,
             filename: str,
-            map_location: Optional[str] = None,
+            map_location: Union[str, Callable, None] = None,
             logger: Optional[logging.Logger] = None
     ) -> Union[dict, OrderedDict]:
         """load checkpoint through URL scheme path.
@@ -301,8 +301,9 @@ class CheckpointLoader:
 @CheckpointLoader.register_scheme(prefixes='')
 def load_from_local(
         filename: str,
-        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
+        map_location: Union[str, Callable, None] = None,
+) -> Union[dict, OrderedDict]:
     """load checkpoint by local file path.

     Args:
@@ -322,7 +323,7 @@ def load_from_local(
 @CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
 def load_from_http(
         filename: str,
-        map_location: Optional[str] = None,
+        map_location: Union[str, Callable, None] = None,
         model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through HTTP or HTTPS scheme path. In distributed
     setting, this function only download checkpoint at local rank 0.
@@ -351,8 +352,9 @@ def load_from_http(
 @CheckpointLoader.register_scheme(prefixes='pavi://')
 def load_from_pavi(
         filename: str,
-        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
+        map_location: Union[str, Callable, None] = None,
+) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with pavi. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
@@ -385,7 +387,7 @@ def load_from_pavi(
 @CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
 def load_from_ceph(filename: str,
-                   map_location: Optional[str] = None,
+                   map_location: Union[str, Callable, None] = None,
                    backend: str = 'petrel') -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
@@ -434,8 +436,9 @@ def load_from_ceph(filename: str,
 @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
 def load_from_torchvision(
         filename: str,
-        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
+        map_location: Union[str, Callable, None] = None,
+) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with modelzoo or
     torchvision.
@@ -465,8 +468,9 @@ def load_from_torchvision(
 @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
 def load_from_openmmlab(
         filename: str,
-        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
+        map_location: Union[str, Callable, None] = None,
+) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with open-mmlab or
     openmmlab.
@@ -509,8 +513,9 @@ def load_from_openmmlab(
 @CheckpointLoader.register_scheme(prefixes='mmcls://')
 def load_from_mmcls(
         filename: str,
-        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
+        map_location: Union[str, Callable, None] = None,
+) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with mmcls.

     Args:
@@ -531,7 +536,7 @@ def load_from_mmcls(
 def _load_checkpoint(
         filename: str,
-        map_location: Optional[str] = None,
+        map_location: Union[str, Callable, None] = None,
         logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
     """Load checkpoint from somewhere (modelzoo, file, url).
@@ -553,9 +558,10 @@ def _load_checkpoint(
 def _load_checkpoint_with_prefix(
         prefix: str,
         filename: str,
-        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
+        map_location: Union[str, Callable, None] = None,
+) -> Union[dict, OrderedDict]:
     """Load partial pretrained model with specific prefix.

     Args:
@@ -591,7 +597,7 @@ def _load_checkpoint_with_prefix(
 def load_checkpoint(
         model: torch.nn.Module,
         filename: str,
-        map_location: Optional[str] = None,
+        map_location: Union[str, Callable, None] = None,
         strict: bool = False,
         logger: Optional[logging.Logger] = None,
         revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
......
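The checkpoint-loader hunks above widen map_location from Optional[str] to Union[str, Callable, None], matching what torch.load accepts. A hedged sketch (not from this commit) of both forms; the path 'checkpoint.pth' is a hypothetical local file assumed to exist.

# Hypothetical example; 'checkpoint.pth' is a placeholder path.
import torch.nn as nn
from mmcv.runner import load_checkpoint

model = nn.Linear(2, 2)

# string form
load_checkpoint(model, 'checkpoint.pth', map_location='cpu', strict=False)

# callable form, as accepted by torch.load: keep every storage where it was loaded
load_checkpoint(
    model, 'checkpoint.pth',
    map_location=lambda storage, loc: storage, strict=False)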
@@ -4,8 +4,10 @@ import platform
 import shutil
 import time
 import warnings
+from typing import Any, Dict, List, Optional, Tuple

 import torch
+from torch.utils.data import DataLoader

 import mmcv
 from .base_runner import BaseRunner
@@ -21,7 +23,7 @@ class EpochBasedRunner(BaseRunner):
     This runner train models epoch by epoch.
     """

-    def run_iter(self, data_batch, train_mode, **kwargs):
+    def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
         if self.batch_processor is not None:
             outputs = self.batch_processor(
                 self.model, data_batch, train_mode=train_mode, **kwargs)
@@ -72,7 +74,11 @@ class EpochBasedRunner(BaseRunner):
             del self.data_batch
         self.call_hook('after_val_epoch')

-    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+    def run(self,
+            data_loaders: List[DataLoader],
+            workflow: List[Tuple[str, int]],
+            max_epochs: Optional[int] = None,
+            **kwargs) -> None:
         """Start running.

         Args:
@@ -133,11 +139,11 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('after_run')

     def save_checkpoint(self,
-                        out_dir,
-                        filename_tmpl='epoch_{}.pth',
-                        save_optimizer=True,
-                        meta=None,
-                        create_symlink=True):
+                        out_dir: str,
+                        filename_tmpl: str = 'epoch_{}.pth',
+                        save_optimizer: bool = True,
+                        meta: Optional[Dict] = None,
+                        create_symlink: bool = True) -> None:
         """Save the checkpoint.

         Args:
......
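Continuing the hypothetical runner from the first sketch, the annotated EpochBasedRunner.save_checkpoint can be exercised as below; the out_dir and template values remain placeholders.

# Writes ./work_dir/epoch_1.pth; create_symlink=False skips the 'latest.pth' symlink.
runner.save_checkpoint(
    out_dir='./work_dir',            # str
    filename_tmpl='epoch_{}.pth',    # str, formatted with the 1-based epoch
    save_optimizer=True,             # bool
    meta=None,                       # Optional[Dict]
    create_symlink=False)            # bool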
...@@ -119,10 +119,11 @@ class LrUpdaterHook(Hook): ...@@ -119,10 +119,11 @@ class LrUpdaterHook(Hook):
] ]
self.base_lr.update({k: _base_lr}) self.base_lr.update({k: _base_lr})
else: else:
for group in runner.optimizer.param_groups: for group in runner.optimizer.param_groups: # type: ignore
group.setdefault('initial_lr', group['lr']) group.setdefault('initial_lr', group['lr'])
self.base_lr = [ self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups group['initial_lr']
for group in runner.optimizer.param_groups # type: ignore
] ]
def before_train_epoch(self, runner: 'runner.BaseRunner'): def before_train_epoch(self, runner: 'runner.BaseRunner'):
...@@ -138,6 +139,7 @@ class LrUpdaterHook(Hook): ...@@ -138,6 +139,7 @@ class LrUpdaterHook(Hook):
def before_train_iter(self, runner: 'runner.BaseRunner'): def before_train_iter(self, runner: 'runner.BaseRunner'):
cur_iter = runner.iter cur_iter = runner.iter
assert isinstance(self.warmup_iters, int)
if not self.by_epoch: if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner) self.regular_lr = self.get_regular_lr(runner)
if self.warmup is None or cur_iter >= self.warmup_iters: if self.warmup is None or cur_iter >= self.warmup_iters:
...@@ -505,7 +507,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook): ...@@ -505,7 +507,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
# total lr_phases are separated as up and down # total lr_phases are separated as up and down
self.max_iter_per_phase = runner.max_iters // self.cyclic_times self.max_iter_per_phase = runner.max_iters // self.cyclic_times
iter_up_phase = int(self.step_ratio_up * iter_up_phase = int(self.step_ratio_up *
self.max_iter_per_phase) # type:ignore self.max_iter_per_phase) # type: ignore
self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]]) self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]])
self.lr_phases.append([ self.lr_phases.append([
iter_up_phase, self.max_iter_per_phase, self.target_ratio[0], iter_up_phase, self.max_iter_per_phase, self.target_ratio[0],
...@@ -513,8 +515,8 @@ class CyclicLrUpdaterHook(LrUpdaterHook): ...@@ -513,8 +515,8 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
]) ])
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
curr_iter = runner.iter % self.max_iter_per_phase curr_iter = runner.iter % self.max_iter_per_phase # type: ignore
curr_cycle = runner.iter // self.max_iter_per_phase curr_cycle = runner.iter // self.max_iter_per_phase # type: ignore
# Update weight decay # Update weight decay
scale = self.gamma**curr_cycle scale = self.gamma**curr_cycle
...@@ -637,7 +639,8 @@ class OneCycleLrUpdaterHook(LrUpdaterHook): ...@@ -637,7 +639,8 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
k = type(runner.optimizer).__name__ k = type(runner.optimizer).__name__
_max_lr = format_param(k, runner.optimizer, self._max_lr) _max_lr = format_param(k, runner.optimizer, self._max_lr)
self.base_lr = [lr / self.div_factor for lr in _max_lr] self.base_lr = [lr / self.div_factor for lr in _max_lr]
for group, lr in zip(runner.optimizer.param_groups, self.base_lr): optim_param_groups = runner.optimizer.param_groups # type: ignore
for group, lr in zip(optim_param_groups, self.base_lr):
group.setdefault('initial_lr', lr) group.setdefault('initial_lr', lr)
if self.three_phase: if self.three_phase:
......
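Because register_lr_hook now takes Union[Dict, Hook, None], a dict config, a Hook instance, or None are all valid. A short hedged sketch (not from this commit), reusing the hypothetical runner from the first sketch; the step value is illustrative.

# A dict config is resolved through the HOOKS registry: policy='step' maps to
# StepLrUpdaterHook; a Hook instance would be registered directly.
runner.register_lr_hook(dict(policy='step', step=[1]))

# None is a no-op: the method simply returns without registering anything.
runner.register_momentum_hook(None)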
@@ -4,9 +4,11 @@ import platform
 import shutil
 import time
 import warnings
+from typing import Callable, Dict, List, Optional, Tuple, Union, no_type_check

 import torch
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader

 import mmcv
 from .base_runner import BaseRunner
@@ -18,13 +20,13 @@ from .utils import get_host_info
 class IterLoader:

-    def __init__(self, dataloader):
+    def __init__(self, dataloader: DataLoader):
         self._dataloader = dataloader
         self.iter_loader = iter(self._dataloader)
         self._epoch = 0

     @property
-    def epoch(self):
+    def epoch(self) -> int:
         return self._epoch

     def __next__(self):
@@ -88,7 +90,11 @@ class IterBasedRunner(BaseRunner):
             del self.data_batch
         self._inner_iter += 1

-    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+    def run(self,
+            data_loaders: List[DataLoader],
+            workflow: List[Tuple[str, int]],
+            max_iters: Optional[int] = None,
+            **kwargs) -> None:
         """Start running.

         Args:
@@ -141,10 +147,11 @@ class IterBasedRunner(BaseRunner):
         self.call_hook('after_epoch')
         self.call_hook('after_run')

+    @no_type_check
     def resume(self,
-               checkpoint,
-               resume_optimizer=True,
-               map_location='default'):
+               checkpoint: str,
+               resume_optimizer: bool = True,
+               map_location: Union[str, Callable] = 'default') -> None:
         """Resume model from checkpoint.

         Args:
@@ -180,12 +187,13 @@ class IterBasedRunner(BaseRunner):
         self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

-    def save_checkpoint(self,
-                        out_dir,
-                        filename_tmpl='iter_{}.pth',
-                        meta=None,
-                        save_optimizer=True,
-                        create_symlink=True):
+    def save_checkpoint(  # type: ignore
+            self,
+            out_dir: str,
+            filename_tmpl: str = 'iter_{}.pth',
+            meta: Optional[Dict] = None,
+            save_optimizer: bool = True,
+            create_symlink: bool = True) -> None:
         """Save checkpoint to file.

         Args:
@@ -261,9 +269,9 @@ class IterBasedRunner(BaseRunner):
             will be triggered after default hooks.
         """
         if checkpoint_config is not None:
-            checkpoint_config.setdefault('by_epoch', False)
+            checkpoint_config.setdefault('by_epoch', False)  # type: ignore
         if lr_config is not None:
-            lr_config.setdefault('by_epoch', False)
+            lr_config.setdefault('by_epoch', False)  # type: ignore
         if log_config is not None:
             for info in log_config['hooks']:
                 info.setdefault('by_epoch', False)
......
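A final hedged sketch (not from this commit) of the iteration-based runner with the annotated arguments, reusing ToyModel from the first sketch; all config values and paths are placeholders.

# Hypothetical example, reusing ToyModel defined earlier.
import logging

import torch
from mmcv.runner import IterBasedRunner

model = ToyModel()
iter_runner = IterBasedRunner(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    work_dir='./work_dir',
    logger=logging.getLogger(__name__),
    max_iters=10)

# dict configs get by_epoch=False injected by IterBasedRunner before the hooks
# are built from the registry and registered
iter_runner.register_training_hooks(
    lr_config=dict(policy='fixed'),
    checkpoint_config=dict(interval=5),
    log_config=None)

# Writes ./work_dir/iter_1.pth; filename_tmpl is formatted with the 1-based iteration.
iter_runner.save_checkpoint('./work_dir', filename_tmpl='iter_{}.pth',
                            create_symlink=False)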