Unverified Commit bb710675 authored by Tri Wahyu Utomo, committed by GitHub

Add type hints for mmcv/runner (#1991)



* add typehints

* fix typehints

* fix _init_rule type hint

* update error message

* simplify type hint

* minor refinement

* minor refinement
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
parent f5425ab7
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional
from mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
......@@ -49,14 +50,14 @@ class CheckpointHook(Hook):
"""
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
save_last=True,
sync_buffer=False,
file_client_args=None,
interval: int = -1,
by_epoch: bool = True,
save_optimizer: bool = True,
out_dir: Optional[str] = None,
max_keep_ckpts: int = -1,
save_last: bool = True,
sync_buffer: bool = False,
file_client_args: Optional[dict] = None,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
......
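To show how the newly annotated CheckpointHook signature reads at a call site, here is a minimal construction sketch; the commented registration line is an assumption about typical runner usage and is not part of this diff.

from mmcv.runner import CheckpointHook

# save every 5 epochs and keep only the 3 newest checkpoints
ckpt_hook = CheckpointHook(
    interval=5,           # int
    by_epoch=True,        # bool
    save_optimizer=True,  # bool
    max_keep_ckpts=3,     # int
    out_dir=None,         # Optional[str]: falls back to the runner's work_dir
)
# runner.register_hook(ckpt_hook)  # assumes an already-built runner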
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable
from .hook import HOOKS, Hook
@HOOKS.register_module()
class ClosureHook(Hook):
def __init__(self, fn_name, fn):
def __init__(self, fn_name: str, fn: Callable):
assert hasattr(self, fn_name)
assert callable(fn)
setattr(self, fn_name, fn)
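A quick usage sketch for the annotated ClosureHook: `fn_name` must name an existing Hook method and `fn` can be any callable taking the runner. The callable body below is purely illustrative.

from mmcv.runner.hooks import ClosureHook

def log_iter(runner) -> None:
    # illustrative callable; it receives the runner like any hook method
    print(f'finished iter {runner.iter}')

# binds the callable as this hook's after_train_iter method
closure_hook = ClosureHook(fn_name='after_train_iter', fn=log_iter)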
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from ...parallel import is_module_wrapper
from ..hooks.hook import HOOKS, Hook
......@@ -23,14 +25,14 @@ class EMAHook(Hook):
Defaults to 1.
warm_up (int): During first warm_up steps, we may use smaller momentum
to update ema parameters more slowly. Defaults to 100.
resume_from (str): The checkpoint path. Defaults to None.
resume_from (str, optional): The checkpoint path. Defaults to None.
"""
def __init__(self,
momentum=0.0002,
interval=1,
warm_up=100,
resume_from=None):
momentum: float = 0.0002,
interval: int = 1,
warm_up: int = 100,
resume_from: Optional[str] = None):
assert isinstance(interval, int) and interval > 0
self.warm_up = warm_up
self.interval = interval
......
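As background for the `momentum` and `warm_up` parameters typed above, an exponential moving average of model parameters is generally updated as in the sketch below; this is a generic illustration of the update rule, not the hook's exact implementation.

def ema_update(ema_value: float, new_value: float,
               momentum: float = 0.0002) -> float:
    # a small momentum means the running average moves slowly towards new_value
    return (1 - momentum) * ema_value + momentum * new_value

# after many steps the average drifts towards the latest values
x = 0.0
for _ in range(1000):
    x = ema_update(x, 1.0)
print(round(x, 3))  # ~0.181 with momentum=0.0002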
......@@ -2,6 +2,7 @@
import os.path as osp
import warnings
from math import inf
from typing import Callable, List, Optional
import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm
......@@ -83,17 +84,17 @@ class EvalHook(Hook):
_default_less_keys = ['loss']
def __init__(self,
dataloader,
start=None,
interval=1,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
out_dir=None,
file_client_args=None,
dataloader: DataLoader,
start: Optional[int] = None,
interval: int = 1,
by_epoch: bool = True,
save_best: Optional[str] = None,
rule: Optional[str] = None,
test_fn: Optional[Callable] = None,
greater_keys: Optional[List[str]] = None,
less_keys: Optional[List[str]] = None,
out_dir: Optional[str] = None,
file_client_args: Optional[dict] = None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
......@@ -131,6 +132,7 @@ class EvalHook(Hook):
self.greater_keys = self._default_greater_keys
else:
if not isinstance(greater_keys, (list, tuple)):
assert isinstance(greater_keys, str)
greater_keys = (greater_keys, )
assert is_seq_of(greater_keys, str)
self.greater_keys = greater_keys
......@@ -139,6 +141,7 @@ class EvalHook(Hook):
self.less_keys = self._default_less_keys
else:
if not isinstance(less_keys, (list, tuple)):
assert isinstance(less_keys, str)
less_keys = (less_keys, )
assert is_seq_of(less_keys, str)
self.less_keys = less_keys
......@@ -150,7 +153,7 @@ class EvalHook(Hook):
self.out_dir = out_dir
self.file_client_args = file_client_args
def _init_rule(self, rule, key_indicator):
def _init_rule(self, rule: Optional[str], key_indicator: str):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for key indicator
......@@ -178,6 +181,7 @@ class EvalHook(Hook):
if key_indicator != 'auto':
# `_lc` here means we use the lower case of keys for
# case-insensitive matching
assert isinstance(key_indicator, str)
key_indicator_lc = key_indicator.lower()
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]
......@@ -439,20 +443,20 @@ class DistEvalHook(EvalHook):
"""
def __init__(self,
dataloader,
start=None,
interval=1,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
out_dir=None,
file_client_args=None,
dataloader: DataLoader,
start: Optional[int] = None,
interval: int = 1,
by_epoch: bool = True,
save_best: Optional[str] = None,
rule: Optional[str] = None,
test_fn: Optional[Callable] = None,
greater_keys: Optional[List[str]] = None,
less_keys: Optional[List[str]] = None,
broadcast_bn_buffer: bool = True,
tmpdir: Optional[str] = None,
gpu_collect: bool = False,
out_dir: Optional[str] = None,
file_client_args: Optional[dict] = None,
**eval_kwargs):
if test_fn is None:
......
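For reference, a hedged construction sketch for the annotated EvalHook parameters; the dataset is a throwaway placeholder and `mAP` is only an example metric key.

import torch
from torch.utils.data import DataLoader, TensorDataset
from mmcv.runner import EvalHook

# placeholder dataset so the DataLoader type check in __init__ passes
val_loader = DataLoader(TensorDataset(torch.zeros(4, 3)), batch_size=1)

eval_hook = EvalHook(
    val_loader,
    interval=1,        # int: evaluate every epoch when by_epoch=True
    save_best='mAP',   # Optional[str]: metric used to pick the best checkpoint
    rule='greater',    # Optional[str]: comparison rule; inferred from the key if None
)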
......@@ -705,7 +705,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
def annealing_cos(start: float,
end: float,
factor: float,
weight: float = 1) -> float:
weight: float = 1.) -> float:
"""Calculate annealing cos learning rate.
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
......
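The docstring above fully determines the cosine schedule, so a reference sketch is easy to sanity-check: it anneals from `weight * start + (1 - weight) * end` at factor 0 down to `end` at factor 1. The reimplementation below mirrors that documented behaviour.

from math import cos, pi

def annealing_cos_sketch(start: float, end: float, factor: float,
                         weight: float = 1.) -> float:
    # factor goes from 0.0 to 1.0 over the schedule
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out

assert abs(annealing_cos_sketch(0.1, 0.001, 0.0) - 0.1) < 1e-9    # starts at `start`
assert abs(annealing_cos_sketch(0.1, 0.001, 1.0) - 0.001) < 1e-9  # ends at `end`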
......@@ -7,7 +7,10 @@ from .hook import HOOKS, Hook
@HOOKS.register_module()
class EmptyCacheHook(Hook):
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
def __init__(self,
before_epoch: bool = False,
after_epoch: bool = True,
after_iter: bool = False):
self._before_epoch = before_epoch
self._after_epoch = after_epoch
self._after_iter = after_iter
......
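A one-line usage sketch for the now fully annotated EmptyCacheHook; the values shown are simply the defaults.

from mmcv.runner.hooks import EmptyCacheHook

# frees cached GPU memory (torch.cuda.empty_cache) only after each epoch
cache_hook = EmptyCacheHook(before_epoch=False, after_epoch=True, after_iter=False)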
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Tuple, Union
import mmcv
from .hook import HOOKS, Hook
from .lr_updater import annealing_cos, annealing_linear, format_param
......@@ -7,10 +9,10 @@ from .lr_updater import annealing_cos, annealing_linear, format_param
class MomentumUpdaterHook(Hook):
def __init__(self,
by_epoch=True,
warmup=None,
warmup_iters=0,
warmup_ratio=0.9):
by_epoch: bool = True,
warmup: Optional[str] = None,
warmup_iters: int = 0,
warmup_ratio: float = 0.9):
# validate the "warmup" argument
if warmup is not None:
if warmup not in ['constant', 'linear', 'exp']:
......@@ -28,9 +30,10 @@ class MomentumUpdaterHook(Hook):
self.warmup_iters = warmup_iters
self.warmup_ratio = warmup_ratio
self.base_momentum = [] # initial momentum for all param groups
self.regular_momentum = [
] # expected momentum if no warming up is performed
# initial momentum for all param groups
self.base_momentum: Union[list, dict] = []
# expected momentum if no warming up is performed
self.regular_momentum: Union[list, dict] = []
def _set_momentum(self, runner, momentum_groups):
if isinstance(runner.optimizer, dict):
......@@ -49,26 +52,30 @@ class MomentumUpdaterHook(Hook):
elif 'betas' in param_group.keys():
param_group['betas'] = (mom, param_group['betas'][1])
def get_momentum(self, runner, base_momentum):
def get_momentum(self, runner, base_momentum) -> float:
raise NotImplementedError
def get_regular_momentum(self, runner):
def get_regular_momentum(self, runner) -> Union[list, Dict[str, list]]:
if isinstance(runner.optimizer, dict):
momentum_groups = {}
assert isinstance(self.base_momentum, dict)
momentum_groups: Dict[str, List[float]] = {}
for k in runner.optimizer.keys():
_momentum_group = [
_momentum_group: List[float] = [
self.get_momentum(runner, _base_momentum)
for _base_momentum in self.base_momentum[k]
]
momentum_groups.update({k: _momentum_group})
return momentum_groups
else:
assert isinstance(self.base_momentum, list)
return [
self.get_momentum(runner, _base_momentum)
for _base_momentum in self.base_momentum
]
def get_warmup_momentum(self, cur_iters):
def get_warmup_momentum(
self,
cur_iters: int) -> Union[List[float], Dict[str, List[float]]]:
def _get_warmup_momentum(cur_iters, regular_momentum):
if self.warmup == 'constant':
......@@ -87,6 +94,10 @@ class MomentumUpdaterHook(Hook):
warmup_momentum = [
_momentum / k for _momentum in regular_momentum
]
else:
raise ValueError(
'Expected values of `self.warmup` to be "constant", '
f'"linear", or "exp", got {self.warmup}')
return warmup_momentum
if isinstance(self.regular_momentum, dict):
......@@ -165,7 +176,11 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
Default: None.
"""
def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
def __init__(self,
step: Union[int, List[int]],
gamma: float = 0.5,
min_momentum: Optional[float] = None,
**kwargs):
if isinstance(step, list):
assert mmcv.is_list_of(step, int)
assert all([s > 0 for s in step])
......@@ -178,7 +193,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
self.min_momentum = min_momentum
super().__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
def get_momentum(self, runner, base_momentum: float) -> float:
progress = runner.epoch if self.by_epoch else runner.iter
# calculate exponential term
......@@ -210,13 +225,16 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
should be specified. Default: None.
"""
def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
def __init__(self,
min_momentum: Optional[float] = None,
min_momentum_ratio: Optional[float] = None,
**kwargs):
assert (min_momentum is None) ^ (min_momentum_ratio is None)
self.min_momentum = min_momentum
self.min_momentum_ratio = min_momentum_ratio
super().__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
def get_momentum(self, runner, base_momentum: float) -> float:
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
......@@ -226,6 +244,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
if self.min_momentum_ratio is not None:
target_momentum = base_momentum * self.min_momentum_ratio
else:
assert self.min_momentum is not None
target_momentum = self.min_momentum
return annealing_cos(base_momentum, target_momentum,
progress / max_progress)
......@@ -243,13 +262,16 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
should be specified. Default: None.
"""
def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
def __init__(self,
min_momentum: Optional[float] = None,
min_momentum_ratio: Optional[float] = None,
**kwargs):
assert (min_momentum is None) ^ (min_momentum_ratio is None)
self.min_momentum = min_momentum
self.min_momentum_ratio = min_momentum_ratio
super().__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
def get_momentum(self, runner, base_momentum: float) -> float:
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
......@@ -259,6 +281,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
if self.min_momentum_ratio is not None:
target_momentum = base_momentum * self.min_momentum_ratio
else:
assert self.min_momentum is not None
target_momentum = self.min_momentum
return annealing_linear(base_momentum, target_momentum,
progress / max_progress)
......@@ -291,12 +314,12 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
"""
def __init__(self,
by_epoch=False,
target_ratio=(0.85 / 0.95, 1),
cyclic_times=1,
step_ratio_up=0.4,
anneal_strategy='cos',
gamma=1,
by_epoch: bool = False,
target_ratio: Tuple[float, float] = (0.85 / 0.95, 1.),
cyclic_times: int = 1,
step_ratio_up: float = 0.4,
anneal_strategy: str = 'cos',
gamma: float = 1.,
**kwargs):
if isinstance(target_ratio, float):
target_ratio = (target_ratio, target_ratio / 1e5)
......@@ -316,8 +339,9 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
self.cyclic_times = cyclic_times
self.step_ratio_up = step_ratio_up
self.gamma = gamma
self.momentum_phases = [] # init momentum_phases
self.momentum_phases: List[list] = [] # init momentum_phases
self.anneal_func: Callable[[float, float, float], float]
if anneal_strategy not in ['cos', 'linear']:
raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}')
......@@ -344,7 +368,7 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
self.target_ratio[1]
])
def get_momentum(self, runner, base_momentum):
def get_momentum(self, runner, base_momentum: float) -> float:
curr_iter = runner.iter % self.max_iter_per_phase
curr_cycle = runner.iter // self.max_iter_per_phase
scale = self.gamma**curr_cycle
......@@ -366,6 +390,8 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
return self.anneal_func(base_momentum * start_ratio,
base_momentum * end_ratio,
progress / (end_iter - start_iter))
raise RuntimeError('The method should return within the for-loop and '
'should never reach this point')
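To make the phase bookkeeping above concrete, here is a small self-contained sketch of "pick the active phase, then cosine-anneal inside it"; the phase boundaries and ratios are illustrative numbers, not the hook's real attributes.

from math import cos, pi

def anneal_cos(start: float, end: float, factor: float) -> float:
    # same cosine schedule the hook uses as its anneal_func
    return end + 0.5 * (start - end) * (cos(pi * factor) + 1)

# illustrative phases: (start_iter, end_iter, start_ratio, end_ratio)
phases = [(0, 40, 1.0, 0.85 / 0.95), (40, 100, 0.85 / 0.95, 1.0)]

def momentum_at(curr_iter: int, base_momentum: float) -> float:
    for start_iter, end_iter, start_ratio, end_ratio in phases:
        if curr_iter < end_iter:
            progress = (curr_iter - start_iter) / (end_iter - start_iter)
            return anneal_cos(base_momentum * start_ratio,
                              base_momentum * end_ratio, progress)
    return base_momentum * phases[-1][3]  # past the final phase

print(momentum_at(0, 0.95))    # 0.95: start of the cycle
print(momentum_at(40, 0.95))   # ~0.85: bottom of the dip after the "up" phase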
@HOOKS.register_module()
......@@ -404,11 +430,11 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
"""
def __init__(self,
base_momentum=0.85,
max_momentum=0.95,
pct_start=0.3,
anneal_strategy='cos',
three_phase=False,
base_momentum: Union[float, list, dict] = 0.85,
max_momentum: Union[float, list, dict] = 0.95,
pct_start: float = 0.3,
anneal_strategy: str = 'cos',
three_phase: bool = False,
**kwargs):
# validate by_epoch, currently only support by_epoch=False
if 'by_epoch' not in kwargs:
......@@ -430,6 +456,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
f'got {pct_start}')
self.pct_start = pct_start
# validate anneal_strategy
self.anneal_func: Callable[[float, float, float], float]
if anneal_strategy not in ['cos', 'linear']:
raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}')
......@@ -438,7 +465,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear
self.three_phase = three_phase
self.momentum_phases = [] # init momentum_phases
self.momentum_phases: List[dict] = [] # init momentum_phases
super().__init__(**kwargs)
def before_run(self, runner):
......@@ -535,9 +562,10 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
elif 'betas' in param_group.keys():
param_group['betas'] = (mom, param_group['betas'][1])
def get_momentum(self, runner, param_group):
def get_momentum(self, runner, param_group: Dict[str, float]) -> float:
curr_iter = runner.iter
start_iter = 0
momentum = 0.
for i, phase in enumerate(self.momentum_phases):
end_iter = phase['end_iter']
if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
......
......@@ -3,7 +3,10 @@ import copy
import logging
from collections import defaultdict
from itertools import chain
from typing import Optional, Union
import torch.nn as nn
from torch import Tensor
from torch.nn.utils import clip_grad
from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
......@@ -39,7 +42,9 @@ class OptimizerHook(Hook):
Default: False.
"""
def __init__(self, grad_clip=None, detect_anomalous_params=False):
def __init__(self,
grad_clip: Optional[dict] = None,
detect_anomalous_params: bool = False):
self.grad_clip = grad_clip
self.detect_anomalous_params = detect_anomalous_params
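For context on the `grad_clip: Optional[dict]` annotation: the dict is forwarded to PyTorch's gradient-norm clipping, so its keys follow that function's keyword arguments. A minimal sketch:

from mmcv.runner import OptimizerHook

# grad_clip keys follow torch.nn.utils.clip_grad_norm_ (max_norm, norm_type, ...)
optimizer_hook = OptimizerHook(
    grad_clip=dict(max_norm=35, norm_type=2),  # Optional[dict]; None disables clipping
    detect_anomalous_params=False,             # bool
)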
......@@ -63,7 +68,7 @@ class OptimizerHook(Hook):
runner.outputs['num_samples'])
runner.optimizer.step()
def detect_anomalous_parameters(self, loss, runner):
def detect_anomalous_parameters(self, loss: Tensor, runner) -> None:
logger = runner.logger
parameters_in_graph = set()
visited = set()
......@@ -109,7 +114,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
>>> optim_hook = OptimizerHook()
"""
def __init__(self, cumulative_iters=1, **kwargs):
def __init__(self, cumulative_iters: int = 1, **kwargs):
super().__init__(**kwargs)
assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
......@@ -121,7 +126,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
self.remainder_iters = 0
self.initialized = False
def has_batch_norm(self, module):
def has_batch_norm(self, module: nn.Module) -> bool:
if isinstance(module, _BatchNorm):
return True
for m in module.children():
......@@ -208,11 +213,11 @@ if (TORCH_VERSION != 'parrots'
"""
def __init__(self,
grad_clip=None,
coalesce=True,
bucket_size_mb=-1,
loss_scale=512.,
distributed=True):
grad_clip: Optional[dict] = None,
coalesce: bool = True,
bucket_size_mb: int = -1,
loss_scale: Union[float, str, dict] = 512.,
distributed: bool = True):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
......@@ -229,7 +234,7 @@ if (TORCH_VERSION != 'parrots'
raise ValueError('loss_scale must be of type float, dict, or '
f'"dynamic", got {loss_scale}')
def before_run(self, runner):
def before_run(self, runner) -> None:
"""Preparing steps before Mixed Precision Training."""
# wrap model mode to fp16
wrap_fp16_model(runner.model)
......@@ -238,7 +243,8 @@ if (TORCH_VERSION != 'parrots'
scaler_state_dict = runner.meta['fp16']['loss_scaler']
self.loss_scaler.load_state_dict(scaler_state_dict)
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
def copy_grads_to_fp32(self, fp16_net: nn.Module,
fp32_weights: Tensor) -> None:
"""Copy gradients from fp16 model to fp32 weight copy."""
for fp32_param, fp16_param in zip(fp32_weights,
fp16_net.parameters()):
......@@ -248,13 +254,14 @@ if (TORCH_VERSION != 'parrots'
fp32_param.size())
fp32_param.grad.copy_(fp16_param.grad)
def copy_params_to_fp16(self, fp16_net, fp32_weights):
def copy_params_to_fp16(self, fp16_net: nn.Module,
fp32_weights: Tensor) -> None:
"""Copy updated params from fp32 weight copy to fp16 model."""
for fp16_param, fp32_param in zip(fp16_net.parameters(),
fp32_weights):
fp16_param.data.copy_(fp32_param.data)
def after_train_iter(self, runner):
def after_train_iter(self, runner) -> None:
"""Backward optimization steps for Mixed Precision Training. For
dynamic loss scaling, please refer to
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
......@@ -299,7 +306,7 @@ if (TORCH_VERSION != 'parrots'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def after_train_iter(self, runner):
def after_train_iter(self, runner) -> None:
if not self.initialized:
self._init(runner)
......@@ -363,11 +370,11 @@ else:
"""
def __init__(self,
grad_clip=None,
coalesce=True,
bucket_size_mb=-1,
loss_scale=512.,
distributed=True):
grad_clip: Optional[dict] = None,
coalesce: bool = True,
bucket_size_mb: int = -1,
loss_scale: Union[float, str, dict] = 512.,
distributed: bool = True):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
......@@ -383,7 +390,7 @@ else:
raise ValueError('loss_scale must be of type float, dict, or '
f'"dynamic", got {loss_scale}')
def before_run(self, runner):
def before_run(self, runner) -> None:
"""Preparing steps before Mixed Precision Training.
1. Make a master copy of fp32 weights for optimization.
......@@ -393,7 +400,7 @@ else:
old_groups = runner.optimizer.param_groups
runner.optimizer.param_groups = copy.deepcopy(
runner.optimizer.param_groups)
state = defaultdict(dict)
state: defaultdict = defaultdict(dict)
p_map = {
old_p: p
for old_p, p in zip(
......@@ -411,7 +418,8 @@ else:
scaler_state_dict = runner.meta['fp16']['loss_scaler']
self.loss_scaler.load_state_dict(scaler_state_dict)
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
def copy_grads_to_fp32(self, fp16_net: nn.Module,
fp32_weights: Tensor) -> None:
"""Copy gradients from fp16 model to fp32 weight copy."""
for fp32_param, fp16_param in zip(fp32_weights,
fp16_net.parameters()):
......@@ -421,13 +429,14 @@ else:
fp32_param.size())
fp32_param.grad.copy_(fp16_param.grad)
def copy_params_to_fp16(self, fp16_net, fp32_weights):
def copy_params_to_fp16(self, fp16_net: nn.Module,
fp32_weights: Tensor) -> None:
"""Copy updated params from fp32 weight copy to fp16 model."""
for fp16_param, fp32_param in zip(fp16_net.parameters(),
fp32_weights):
fp16_param.data.copy_(fp32_param.data)
def after_train_iter(self, runner):
def after_train_iter(self, runner) -> None:
"""Backward optimization steps for Mixed Precision Training. For
dynamic loss scaling, please refer `loss_scalar.py`
......@@ -491,7 +500,7 @@ else:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def after_train_iter(self, runner):
def after_train_iter(self, runner) -> None:
if not self.initialized:
self._init(runner)
......
......@@ -13,7 +13,7 @@ class SyncBuffersHook(Hook):
effective only for distributed training. Defaults to True.
"""
def __init__(self, distributed=True):
def __init__(self, distributed: bool = True):
self.distributed = distributed
def after_epoch(self, runner):
......