Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c53ddda8
Unverified
Commit
c53ddda8
authored
Feb 06, 2024
by
Hongxin Liu
Committed by
GitHub
Feb 06, 2024
Browse files
[lr-scheduler] fix load state dict and add test (#5369)
parent
2dd01e3a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
42 deletions
+60
-42
colossalai/nn/lr_scheduler/delayed.py
colossalai/nn/lr_scheduler/delayed.py
+40
-42
tests/test_optimizer/test_lr_scheduler.py
tests/test_optimizer/test_lr_scheduler.py
+20
-0
No files found.
colossalai/nn/lr_scheduler/delayed.py
View file @
c53ddda8
...
...
@@ -6,6 +6,8 @@ if Version(torch.__version__) >= Version("2.0.0"):
else
:
from
torch.optim.lr_scheduler
import
_LRScheduler
from
colossalai.logging
import
get_dist_logger
class
_enable_get_lr_call
:
def
__init__
(
self
,
o
):
...
...
@@ -19,7 +21,39 @@ class _enable_get_lr_call:
self
.
o
.
_get_lr_called_within_step
=
False
class
DelayerScheduler
(
_LRScheduler
):
class
TwoStageScheduler
(
_LRScheduler
):
def
__init__
(
self
,
optimizer
,
after_scheduler
:
_LRScheduler
,
last_epoch
=-
1
):
self
.
after_scheduler
=
after_scheduler
self
.
finished
=
False
super
().
__init__
(
optimizer
,
last_epoch
)
def
state_dict
(
self
):
state_dict
=
{
key
:
value
for
key
,
value
in
self
.
__dict__
.
items
()
if
key
not
in
"optimizer"
}
if
isinstance
(
state_dict
[
"after_scheduler"
],
_LRScheduler
):
state_dict
[
"after_scheduler_type"
]
=
type
(
state_dict
[
"after_scheduler"
]).
__name__
state_dict
[
"after_scheduler_dict"
]
=
state_dict
[
"after_scheduler"
].
state_dict
()
del
state_dict
[
"after_scheduler"
]
else
:
raise
NotImplementedError
()
return
state_dict
def
load_state_dict
(
self
,
state_dict
):
if
"after_scheduler_dict"
not
in
state_dict
:
logger
=
get_dist_logger
()
logger
.
warning
(
"after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior."
)
else
:
self
.
after_scheduler
.
load_state_dict
(
state_dict
[
"after_scheduler_dict"
])
state_dict
=
{
key
:
value
for
key
,
value
in
state_dict
.
items
()
if
key
not
in
(
"after_scheduler_type"
,
"after_scheduler_dict"
)
}
super
().
load_state_dict
(
state_dict
)
class
DelayerScheduler
(
TwoStageScheduler
):
"""Starts with a flat lr schedule until it reaches N epochs then applies
the specific scheduler (For example: ReduceLROnPlateau)
...
...
@@ -35,19 +69,7 @@ class DelayerScheduler(_LRScheduler):
if
delay_epochs
<
0
:
raise
ValueError
(
f
"delay_epochs must >= 0, got
{
delay_epochs
}
"
)
self
.
delay_epochs
=
delay_epochs
self
.
after_scheduler
=
after_scheduler
self
.
finished
=
False
super
().
__init__
(
optimizer
,
last_epoch
)
def
state_dict
(
self
):
state_dict
=
{
key
:
value
for
key
,
value
in
self
.
__dict__
.
items
()
if
key
not
in
"optimizer"
}
if
isinstance
(
state_dict
[
"after_scheduler"
],
_LRScheduler
):
state_dict
[
"after_scheduler_type"
]
=
type
(
state_dict
[
"after_scheduler"
]).
__name__
state_dict
[
"after_scheduler_dict"
]
=
state_dict
[
"after_scheduler"
].
state_dict
()
del
state_dict
[
"after_scheduler"
]
else
:
raise
NotImplementedError
()
return
state_dict
super
().
__init__
(
optimizer
,
after_scheduler
,
last_epoch
)
def
get_lr
(
self
):
if
self
.
last_epoch
>=
self
.
delay_epochs
:
...
...
@@ -71,7 +93,7 @@ class DelayerScheduler(_LRScheduler):
return
super
(
DelayerScheduler
,
self
).
step
(
epoch
)
class
WarmupScheduler
(
_LR
Scheduler
):
class
WarmupScheduler
(
TwoStage
Scheduler
):
"""Starts with a linear warmup lr schedule until it reaches N epochs then applies
the specific scheduler (For example: ReduceLROnPlateau).
...
...
@@ -85,19 +107,7 @@ class WarmupScheduler(_LRScheduler):
def
__init__
(
self
,
optimizer
,
warmup_epochs
,
after_scheduler
,
last_epoch
=-
1
):
self
.
warmup_epochs
=
int
(
warmup_epochs
)
self
.
after_scheduler
=
after_scheduler
self
.
finished
=
False
super
().
__init__
(
optimizer
,
last_epoch
)
def
state_dict
(
self
):
state_dict
=
{
key
:
value
for
key
,
value
in
self
.
__dict__
.
items
()
if
key
not
in
"optimizer"
}
if
isinstance
(
state_dict
[
"after_scheduler"
],
_LRScheduler
):
state_dict
[
"after_scheduler_type"
]
=
type
(
state_dict
[
"after_scheduler"
]).
__name__
state_dict
[
"after_scheduler_dict"
]
=
state_dict
[
"after_scheduler"
].
state_dict
()
del
state_dict
[
"after_scheduler"
]
else
:
raise
NotImplementedError
()
return
state_dict
super
().
__init__
(
optimizer
,
after_scheduler
,
last_epoch
)
def
get_lr
(
self
):
if
self
.
last_epoch
>=
self
.
warmup_epochs
:
...
...
@@ -120,7 +130,7 @@ class WarmupScheduler(_LRScheduler):
return
super
().
step
(
epoch
)
class
WarmupDelayerScheduler
(
_LR
Scheduler
):
class
WarmupDelayerScheduler
(
TwoStage
Scheduler
):
"""Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule
until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau).
...
...
@@ -140,19 +150,7 @@ class WarmupDelayerScheduler(_LRScheduler):
raise
ValueError
(
f
"warmup_epochs must >= 0, got
{
warmup_epochs
}
"
)
self
.
warmup_epochs
=
warmup_epochs
self
.
delay_epochs
=
delay_epochs
self
.
after_scheduler
=
after_scheduler
self
.
finished
=
False
super
().
__init__
(
optimizer
,
last_epoch
)
def
state_dict
(
self
):
state_dict
=
{
key
:
value
for
key
,
value
in
self
.
__dict__
.
items
()
if
key
not
in
"optimizer"
}
if
isinstance
(
state_dict
[
"after_scheduler"
],
_LRScheduler
):
state_dict
[
"after_scheduler_type"
]
=
type
(
state_dict
[
"after_scheduler"
]).
__name__
state_dict
[
"after_scheduler_dict"
]
=
state_dict
[
"after_scheduler"
].
state_dict
()
del
state_dict
[
"after_scheduler"
]
else
:
raise
NotImplementedError
()
return
state_dict
super
().
__init__
(
optimizer
,
after_scheduler
,
last_epoch
)
def
get_lr
(
self
):
if
self
.
last_epoch
>=
self
.
warmup_epochs
+
self
.
delay_epochs
:
...
...
tests/test_optimizer/test_lr_scheduler.py
0 → 100644
View file @
c53ddda8
import
torch.nn
as
nn
from
torch.optim
import
Adam
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
def
test_lr_scheduler_save_load
():
model
=
nn
.
Linear
(
10
,
10
)
optimizer
=
Adam
(
model
.
parameters
(),
lr
=
1e-3
)
scheduler
=
CosineAnnealingWarmupLR
(
optimizer
,
total_steps
=
5
,
warmup_steps
=
2
)
new_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
,
total_steps
=
5
,
warmup_steps
=
2
)
for
_
in
range
(
5
):
scheduler
.
step
()
state_dict
=
scheduler
.
state_dict
()
new_scheduler
.
load_state_dict
(
state_dict
)
assert
state_dict
==
new_scheduler
.
state_dict
()
if
__name__
==
"__main__"
:
test_lr_scheduler_save_load
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment