Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
335199db
Unverified
Commit
335199db
authored
May 27, 2022
by
WINDSKY45
Committed by
GitHub
May 27, 2022
Browse files
Add type hint in lr_updater.py (#1988)
* [Enhance] Add type hint in `lr_updater.py`. * Fix circle import
parent
23eb359b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
74 deletions
+95
-74
mmcv/runner/hooks/lr_updater.py
mmcv/runner/hooks/lr_updater.py
+95
-74
No files found.
mmcv/runner/hooks/lr_updater.py
View file @
335199db
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
numbers
import
numbers
from
math
import
cos
,
pi
from
math
import
cos
,
pi
from
typing
import
Callable
,
List
,
Optional
,
Union
import
mmcv
import
mmcv
from
mmcv
import
runner
from
.hook
import
HOOKS
,
Hook
from
.hook
import
HOOKS
,
Hook
...
@@ -23,11 +25,11 @@ class LrUpdaterHook(Hook):
...
@@ -23,11 +25,11 @@ class LrUpdaterHook(Hook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
by_epoch
=
True
,
by_epoch
:
bool
=
True
,
warmup
=
None
,
warmup
:
Optional
[
str
]
=
None
,
warmup_iters
=
0
,
warmup_iters
:
int
=
0
,
warmup_ratio
=
0.1
,
warmup_ratio
:
float
=
0.1
,
warmup_by_epoch
=
False
)
:
warmup_by_epoch
:
bool
=
False
)
->
None
:
# validate the "warmup" argument
# validate the "warmup" argument
if
warmup
is
not
None
:
if
warmup
is
not
None
:
if
warmup
not
in
[
'constant'
,
'linear'
,
'exp'
]:
if
warmup
not
in
[
'constant'
,
'linear'
,
'exp'
]:
...
@@ -42,18 +44,18 @@ class LrUpdaterHook(Hook):
...
@@ -42,18 +44,18 @@ class LrUpdaterHook(Hook):
self
.
by_epoch
=
by_epoch
self
.
by_epoch
=
by_epoch
self
.
warmup
=
warmup
self
.
warmup
=
warmup
self
.
warmup_iters
=
warmup_iters
self
.
warmup_iters
:
Optional
[
int
]
=
warmup_iters
self
.
warmup_ratio
=
warmup_ratio
self
.
warmup_ratio
=
warmup_ratio
self
.
warmup_by_epoch
=
warmup_by_epoch
self
.
warmup_by_epoch
=
warmup_by_epoch
if
self
.
warmup_by_epoch
:
if
self
.
warmup_by_epoch
:
self
.
warmup_epochs
=
self
.
warmup_iters
self
.
warmup_epochs
:
Optional
[
int
]
=
self
.
warmup_iters
self
.
warmup_iters
=
None
self
.
warmup_iters
=
None
else
:
else
:
self
.
warmup_epochs
=
None
self
.
warmup_epochs
=
None
self
.
base_lr
=
[]
# initial lr for all param groups
self
.
base_lr
:
Union
[
list
,
dict
]
=
[]
# initial lr for all param groups
self
.
regular_lr
=
[]
# expected lr if no warming up is performed
self
.
regular_lr
:
list
=
[]
# expected lr if no warming up is performed
def
_set_lr
(
self
,
runner
,
lr_groups
):
def
_set_lr
(
self
,
runner
,
lr_groups
):
if
isinstance
(
runner
.
optimizer
,
dict
):
if
isinstance
(
runner
.
optimizer
,
dict
):
...
@@ -65,10 +67,10 @@ class LrUpdaterHook(Hook):
...
@@ -65,10 +67,10 @@ class LrUpdaterHook(Hook):
lr_groups
):
lr_groups
):
param_group
[
'lr'
]
=
lr
param_group
[
'lr'
]
=
lr
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
raise
NotImplementedError
raise
NotImplementedError
def
get_regular_lr
(
self
,
runner
):
def
get_regular_lr
(
self
,
runner
:
'runner.BaseRunner'
):
if
isinstance
(
runner
.
optimizer
,
dict
):
if
isinstance
(
runner
.
optimizer
,
dict
):
lr_groups
=
{}
lr_groups
=
{}
for
k
in
runner
.
optimizer
.
keys
():
for
k
in
runner
.
optimizer
.
keys
():
...
@@ -82,7 +84,7 @@ class LrUpdaterHook(Hook):
...
@@ -82,7 +84,7 @@ class LrUpdaterHook(Hook):
else
:
else
:
return
[
self
.
get_lr
(
runner
,
_base_lr
)
for
_base_lr
in
self
.
base_lr
]
return
[
self
.
get_lr
(
runner
,
_base_lr
)
for
_base_lr
in
self
.
base_lr
]
def
get_warmup_lr
(
self
,
cur_iters
):
def
get_warmup_lr
(
self
,
cur_iters
:
int
):
def
_get_warmup_lr
(
cur_iters
,
regular_lr
):
def
_get_warmup_lr
(
cur_iters
,
regular_lr
):
if
self
.
warmup
==
'constant'
:
if
self
.
warmup
==
'constant'
:
...
@@ -104,7 +106,7 @@ class LrUpdaterHook(Hook):
...
@@ -104,7 +106,7 @@ class LrUpdaterHook(Hook):
else
:
else
:
return
_get_warmup_lr
(
cur_iters
,
self
.
regular_lr
)
return
_get_warmup_lr
(
cur_iters
,
self
.
regular_lr
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
:
'runner.BaseRunner'
):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
# it will be set according to the optimizer params
if
isinstance
(
runner
.
optimizer
,
dict
):
if
isinstance
(
runner
.
optimizer
,
dict
):
...
@@ -123,10 +125,10 @@ class LrUpdaterHook(Hook):
...
@@ -123,10 +125,10 @@ class LrUpdaterHook(Hook):
group
[
'initial_lr'
]
for
group
in
runner
.
optimizer
.
param_groups
group
[
'initial_lr'
]
for
group
in
runner
.
optimizer
.
param_groups
]
]
def
before_train_epoch
(
self
,
runner
):
def
before_train_epoch
(
self
,
runner
:
'runner.BaseRunner'
):
if
self
.
warmup_iters
is
None
:
if
self
.
warmup_iters
is
None
:
epoch_len
=
len
(
runner
.
data_loader
)
epoch_len
=
len
(
runner
.
data_loader
)
# type: ignore
self
.
warmup_iters
=
self
.
warmup_epochs
*
epoch_len
self
.
warmup_iters
=
self
.
warmup_epochs
*
epoch_len
# type: ignore
if
not
self
.
by_epoch
:
if
not
self
.
by_epoch
:
return
return
...
@@ -134,7 +136,7 @@ class LrUpdaterHook(Hook):
...
@@ -134,7 +136,7 @@ class LrUpdaterHook(Hook):
self
.
regular_lr
=
self
.
get_regular_lr
(
runner
)
self
.
regular_lr
=
self
.
get_regular_lr
(
runner
)
self
.
_set_lr
(
runner
,
self
.
regular_lr
)
self
.
_set_lr
(
runner
,
self
.
regular_lr
)
def
before_train_iter
(
self
,
runner
):
def
before_train_iter
(
self
,
runner
:
'runner.BaseRunner'
):
cur_iter
=
runner
.
iter
cur_iter
=
runner
.
iter
if
not
self
.
by_epoch
:
if
not
self
.
by_epoch
:
self
.
regular_lr
=
self
.
get_regular_lr
(
runner
)
self
.
regular_lr
=
self
.
get_regular_lr
(
runner
)
...
@@ -171,13 +173,17 @@ class StepLrUpdaterHook(LrUpdaterHook):
...
@@ -171,13 +173,17 @@ class StepLrUpdaterHook(LrUpdaterHook):
step (int | list[int]): Step to decay the LR. If an int value is given,
step (int | list[int]): Step to decay the LR. If an int value is given,
regard it as the decay interval. If a list is given, decay LR at
regard it as the decay interval. If a list is given, decay LR at
these steps.
these steps.
gamma (float
, optional
): Decay LR ratio. Default
:
0.1.
gamma (float): Decay LR ratio. Default
s to
0.1.
min_lr (float, optional): Minimum LR value to keep. If LR after decay
min_lr (float, optional): Minimum LR value to keep. If LR after decay
is lower than `min_lr`, it will be clipped to this value. If None
is lower than `min_lr`, it will be clipped to this value. If None
is given, we don't perform lr clipping. Default: None.
is given, we don't perform lr clipping. Default: None.
"""
"""
def
__init__
(
self
,
step
,
gamma
=
0.1
,
min_lr
=
None
,
**
kwargs
):
def
__init__
(
self
,
step
:
Union
[
int
,
List
[
int
]],
gamma
:
float
=
0.1
,
min_lr
:
Optional
[
float
]
=
None
,
**
kwargs
)
->
None
:
if
isinstance
(
step
,
list
):
if
isinstance
(
step
,
list
):
assert
mmcv
.
is_list_of
(
step
,
int
)
assert
mmcv
.
is_list_of
(
step
,
int
)
assert
all
([
s
>
0
for
s
in
step
])
assert
all
([
s
>
0
for
s
in
step
])
...
@@ -190,7 +196,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
...
@@ -190,7 +196,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
# calculate exponential term
# calculate exponential term
...
@@ -213,11 +219,11 @@ class StepLrUpdaterHook(LrUpdaterHook):
...
@@ -213,11 +219,11 @@ class StepLrUpdaterHook(LrUpdaterHook):
@
HOOKS
.
register_module
()
@
HOOKS
.
register_module
()
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
def
__init__
(
self
,
gamma
:
float
,
**
kwargs
)
->
None
:
self
.
gamma
=
gamma
self
.
gamma
=
gamma
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
return
base_lr
*
self
.
gamma
**
progress
return
base_lr
*
self
.
gamma
**
progress
...
@@ -225,12 +231,15 @@ class ExpLrUpdaterHook(LrUpdaterHook):
...
@@ -225,12 +231,15 @@ class ExpLrUpdaterHook(LrUpdaterHook):
@
HOOKS
.
register_module
()
@
HOOKS
.
register_module
()
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
def
__init__
(
self
,
power
:
float
=
1.
,
min_lr
:
float
=
0.
,
**
kwargs
)
->
None
:
self
.
power
=
power
self
.
power
=
power
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
progress
=
runner
.
epoch
progress
=
runner
.
epoch
max_progress
=
runner
.
max_epochs
max_progress
=
runner
.
max_epochs
...
@@ -244,12 +253,12 @@ class PolyLrUpdaterHook(LrUpdaterHook):
...
@@ -244,12 +253,12 @@ class PolyLrUpdaterHook(LrUpdaterHook):
@
HOOKS
.
register_module
()
@
HOOKS
.
register_module
()
class
InvLrUpdaterHook
(
LrUpdaterHook
):
class
InvLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
def
__init__
(
self
,
gamma
:
float
,
power
:
float
=
1.
,
**
kwargs
)
->
None
:
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
power
=
power
self
.
power
=
power
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
progress
=
runner
.
epoch
if
self
.
by_epoch
else
runner
.
iter
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
...
@@ -265,13 +274,16 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -265,13 +274,16 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None.
Default: None.
"""
"""
def
__init__
(
self
,
min_lr
=
None
,
min_lr_ratio
=
None
,
**
kwargs
):
def
__init__
(
self
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
)
->
None
:
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
self
.
min_lr_ratio
=
min_lr_ratio
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
progress
=
runner
.
epoch
progress
=
runner
.
epoch
max_progress
=
runner
.
max_epochs
max_progress
=
runner
.
max_epochs
...
@@ -282,7 +294,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -282,7 +294,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
return
annealing_cos
(
base_lr
,
target_lr
,
progress
/
max_progress
)
return
annealing_cos
(
base_lr
,
target_lr
,
progress
/
max_progress
)
...
@@ -304,10 +316,10 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -304,10 +316,10 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
start_percent
=
0.75
,
start_percent
:
float
=
0.75
,
min_lr
=
None
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
):
**
kwargs
)
->
None
:
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
if
start_percent
<
0
or
start_percent
>
1
or
not
isinstance
(
if
start_percent
<
0
or
start_percent
>
1
or
not
isinstance
(
start_percent
,
float
):
start_percent
,
float
):
...
@@ -319,7 +331,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -319,7 +331,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
self
.
min_lr_ratio
=
min_lr_ratio
self
.
min_lr_ratio
=
min_lr_ratio
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
start
=
round
(
runner
.
max_epochs
*
self
.
start_percent
)
start
=
round
(
runner
.
max_epochs
*
self
.
start_percent
)
progress
=
runner
.
epoch
-
start
progress
=
runner
.
epoch
-
start
...
@@ -332,7 +344,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -332,7 +344,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
if
progress
<
0
:
if
progress
<
0
:
return
base_lr
return
base_lr
...
@@ -346,8 +358,8 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
...
@@ -346,8 +358,8 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
Args:
Args:
periods (list[int]): Periods for each cosine anneling cycle.
periods (list[int]): Periods for each cosine anneling cycle.
restart_weights (list[float]
, optional
): Restart weights at each
restart_weights (list[float]): Restart weights at each
restart iteration. Default
:
[1].
restart iteration. Default
s to
[1].
min_lr (float, optional): The minimum lr. Default: None.
min_lr (float, optional): The minimum lr. Default: None.
min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
Either `min_lr` or `min_lr_ratio` should be specified.
Either `min_lr` or `min_lr_ratio` should be specified.
...
@@ -355,11 +367,11 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
...
@@ -355,11 +367,11 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
periods
,
periods
:
List
[
int
]
,
restart_weights
=
[
1
],
restart_weights
:
List
[
float
]
=
[
1
],
min_lr
=
None
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
):
**
kwargs
)
->
None
:
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
periods
=
periods
self
.
periods
=
periods
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
...
@@ -373,7 +385,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
...
@@ -373,7 +385,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
sum
(
self
.
periods
[
0
:
i
+
1
])
for
i
in
range
(
0
,
len
(
self
.
periods
))
sum
(
self
.
periods
[
0
:
i
+
1
])
for
i
in
range
(
0
,
len
(
self
.
periods
))
]
]
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
progress
=
runner
.
epoch
progress
=
runner
.
epoch
else
:
else
:
...
@@ -382,7 +394,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
...
@@ -382,7 +394,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
idx
=
get_position_from_periods
(
progress
,
self
.
cumulative_periods
)
idx
=
get_position_from_periods
(
progress
,
self
.
cumulative_periods
)
current_weight
=
self
.
restart_weights
[
idx
]
current_weight
=
self
.
restart_weights
[
idx
]
...
@@ -393,7 +405,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
...
@@ -393,7 +405,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
return
annealing_cos
(
base_lr
,
target_lr
,
alpha
,
current_weight
)
return
annealing_cos
(
base_lr
,
target_lr
,
alpha
,
current_weight
)
def
get_position_from_periods
(
iteration
,
cumulative_periods
):
def
get_position_from_periods
(
iteration
:
int
,
cumulative_periods
:
List
[
int
]
):
"""Get the position from a period list.
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
It will return the index of the right-closest number in the period list.
...
@@ -444,13 +456,13 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
...
@@ -444,13 +456,13 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
by_epoch
=
False
,
by_epoch
:
bool
=
False
,
target_ratio
=
(
10
,
1e-4
),
target_ratio
:
Union
[
float
,
tuple
]
=
(
10
,
1e-4
),
cyclic_times
=
1
,
cyclic_times
:
int
=
1
,
step_ratio_up
=
0.4
,
step_ratio_up
:
float
=
0.4
,
anneal_strategy
=
'cos'
,
anneal_strategy
:
str
=
'cos'
,
gamma
=
1
,
gamma
:
float
=
1
,
**
kwargs
):
**
kwargs
)
->
None
:
if
isinstance
(
target_ratio
,
float
):
if
isinstance
(
target_ratio
,
float
):
target_ratio
=
(
target_ratio
,
target_ratio
/
1e5
)
target_ratio
=
(
target_ratio
,
target_ratio
/
1e5
)
elif
isinstance
(
target_ratio
,
tuple
):
elif
isinstance
(
target_ratio
,
tuple
):
...
@@ -472,13 +484,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
...
@@ -472,13 +484,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
self
.
step_ratio_up
=
step_ratio_up
self
.
step_ratio_up
=
step_ratio_up
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
max_iter_per_phase
=
None
self
.
max_iter_per_phase
=
None
self
.
lr_phases
=
[]
# init lr_phases
self
.
lr_phases
:
list
=
[]
# init lr_phases
# validate anneal_strategy
# validate anneal_strategy
if
anneal_strategy
not
in
[
'cos'
,
'linear'
]:
if
anneal_strategy
not
in
[
'cos'
,
'linear'
]:
raise
ValueError
(
'anneal_strategy must be one of "cos" or '
raise
ValueError
(
'anneal_strategy must be one of "cos" or '
f
'"linear", instead got
{
anneal_strategy
}
'
)
f
'"linear", instead got
{
anneal_strategy
}
'
)
elif
anneal_strategy
==
'cos'
:
elif
anneal_strategy
==
'cos'
:
self
.
anneal_func
=
annealing_cos
self
.
anneal_func
:
Callable
[[
float
,
float
,
float
],
float
]
=
annealing_cos
elif
anneal_strategy
==
'linear'
:
elif
anneal_strategy
==
'linear'
:
self
.
anneal_func
=
annealing_linear
self
.
anneal_func
=
annealing_linear
...
@@ -486,19 +499,20 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
...
@@ -486,19 +499,20 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
'currently only support "by_epoch" = False'
'currently only support "by_epoch" = False'
super
().
__init__
(
by_epoch
,
**
kwargs
)
super
().
__init__
(
by_epoch
,
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
:
'runner.BaseRunner'
):
super
().
before_run
(
runner
)
super
().
before_run
(
runner
)
# initiate lr_phases
# initiate lr_phases
# total lr_phases are separated as up and down
# total lr_phases are separated as up and down
self
.
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
self
.
max_iter_per_phase
=
runner
.
max_iters
//
self
.
cyclic_times
iter_up_phase
=
int
(
self
.
step_ratio_up
*
self
.
max_iter_per_phase
)
iter_up_phase
=
int
(
self
.
step_ratio_up
*
self
.
max_iter_per_phase
)
# type:ignore
self
.
lr_phases
.
append
([
0
,
iter_up_phase
,
1
,
self
.
target_ratio
[
0
]])
self
.
lr_phases
.
append
([
0
,
iter_up_phase
,
1
,
self
.
target_ratio
[
0
]])
self
.
lr_phases
.
append
([
self
.
lr_phases
.
append
([
iter_up_phase
,
self
.
max_iter_per_phase
,
self
.
target_ratio
[
0
],
iter_up_phase
,
self
.
max_iter_per_phase
,
self
.
target_ratio
[
0
],
self
.
target_ratio
[
1
]
self
.
target_ratio
[
1
]
])
])
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
curr_iter
=
runner
.
iter
%
self
.
max_iter_per_phase
curr_iter
=
runner
.
iter
%
self
.
max_iter_per_phase
curr_cycle
=
runner
.
iter
//
self
.
max_iter_per_phase
curr_cycle
=
runner
.
iter
//
self
.
max_iter_per_phase
# Update weight decay
# Update weight decay
...
@@ -558,14 +572,14 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
...
@@ -558,14 +572,14 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
max_lr
,
max_lr
:
Union
[
float
,
List
]
,
total_steps
=
None
,
total_steps
:
Optional
[
int
]
=
None
,
pct_start
=
0.3
,
pct_start
:
float
=
0.3
,
anneal_strategy
=
'cos'
,
anneal_strategy
:
str
=
'cos'
,
div_factor
=
25
,
div_factor
:
float
=
25
,
final_div_factor
=
1e4
,
final_div_factor
:
float
=
1e4
,
three_phase
=
False
,
three_phase
:
bool
=
False
,
**
kwargs
):
**
kwargs
)
->
None
:
# validate by_epoch, currently only support by_epoch = False
# validate by_epoch, currently only support by_epoch = False
if
'by_epoch'
not
in
kwargs
:
if
'by_epoch'
not
in
kwargs
:
kwargs
[
'by_epoch'
]
=
False
kwargs
[
'by_epoch'
]
=
False
...
@@ -591,16 +605,17 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
...
@@ -591,16 +605,17 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
raise
ValueError
(
'anneal_strategy must be one of "cos" or '
raise
ValueError
(
'anneal_strategy must be one of "cos" or '
f
'"linear", instead got
{
anneal_strategy
}
'
)
f
'"linear", instead got
{
anneal_strategy
}
'
)
elif
anneal_strategy
==
'cos'
:
elif
anneal_strategy
==
'cos'
:
self
.
anneal_func
=
annealing_cos
self
.
anneal_func
:
Callable
[[
float
,
float
,
float
],
float
]
=
annealing_cos
elif
anneal_strategy
==
'linear'
:
elif
anneal_strategy
==
'linear'
:
self
.
anneal_func
=
annealing_linear
self
.
anneal_func
=
annealing_linear
self
.
div_factor
=
div_factor
self
.
div_factor
=
div_factor
self
.
final_div_factor
=
final_div_factor
self
.
final_div_factor
=
final_div_factor
self
.
three_phase
=
three_phase
self
.
three_phase
=
three_phase
self
.
lr_phases
=
[]
# init lr_phases
self
.
lr_phases
:
list
=
[]
# init lr_phases
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
:
'runner.BaseRunner'
):
if
hasattr
(
self
,
'total_steps'
):
if
hasattr
(
self
,
'total_steps'
):
total_steps
=
self
.
total_steps
total_steps
=
self
.
total_steps
else
:
else
:
...
@@ -639,7 +654,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
...
@@ -639,7 +654,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
self
.
lr_phases
.
append
(
self
.
lr_phases
.
append
(
[
total_steps
-
1
,
self
.
div_factor
,
1
/
self
.
final_div_factor
])
[
total_steps
-
1
,
self
.
div_factor
,
1
/
self
.
final_div_factor
])
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
curr_iter
=
runner
.
iter
curr_iter
=
runner
.
iter
start_iter
=
0
start_iter
=
0
for
i
,
(
end_iter
,
start_lr
,
end_lr
)
in
enumerate
(
self
.
lr_phases
):
for
i
,
(
end_iter
,
start_lr
,
end_lr
)
in
enumerate
(
self
.
lr_phases
):
...
@@ -664,13 +679,16 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -664,13 +679,16 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
Default: None.
Default: None.
"""
"""
def
__init__
(
self
,
min_lr
=
None
,
min_lr_ratio
=
None
,
**
kwargs
):
def
__init__
(
self
,
min_lr
:
Optional
[
float
]
=
None
,
min_lr_ratio
:
Optional
[
float
]
=
None
,
**
kwargs
):
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
assert
(
min_lr
is
None
)
^
(
min_lr_ratio
is
None
)
self
.
min_lr
=
min_lr
self
.
min_lr
=
min_lr
self
.
min_lr_ratio
=
min_lr_ratio
self
.
min_lr_ratio
=
min_lr_ratio
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
get_lr
(
self
,
runner
,
base_lr
):
def
get_lr
(
self
,
runner
:
'runner.BaseRunner'
,
base_lr
:
float
):
if
self
.
by_epoch
:
if
self
.
by_epoch
:
progress
=
runner
.
epoch
progress
=
runner
.
epoch
max_progress
=
runner
.
max_epochs
max_progress
=
runner
.
max_epochs
...
@@ -680,11 +698,14 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
...
@@ -680,11 +698,14 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
if
self
.
min_lr_ratio
is
not
None
:
if
self
.
min_lr_ratio
is
not
None
:
target_lr
=
base_lr
*
self
.
min_lr_ratio
target_lr
=
base_lr
*
self
.
min_lr_ratio
else
:
else
:
target_lr
=
self
.
min_lr
target_lr
=
self
.
min_lr
# type:ignore
return
annealing_linear
(
base_lr
,
target_lr
,
progress
/
max_progress
)
return
annealing_linear
(
base_lr
,
target_lr
,
progress
/
max_progress
)
def
annealing_cos
(
start
,
end
,
factor
,
weight
=
1
):
def
annealing_cos
(
start
:
float
,
end
:
float
,
factor
:
float
,
weight
:
float
=
1
)
->
float
:
"""Calculate annealing cos learning rate.
"""Calculate annealing cos learning rate.
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
...
@@ -702,7 +723,7 @@ def annealing_cos(start, end, factor, weight=1):
...
@@ -702,7 +723,7 @@ def annealing_cos(start, end, factor, weight=1):
return
end
+
0.5
*
weight
*
(
start
-
end
)
*
cos_out
return
end
+
0.5
*
weight
*
(
start
-
end
)
*
cos_out
def
annealing_linear
(
start
,
end
,
factor
)
:
def
annealing_linear
(
start
:
float
,
end
:
float
,
factor
:
float
)
->
float
:
"""Calculate annealing linear learning rate.
"""Calculate annealing linear learning rate.
Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment