Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
39f2582e
Unverified
Commit
39f2582e
authored
Oct 12, 2023
by
Baizhou Zhang
Committed by
GitHub
Oct 12, 2023
Browse files
[hotfix] fix lr scheduler bug in torch 2.0 (#4864)
parent
83b52c56
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
3 deletions
+17
-3
colossalai/nn/lr_scheduler/delayed.py
colossalai/nn/lr_scheduler/delayed.py
+7
-1
tests/test_checkpoint_io/test_general_checkpoint_io.py
tests/test_checkpoint_io/test_general_checkpoint_io.py
+9
-2
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+1
-0
No files found.
colossalai/nn/lr_scheduler/delayed.py
View file @
39f2582e
from
torch.optim.lr_scheduler
import
_LRScheduler
import
torch
from
packaging.version
import
Version
if
Version
(
torch
.
__version__
)
>=
Version
(
"2.0.0"
):
from
torch.optim.lr_scheduler
import
LRScheduler
as
_LRScheduler
else
:
from
torch.optim.lr_scheduler
import
_LRScheduler
class
_enable_get_lr_call
:
...
...
tests/test_checkpoint_io/test_general_checkpoint_io.py
View file @
39f2582e
...
...
@@ -6,6 +6,7 @@ from torch.optim import Adam
from
torchvision.models
import
resnet18
from
colossalai.checkpoint_io
import
GeneralCheckpointIO
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.testing
import
check_state_dict_equal
,
clear_cache_before_run
,
parameterize
# ========
...
...
@@ -22,6 +23,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
# create a model and optimizer
model
=
resnet18
()
optimizer
=
Adam
(
model
.
parameters
(),
lr
=
0.001
)
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
,
total_steps
=
10
)
# create test data sample
x
=
torch
.
randn
(
1
,
3
,
224
,
224
)
...
...
@@ -31,6 +33,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
loss
=
y
.
sum
()
loss
.
backward
()
optimizer
.
step
()
lr_scheduler
.
step
()
# create a temp file for checkpoint
if
use_safetensors
:
...
...
@@ -39,19 +42,23 @@ def test_unsharded_checkpoint(use_safetensors: bool):
suffix
=
".bin"
model_ckpt_tempfile
=
tempfile
.
NamedTemporaryFile
(
suffix
=
suffix
)
optimizer_ckpt_tempfile
=
tempfile
.
NamedTemporaryFile
()
lr_scheduler_ckpt_tempfile
=
tempfile
.
NamedTemporaryFile
()
# save the model
and
optimizer
# save the model
,
optimizer
, lr_scheduler
ckpt_io
=
GeneralCheckpointIO
()
ckpt_io
.
save_model
(
model
,
model_ckpt_tempfile
.
name
,
use_safetensors
=
use_safetensors
)
ckpt_io
.
save_optimizer
(
optimizer
,
optimizer_ckpt_tempfile
.
name
)
ckpt_io
.
save_lr_scheduler
(
lr_scheduler
,
lr_scheduler_ckpt_tempfile
.
name
)
# create new model
new_model
=
resnet18
()
new_optimizer
=
Adam
(
new_model
.
parameters
(),
lr
=
0.001
)
new_lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
,
total_steps
=
10
)
# load the model
and
optimizer
# load the model
,
optimizer
, lr_scheduler
ckpt_io
.
load_model
(
new_model
,
model_ckpt_tempfile
.
name
)
ckpt_io
.
load_optimizer
(
new_optimizer
,
optimizer_ckpt_tempfile
.
name
)
ckpt_io
.
load_lr_scheduler
(
new_lr_scheduler
,
lr_scheduler_ckpt_tempfile
.
name
)
# check for model and optimizer state dict recursively
check_state_dict_equal
(
model
.
state_dict
(),
new_model
.
state_dict
())
...
...
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
View file @
39f2582e
...
...
@@ -72,6 +72,7 @@ def run_dist(rank, world_size, port):
exam_zero_optim_state_dict
()
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment