Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9e4c6449
Unverified
Commit
9e4c6449
authored
Jul 15, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 15, 2022
Browse files
[checkpoint] add ColoOptimizer checkpointing (#1316)
parent
7c2634f4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
15 deletions
+74
-15
colossalai/nn/optimizer/colossalai_optimizer.py
colossalai/nn/optimizer/colossalai_optimizer.py
+0
-3
colossalai/utils/checkpoint/module_checkpoint.py
colossalai/utils/checkpoint/module_checkpoint.py
+38
-4
tests/test_utils/test_colo_checkpoint.py
tests/test_utils/test_colo_checkpoint.py
+36
-8
No files found.
colossalai/nn/optimizer/colossalai_optimizer.py
View file @
9e4c6449
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch
import
torch.nn
as
nn
from
torch
import
Tensor
...
...
colossalai/utils/checkpoint/module_checkpoint.py
View file @
9e4c6449
import
torch
import
torch.distributed
as
dist
from
colossalai.tensor
import
ColoTensor
,
DistSpecManager
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
copy
import
copy
from
typing
import
Optional
def
save_checkpoint
(
dire
:
str
,
epoch
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
optimizer
:
Optional
[
Colossalai
Optimizer
]
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
*
args
,
**
kwargs
):
...
...
@@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
dire (str): directory to save the checkpoint files.
epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext
optimizer (
torch.optim.
Optimizer, optional): optimizers. Defaults to None.
optimizer (
Colossalai
Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
...
...
@@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
# delete the new dict
del
new_dict
optim_state_copy
=
copy
(
optimizer
.
state_dict
())
for
k
,
v
in
optim_state_copy
[
'state'
].
items
():
for
n
,
t
in
v
.
items
():
if
isinstance
(
t
,
ColoTensor
):
t
.
to_replicate_
()
if
dist
.
get_rank
()
==
0
:
model_state
=
{
'epoch'
:
epoch
,
'optim'
:
optim_state_copy
}
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_optim.pth'
.
format
(
epoch
))
del
optim_state_copy
def
load_checkpoint
(
dire
,
epoch
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
optimizer
:
Optional
[
Colossalai
Optimizer
]
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
*
args
,
**
kwargs
):
...
...
@@ -56,7 +69,7 @@ def load_checkpoint(dire,
epoch (int): the epoch number identifying which checkpoint files to load.
rank (int): the rank of the current process (presumably the distributed rank — verify against caller).
model (torch.nn.Module): the model whose state_dict entries will be restored from the checkpoint.
optimizer (
torch.optim.
Optimizer, optional): _description_. Defaults to None.
optimizer (
Colossalai
Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
...
...
@@ -74,3 +87,24 @@ def load_checkpoint(dire,
for
k
,
v
in
model
.
state_dict
().
items
():
if
isinstance
(
v
,
ColoTensor
):
v
.
set_tensor_spec
(
*
mapping
[
k
])
del
mapping
mapping
=
dict
()
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
for
n
,
t
in
v
.
items
():
if
isinstance
(
t
,
ColoTensor
):
mapping
[(
k
,
n
)]
=
(
t
.
dist_spec
,
t
.
compute_spec
)
t
.
to_replicate_
()
colo_checkpoint
=
torch
.
load
(
dire
+
'/epoch_{}_optim.pth'
.
format
(
epoch
))
optimizer
.
load_state_dict
(
colo_checkpoint
[
'optim'
])
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
for
n
,
t
in
v
.
items
():
if
isinstance
(
t
,
ColoTensor
):
# skip key not in mapping.
# For Adam, if step() has not been executed at least once, exp_avg and exp_avg_sq will not be present in the optimizer state
if
(
k
,
n
)
not
in
mapping
:
continue
t
.
set_tensor_spec
(
*
mapping
[(
k
,
n
)])
tests/test_utils/test_colo_checkpoint.py
View file @
9e4c6449
...
...
@@ -77,6 +77,18 @@ def remove(path):
raise
ValueError
(
"file {} is not a file or dir."
.
format
(
path
))
def compare_optims(optim1, optim2):
    """Assert that the ColoTensor entries shared by two optimizers' states match.

    Keys present in ``optim1``'s state but absent from ``optim2``'s are ignored.
    Each ColoTensor pair is replicated in place (``to_replicate_``) before being
    compared with loose tolerances, since the values come from independent runs.
    """
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for key in state1:
        # Only compare entries that exist on both sides.
        if key not in state2:
            continue
        entry1 = state1[key]
        entry2 = state2[key]
        if not isinstance(entry1, ColoTensor):
            continue
        assert isinstance(entry2, ColoTensor)
        # Replicate both tensors so they are directly comparable, then allow
        # a generous tolerance for cross-run numerical drift.
        replicated1 = entry1.to_replicate_()
        replicated2 = entry2.to_replicate_()
        assert torch.allclose(replicated1, replicated2, rtol=1e-3, atol=1e-1)
def
_run_checkpoint
(
model_name
,
init_spec_func
,
use_ddp
,
use_mp_reload
,
test_scheduler
,
pg
):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
...
@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload
=
model_reload
.
cuda
()
model_reload
.
train
()
colo_optimizer
=
ColossalaiOptimizer
(
torch
.
optim
.
SGD
(
model
.
named_parameters
(),
r
=
0.1
))
opt_class
=
torch
.
optim
.
Adam
colo_optimizer
=
ColossalaiOptimizer
(
opt_class
(
model
.
parameters
(),
lr
=
0.1
))
colo_optimizer_reload
=
ColossalaiOptimizer
(
opt_class
(
model_reload
.
parameters
(),
lr
=
0.1
))
run_reload
=
False
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
...
...
@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# Bcast rank0 data to all processes
if
criterion
:
output
=
model
(
data
)
output_reload
=
model_reload
(
data
)
loss
=
criterion
(
output
,
label
)
loss_reload
=
criterion
(
output_reload
,
label
)
else
:
output
=
model
(
data
,
label
)
loss
=
output
loss
=
model
(
data
,
label
)
loss
_reload
=
model_reload
(
data
,
label
)
loss
.
backward
()
colo_optimizer
.
step
()
loss_reload
.
backward
()
if
run_reload
:
colo_optimizer_reload
.
zero_grad
()
if
criterion
:
output_reload
=
model_reload
(
data
)
loss_reload
=
criterion
(
output_reload
,
label
)
else
:
loss_reload
=
model_reload
(
data
,
label
)
loss_reload
.
backward
()
colo_optimizer_reload
.
step
()
if
i
>
2
:
break
if
not
os
.
path
.
isdir
(
'./checkpoint'
)
and
rank
==
0
:
os
.
mkdir
(
'./checkpoint'
)
save_checkpoint
(
'./checkpoint'
,
0
,
model
,
None
,
None
)
save_checkpoint
(
'./checkpoint'
,
0
,
model
,
colo_optimizer
,
None
)
dist
.
barrier
()
load_checkpoint
(
'./checkpoint'
,
0
,
model_reload
,
colo_optimizer_reload
,
None
)
dist
.
barrier
()
load_checkpoint
(
'./checkpoint'
,
0
,
model_reload
,
None
,
None
)
# Since model is sharded, we merge them before param checking.
for
p
in
model
.
parameters
():
...
...
@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
p
.
to_replicate_
()
check_param_equal
(
model
,
model_reload
)
compare_optims
(
colo_optimizer
,
colo_optimizer_reload
)
if
rank
==
0
:
remove
(
'./checkpoint'
)
...
...
@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
use_mp_reload
,
test_scheduler
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
for
model_name
in
[
'bert'
,
'simple_net'
]:
for
model_name
in
[
'simple_net'
,
'bert'
]:
_run_checkpoint
(
model_name
,
init_1d_row_for_linear_weight_spec
,
use_ddp
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment