Unverified Commit 335199db authored by WINDSKY45's avatar WINDSKY45 Committed by GitHub
Browse files

Add type hint in lr_updater.py (#1988)

* [Enhance] Add type hint in `lr_updater.py`.

* Fix circle import
parent 23eb359b
# Copyright (c) OpenMMLab. All rights reserved.
import numbers
from math import cos, pi
from typing import Callable, List, Optional, Union
import mmcv
from mmcv import runner
from .hook import HOOKS, Hook
......@@ -23,11 +25,11 @@ class LrUpdaterHook(Hook):
"""
def __init__(self,
by_epoch=True,
warmup=None,
warmup_iters=0,
warmup_ratio=0.1,
warmup_by_epoch=False):
by_epoch: bool = True,
warmup: Optional[str] = None,
warmup_iters: int = 0,
warmup_ratio: float = 0.1,
warmup_by_epoch: bool = False) -> None:
# validate the "warmup" argument
if warmup is not None:
if warmup not in ['constant', 'linear', 'exp']:
......@@ -42,18 +44,18 @@ class LrUpdaterHook(Hook):
self.by_epoch = by_epoch
self.warmup = warmup
self.warmup_iters = warmup_iters
self.warmup_iters: Optional[int] = warmup_iters
self.warmup_ratio = warmup_ratio
self.warmup_by_epoch = warmup_by_epoch
if self.warmup_by_epoch:
self.warmup_epochs = self.warmup_iters
self.warmup_epochs: Optional[int] = self.warmup_iters
self.warmup_iters = None
else:
self.warmup_epochs = None
self.base_lr = [] # initial lr for all param groups
self.regular_lr = [] # expected lr if no warming up is performed
self.base_lr: Union[list, dict] = [] # initial lr for all param groups
self.regular_lr: list = [] # expected lr if no warming up is performed
def _set_lr(self, runner, lr_groups):
if isinstance(runner.optimizer, dict):
......@@ -65,10 +67,10 @@ class LrUpdaterHook(Hook):
lr_groups):
param_group['lr'] = lr
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
    """Compute the new lr for one param group.

    This is the policy hook that every concrete LR-updater must
    override; the base class provides no schedule of its own.

    Args:
        runner (runner.BaseRunner): The runner driving training.
        base_lr (float): The initial lr of the param group.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError
def get_regular_lr(self, runner):
def get_regular_lr(self, runner: 'runner.BaseRunner'):
if isinstance(runner.optimizer, dict):
lr_groups = {}
for k in runner.optimizer.keys():
......@@ -82,7 +84,7 @@ class LrUpdaterHook(Hook):
else:
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
def get_warmup_lr(self, cur_iters):
def get_warmup_lr(self, cur_iters: int):
def _get_warmup_lr(cur_iters, regular_lr):
if self.warmup == 'constant':
......@@ -104,7 +106,7 @@ class LrUpdaterHook(Hook):
else:
return _get_warmup_lr(cur_iters, self.regular_lr)
def before_run(self, runner):
def before_run(self, runner: 'runner.BaseRunner'):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
if isinstance(runner.optimizer, dict):
......@@ -123,10 +125,10 @@ class LrUpdaterHook(Hook):
group['initial_lr'] for group in runner.optimizer.param_groups
]
def before_train_epoch(self, runner):
def before_train_epoch(self, runner: 'runner.BaseRunner'):
if self.warmup_iters is None:
epoch_len = len(runner.data_loader)
self.warmup_iters = self.warmup_epochs * epoch_len
epoch_len = len(runner.data_loader) # type: ignore
self.warmup_iters = self.warmup_epochs * epoch_len # type: ignore
if not self.by_epoch:
return
......@@ -134,7 +136,7 @@ class LrUpdaterHook(Hook):
self.regular_lr = self.get_regular_lr(runner)
self._set_lr(runner, self.regular_lr)
def before_train_iter(self, runner):
def before_train_iter(self, runner: 'runner.BaseRunner'):
cur_iter = runner.iter
if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner)
......@@ -171,13 +173,17 @@ class StepLrUpdaterHook(LrUpdaterHook):
step (int | list[int]): Step to decay the LR. If an int value is given,
regard it as the decay interval. If a list is given, decay LR at
these steps.
gamma (float, optional): Decay LR ratio. Default: 0.1.
gamma (float): Decay LR ratio. Defaults to 0.1.
min_lr (float, optional): Minimum LR value to keep. If LR after decay
is lower than `min_lr`, it will be clipped to this value. If None
is given, we don't perform lr clipping. Default: None.
"""
def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
def __init__(self,
step: Union[int, List[int]],
gamma: float = 0.1,
min_lr: Optional[float] = None,
**kwargs) -> None:
if isinstance(step, list):
assert mmcv.is_list_of(step, int)
assert all([s > 0 for s in step])
......@@ -190,7 +196,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self.min_lr = min_lr
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
progress = runner.epoch if self.by_epoch else runner.iter
# calculate exponential term
......@@ -213,11 +219,11 @@ class StepLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module()
class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma: float, **kwargs) -> None:
    """Exponential LR decay: ``lr = base_lr * gamma ** progress``.

    Args:
        gamma (float): Multiplicative decay factor applied per epoch
            (or per iteration when ``by_epoch`` is False).
    """
    self.gamma = gamma
    super().__init__(**kwargs)
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
    """Return ``base_lr`` decayed exponentially by the current progress."""
    # Progress is counted in epochs or iterations per ``by_epoch``.
    if self.by_epoch:
        progress = runner.epoch
    else:
        progress = runner.iter
    return base_lr * self.gamma**progress
......@@ -225,12 +231,15 @@ class ExpLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module()
class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self,
             power: float = 1.,
             min_lr: float = 0.,
             **kwargs) -> None:
    """Polynomial LR decay.

    Args:
        power (float): Exponent of the polynomial schedule.
            Defaults to 1.
        min_lr (float): Floor the decayed lr is clipped to.
            Defaults to 0.
    """
    self.power = power
    self.min_lr = min_lr
    super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
......@@ -244,12 +253,12 @@ class PolyLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module()
class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma: float, power: float = 1., **kwargs) -> None:
    """Inverse-decay LR: ``lr = base_lr * (1 + gamma * p) ** -power``.

    Args:
        gamma (float): Scale applied to the progress counter.
        power (float): Decay exponent. Defaults to 1.
    """
    self.gamma = gamma
    self.power = power
    super().__init__(**kwargs)
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
    """Return the inverse-decayed lr for the current progress."""
    if self.by_epoch:
        progress = runner.epoch
    else:
        progress = runner.iter
    # lr = base_lr * (1 + gamma * progress) ** (-power)
    scale = 1 + self.gamma * progress
    return base_lr * scale**(-self.power)
......@@ -265,13 +274,16 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None.
"""
def __init__(self,
             min_lr: Optional[float] = None,
             min_lr_ratio: Optional[float] = None,
             **kwargs) -> None:
    """Cosine annealing schedule.

    Args:
        min_lr (float, optional): Target minimum lr. Mutually
            exclusive with ``min_lr_ratio``.
        min_lr_ratio (float, optional): Target minimum lr expressed
            as a ratio of the base lr. Mutually exclusive with
            ``min_lr``.
    """
    # Exactly one of the two target specifications must be given.
    assert (min_lr is None) ^ (min_lr_ratio is None)
    self.min_lr = min_lr
    self.min_lr_ratio = min_lr_ratio
    super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
......@@ -282,7 +294,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr
target_lr = self.min_lr # type:ignore
return annealing_cos(base_lr, target_lr, progress / max_progress)
......@@ -304,10 +316,10 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
"""
def __init__(self,
start_percent=0.75,
min_lr=None,
min_lr_ratio=None,
**kwargs):
start_percent: float = 0.75,
min_lr: Optional[float] = None,
min_lr_ratio: Optional[float] = None,
**kwargs) -> None:
assert (min_lr is None) ^ (min_lr_ratio is None)
if start_percent < 0 or start_percent > 1 or not isinstance(
start_percent, float):
......@@ -319,7 +331,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
self.min_lr_ratio = min_lr_ratio
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch:
start = round(runner.max_epochs * self.start_percent)
progress = runner.epoch - start
......@@ -332,7 +344,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr
target_lr = self.min_lr # type:ignore
if progress < 0:
return base_lr
......@@ -346,8 +358,8 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
Args:
periods (list[int]): Periods for each cosine anneling cycle.
restart_weights (list[float], optional): Restart weights at each
restart iteration. Default: [1].
restart_weights (list[float]): Restart weights at each
restart iteration. Defaults to [1].
min_lr (float, optional): The minimum lr. Default: None.
min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
Either `min_lr` or `min_lr_ratio` should be specified.
......@@ -355,11 +367,11 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
"""
def __init__(self,
periods,
restart_weights=[1],
min_lr=None,
min_lr_ratio=None,
**kwargs):
periods: List[int],
restart_weights: List[float] = [1],
min_lr: Optional[float] = None,
min_lr_ratio: Optional[float] = None,
**kwargs) -> None:
assert (min_lr is None) ^ (min_lr_ratio is None)
self.periods = periods
self.min_lr = min_lr
......@@ -373,7 +385,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
]
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch:
progress = runner.epoch
else:
......@@ -382,7 +394,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr
target_lr = self.min_lr # type:ignore
idx = get_position_from_periods(progress, self.cumulative_periods)
current_weight = self.restart_weights[idx]
......@@ -393,7 +405,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
return annealing_cos(base_lr, target_lr, alpha, current_weight)
def get_position_from_periods(iteration, cumulative_periods):
def get_position_from_periods(iteration: int, cumulative_periods: List[int]):
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
......@@ -444,13 +456,13 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
"""
def __init__(self,
by_epoch=False,
target_ratio=(10, 1e-4),
cyclic_times=1,
step_ratio_up=0.4,
anneal_strategy='cos',
gamma=1,
**kwargs):
by_epoch: bool = False,
target_ratio: Union[float, tuple] = (10, 1e-4),
cyclic_times: int = 1,
step_ratio_up: float = 0.4,
anneal_strategy: str = 'cos',
gamma: float = 1,
**kwargs) -> None:
if isinstance(target_ratio, float):
target_ratio = (target_ratio, target_ratio / 1e5)
elif isinstance(target_ratio, tuple):
......@@ -472,13 +484,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
self.step_ratio_up = step_ratio_up
self.gamma = gamma
self.max_iter_per_phase = None
self.lr_phases = [] # init lr_phases
self.lr_phases: list = [] # init lr_phases
# validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}')
elif anneal_strategy == 'cos':
self.anneal_func = annealing_cos
self.anneal_func: Callable[[float, float, float],
float] = annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear
......@@ -486,19 +499,20 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
'currently only support "by_epoch" = False'
super().__init__(by_epoch, **kwargs)
def before_run(self, runner: 'runner.BaseRunner'):
    """Build the per-cycle up/down lr phases before training starts."""
    super().before_run(runner)
    # Each cycle gets an equal share of the total iteration budget.
    self.max_iter_per_phase = runner.max_iters // self.cyclic_times
    up_iters = int(self.step_ratio_up *
                   self.max_iter_per_phase)  # type:ignore
    # Ramp phase: from the base lr up to target_ratio[0] * base lr.
    self.lr_phases.append([0, up_iters, 1, self.target_ratio[0]])
    # Anneal phase: back down to target_ratio[1] * base lr.
    self.lr_phases.append([
        up_iters, self.max_iter_per_phase, self.target_ratio[0],
        self.target_ratio[1]
    ])
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
curr_iter = runner.iter % self.max_iter_per_phase
curr_cycle = runner.iter // self.max_iter_per_phase
# Update weight decay
......@@ -558,14 +572,14 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
"""
def __init__(self,
max_lr,
total_steps=None,
pct_start=0.3,
anneal_strategy='cos',
div_factor=25,
final_div_factor=1e4,
three_phase=False,
**kwargs):
max_lr: Union[float, List],
total_steps: Optional[int] = None,
pct_start: float = 0.3,
anneal_strategy: str = 'cos',
div_factor: float = 25,
final_div_factor: float = 1e4,
three_phase: bool = False,
**kwargs) -> None:
# validate by_epoch, currently only support by_epoch = False
if 'by_epoch' not in kwargs:
kwargs['by_epoch'] = False
......@@ -591,16 +605,17 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}')
elif anneal_strategy == 'cos':
self.anneal_func = annealing_cos
self.anneal_func: Callable[[float, float, float],
float] = annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear
self.div_factor = div_factor
self.final_div_factor = final_div_factor
self.three_phase = three_phase
self.lr_phases = [] # init lr_phases
self.lr_phases: list = [] # init lr_phases
super().__init__(**kwargs)
def before_run(self, runner):
def before_run(self, runner: 'runner.BaseRunner'):
if hasattr(self, 'total_steps'):
total_steps = self.total_steps
else:
......@@ -639,7 +654,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
self.lr_phases.append(
[total_steps - 1, self.div_factor, 1 / self.final_div_factor])
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
curr_iter = runner.iter
start_iter = 0
for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
......@@ -664,13 +679,16 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None.
"""
def __init__(self,
             min_lr: Optional[float] = None,
             min_lr_ratio: Optional[float] = None,
             **kwargs) -> None:
    """Linear annealing toward a target lr.

    Args:
        min_lr (float, optional): Target minimum lr. Mutually
            exclusive with ``min_lr_ratio``.
        min_lr_ratio (float, optional): Target minimum lr expressed
            as a ratio of the base lr. Mutually exclusive with
            ``min_lr``.
    """
    # Exactly one of the two target specifications must be given.
    assert (min_lr is None) ^ (min_lr_ratio is None)
    self.min_lr = min_lr
    self.min_lr_ratio = min_lr_ratio
    super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
......@@ -680,11 +698,14 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr
target_lr = self.min_lr # type:ignore
return annealing_linear(base_lr, target_lr, progress / max_progress)
def annealing_cos(start, end, factor, weight=1):
def annealing_cos(start: float,
end: float,
factor: float,
weight: float = 1) -> float:
"""Calculate annealing cos learning rate.
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
......@@ -702,7 +723,7 @@ def annealing_cos(start, end, factor, weight=1):
return end + 0.5 * weight * (start - end) * cos_out
def annealing_linear(start, end, factor):
def annealing_linear(start: float, end: float, factor: float) -> float:
"""Calculate annealing linear learning rate.
Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment