OpenDAS / ColossalAI · Commits

Commit 04537bf8 (unverified), parent a98319f0
Authored Jul 07, 2022 by Yi Zhao; committed by GitHub on Jul 07, 2022

[checkpoint] support generalized scheduler (#1222)
Showing 4 changed files, with 85 additions and 20 deletions (+85 −20):

- colossalai/nn/lr_scheduler/delayed.py (+31 −0)
- colossalai/tensor/colo_tensor.py (+0 −1)
- colossalai/utils/checkpoint/module_checkpoint.py (+21 −8)
- tests/test_utils/test_colo_checkpoint.py (+33 −11)
colossalai/nn/lr_scheduler/delayed.py:

```diff
@@ -2,6 +2,7 @@ from torch.optim.lr_scheduler import _LRScheduler
 class _enable_get_lr_call:

     def __init__(self, o):
         self.o = o
@@ -33,6 +34,16 @@ class DelayerScheduler(_LRScheduler):
         self.finished = False
         super().__init__(optimizer, last_epoch)

+    def state_dict(self):
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
+        if isinstance(state_dict['after_scheduler'], _LRScheduler):
+            state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
+            state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
+            del state_dict['after_scheduler']
+        else:
+            raise NotImplementedError()
+        return state_dict
+
     def get_lr(self):
         if self.last_epoch >= self.delay_epochs:
             if not self.finished:
@@ -73,6 +84,16 @@ class WarmupScheduler(_LRScheduler):
         self.finished = False
         super().__init__(optimizer, last_epoch)

+    def state_dict(self):
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
+        if isinstance(state_dict['after_scheduler'], _LRScheduler):
+            state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
+            state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
+            del state_dict['after_scheduler']
+        else:
+            raise NotImplementedError()
+        return state_dict
+
     def get_lr(self):
         if self.last_epoch >= self.warmup_epochs:
             if not self.finished:
@@ -118,6 +139,16 @@ class WarmupDelayerScheduler(_LRScheduler):
         self.finished = False
         super().__init__(optimizer, last_epoch)

+    def state_dict(self):
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
+        if isinstance(state_dict['after_scheduler'], _LRScheduler):
+            state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
+            state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
+            del state_dict['after_scheduler']
+        else:
+            raise NotImplementedError()
+        return state_dict
+
     def get_lr(self):
         if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
             if not self.finished:
```
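Each of the three wrapper classes gains the same `state_dict` override, so whatever torch scheduler sits in `after_scheduler` is flattened into two picklable entries (`after_scheduler_type` and `after_scheduler_dict`) instead of a live object. Note that `key not in 'optimizer'` is a substring test rather than list membership; it works here because the only attribute to skip is literally named `optimizer`. Below is a minimal sketch, not from the commit, of what the serialized form looks like; the `WarmupScheduler(optimizer, warmup_epochs, after_scheduler)` constructor is an assumption read off the diff context.

```python
# Sketch only: inspect the generalized state_dict format added by this commit.
# The WarmupScheduler constructor arguments are inferred from the diff context,
# not a verified snippet from the repository.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from colossalai.nn.lr_scheduler.delayed import WarmupScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = WarmupScheduler(optimizer, 2, CosineAnnealingLR(optimizer, T_max=10))

sd = scheduler.state_dict()
# The live object is replaced by its class name and its own state dict:
#   sd['after_scheduler_type']  -> 'CosineAnnealingLR'
#   sd['after_scheduler_dict']  -> CosineAnnealingLR's own state_dict()
# 'after_scheduler' itself is deleted, so torch.save(sd, path) never has to
# pickle a scheduler object, which is what makes the format "generalized".
print(sd['after_scheduler_type'], sorted(sd['after_scheduler_dict']))
```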
colossalai/tensor/colo_tensor.py:

```diff
@@ -29,7 +29,6 @@ def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
             pg = _scan_for_pg_from_args(elem, {})
             if pg is not None:
                 return pg
-            print(type(elem), elem, isinstance(elem, (list, tuple)))
     for k, v in kwargs:
         if isinstance(v, ColoTensor):
             pg = v.get_process_group()
```
colossalai/utils/checkpoint/module_checkpoint.py:

```diff
@@ -2,10 +2,20 @@ import torch
 import torch.nn as nn
 import torch.distributed as dist
-import collections
-from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
+import inspect
 from colossalai.utils.model.colo_init_context import colo_state_dict

+
+def filter_dict(dict_to_filter, thing_with_kwargs):
+    sig = inspect.signature(thing_with_kwargs)
+    filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
+    filter_dict = {}
+    for filter_key in filter_keys:
+        if filter_key in dict_to_filter:
+            filter_dict[filter_key] = dict_to_filter[filter_key]
+    return filter_dict
+
+
 def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
@@ -25,9 +35,7 @@ def save_checkpoint(dire: str,
     model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-    lr_scheduler_dict = lr_scheduler.state_dict()
-    lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
-    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler_dict}
+    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
     torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
@@ -55,8 +63,13 @@ def load_checkpoint(dire,
     optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
     optimizer.load_state_dict(optim_state['optimizer'])
     lr_scheduler_dict = optim_state['lr_scheduler']
-    after_scheduler_dict = lr_scheduler_dict['after_scheduler']
-    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(optimizer, after_scheduler_dict['T_max'],
-                                                              after_scheduler_dict['eta_min'],
-                                                              after_scheduler_dict['last_epoch'])
+    if 'after_scheduler_type' in lr_scheduler_dict:
+        after_scheduler_type = lr_scheduler_dict.pop('after_scheduler_type')
+        after_scheduler_dict = lr_scheduler_dict.pop('after_scheduler_dict')
+        reload_scheduler = getattr(torch.optim.lr_scheduler, after_scheduler_type)
+        filtered_dict = filter_dict(after_scheduler_dict, reload_scheduler)
+        lr_scheduler_dict['after_scheduler'] = reload_scheduler(optimizer, **filtered_dict)
     lr_scheduler.load_state_dict(lr_scheduler_dict)
```
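With `filter_dict`, the reload path no longer hard-codes `_CosineAnnealingLR`: `load_checkpoint` looks the scheduler class up by name on `torch.optim.lr_scheduler` and keeps only the saved entries that match the constructor's positional-or-keyword parameters. A self-contained sketch of that mechanism, reusing `filter_dict` from the diff above with a toy optimizer (values are illustrative):

```python
# Rebuild a torch scheduler from (class name, state dict), as load_checkpoint
# now does. filter_dict is copied from the diff above; everything else is a
# toy setup for demonstration.
import inspect
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR


def filter_dict(dict_to_filter, thing_with_kwargs):
    sig = inspect.signature(thing_with_kwargs)
    filter_keys = [p.name for p in sig.parameters.values() if p.kind == p.POSITIONAL_OR_KEYWORD]
    return {k: dict_to_filter[k] for k in filter_keys if k in dict_to_filter}


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
saved = CosineAnnealingLR(optimizer, T_max=10).state_dict()    # stand-in for a checkpointed dict

cls = getattr(torch.optim.lr_scheduler, 'CosineAnnealingLR')   # name stored at save time
kwargs = filter_dict(saved, cls)      # keeps T_max / eta_min / last_epoch, drops internals
reloaded = cls(optimizer, **kwargs)   # optimizer already carries 'initial_lr' at this point
```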
tests/test_utils/test_colo_checkpoint.py:

```diff
@@ -8,6 +8,8 @@ from functools import partial
 import torch.multiprocessing as mp
 import torch.distributed as dist
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim.lr_scheduler import MultiplicativeLR
 import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
@@ -102,10 +104,14 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))

-def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
+def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
+    num_epoch = 5
+    warmup_epoch = 2
     batch = 3
     feature = 32
     category = 16
     train_dataloader = DummyDataLoader(batch, category, feature, length=16)
     with ColoInitContext(device=get_current_device()):
         model = MLP(feature, category)
@@ -129,14 +135,25 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
                                      weight_decay=0)
     optimizer_ref = torch.optim.Adam(model_ref.parameters(),
                                      lr=0.001,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=0)
-    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)
-    lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5)
-    lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5)
+    if test_scheduler == 'colossalai_cosine_warmup':
+        lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
+        lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
+                                                      total_steps=num_epoch,
+                                                      warmup_steps=warmup_epoch)
+    elif test_scheduler == 'torch_cosine':
+        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
+        lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
+    elif test_scheduler == 'torch_lambda':
+        lr_lambda = lambda epoch: 0.95
+        lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
+        lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
     init_spec_func(model, pg)
     init_spec_func(model_ref, pg)
-    for epoch in range(0, 20):
+    for epoch in range(0, num_epoch):
         if epoch <= test_epoch:
             for i, image_dict in enumerate(train_dataloader):
                 if use_ddp:
@@ -155,7 +172,6 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
             for ref_p, p in zip(model_ref.parameters(), model.parameters()):
                 ref_p.data.copy_(p)
             optimizer_ref = copy.deepcopy(optimizer)
-            lr_scheduler_ref = copy.deepcopy(lr_scheduler)
             check_param_equal(model, model_ref)
             save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
@@ -189,28 +205,34 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
     check_param_equal(model_ref, model_reload)

-def run_dist(rank, world_size, port, use_ddp, test_epoch):
+def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
     if use_ddp and world_size == 1:
         return
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg)
+    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [4])
 @pytest.mark.parametrize('use_ddp', [True])
 @pytest.mark.parametrize('test_epoch', [1, 2, 3])
+@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
-def test_checkpoint(world_size, use_ddp, test_epoch):
+def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
     if not os.path.isdir('./checkpoint'):
         os.mkdir('./checkpoint')
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch)
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       use_ddp=use_ddp,
+                       test_epoch=test_epoch,
+                       test_scheduler=test_scheduler)
     mp.spawn(run_func, nprocs=world_size)
     remove('./checkpoint')


 if __name__ == '__main__':
-    test_checkpoint(4, True, 1)
+    test_checkpoint(4, True, 1, 1)
```
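The test now parametrizes `test_scheduler` over ColossalAI's warmup scheduler and two stock torch schedulers, checkpoints at `test_epoch`, and checks the reloaded run against a reference copy. (Note that the `__main__` guard passes `1` for `test_scheduler`, which matches none of the parametrized strings; the pytest path always passes one of the three names.) The same round-trip idea in a single-process sketch, with illustrative values and no ColossalAI launch or DDP:

```python
# Single-process illustration of the scheduler round trip the test verifies:
# step, serialize, rebuild fresh, load, compare learning rates.
import torch
from torch.optim.lr_scheduler import MultiplicativeLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.95)

for _ in range(3):                  # a few "epochs", then checkpoint
    optimizer.step()
    lr_scheduler.step()
state = lr_scheduler.state_dict()   # the lambda itself is not serialized

optimizer_reload = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lambda epoch: 0.95)
lr_scheduler_reload.load_state_dict(state)

assert lr_scheduler_reload.get_last_lr() == lr_scheduler.get_last_lr()
```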