Unverified commit bb710675 authored by Tri Wahyu Utomo, committed by GitHub

Add type hints for mmcv/runner (#1991)



* add typehints

* fix typehints

* fix _init_rule type hint

* update error message

* simplify type hint

* minor refinement

* minor refinement
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
parent f5425ab7
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import warnings
+from typing import Optional
 
 from mmcv.fileio import FileClient
 from ..dist_utils import allreduce_params, master_only
@@ -49,14 +50,14 @@ class CheckpointHook(Hook):
     """
 
     def __init__(self,
-                 interval=-1,
-                 by_epoch=True,
-                 save_optimizer=True,
-                 out_dir=None,
-                 max_keep_ckpts=-1,
-                 save_last=True,
-                 sync_buffer=False,
-                 file_client_args=None,
+                 interval: int = -1,
+                 by_epoch: bool = True,
+                 save_optimizer: bool = True,
+                 out_dir: Optional[str] = None,
+                 max_keep_ckpts: int = -1,
+                 save_last: bool = True,
+                 sync_buffer: bool = False,
+                 file_client_args: Optional[dict] = None,
                  **kwargs):
         self.interval = interval
         self.by_epoch = by_epoch
...
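For orientation, the annotated signature maps directly onto how the hook is usually constructed. A minimal sketch (argument values are illustrative, not from the commit); `out_dir: Optional[str]` falls back to the runner's `work_dir` when left as `None`:

```python
from mmcv.runner.hooks import CheckpointHook

# interval counts epochs because by_epoch=True; max_keep_ckpts=3 keeps
# only the three most recent checkpoints on disk
ckpt_hook = CheckpointHook(
    interval=1,              # int
    by_epoch=True,           # bool
    max_keep_ckpts=3,        # int
    out_dir=None,            # Optional[str] -> defaults to runner.work_dir
    file_client_args=None)   # Optional[dict]
```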
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable
 
 from .hook import HOOKS, Hook
 
 
 @HOOKS.register_module()
 class ClosureHook(Hook):
 
-    def __init__(self, fn_name, fn):
+    def __init__(self, fn_name: str, fn: Callable):
         assert hasattr(self, fn_name)
         assert callable(fn)
         setattr(self, fn_name, fn)
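`ClosureHook` grafts a user callable onto one of the standard hook points, which is what the new `fn_name: str, fn: Callable` hints express. A hedged usage sketch (the print body is illustrative):

```python
from mmcv.runner.hooks import ClosureHook

def log_epoch(runner):
    # invoked at the hook point named by fn_name
    print(f'starting epoch {runner.epoch}')

# fn_name must be an existing Hook method, e.g. 'before_train_epoch'
hook = ClosureHook(fn_name='before_train_epoch', fn=log_epoch)
```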
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
 
 from ...parallel import is_module_wrapper
 from ..hooks.hook import HOOKS, Hook
@@ -23,14 +25,14 @@ class EMAHook(Hook):
             Defaults to 1.
         warm_up (int): During first warm_up steps, we may use smaller momentum
             to update ema parameters more slowly. Defaults to 100.
-        resume_from (str): The checkpoint path. Defaults to None.
+        resume_from (str, optional): The checkpoint path. Defaults to None.
     """
 
     def __init__(self,
-                 momentum=0.0002,
-                 interval=1,
-                 warm_up=100,
-                 resume_from=None):
+                 momentum: float = 0.0002,
+                 interval: int = 1,
+                 warm_up: int = 100,
+                 resume_from: Optional[str] = None):
         assert isinstance(interval, int) and interval > 0
         self.warm_up = warm_up
         self.interval = interval
...
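The `warm_up` argument exists because early in training the EMA buffers should change more slowly; mmcv caps the effective momentum during the first steps. A hedged paraphrase of the schedule (the hook body is not part of this diff, so treat it as an assumption based on the docstring):

```python
def effective_momentum(momentum: float, warm_up: int, curr_step: int) -> float:
    # during the first `warm_up` iterations the second term stays below a
    # large configured momentum, so the EMA update
    #   ema = (1 - m) * ema + m * param
    # moves the averaged weights more slowly
    return min(momentum, (1 + curr_step) / (warm_up + curr_step))
```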
@@ -2,6 +2,7 @@
 import os.path as osp
 import warnings
 from math import inf
+from typing import Callable, List, Optional
 
 import torch.distributed as dist
 from torch.nn.modules.batchnorm import _BatchNorm
@@ -83,17 +84,17 @@ class EvalHook(Hook):
     _default_less_keys = ['loss']
 
     def __init__(self,
-                 dataloader,
-                 start=None,
-                 interval=1,
-                 by_epoch=True,
-                 save_best=None,
-                 rule=None,
-                 test_fn=None,
-                 greater_keys=None,
-                 less_keys=None,
-                 out_dir=None,
-                 file_client_args=None,
+                 dataloader: DataLoader,
+                 start: Optional[int] = None,
+                 interval: int = 1,
+                 by_epoch: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 test_fn: Optional[Callable] = None,
+                 greater_keys: Optional[List[str]] = None,
+                 less_keys: Optional[List[str]] = None,
+                 out_dir: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
                  **eval_kwargs):
         if not isinstance(dataloader, DataLoader):
             raise TypeError(f'dataloader must be a pytorch DataLoader, '
@@ -131,6 +132,7 @@ class EvalHook(Hook):
             self.greater_keys = self._default_greater_keys
         else:
             if not isinstance(greater_keys, (list, tuple)):
+                assert isinstance(greater_keys, str)
                 greater_keys = (greater_keys, )
             assert is_seq_of(greater_keys, str)
             self.greater_keys = greater_keys
@@ -139,6 +141,7 @@ class EvalHook(Hook):
             self.less_keys = self._default_less_keys
         else:
             if not isinstance(less_keys, (list, tuple)):
+                assert isinstance(less_keys, str)
                 less_keys = (less_keys, )
             assert is_seq_of(less_keys, str)
             self.less_keys = less_keys
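The `assert isinstance(...)` lines added in this commit (and the `assert ... is not None` lines further down) are type-narrowing guards: mypy treats the asserted type as the variable's type from that point on, so the annotated signatures type-check without casts. A standalone illustration of the pattern (not from the diff):

```python
from typing import Optional

def shout(name: Optional[str]) -> str:
    # the assert both guards at runtime and narrows Optional[str] to str,
    # so mypy accepts the .upper() call below without a cast
    assert name is not None
    return name.upper()
```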
@@ -150,7 +153,7 @@
         self.out_dir = out_dir
         self.file_client_args = file_client_args
 
-    def _init_rule(self, rule, key_indicator):
+    def _init_rule(self, rule: Optional[str], key_indicator: str):
         """Initialize rule, key_indicator, comparison_func, and best score.
 
         Here is the rule to determine which rule is used for key indicator
@@ -178,6 +181,7 @@
         if key_indicator != 'auto':
             # `_lc` here means we use the lower case of keys for
             # case-insensitive matching
+            assert isinstance(key_indicator, str)
             key_indicator_lc = key_indicator.lower()
             greater_keys = [key.lower() for key in self.greater_keys]
             less_keys = [key.lower() for key in self.less_keys]
@@ -439,20 +443,20 @@ class DistEvalHook(EvalHook):
     """
 
     def __init__(self,
-                 dataloader,
-                 start=None,
-                 interval=1,
-                 by_epoch=True,
-                 save_best=None,
-                 rule=None,
-                 test_fn=None,
-                 greater_keys=None,
-                 less_keys=None,
-                 broadcast_bn_buffer=True,
-                 tmpdir=None,
-                 gpu_collect=False,
-                 out_dir=None,
-                 file_client_args=None,
+                 dataloader: DataLoader,
+                 start: Optional[int] = None,
+                 interval: int = 1,
+                 by_epoch: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 test_fn: Optional[Callable] = None,
+                 greater_keys: Optional[List[str]] = None,
+                 less_keys: Optional[List[str]] = None,
+                 broadcast_bn_buffer: bool = True,
+                 tmpdir: Optional[str] = None,
+                 gpu_collect: bool = False,
+                 out_dir: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
                  **eval_kwargs):
         if test_fn is None:
...
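A hedged sketch of how the typed `EvalHook` signature is exercised; the dataset here is a stand-in so the constructor's `DataLoader` check passes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from mmcv.runner.hooks import EvalHook

loader = DataLoader(TensorDataset(torch.zeros(4, 3)), batch_size=2)

# construction only; running the evaluation needs a model and a runner
eval_hook = EvalHook(
    loader,
    interval=1,            # evaluate every epoch (by_epoch=True)
    save_best='accuracy',  # Optional[str]: metric key to track
    rule='greater')        # Optional[str]: 'greater' or 'less'
```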
@@ -705,7 +705,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
 def annealing_cos(start: float,
                   end: float,
                   factor: float,
-                  weight: float = 1) -> float:
+                  weight: float = 1.) -> float:
     """Calculate annealing cos learning rate.
 
     Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
...
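For reference, the two annealing helpers that the momentum hooks below bind to `self.anneal_func` interpolate between `start` and `end` as `factor` runs from 0 to 1. A self-contained paraphrase of their formulas (the bodies are not shown in this diff, so treat them as an assumption based on the docstrings):

```python
from math import cos, pi

def annealing_cos(start: float, end: float, factor: float,
                  weight: float = 1.) -> float:
    # half-cosine from weight * start + (1 - weight) * end down to end
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out

def annealing_linear(start: float, end: float, factor: float) -> float:
    # straight-line interpolation from start to end
    return start + (end - start) * factor
```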
@@ -7,7 +7,10 @@ from .hook import HOOKS, Hook
 @HOOKS.register_module()
 class EmptyCacheHook(Hook):
 
-    def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+    def __init__(self,
+                 before_epoch: bool = False,
+                 after_epoch: bool = True,
+                 after_iter: bool = False):
         self._before_epoch = before_epoch
         self._after_epoch = after_epoch
         self._after_iter = after_iter
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import mmcv
 from .hook import HOOKS, Hook
 from .lr_updater import annealing_cos, annealing_linear, format_param
@@ -7,10 +9,10 @@ from .lr_updater import annealing_cos, annealing_linear, format_param
 class MomentumUpdaterHook(Hook):
 
     def __init__(self,
-                 by_epoch=True,
-                 warmup=None,
-                 warmup_iters=0,
-                 warmup_ratio=0.9):
+                 by_epoch: bool = True,
+                 warmup: Optional[str] = None,
+                 warmup_iters: int = 0,
+                 warmup_ratio: float = 0.9):
         # validate the "warmup" argument
         if warmup is not None:
             if warmup not in ['constant', 'linear', 'exp']:
@@ -28,9 +30,10 @@ class MomentumUpdaterHook(Hook):
         self.warmup_iters = warmup_iters
         self.warmup_ratio = warmup_ratio
 
-        self.base_momentum = []  # initial momentum for all param groups
-        self.regular_momentum = [
-        ]  # expected momentum if no warming up is performed
+        # initial momentum for all param groups
+        self.base_momentum: Union[list, dict] = []
+        # expected momentum if no warming up is performed
+        self.regular_momentum: Union[list, dict] = []
 
     def _set_momentum(self, runner, momentum_groups):
         if isinstance(runner.optimizer, dict):
@@ -49,26 +52,30 @@
             elif 'betas' in param_group.keys():
                 param_group['betas'] = (mom, param_group['betas'][1])
 
-    def get_momentum(self, runner, base_momentum):
+    def get_momentum(self, runner, base_momentum) -> float:
         raise NotImplementedError
 
-    def get_regular_momentum(self, runner):
+    def get_regular_momentum(self, runner) -> Union[list, Dict[str, list]]:
         if isinstance(runner.optimizer, dict):
-            momentum_groups = {}
+            assert isinstance(self.base_momentum, dict)
+            momentum_groups: Dict[str, List[float]] = {}
             for k in runner.optimizer.keys():
-                _momentum_group = [
+                _momentum_group: List[float] = [
                     self.get_momentum(runner, _base_momentum)
                     for _base_momentum in self.base_momentum[k]
                 ]
                 momentum_groups.update({k: _momentum_group})
             return momentum_groups
         else:
+            assert isinstance(self.base_momentum, list)
             return [
                 self.get_momentum(runner, _base_momentum)
                 for _base_momentum in self.base_momentum
             ]
 
-    def get_warmup_momentum(self, cur_iters):
+    def get_warmup_momentum(
+            self,
+            cur_iters: int) -> Union[List[float], Dict[str, List[float]]]:
 
         def _get_warmup_momentum(cur_iters, regular_momentum):
             if self.warmup == 'constant':
@@ -87,6 +94,10 @@ class MomentumUpdaterHook(Hook):
                 warmup_momentum = [
                     _momentum / k for _momentum in regular_momentum
                 ]
+            else:
+                raise ValueError(
+                    'Expected values of `self.warmup` to be "constant", '
+                    f'"linear", or "exp", got {self.warmup}')
             return warmup_momentum
 
         if isinstance(self.regular_momentum, dict):
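The new `else: raise ValueError` closes the `constant`/`linear`/`exp` chain so no branch falls through with `warmup_momentum` unbound. Only the tail of the `exp` branch is visible above; a hedged standalone restatement of the three strategies:

```python
from typing import List

def warmup_momentum(warmup: str, warmup_ratio: float, warmup_iters: int,
                    cur_iters: int, regular: List[float]) -> List[float]:
    # momentum warmup divides (rather than multiplies, as lr warmup does)
    if warmup == 'constant':
        return [m / warmup_ratio for m in regular]
    if warmup == 'linear':
        k = (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)
        return [m / (1 - k) for m in regular]
    if warmup == 'exp':
        k = warmup_ratio**(1 - cur_iters / warmup_iters)
        return [m / k for m in regular]
    raise ValueError(f'unknown warmup strategy: {warmup}')
```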
@@ -165,7 +176,11 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
             Default: None.
     """
 
-    def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+    def __init__(self,
+                 step: Union[int, List[int]],
+                 gamma: float = 0.5,
+                 min_momentum: Optional[float] = None,
+                 **kwargs):
         if isinstance(step, list):
             assert mmcv.is_list_of(step, int)
             assert all([s > 0 for s in step])
@@ -178,7 +193,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
         self.min_momentum = min_momentum
         super().__init__(**kwargs)
 
-    def get_momentum(self, runner, base_momentum):
+    def get_momentum(self, runner, base_momentum: float) -> float:
         progress = runner.epoch if self.by_epoch else runner.iter
 
         # calculate exponential term
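The step schedule itself is simple: the exponent counts how many milestones `progress` has passed. A hypothetical standalone restatement (names and floor handling are assumptions, not the hook's actual body):

```python
from typing import List, Optional

def step_momentum(base: float, progress: int, steps: List[int],
                  gamma: float = 0.5,
                  min_momentum: Optional[float] = None) -> float:
    # e.g. steps=[30, 60], gamma=0.5: epochs 0-29 use base, 30-59 use
    # 0.5 * base, 60+ use 0.25 * base, floored at min_momentum if set
    exp = sum(progress >= s for s in steps)
    out = base * gamma**exp
    return out if min_momentum is None else max(out, min_momentum)
```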
@@ -210,13 +225,16 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
             should be specified. Default: None.
     """
 
-    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+    def __init__(self,
+                 min_momentum: Optional[float] = None,
+                 min_momentum_ratio: Optional[float] = None,
+                 **kwargs):
         assert (min_momentum is None) ^ (min_momentum_ratio is None)
         self.min_momentum = min_momentum
         self.min_momentum_ratio = min_momentum_ratio
         super().__init__(**kwargs)
 
-    def get_momentum(self, runner, base_momentum):
+    def get_momentum(self, runner, base_momentum: float) -> float:
         if self.by_epoch:
             progress = runner.epoch
             max_progress = runner.max_epochs
@@ -226,6 +244,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
         if self.min_momentum_ratio is not None:
             target_momentum = base_momentum * self.min_momentum_ratio
         else:
+            assert self.min_momentum is not None
             target_momentum = self.min_momentum
         return annealing_cos(base_momentum, target_momentum,
                              progress / max_progress)
@@ -243,13 +262,16 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
             should be specified. Default: None.
     """
 
-    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+    def __init__(self,
+                 min_momentum: Optional[float] = None,
+                 min_momentum_ratio: Optional[float] = None,
+                 **kwargs):
         assert (min_momentum is None) ^ (min_momentum_ratio is None)
         self.min_momentum = min_momentum
         self.min_momentum_ratio = min_momentum_ratio
         super().__init__(**kwargs)
 
-    def get_momentum(self, runner, base_momentum):
+    def get_momentum(self, runner, base_momentum: float) -> float:
         if self.by_epoch:
             progress = runner.epoch
             max_progress = runner.max_epochs
@@ -259,6 +281,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
         if self.min_momentum_ratio is not None:
             target_momentum = base_momentum * self.min_momentum_ratio
         else:
+            assert self.min_momentum is not None
            target_momentum = self.min_momentum
         return annealing_linear(base_momentum, target_momentum,
                                 progress / max_progress)
@@ -291,12 +314,12 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
     """
 
     def __init__(self,
-                 by_epoch=False,
-                 target_ratio=(0.85 / 0.95, 1),
-                 cyclic_times=1,
-                 step_ratio_up=0.4,
-                 anneal_strategy='cos',
-                 gamma=1,
+                 by_epoch: bool = False,
+                 target_ratio: Tuple[float, float] = (0.85 / 0.95, 1.),
+                 cyclic_times: int = 1,
+                 step_ratio_up: float = 0.4,
+                 anneal_strategy: str = 'cos',
+                 gamma: float = 1.,
                  **kwargs):
         if isinstance(target_ratio, float):
             target_ratio = (target_ratio, target_ratio / 1e5)
@@ -316,8 +339,9 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         self.cyclic_times = cyclic_times
         self.step_ratio_up = step_ratio_up
         self.gamma = gamma
-        self.momentum_phases = []  # init momentum_phases
+        self.momentum_phases: List[list] = []  # init momentum_phases
+        self.anneal_func: Callable[[float, float, float], float]
         if anneal_strategy not in ['cos', 'linear']:
             raise ValueError('anneal_strategy must be one of "cos" or '
                              f'"linear", instead got {anneal_strategy}')
@@ -344,7 +368,7 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
             self.target_ratio[1]
         ])
 
-    def get_momentum(self, runner, base_momentum):
+    def get_momentum(self, runner, base_momentum: float) -> float:
         curr_iter = runner.iter % self.max_iter_per_phase
         curr_cycle = runner.iter // self.max_iter_per_phase
         scale = self.gamma**curr_cycle
@@ -366,6 +390,8 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
             return self.anneal_func(base_momentum * start_ratio,
                                     base_momentum * end_ratio,
                                     progress / (end_iter - start_iter))
+        raise RuntimeError('The method should return in the for-loop and '
+                           'should not be executed until this')
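The trailing `raise RuntimeError` is there for the type checker as much as for safety: with a `-> float` return hint, mypy flags a function whose loop might fall through without returning. A minimal illustration:

```python
from typing import List, Tuple

def pick(phases: List[Tuple[int, int, float]], x: int) -> float:
    for start, end, value in phases:
        if start <= x < end:
            return value
    # without this raise, mypy reports "missing return statement",
    # since it cannot prove the loop always returns
    raise RuntimeError(f'{x} is not covered by any phase')
```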
 @HOOKS.register_module()
@@ -404,11 +430,11 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
     """
 
     def __init__(self,
-                 base_momentum=0.85,
-                 max_momentum=0.95,
-                 pct_start=0.3,
-                 anneal_strategy='cos',
-                 three_phase=False,
+                 base_momentum: Union[float, list, dict] = 0.85,
+                 max_momentum: Union[float, list, dict] = 0.95,
+                 pct_start: float = 0.3,
+                 anneal_strategy: str = 'cos',
+                 three_phase: bool = False,
                  **kwargs):
         # validate by_epoch, currently only support by_epoch=False
         if 'by_epoch' not in kwargs:
@@ -430,6 +456,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
                              f'got {pct_start}')
         self.pct_start = pct_start
         # validate anneal_strategy
+        self.anneal_func: Callable[[float, float, float], float]
         if anneal_strategy not in ['cos', 'linear']:
             raise ValueError('anneal_strategy must by one of "cos" or '
                              f'"linear", instead got {anneal_strategy}')
@@ -438,7 +465,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
         elif anneal_strategy == 'linear':
             self.anneal_func = annealing_linear
         self.three_phase = three_phase
-        self.momentum_phases = []  # init momentum_phases
+        self.momentum_phases: List[dict] = []  # init momentum_phases
         super().__init__(**kwargs)
 
     def before_run(self, runner):
@@ -535,9 +562,10 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
             elif 'betas' in param_group.keys():
                 param_group['betas'] = (mom, param_group['betas'][1])
 
-    def get_momentum(self, runner, param_group):
+    def get_momentum(self, runner, param_group: Dict[str, float]) -> float:
         curr_iter = runner.iter
         start_iter = 0
+        momentum = 0.
         for i, phase in enumerate(self.momentum_phases):
             end_iter = phase['end_iter']
             if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
...
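In practice these hooks are built from a runner's `momentum_config`, where `policy` picks the registered class; a hedged config sketch (field values illustrative):

```python
# 'cyclic' resolves to CyclicMomentumUpdaterHook via the HOOKS registry;
# the fields mirror the typed __init__ parameters above
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.85 / 0.95, 1.),  # Tuple[float, float]
    cyclic_times=1,                  # int
    step_ratio_up=0.4)               # float
```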
@@ -3,7 +3,10 @@ import copy
 import logging
 from collections import defaultdict
 from itertools import chain
+from typing import Optional, Union
 
+import torch.nn as nn
+from torch import Tensor
 from torch.nn.utils import clip_grad
 
 from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
@@ -39,7 +42,9 @@ class OptimizerHook(Hook):
             Default: False.
     """
 
-    def __init__(self, grad_clip=None, detect_anomalous_params=False):
+    def __init__(self,
+                 grad_clip: Optional[dict] = None,
+                 detect_anomalous_params: bool = False):
         self.grad_clip = grad_clip
         self.detect_anomalous_params = detect_anomalous_params
 
@@ -63,7 +68,7 @@ class OptimizerHook(Hook):
                                      runner.outputs['num_samples'])
         runner.optimizer.step()
 
-    def detect_anomalous_parameters(self, loss, runner):
+    def detect_anomalous_parameters(self, loss: Tensor, runner) -> None:
         logger = runner.logger
         parameters_in_graph = set()
         visited = set()
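`grad_clip: Optional[dict]` is a kwargs dict forwarded to `torch.nn.utils.clip_grad.clip_grad_norm_` when set; a typical value (illustrative):

```python
# forwarded as clip_grad_norm_(params, **grad_clip); None disables clipping
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
```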
@@ -109,7 +114,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
         >>> optim_hook = OptimizerHook()
     """
 
-    def __init__(self, cumulative_iters=1, **kwargs):
+    def __init__(self, cumulative_iters: int = 1, **kwargs):
         super().__init__(**kwargs)
 
         assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
@@ -121,7 +126,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
         self.remainder_iters = 0
         self.initialized = False
 
-    def has_batch_norm(self, module):
+    def has_batch_norm(self, module: nn.Module) -> bool:
         if isinstance(module, _BatchNorm):
             return True
         for m in module.children():
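As the docstring's example suggests, `cumulative_iters: int` trades batch size for accumulation steps; a hedged sketch:

```python
from mmcv.runner.hooks import GradientCumulativeOptimizerHook

# accumulate gradients over 4 iterations before each optimizer.step();
# with batch size 16 this approximates a single step at batch size 64
optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
```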
@@ -208,11 +213,11 @@ if (TORCH_VERSION != 'parrots'
         """
 
         def __init__(self,
-                     grad_clip=None,
-                     coalesce=True,
-                     bucket_size_mb=-1,
-                     loss_scale=512.,
-                     distributed=True):
+                     grad_clip: Optional[dict] = None,
+                     coalesce: bool = True,
+                     bucket_size_mb: int = -1,
+                     loss_scale: Union[float, str, dict] = 512.,
+                     distributed: bool = True):
             self.grad_clip = grad_clip
             self.coalesce = coalesce
             self.bucket_size_mb = bucket_size_mb
@@ -229,7 +234,7 @@ if (TORCH_VERSION != 'parrots'
                 raise ValueError('loss_scale must be of type float, dict, or '
                                  f'"dynamic", got {loss_scale}')
 
-        def before_run(self, runner):
+        def before_run(self, runner) -> None:
             """Preparing steps before Mixed Precision Training."""
             # wrap model mode to fp16
             wrap_fp16_model(runner.model)
@@ -238,7 +243,8 @@ if (TORCH_VERSION != 'parrots'
                 scaler_state_dict = runner.meta['fp16']['loss_scaler']
                 self.loss_scaler.load_state_dict(scaler_state_dict)
 
-        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+        def copy_grads_to_fp32(self, fp16_net: nn.Module,
+                               fp32_weights: Tensor) -> None:
             """Copy gradients from fp16 model to fp32 weight copy."""
             for fp32_param, fp16_param in zip(fp32_weights,
                                               fp16_net.parameters()):
@@ -248,13 +254,14 @@ if (TORCH_VERSION != 'parrots'
                                                fp32_param.size())
                 fp32_param.grad.copy_(fp16_param.grad)
 
-        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+        def copy_params_to_fp16(self, fp16_net: nn.Module,
+                                fp32_weights: Tensor) -> None:
             """Copy updated params from fp32 weight copy to fp16 model."""
             for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                               fp32_weights):
                 fp16_param.data.copy_(fp32_param.data)
 
-        def after_train_iter(self, runner):
+        def after_train_iter(self, runner) -> None:
             """Backward optimization steps for Mixed Precision Training. For
             dynamic loss scaling, please refer to
             https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
@@ -299,7 +306,7 @@ if (TORCH_VERSION != 'parrots'
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
 
-        def after_train_iter(self, runner):
+        def after_train_iter(self, runner) -> None:
             if not self.initialized:
                 self._init(runner)
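`loss_scale: Union[float, str, dict]` covers three modes in this (torch >= 1.6) branch: a fixed scale, the string `'dynamic'`, or a dict of `torch.cuda.amp.GradScaler` kwargs; illustrative constructions:

```python
from mmcv.runner.hooks import Fp16OptimizerHook

static = Fp16OptimizerHook(loss_scale=512.)        # fixed scaling factor
dynamic = Fp16OptimizerHook(loss_scale='dynamic')  # GradScaler defaults
custom = Fp16OptimizerHook(loss_scale=dict(init_scale=2.**16))  # kwargs
```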
@@ -363,11 +370,11 @@ else:
         """
 
         def __init__(self,
-                     grad_clip=None,
-                     coalesce=True,
-                     bucket_size_mb=-1,
-                     loss_scale=512.,
-                     distributed=True):
+                     grad_clip: Optional[dict] = None,
+                     coalesce: bool = True,
+                     bucket_size_mb: int = -1,
+                     loss_scale: Union[float, str, dict] = 512.,
+                     distributed: bool = True):
             self.grad_clip = grad_clip
             self.coalesce = coalesce
             self.bucket_size_mb = bucket_size_mb
@@ -383,7 +390,7 @@ else:
                 raise ValueError('loss_scale must be of type float, dict, or '
                                  f'"dynamic", got {loss_scale}')
 
-        def before_run(self, runner):
+        def before_run(self, runner) -> None:
             """Preparing steps before Mixed Precision Training.
 
             1. Make a master copy of fp32 weights for optimization.
@@ -393,7 +400,7 @@ else:
             old_groups = runner.optimizer.param_groups
             runner.optimizer.param_groups = copy.deepcopy(
                 runner.optimizer.param_groups)
-            state = defaultdict(dict)
+            state: defaultdict = defaultdict(dict)
             p_map = {
                 old_p: p
                 for old_p, p in zip(
@@ -411,7 +418,8 @@ else:
                 scaler_state_dict = runner.meta['fp16']['loss_scaler']
                 self.loss_scaler.load_state_dict(scaler_state_dict)
 
-        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+        def copy_grads_to_fp32(self, fp16_net: nn.Module,
+                               fp32_weights: Tensor) -> None:
             """Copy gradients from fp16 model to fp32 weight copy."""
             for fp32_param, fp16_param in zip(fp32_weights,
                                               fp16_net.parameters()):
@@ -421,13 +429,14 @@ else:
                                                fp32_param.size())
                 fp32_param.grad.copy_(fp16_param.grad)
 
-        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+        def copy_params_to_fp16(self, fp16_net: nn.Module,
+                                fp32_weights: Tensor) -> None:
             """Copy updated params from fp32 weight copy to fp16 model."""
             for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                               fp32_weights):
                 fp16_param.data.copy_(fp32_param.data)
 
-        def after_train_iter(self, runner):
+        def after_train_iter(self, runner) -> None:
             """Backward optimization steps for Mixed Precision Training. For
             dynamic loss scaling, please refer `loss_scalar.py`
 
@@ -491,7 +500,7 @@ else:
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
 
-        def after_train_iter(self, runner):
+        def after_train_iter(self, runner) -> None:
             if not self.initialized:
                 self._init(runner)
...
@@ -13,7 +13,7 @@ class SyncBuffersHook(Hook):
             effective only for distributed training. Defaults to True.
     """
 
-    def __init__(self, distributed=True):
+    def __init__(self, distributed: bool = True):
         self.distributed = distributed
 
     def after_epoch(self, runner):
...