Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
335199db
Unverified
Commit
335199db
authored
May 27, 2022
by
WINDSKY45
Committed by
GitHub
May 27, 2022
Browse files
Add type hint in lr_updater.py (#1988)
* [Enhance] Add type hint in `lr_updater.py`. * Fix circle import
parent
23eb359b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
74 deletions
+95
-74
mmcv/runner/hooks/lr_updater.py
mmcv/runner/hooks/lr_updater.py
+95
-74
No files found.
mmcv/runner/hooks/lr_updater.py
View file @
335199db
# Copyright (c) OpenMMLab. All rights reserved.
import
numbers
from
math
import
cos
,
pi
from
typing
import
Callable
,
List
,
Optional
,
Union
import
mmcv
from
mmcv
import
runner
from
.hook
import
HOOKS
,
Hook
...
...
@@ -23,11 +25,11 @@ class LrUpdaterHook(Hook):
"""
def
__init__
(
self
,
by_epoch
=
True
,
warmup
=
None
,
warmup_iters
=
0
,
warmup_ratio
=
0.1
,
warmup_by_epoch
=
False
)
:
by_epoch
:
bool
=
True
,
warmup
:
Optional
[
str
]
=
None
,
warmup_iters
:
int
=
0
,
warmup_ratio
:
float
=
0.1
,
warmup_by_epoch
:
bool
=
False
)
->
None
:
# validate the "warmup" argument
if
warmup
is
not
None
:
if
warmup
not
in
[
'constant'
,
'linear'
,
'exp'
]:
...
...
@@ -42,18 +44,18 @@ class LrUpdaterHook(Hook):
self
.
by_epoch
=
by_epoch
self
.
warmup
=
warmup
self
.
warmup_iters
=
warmup_iters
self
.
warmup_iters
:
Optional
[
int
]
=
warmup_iters
self
.
warmup_ratio
=
warmup_ratio
self
.
warmup_by_epoch
=
warmup_by_epoch
if
self
.
warmup_by_epoch
:
self
.
warmup_epochs
=
self
.
warmup_iters
self
.
warmup_epochs
:
Optional
[
int
]
=
self
.
warmup_iters
self
.
warmup_iters
=
None
else
:
self
.
warmup_epochs
=
None
self
.
base_lr
=
[]
# initial lr for all param groups
self
.
regular_lr
=
[]
# expected lr if no warming up is performed
self
.
base_lr
:
Union
[
list
,
dict
]
=
[]
# initial lr for all param groups
self
.
regular_lr
:
list
=
[]
# expected lr if no warming up is performed
def
_set_lr
(
self
,
runner
,
lr_groups
):
if
isinstance
(
runner
.
optimizer
,
dict
):
...
...
@@ -65,10 +67,10 @@ class LrUpdaterHook(Hook):
lr_groups
):
param_group
[
'lr'
]
=
lr
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
raise
NotImplementedError
def
get_regular_lr
(
self
,
runner
):
def
get_regular_lr
(
self
,
runner
:
'runner.BaseRunner'
):
if
isinstance
(
runner
.
optimizer
,
dict
):
lr_groups
=
{}
for
k
in
runner
.
optimizer
.
keys
():
...
...
@@ -82,7 +84,7 @@ class LrUpdaterHook(Hook):
else
:
return
[
self
.
get_lr
(
runner
,
_base_lr
)
for
_base_lr
in
self
.
base_lr
]
def
get_warmup_lr
(
self
,
cur_iters
):
def
get_warmup_lr
(
self
,
cur_iters
:
int
):
def
_get_warmup_lr
(
cur_iters
,
regular_lr
):
if
self
.
warmup
==
'constant'
:
...
...
@@ -104,7 +106,7 @@ class LrUpdaterHook(Hook):
else
:
return
_get_warmup_lr
(
cur_iters
,
self
.
regular_lr
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
:
'runner.BaseRunner'
):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
if
isinstance
(
runner
.
optimizer
,
dict
):
...
...
@@ -123,10 +125,10 @@ class LrUpdaterHook(Hook):
group
[
'initial_lr'
]
for
group
in
runner
.
optimizer
.
param_groups
]
def
before_train_epoch
(
self
,
runner
):
def
before_train_epoch
(
self
,
runner
:
'runner.BaseRunner'
):
if
self
.
warmup_iters
is
None
:
epoch_len
=
len
(
runner
.
data_loader
)
self
.
warmup_iters
=
self
.
warmup_epochs
*
epoch_len
epoch_len
=
len
(
runner
.
data_loader
)
# type: ignore
self
.
warmup_iters
=
self
.
warmup_epochs
*
epoch_len
# type: ignore
if
not
self
.
by_epoch
:
return
...
...
@@ -134,7 +136,7 @@ class LrUpdaterHook(Hook):
self
.
regular_lr
=
self
.
get_regular_lr
(
runner
)
self
.
_set_lr
(
runner
,
self
.
regular_lr
)
def
before_train_iter
(
self
,
runner
):
def
before_train_iter
(
self
,
runner
:
'runner.BaseRunner'
):
cur_iter
=
runner
.
iter
if
not
self
.
by_epoch
:
self
.
regular_lr
=
self
.
get_regular_lr
(
runner
)
...
...
@@ -171,13 +173,17 @@ class StepLrUpdaterHook(LrUpdaterHook):
step (int | list[int]): Step to decay the LR. If an int value is given,
regard it as the decay interval. If a list is given, decay LR at
these steps.
gamma (float
, optional
): Decay LR ratio. Default
:
0.1.
gamma (float): Decay LR ratio. Default
s to
0.1.
min_lr (float, optional): Minimum LR value to keep. If LR after decay
is lower than `min_lr`, it will be clipped to this value. If None
is given, we don't perform lr clipping. Default: None.
"""
def
__init__
(
self
,
step
,
gamma
=
0.1
,
min_lr
=
None
,
**
kwargs
):
def
__init__
(
self
,
step
:
Union
[
int
,
List
[
int
]],
gamma
:
float
=
0.1
,
min_lr
:
Optional
[
float
]
=
None
,
**
kwargs
)
->
None
:
if
isinstance
(
step
,
list
):
assert
mmcv
.
is_list_of
(
step
,
int
)
assert
all
([
s
>
0
for
s
in
step
])
...
...
@@ -190,7 +196,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self
.
min_lr
=
min_lr
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
# calculate exponential term
...
...
@@ -213,11 +219,11 @@ class StepLrUpdaterHook(LrUpdaterHook):
@
HOOKS
.
register_module
()
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
def
__init__
(
self
,
gamma
:
float
,
**
kwargs
)
->
None
:
self
.
gamma
=
gamma
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
return
base_lr
*
self
.
gamma
**
progress
...
...
@@ -225,12 +231,15 @@ class ExpLrUpdaterHook(LrUpdaterHook):
@
HOOKS
.
register_module
()
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
def
__init__
(
self
,
power
:
float
=
1.
,
min_lr
:
float
=
0.
,
**
kwargs
)
->
None
:
self
.
power
=
power
self
.
min_lr
=
min_lr
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
progress
=
runner
.
epoch
max_progress
=
runner
.
max_epochs
...
...
@@ -244,12 +253,12 @@ class PolyLrUpdaterHook(LrUpdaterHook):
@
HOOKS
.
register_module
()
class
InvLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
def
__init__
(
self
,
gamma
:
float
,
power
:
float
=
1.
,
**
kwargs
)
->
None
:
self
.
gamma
=
gamma
self
.
power
=
power
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
...
...
@@ -265,13 +274,16 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None.
"""
def
__init__
(
self
,
min_lr
=
None
,
min_lr_ratio
=
None
,
**
kwargs
):
def
__init__
(
self
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
)
->
None
:
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
progress
=
runner
.
epoch
max_progress
=
runner
.
max_epochs
...
...
@@ -282,7 +294,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
return
annealing_cos
(
base_lr
,
target_lr
,
progress
/
max_progress
)
...
...
@@ -304,10 +316,10 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
"""
def
__init__
(
self
,
start_percent
=
0.75
,
min_lr
=
None
,
min_lr_ratio
=
None
,
**
kwargs
):
start_percent
:
float
=
0.75
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
)
->
None
:
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
if
start_percent
<
0
or
start_percent
>
1
or
not
isinstance
(
start_percent
,
float
):
...
...
@@ -319,7 +331,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
self
.
min_lr_ratio
=
min_lr_ratio
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
start
=
round
(
runner
.
max_epochs
*
self
.
start_percent
)
progress
=
runner
.
epoch
-
start
...
...
@@ -332,7 +344,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
if
progress
<
0
:
return
base_lr
...
...
@@ -346,8 +358,8 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
Args:
periods (list[int]): Periods for each cosine anneling cycle.
restart_weights (list[float]
, optional
): Restart weights at each
restart iteration. Default
:
[1].
restart_weights (list[float]): Restart weights at each
restart iteration. Default
s to
[1].
min_lr (float, optional): The minimum lr. Default: None.
min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
Either `min_lr` or `min_lr_ratio` should be specified.
...
...
@@ -355,11 +367,11 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
"""
def
__init__
(
self
,
periods
,
restart_weights
=
[
1
],
min_lr
=
None
,
min_lr_ratio
=
None
,
**
kwargs
):
periods
:
List
[
int
]
,
restart_weights
:
List
[
float
]
=
[
1
],
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
)
->
None
:
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
periods
=
periods
self
.
min_lr
=
min_lr
...
...
@@ -373,7 +385,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
sum
(
self
.
periods
[
0
:
i
+
1
])
for
i
in
range
(
0
,
len
(
self
.
periods
))
]
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
progress
=
runner
.
epoch
else
:
...
...
@@ -382,7 +394,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
idx
=
get_position_from_periods
(
progress
,
self
.
cumulative_periods
)
current_weight
=
self
.
restart_weights
[
idx
]
...
...
@@ -393,7 +405,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
return
annealing_cos
(
base_lr
,
target_lr
,
alpha
,
current_weight
)
def
get_position_from_periods
(
iteration
,
cumulative_periods
):
def
get_position_from_periods
(
iteration
:
int
,
cumulative_periods
:
List
[
int
]
):
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
...
...
@@ -444,13 +456,13 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
"""
def
__init__
(
self
,
by_epoch
=
False
,
target_ratio
=
(
10
,
1e-4
),
cyclic_times
=
1
,
step_ratio_up
=
0.4
,
anneal_strategy
=
'cos'
,
gamma
=
1
,
**
kwargs
):
by_epoch
:
bool
=
False
,
target_ratio
:
Union
[
float
,
tuple
]
=
(
10
,
1e-4
),
cyclic_times
:
int
=
1
,
step_ratio_up
:
float
=
0.4
,
anneal_strategy
:
str
=
'cos'
,
gamma
:
float
=
1
,
**
kwargs
)
->
None
:
if
isinstance
(
target_ratio
,
float
):
target_ratio
=
(
target_ratio
,
target_ratio
/
1e5
)
elif
isinstance
(
target_ratio
,
tuple
):
...
...
@@ -472,13 +484,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
self
.
step_ratio_up
=
step_ratio_up
self
.
gamma
=
gamma
self
.
max_iter_per_phase
=
None
self
.
lr_phases
=
[]
# init lr_phases
self
.
lr_phases
:
list
=
[]
# init lr_phases
# validate anneal_strategy
if
anneal_strategy
not
in
[
'cos'
,
'linear'
]:
raise
ValueError
(
'anneal_strategy must be one of "cos" or '
f
'"linear", instead got
{
anneal_strategy
}
'
)
elif
anneal_strategy
==
'cos'
:
self
.
anneal_func
=
annealing_cos
self
.
anneal_func
:
Callable
[[
float
,
float
,
float
],
float
]
=
annealing_cos
elif
anneal_strategy
==
'linear'
:
self
.
anneal_func
=
annealing_linear
...
...
@@ -486,19 +499,20 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
'currently only support "by_epoch" = False'
super
().
__init__
(
by_epoch
,
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
:
'runner.BaseRunner'
):
super
().
before_run
(
runner
)
# initiate lr_phases
# total lr_phases are separated as up and down
self
.
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
iter_up_phase
=
int
(
self
.
step_ratio_up
*
self
.
max_iter_per_phase
)
iter_up_phase
=
int
(
self
.
step_ratio_up
*
self
.
max_iter_per_phase
)
# type:ignore
self
.
lr_phases
.
append
([
0
,
iter_up_phase
,
1
,
self
.
target_ratio
[
0
]])
self
.
lr_phases
.
append
([
iter_up_phase
,
self
.
max_iter_per_phase
,
self
.
target_ratio
[
0
],
self
.
target_ratio
[
1
]
])
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
curr_iter
=
runner
.
iter
%
self
.
max_iter_per_phase
curr_cycle
=
runner
.
iter
//
self
.
max_iter_per_phase
# Update weight decay
...
...
@@ -558,14 +572,14 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
"""
def
__init__
(
self
,
max_lr
,
total_steps
=
None
,
pct_start
=
0.3
,
anneal_strategy
=
'cos'
,
div_factor
=
25
,
final_div_factor
=
1e4
,
three_phase
=
False
,
**
kwargs
):
max_lr
:
Union
[
float
,
List
]
,
total_steps
:
Optional
[
int
]
=
None
,
pct_start
:
float
=
0.3
,
anneal_strategy
:
str
=
'cos'
,
div_factor
:
float
=
25
,
final_div_factor
:
float
=
1e4
,
three_phase
:
bool
=
False
,
**
kwargs
)
->
None
:
# validate by_epoch, currently only support by_epoch = False
if
'by_epoch'
not
in
kwargs
:
kwargs
[
'by_epoch'
]
=
False
...
...
@@ -591,16 +605,17 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
raise
ValueError
(
'anneal_strategy must be one of "cos" or '
f
'"linear", instead got
{
anneal_strategy
}
'
)
elif
anneal_strategy
==
'cos'
:
self
.
anneal_func
=
annealing_cos
self
.
anneal_func
:
Callable
[[
float
,
float
,
float
],
float
]
=
annealing_cos
elif
anneal_strategy
==
'linear'
:
self
.
anneal_func
=
annealing_linear
self
.
div_factor
=
div_factor
self
.
final_div_factor
=
final_div_factor
self
.
three_phase
=
three_phase
self
.
lr_phases
=
[]
# init lr_phases
self
.
lr_phases
:
list
=
[]
# init lr_phases
super
().
__init__
(
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
:
'runner.BaseRunner'
):
if
hasattr
(
self
,
'total_steps'
):
total_steps
=
self
.
total_steps
else
:
...
...
@@ -639,7 +654,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
self
.
lr_phases
.
append
(
[
total_steps
-
1
,
self
.
div_factor
,
1
/
self
.
final_div_factor
])
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
curr_iter
=
runner
.
iter
start_iter
=
0
for
i
,
(
end_iter
,
start_lr
,
end_lr
)
in
enumerate
(
self
.
lr_phases
):
...
...
@@ -664,13 +679,16 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None.
"""
def
__init__
(
self
,
min_lr
=
None
,
min_lr_ratio
=
None
,
**
kwargs
):
def
__init__
(
self
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
):
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
progress
=
runner
.
epoch
max_progress
=
runner
.
max_epochs
...
...
@@ -680,11 +698,14 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
return
annealing_linear
(
base_lr
,
target_lr
,
progress
/
max_progress
)
def
annealing_cos
(
start
,
end
,
factor
,
weight
=
1
):
def
annealing_cos
(
start
:
float
,
end
:
float
,
factor
:
float
,
weight
:
float
=
1
)
->
float
:
"""Calculate annealing cos learning rate.
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
...
...
@@ -702,7 +723,7 @@ def annealing_cos(start, end, factor, weight=1):
return
end
+
0.5
*
weight
*
(
start
-
end
)
*
cos_out
def
annealing_linear
(
start
,
end
,
factor
)
:
def
annealing_linear
(
start
:
float
,
end
:
float
,
factor
:
float
)
->
float
:
"""Calculate annealing linear learning rate.
Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment