Unverified Commit 335199db authored by WINDSKY45's avatar WINDSKY45 Committed by GitHub
Browse files

Add type hint in lr_updater.py (#1988)

* [Enhance] Add type hint in `lr_updater.py`.

* Fix circle import
parent 23eb359b
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numbers import numbers
from math import cos, pi from math import cos, pi
from typing import Callable, List, Optional, Union
import mmcv import mmcv
from mmcv import runner
from .hook import HOOKS, Hook from .hook import HOOKS, Hook
...@@ -23,11 +25,11 @@ class LrUpdaterHook(Hook): ...@@ -23,11 +25,11 @@ class LrUpdaterHook(Hook):
""" """
def __init__(self, def __init__(self,
by_epoch=True, by_epoch: bool = True,
warmup=None, warmup: Optional[str] = None,
warmup_iters=0, warmup_iters: int = 0,
warmup_ratio=0.1, warmup_ratio: float = 0.1,
warmup_by_epoch=False): warmup_by_epoch: bool = False) -> None:
# validate the "warmup" argument # validate the "warmup" argument
if warmup is not None: if warmup is not None:
if warmup not in ['constant', 'linear', 'exp']: if warmup not in ['constant', 'linear', 'exp']:
...@@ -42,18 +44,18 @@ class LrUpdaterHook(Hook): ...@@ -42,18 +44,18 @@ class LrUpdaterHook(Hook):
self.by_epoch = by_epoch self.by_epoch = by_epoch
self.warmup = warmup self.warmup = warmup
self.warmup_iters = warmup_iters self.warmup_iters: Optional[int] = warmup_iters
self.warmup_ratio = warmup_ratio self.warmup_ratio = warmup_ratio
self.warmup_by_epoch = warmup_by_epoch self.warmup_by_epoch = warmup_by_epoch
if self.warmup_by_epoch: if self.warmup_by_epoch:
self.warmup_epochs = self.warmup_iters self.warmup_epochs: Optional[int] = self.warmup_iters
self.warmup_iters = None self.warmup_iters = None
else: else:
self.warmup_epochs = None self.warmup_epochs = None
self.base_lr = [] # initial lr for all param groups self.base_lr: Union[list, dict] = [] # initial lr for all param groups
self.regular_lr = [] # expected lr if no warming up is performed self.regular_lr: list = [] # expected lr if no warming up is performed
def _set_lr(self, runner, lr_groups): def _set_lr(self, runner, lr_groups):
if isinstance(runner.optimizer, dict): if isinstance(runner.optimizer, dict):
...@@ -65,10 +67,10 @@ class LrUpdaterHook(Hook): ...@@ -65,10 +67,10 @@ class LrUpdaterHook(Hook):
lr_groups): lr_groups):
param_group['lr'] = lr param_group['lr'] = lr
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
raise NotImplementedError raise NotImplementedError
def get_regular_lr(self, runner): def get_regular_lr(self, runner: 'runner.BaseRunner'):
if isinstance(runner.optimizer, dict): if isinstance(runner.optimizer, dict):
lr_groups = {} lr_groups = {}
for k in runner.optimizer.keys(): for k in runner.optimizer.keys():
...@@ -82,7 +84,7 @@ class LrUpdaterHook(Hook): ...@@ -82,7 +84,7 @@ class LrUpdaterHook(Hook):
else: else:
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr] return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
def get_warmup_lr(self, cur_iters): def get_warmup_lr(self, cur_iters: int):
def _get_warmup_lr(cur_iters, regular_lr): def _get_warmup_lr(cur_iters, regular_lr):
if self.warmup == 'constant': if self.warmup == 'constant':
...@@ -104,7 +106,7 @@ class LrUpdaterHook(Hook): ...@@ -104,7 +106,7 @@ class LrUpdaterHook(Hook):
else: else:
return _get_warmup_lr(cur_iters, self.regular_lr) return _get_warmup_lr(cur_iters, self.regular_lr)
def before_run(self, runner): def before_run(self, runner: 'runner.BaseRunner'):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params # it will be set according to the optimizer params
if isinstance(runner.optimizer, dict): if isinstance(runner.optimizer, dict):
...@@ -123,10 +125,10 @@ class LrUpdaterHook(Hook): ...@@ -123,10 +125,10 @@ class LrUpdaterHook(Hook):
group['initial_lr'] for group in runner.optimizer.param_groups group['initial_lr'] for group in runner.optimizer.param_groups
] ]
def before_train_epoch(self, runner): def before_train_epoch(self, runner: 'runner.BaseRunner'):
if self.warmup_iters is None: if self.warmup_iters is None:
epoch_len = len(runner.data_loader) epoch_len = len(runner.data_loader) # type: ignore
self.warmup_iters = self.warmup_epochs * epoch_len self.warmup_iters = self.warmup_epochs * epoch_len # type: ignore
if not self.by_epoch: if not self.by_epoch:
return return
...@@ -134,7 +136,7 @@ class LrUpdaterHook(Hook): ...@@ -134,7 +136,7 @@ class LrUpdaterHook(Hook):
self.regular_lr = self.get_regular_lr(runner) self.regular_lr = self.get_regular_lr(runner)
self._set_lr(runner, self.regular_lr) self._set_lr(runner, self.regular_lr)
def before_train_iter(self, runner): def before_train_iter(self, runner: 'runner.BaseRunner'):
cur_iter = runner.iter cur_iter = runner.iter
if not self.by_epoch: if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner) self.regular_lr = self.get_regular_lr(runner)
...@@ -171,13 +173,17 @@ class StepLrUpdaterHook(LrUpdaterHook): ...@@ -171,13 +173,17 @@ class StepLrUpdaterHook(LrUpdaterHook):
step (int | list[int]): Step to decay the LR. If an int value is given, step (int | list[int]): Step to decay the LR. If an int value is given,
regard it as the decay interval. If a list is given, decay LR at regard it as the decay interval. If a list is given, decay LR at
these steps. these steps.
gamma (float, optional): Decay LR ratio. Default: 0.1. gamma (float): Decay LR ratio. Defaults to 0.1.
min_lr (float, optional): Minimum LR value to keep. If LR after decay min_lr (float, optional): Minimum LR value to keep. If LR after decay
is lower than `min_lr`, it will be clipped to this value. If None is lower than `min_lr`, it will be clipped to this value. If None
is given, we don't perform lr clipping. Default: None. is given, we don't perform lr clipping. Default: None.
""" """
def __init__(self, step, gamma=0.1, min_lr=None, **kwargs): def __init__(self,
step: Union[int, List[int]],
gamma: float = 0.1,
min_lr: Optional[float] = None,
**kwargs) -> None:
if isinstance(step, list): if isinstance(step, list):
assert mmcv.is_list_of(step, int) assert mmcv.is_list_of(step, int)
assert all([s > 0 for s in step]) assert all([s > 0 for s in step])
...@@ -190,7 +196,7 @@ class StepLrUpdaterHook(LrUpdaterHook): ...@@ -190,7 +196,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self.min_lr = min_lr self.min_lr = min_lr
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
progress = runner.epoch if self.by_epoch else runner.iter progress = runner.epoch if self.by_epoch else runner.iter
# calculate exponential term # calculate exponential term
...@@ -213,11 +219,11 @@ class StepLrUpdaterHook(LrUpdaterHook): ...@@ -213,11 +219,11 @@ class StepLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module() @HOOKS.register_module()
class ExpLrUpdaterHook(LrUpdaterHook): class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, **kwargs): def __init__(self, gamma: float, **kwargs) -> None:
self.gamma = gamma self.gamma = gamma
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
progress = runner.epoch if self.by_epoch else runner.iter progress = runner.epoch if self.by_epoch else runner.iter
return base_lr * self.gamma**progress return base_lr * self.gamma**progress
...@@ -225,12 +231,15 @@ class ExpLrUpdaterHook(LrUpdaterHook): ...@@ -225,12 +231,15 @@ class ExpLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module() @HOOKS.register_module()
class PolyLrUpdaterHook(LrUpdaterHook): class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self, power=1., min_lr=0., **kwargs): def __init__(self,
power: float = 1.,
min_lr: float = 0.,
**kwargs) -> None:
self.power = power self.power = power
self.min_lr = min_lr self.min_lr = min_lr
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch: if self.by_epoch:
progress = runner.epoch progress = runner.epoch
max_progress = runner.max_epochs max_progress = runner.max_epochs
...@@ -244,12 +253,12 @@ class PolyLrUpdaterHook(LrUpdaterHook): ...@@ -244,12 +253,12 @@ class PolyLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module() @HOOKS.register_module()
class InvLrUpdaterHook(LrUpdaterHook): class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, power=1., **kwargs): def __init__(self, gamma: float, power: float = 1., **kwargs) -> None:
self.gamma = gamma self.gamma = gamma
self.power = power self.power = power
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
progress = runner.epoch if self.by_epoch else runner.iter progress = runner.epoch if self.by_epoch else runner.iter
return base_lr * (1 + self.gamma * progress)**(-self.power) return base_lr * (1 + self.gamma * progress)**(-self.power)
...@@ -265,13 +274,16 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -265,13 +274,16 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None. Default: None.
""" """
def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs): def __init__(self,
min_lr: Optional[float] = None,
min_lr_ratio: Optional[float] = None,
**kwargs) -> None:
assert (min_lr is None) ^ (min_lr_ratio is None) assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio self.min_lr_ratio = min_lr_ratio
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch: if self.by_epoch:
progress = runner.epoch progress = runner.epoch
max_progress = runner.max_epochs max_progress = runner.max_epochs
...@@ -282,7 +294,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -282,7 +294,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None: if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio target_lr = base_lr * self.min_lr_ratio
else: else:
target_lr = self.min_lr target_lr = self.min_lr # type:ignore
return annealing_cos(base_lr, target_lr, progress / max_progress) return annealing_cos(base_lr, target_lr, progress / max_progress)
...@@ -304,10 +316,10 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -304,10 +316,10 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
""" """
def __init__(self, def __init__(self,
start_percent=0.75, start_percent: float = 0.75,
min_lr=None, min_lr: Optional[float] = None,
min_lr_ratio=None, min_lr_ratio: Optional[float] = None,
**kwargs): **kwargs) -> None:
assert (min_lr is None) ^ (min_lr_ratio is None) assert (min_lr is None) ^ (min_lr_ratio is None)
if start_percent < 0 or start_percent > 1 or not isinstance( if start_percent < 0 or start_percent > 1 or not isinstance(
start_percent, float): start_percent, float):
...@@ -319,7 +331,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -319,7 +331,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
self.min_lr_ratio = min_lr_ratio self.min_lr_ratio = min_lr_ratio
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch: if self.by_epoch:
start = round(runner.max_epochs * self.start_percent) start = round(runner.max_epochs * self.start_percent)
progress = runner.epoch - start progress = runner.epoch - start
...@@ -332,7 +344,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -332,7 +344,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None: if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio target_lr = base_lr * self.min_lr_ratio
else: else:
target_lr = self.min_lr target_lr = self.min_lr # type:ignore
if progress < 0: if progress < 0:
return base_lr return base_lr
...@@ -346,8 +358,8 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook): ...@@ -346,8 +358,8 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
Args: Args:
periods (list[int]): Periods for each cosine anneling cycle. periods (list[int]): Periods for each cosine anneling cycle.
restart_weights (list[float], optional): Restart weights at each restart_weights (list[float]): Restart weights at each
restart iteration. Default: [1]. restart iteration. Defaults to [1].
min_lr (float, optional): The minimum lr. Default: None. min_lr (float, optional): The minimum lr. Default: None.
min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
Either `min_lr` or `min_lr_ratio` should be specified. Either `min_lr` or `min_lr_ratio` should be specified.
...@@ -355,11 +367,11 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook): ...@@ -355,11 +367,11 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
""" """
def __init__(self, def __init__(self,
periods, periods: List[int],
restart_weights=[1], restart_weights: List[float] = [1],
min_lr=None, min_lr: Optional[float] = None,
min_lr_ratio=None, min_lr_ratio: Optional[float] = None,
**kwargs): **kwargs) -> None:
assert (min_lr is None) ^ (min_lr_ratio is None) assert (min_lr is None) ^ (min_lr_ratio is None)
self.periods = periods self.periods = periods
self.min_lr = min_lr self.min_lr = min_lr
...@@ -373,7 +385,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook): ...@@ -373,7 +385,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
] ]
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch: if self.by_epoch:
progress = runner.epoch progress = runner.epoch
else: else:
...@@ -382,7 +394,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook): ...@@ -382,7 +394,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None: if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio target_lr = base_lr * self.min_lr_ratio
else: else:
target_lr = self.min_lr target_lr = self.min_lr # type:ignore
idx = get_position_from_periods(progress, self.cumulative_periods) idx = get_position_from_periods(progress, self.cumulative_periods)
current_weight = self.restart_weights[idx] current_weight = self.restart_weights[idx]
...@@ -393,7 +405,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook): ...@@ -393,7 +405,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
return annealing_cos(base_lr, target_lr, alpha, current_weight) return annealing_cos(base_lr, target_lr, alpha, current_weight)
def get_position_from_periods(iteration, cumulative_periods): def get_position_from_periods(iteration: int, cumulative_periods: List[int]):
"""Get the position from a period list. """Get the position from a period list.
It will return the index of the right-closest number in the period list. It will return the index of the right-closest number in the period list.
...@@ -444,13 +456,13 @@ class CyclicLrUpdaterHook(LrUpdaterHook): ...@@ -444,13 +456,13 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
""" """
def __init__(self, def __init__(self,
by_epoch=False, by_epoch: bool = False,
target_ratio=(10, 1e-4), target_ratio: Union[float, tuple] = (10, 1e-4),
cyclic_times=1, cyclic_times: int = 1,
step_ratio_up=0.4, step_ratio_up: float = 0.4,
anneal_strategy='cos', anneal_strategy: str = 'cos',
gamma=1, gamma: float = 1,
**kwargs): **kwargs) -> None:
if isinstance(target_ratio, float): if isinstance(target_ratio, float):
target_ratio = (target_ratio, target_ratio / 1e5) target_ratio = (target_ratio, target_ratio / 1e5)
elif isinstance(target_ratio, tuple): elif isinstance(target_ratio, tuple):
...@@ -472,13 +484,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook): ...@@ -472,13 +484,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
self.step_ratio_up = step_ratio_up self.step_ratio_up = step_ratio_up
self.gamma = gamma self.gamma = gamma
self.max_iter_per_phase = None self.max_iter_per_phase = None
self.lr_phases = [] # init lr_phases self.lr_phases: list = [] # init lr_phases
# validate anneal_strategy # validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']: if anneal_strategy not in ['cos', 'linear']:
raise ValueError('anneal_strategy must be one of "cos" or ' raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}') f'"linear", instead got {anneal_strategy}')
elif anneal_strategy == 'cos': elif anneal_strategy == 'cos':
self.anneal_func = annealing_cos self.anneal_func: Callable[[float, float, float],
float] = annealing_cos
elif anneal_strategy == 'linear': elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear self.anneal_func = annealing_linear
...@@ -486,19 +499,20 @@ class CyclicLrUpdaterHook(LrUpdaterHook): ...@@ -486,19 +499,20 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
'currently only support "by_epoch" = False' 'currently only support "by_epoch" = False'
super().__init__(by_epoch, **kwargs) super().__init__(by_epoch, **kwargs)
def before_run(self, runner): def before_run(self, runner: 'runner.BaseRunner'):
super().before_run(runner) super().before_run(runner)
# initiate lr_phases # initiate lr_phases
# total lr_phases are separated as up and down # total lr_phases are separated as up and down
self.max_iter_per_phase = runner.max_iters // self.cyclic_times self.max_iter_per_phase = runner.max_iters // self.cyclic_times
iter_up_phase = int(self.step_ratio_up * self.max_iter_per_phase) iter_up_phase = int(self.step_ratio_up *
self.max_iter_per_phase) # type:ignore
self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]]) self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]])
self.lr_phases.append([ self.lr_phases.append([
iter_up_phase, self.max_iter_per_phase, self.target_ratio[0], iter_up_phase, self.max_iter_per_phase, self.target_ratio[0],
self.target_ratio[1] self.target_ratio[1]
]) ])
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
curr_iter = runner.iter % self.max_iter_per_phase curr_iter = runner.iter % self.max_iter_per_phase
curr_cycle = runner.iter // self.max_iter_per_phase curr_cycle = runner.iter // self.max_iter_per_phase
# Update weight decay # Update weight decay
...@@ -558,14 +572,14 @@ class OneCycleLrUpdaterHook(LrUpdaterHook): ...@@ -558,14 +572,14 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
""" """
def __init__(self, def __init__(self,
max_lr, max_lr: Union[float, List],
total_steps=None, total_steps: Optional[int] = None,
pct_start=0.3, pct_start: float = 0.3,
anneal_strategy='cos', anneal_strategy: str = 'cos',
div_factor=25, div_factor: float = 25,
final_div_factor=1e4, final_div_factor: float = 1e4,
three_phase=False, three_phase: bool = False,
**kwargs): **kwargs) -> None:
# validate by_epoch, currently only support by_epoch = False # validate by_epoch, currently only support by_epoch = False
if 'by_epoch' not in kwargs: if 'by_epoch' not in kwargs:
kwargs['by_epoch'] = False kwargs['by_epoch'] = False
...@@ -591,16 +605,17 @@ class OneCycleLrUpdaterHook(LrUpdaterHook): ...@@ -591,16 +605,17 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
raise ValueError('anneal_strategy must be one of "cos" or ' raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}') f'"linear", instead got {anneal_strategy}')
elif anneal_strategy == 'cos': elif anneal_strategy == 'cos':
self.anneal_func = annealing_cos self.anneal_func: Callable[[float, float, float],
float] = annealing_cos
elif anneal_strategy == 'linear': elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear self.anneal_func = annealing_linear
self.div_factor = div_factor self.div_factor = div_factor
self.final_div_factor = final_div_factor self.final_div_factor = final_div_factor
self.three_phase = three_phase self.three_phase = three_phase
self.lr_phases = [] # init lr_phases self.lr_phases: list = [] # init lr_phases
super().__init__(**kwargs) super().__init__(**kwargs)
def before_run(self, runner): def before_run(self, runner: 'runner.BaseRunner'):
if hasattr(self, 'total_steps'): if hasattr(self, 'total_steps'):
total_steps = self.total_steps total_steps = self.total_steps
else: else:
...@@ -639,7 +654,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook): ...@@ -639,7 +654,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
self.lr_phases.append( self.lr_phases.append(
[total_steps - 1, self.div_factor, 1 / self.final_div_factor]) [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
curr_iter = runner.iter curr_iter = runner.iter
start_iter = 0 start_iter = 0
for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases): for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
...@@ -664,13 +679,16 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -664,13 +679,16 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None. Default: None.
""" """
def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs): def __init__(self,
min_lr: Optional[float] = None,
min_lr_ratio: Optional[float] = None,
**kwargs):
assert (min_lr is None) ^ (min_lr_ratio is None) assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio self.min_lr_ratio = min_lr_ratio
super().__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner: 'runner.BaseRunner', base_lr: float):
if self.by_epoch: if self.by_epoch:
progress = runner.epoch progress = runner.epoch
max_progress = runner.max_epochs max_progress = runner.max_epochs
...@@ -680,11 +698,14 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook): ...@@ -680,11 +698,14 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
if self.min_lr_ratio is not None: if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio target_lr = base_lr * self.min_lr_ratio
else: else:
target_lr = self.min_lr target_lr = self.min_lr # type:ignore
return annealing_linear(base_lr, target_lr, progress / max_progress) return annealing_linear(base_lr, target_lr, progress / max_progress)
def annealing_cos(start, end, factor, weight=1): def annealing_cos(start: float,
end: float,
factor: float,
weight: float = 1) -> float:
"""Calculate annealing cos learning rate. """Calculate annealing cos learning rate.
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
...@@ -702,7 +723,7 @@ def annealing_cos(start, end, factor, weight=1): ...@@ -702,7 +723,7 @@ def annealing_cos(start, end, factor, weight=1):
return end + 0.5 * weight * (start - end) * cos_out return end + 0.5 * weight * (start - end) * cos_out
def annealing_linear(start, end, factor): def annealing_linear(start: float, end: float, factor: float) -> float:
"""Calculate annealing linear learning rate. """Calculate annealing linear learning rate.
Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0. Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment