Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
20da6e48
Unverified
Commit
20da6e48
authored
Jul 08, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 08, 2022
Browse files
[checkpoint] save sharded optimizer states (#1237)
parent
4a76084d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
19 deletions
+28
-19
colossalai/tensor/process_group.py
colossalai/tensor/process_group.py
+10
-11
colossalai/utils/checkpoint/module_checkpoint.py
colossalai/utils/checkpoint/module_checkpoint.py
+6
-1
tests/test_utils/test_colo_checkpoint.py
tests/test_utils/test_colo_checkpoint.py
+12
-7
No files found.
colossalai/tensor/process_group.py
View file @
20da6e48
...
@@ -93,20 +93,17 @@ class ProcessGroup:
...
@@ -93,20 +93,17 @@ class ProcessGroup:
if
idx
//
self
.
_tp_degree
==
self
.
_rank_idx
//
self
.
_tp_degree
:
if
idx
//
self
.
_tp_degree
==
self
.
_rank_idx
//
self
.
_tp_degree
:
self
.
_tp_rank_list
.
append
(
rank_id
)
self
.
_tp_rank_list
.
append
(
rank_id
)
self
.
_tp_process_group
=
PYTORCHPGDICT_
.
get
(
self
.
_tp_rank_list
,
'nccl'
)
self
.
_dp_process_group
=
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
,
'nccl'
)
self
.
_has_cpu_groups
=
False
self
.
_has_cpu_groups
=
False
self
.
_cpu_dp_process_group
=
None
PYTORCHPGDICT_
.
get
(
self
.
_tp_rank_list
,
'nccl'
)
self
.
_cpu_tp_process_group
=
None
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
,
'nccl'
)
def
set_cpu_groups
(
self
):
def
set_cpu_groups
(
self
):
if
self
.
has_cpu_groups
:
if
self
.
has_cpu_groups
:
return
return
self
.
logger
.
info
(
self
.
logger
.
info
(
f
'
{
self
.
_rank
}
Gloo initialize TP group on
{
self
.
_tp_rank_list
}
, DP group on
{
self
.
_dp_rank_list
}
'
)
f
'
{
self
.
_rank
}
Gloo initialize TP group on
{
self
.
_tp_rank_list
}
, DP group on
{
self
.
_dp_rank_list
}
'
)
self
.
_cpu_tp_process_group
=
PYTORCHPGDICT_
.
get
(
self
.
_tp_rank_list
,
'gloo'
)
PYTORCHPGDICT_
.
get
(
self
.
_tp_rank_list
,
'gloo'
)
self
.
_cpu_dp_process_group
=
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
,
'gloo'
)
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
,
'gloo'
)
@
property
@
property
def
has_cpu_groups
(
self
):
def
has_cpu_groups
(
self
):
...
@@ -152,13 +149,15 @@ class ProcessGroup:
...
@@ -152,13 +149,15 @@ class ProcessGroup:
return
len
(
self
.
_tp_rank_list
)
return
len
(
self
.
_tp_rank_list
)
def
dp_process_group
(
self
):
def
dp_process_group
(
self
):
return
self
.
_dp_process_group
# return self._dp_process_group
return
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
,
'nccl'
)
def
tp_process_group
(
self
):
def
tp_process_group
(
self
):
return
self
.
_tp_process_group
# return self._tp_process_group
return
PYTORCHPGDICT_
.
get
(
self
.
_tp_rank_list
,
'nccl'
)
def
cpu_dp_process_group
(
self
):
def
cpu_dp_process_group
(
self
):
return
self
.
_cpu_dp_process_group
return
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
,
'gloo'
)
def
cpu_tp_process_group
(
self
):
def
cpu_tp_process_group
(
self
):
return
self
.
_cpu_tp_process_group
return
PYTORCHPGDICT_
.
get
(
self
.
_tp_rank_list
,
'gloo'
)
colossalai/utils/checkpoint/module_checkpoint.py
View file @
20da6e48
...
@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
...
@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
"""
model_state
=
{
'epoch'
:
epoch
,
'model'
:
colo_state_dict
(
model
,
state_dict_func
=
nn
.
Module
.
state_dict
)}
model_state
=
{
'epoch'
:
epoch
,
'model'
:
model
.
state_dict
(
)}
if
dist
.
get_rank
()
==
0
:
if
dist
.
get_rank
()
==
0
:
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
# 1. convert SHARD ColoTensor to REPLICATE
# only rank 0 saves the REPLICATE tensors.
optim_state
=
{
'epoch'
:
epoch
,
'optimizer'
:
optimizer
.
state_dict
(),
'lr_scheduler'
:
lr_scheduler
.
state_dict
()}
optim_state
=
{
'epoch'
:
epoch
,
'optimizer'
:
optimizer
.
state_dict
(),
'lr_scheduler'
:
lr_scheduler
.
state_dict
()}
torch
.
save
(
optim_state
,
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
dist
.
get_rank
()))
torch
.
save
(
optim_state
,
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
dist
.
get_rank
()))
...
...
tests/test_utils/test_colo_checkpoint.py
View file @
20da6e48
...
@@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
...
@@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
model_reload
=
ColoDDP
(
model_reload
,
pg
)
model_reload
=
ColoDDP
(
model_reload
,
pg
)
model_ref
=
ColoDDP
(
model_ref
,
pg
)
model_ref
=
ColoDDP
(
model_ref
,
pg
)
init_spec_func
(
model
,
pg
)
init_spec_func
(
model_ref
,
pg
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_reload
=
torch
.
optim
.
Adam
(
model_reload
.
parameters
(),
optimizer_reload
=
torch
.
optim
.
Adam
(
model_reload
.
parameters
(),
...
@@ -135,23 +138,21 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
...
@@ -135,23 +138,21 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
weight_decay
=
0
)
weight_decay
=
0
)
optimizer_ref
=
torch
.
optim
.
Adam
(
model_ref
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_ref
=
torch
.
optim
.
Adam
(
model_ref
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
lr_scheduler
=
None
if
test_scheduler
==
'colossalai_cosine_warmup'
:
if
test_scheduler
==
'colossalai_cosine_warmup'
:
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer
,
total_steps
=
num_epoch
,
warmup_steps
=
warmup_epoch
)
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer
,
total_steps
=
num_epoch
,
warmup_steps
=
warmup_epoch
)
lr_scheduler_reload
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer_reload
,
lr_scheduler_reload
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer_reload
,
total_steps
=
num_epoch
,
total_steps
=
num_epoch
,
warmup_steps
=
warmup_epoch
)
warmup_steps
=
warmup_epoch
)
elif
test_scheduler
==
'torch_cosine'
:
elif
test_scheduler
==
'torch_cosine'
:
lr_scheduler
=
CosineAnnealingLR
(
optimizer
=
optimizer
,
T_max
=
num_epoch
)
lr_scheduler
=
CosineAnnealingLR
(
optimizer
=
optimizer
,
T_max
=
num_epoch
)
lr_scheduler_reload
=
CosineAnnealingLR
(
optimizer
=
optimizer_reload
,
T_max
=
num_epoch
)
lr_scheduler_reload
=
CosineAnnealingLR
(
optimizer
=
optimizer_reload
,
T_max
=
num_epoch
)
elif
test_scheduler
==
'torch_lambda'
:
elif
test_scheduler
==
'torch_lambda'
:
lr_lambda
=
lambda
epoch
:
0.95
lr_lambda
=
lambda
epoch
:
0.95
lr_scheduler
=
MultiplicativeLR
(
optimizer
=
optimizer
,
lr_lambda
=
lr_lambda
)
lr_scheduler
=
MultiplicativeLR
(
optimizer
=
optimizer
,
lr_lambda
=
lr_lambda
)
lr_scheduler_reload
=
MultiplicativeLR
(
optimizer
=
optimizer_reload
,
lr_lambda
=
lr_lambda
)
lr_scheduler_reload
=
MultiplicativeLR
(
optimizer
=
optimizer_reload
,
lr_lambda
=
lr_lambda
)
else
:
init_spec_func
(
model
,
pg
)
raise
TypeError
(
f
"
{
test_scheduler
}
is invalid"
)
init_spec_func
(
model_ref
,
pg
)
for
epoch
in
range
(
0
,
num_epoch
):
for
epoch
in
range
(
0
,
num_epoch
):
if
epoch
<=
test_epoch
:
if
epoch
<=
test_epoch
:
...
@@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
...
@@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
test_epoch
,
test_scheduler
,
pg
)
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
test_epoch
=
test_epoch
,
test_scheduler
=
test_scheduler
,
pg
=
pg
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
skip
...
@@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
...
@@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_checkpoint
(
4
,
True
,
1
,
1
)
test_checkpoint
(
4
,
True
,
1
,
"colossalai_cosine_warmup"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment