Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f38006ea
Unverified
Commit
f38006ea
authored
Jul 06, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 06, 2022
Browse files
[checkpoint] checkpoint for ColoTensor Model (#1196)
parent
291e22aa
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
292 additions
and
1 deletion
+292
-1
colossalai/utils/checkpoint/__init__.py
colossalai/utils/checkpoint/__init__.py
+3
-0
colossalai/utils/checkpoint/module_checkpoint.py
colossalai/utils/checkpoint/module_checkpoint.py
+73
-0
colossalai/utils/model/colo_init_context.py
colossalai/utils/model/colo_init_context.py
+5
-1
tests/test_utils/test_colo_checkpoint.py
tests/test_utils/test_colo_checkpoint.py
+211
-0
No files found.
colossalai/utils/checkpoint/__init__.py
0 → 100644
View file @
f38006ea
from
.module_checkpoint
import
save_checkpoint
,
load_checkpoint
__all__
=
[
'save_checkpoint'
,
'load_checkpoint'
]
colossalai/utils/checkpoint/module_checkpoint.py
0 → 100644
View file @
f38006ea
import
torch
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
collections
from
torch.optim.lr_scheduler
import
CosineAnnealingLR
as
_CosineAnnealingLR
from
colossalai.utils.model.colo_init_context
import
colo_state_dict
def
save_checkpoint
(
dire
,
epoch
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
*
args
,
**
kwargs
):
"""save_checkpoint
save a model, whose parameters are `ColoTensor`s.
Args:
dire (_type_): _description_
epoch (int): _description_
model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
model_state
=
{
'epoch'
:
epoch
,
'model'
:
colo_state_dict
(
model
,
state_dict_func
=
nn
.
Module
.
state_dict
)
}
if
dist
.
get_rank
()
==
0
:
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
lr_scheduler_dict
=
lr_scheduler
.
state_dict
()
lr_scheduler_dict
[
'after_scheduler'
]
=
lr_scheduler_dict
[
'after_scheduler'
].
state_dict
()
optim_state
=
{
'epoch'
:
epoch
,
'optimizer'
:
optimizer
.
state_dict
(),
'lr_scheduler'
:
lr_scheduler_dict
}
torch
.
save
(
optim_state
,
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
dist
.
get_rank
()))
def
load_checkpoint
(
dire
,
epoch
:
int
,
rank
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
*
args
,
**
kwargs
):
"""load_checkpoint
load a model, whose parameters are `ColoTensor`s.
Args:
dire (_type_): _description_
epoch (int): _description_
rank (int): _description_
model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
model_state
=
torch
.
load
(
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
model_state
[
'model'
]
=
collections
.
OrderedDict
([(
k
.
split
(
'.'
,
1
)[
1
],
v
)
for
k
,
v
in
model_state
[
'model'
].
items
()])
model
.
load_state_dict
(
model_state
[
'model'
])
optim_state
=
torch
.
load
(
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
rank
))
optimizer
.
load_state_dict
(
optim_state
[
'optimizer'
])
lr_scheduler_dict
=
optim_state
[
'lr_scheduler'
]
after_scheduler_dict
=
lr_scheduler_dict
[
'after_scheduler'
]
lr_scheduler_dict
[
'after_scheduler'
]
=
_CosineAnnealingLR
(
optimizer
,
after_scheduler_dict
[
'T_max'
],
after_scheduler_dict
[
'eta_min'
],
after_scheduler_dict
[
'last_epoch'
]
)
lr_scheduler
.
load_state_dict
(
lr_scheduler_dict
)
colossalai/utils/model/colo_init_context.py
View file @
f38006ea
...
...
@@ -38,15 +38,18 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
# build param to spec mapping
mapping1
=
dict
()
mapping2
=
dict
()
mapping3
=
dict
()
# gather all params
has_dist_parameter
=
False
with
torch
.
no_grad
():
for
param
in
self
.
parameters
():
if
isinstance
(
param
,
ColoParameter
)
and
param
.
has_compute_spec
()
:
if
isinstance
(
param
,
ColoParameter
):
has_dist_parameter
=
True
mapping1
[
id
(
param
)]
=
copy
(
param
.
dist_spec
)
mapping2
[
id
(
param
)]
=
copy
(
param
.
compute_spec
)
mapping3
[
id
(
param
)]
=
param
.
get_process_group
()
param
.
set_dist_spec
(
distspec
.
replicate
())
param
.
process_group
=
None
# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
...
...
@@ -64,6 +67,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
if
param_id
in
mapping1
:
dist_spec
=
mapping1
[
id
(
param
)]
compute_spec
=
mapping2
[
id
(
param
)]
param
.
process_group
=
mapping3
[
id
(
param
)]
param
.
set_tensor_spec
(
dist_spec
,
compute_spec
)
return
ret
...
...
tests/test_utils/test_colo_checkpoint.py
0 → 100644
View file @
f38006ea
from
abc
import
ABC
,
abstractmethod
import
os
,
sys
,
shutil
import
torch
import
torch.nn
as
nn
import
pytest
import
copy
import
operator
import
colossalai
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ColoTensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
distspec
,
ProcessGroup
,
ColoTensor
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
colossalai.nn.parallel.data_parallel
import
ColoDDP
from
colossalai.utils.checkpoint
import
save_checkpoint
,
load_checkpoint
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
class
DummyDataGenerator
(
ABC
):
def
__init__
(
self
,
length
=
10
):
self
.
length
=
length
@
abstractmethod
def
generate
(
self
):
pass
def
__iter__
(
self
):
self
.
step
=
0
return
self
def
__next__
(
self
):
if
self
.
step
<
self
.
length
:
self
.
step
+=
1
return
self
.
generate
()
else
:
raise
StopIteration
def
__len__
(
self
):
return
self
.
length
class
DummyDataLoader
(
DummyDataGenerator
):
batch_size
=
128
category
=
16
feature_size
=
256
def
generate
(
self
):
image_dict
=
{}
image_dict
[
'pixel_values'
]
=
torch
.
rand
(
DummyDataLoader
.
batch_size
,
DummyDataLoader
.
feature_size
,
device
=
get_current_device
())
*
2
-
1
image_dict
[
'label'
]
=
torch
.
randint
(
DummyDataLoader
.
category
,
(
DummyDataLoader
.
batch_size
,),
dtype
=
torch
.
int64
,
device
=
get_current_device
())
return
image_dict
class
MLP
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
out_features
,
hidden_features
=
None
):
super
().
__init__
()
if
hidden_features
is
None
:
hidden_features
=
out_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
activation
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
activation
(
x
)
x
=
self
.
fc2
(
x
)
return
x
def
init_1d_row_for_linear_weight_spec
(
model
,
pg
:
ProcessGroup
):
spec
=
(
distspec
.
shard
([
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
if
'weight'
in
n
:
p
.
set_process_group
(
pg
)
p
.
set_tensor_spec
(
*
spec
)
def
check_param_equal
(
model
,
torch_model
):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
assert
torch
.
allclose
(
torch_p
,
p
,
rtol
=
1e-3
,
atol
=
1e-1
)
def
remove
(
path
):
""" param <path> could either be relative or absolute. """
if
os
.
path
.
isfile
(
path
)
or
os
.
path
.
islink
(
path
):
os
.
remove
(
path
)
elif
os
.
path
.
isdir
(
path
):
shutil
.
rmtree
(
path
)
else
:
raise
ValueError
(
"file {} is not a file or dir."
.
format
(
path
))
def
run_checkpoint
(
init_spec_func
,
use_ddp
,
test_epoch
,
pg
):
train_dataloader
=
DummyDataLoader
(
length
=
16
)
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
MLP
(
256
,
16
,
64
)
model_reload
=
MLP
(
256
,
16
,
64
)
model_ref
=
MLP
(
256
,
16
,
64
)
model
=
model
.
cuda
()
model_reload
=
model_reload
.
cuda
()
model_ref
=
model_ref
.
cuda
()
if
use_ddp
:
model
=
ColoDDP
(
model
,
pg
)
model_reload
=
ColoDDP
(
model_reload
,
pg
)
model_ref
=
ColoDDP
(
model_ref
,
pg
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_reload
=
torch
.
optim
.
Adam
(
model_reload
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_ref
=
torch
.
optim
.
Adam
(
model_ref
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer
,
total_steps
=
20
,
warmup_steps
=
5
)
lr_scheduler_reload
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer_reload
,
total_steps
=
20
,
warmup_steps
=
5
)
lr_scheduler_ref
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer_ref
,
total_steps
=
20
,
warmup_steps
=
5
)
init_spec_func
(
model
,
pg
)
init_spec_func
(
model_ref
,
pg
)
for
epoch
in
range
(
0
,
20
):
if
epoch
<=
test_epoch
:
for
i
,
image_dict
in
enumerate
(
train_dataloader
):
if
use_ddp
:
model
.
zero_grad
()
else
:
optimizer
.
zero_grad
()
logits
=
model
(
image_dict
[
'pixel_values'
])
loss
=
criterion
(
logits
,
image_dict
[
'label'
])
if
use_ddp
:
model
.
backward
(
loss
)
else
:
loss
.
backward
()
optimizer
.
step
()
if
epoch
==
test_epoch
:
for
ref_p
,
p
in
zip
(
model_ref
.
parameters
(),
model
.
parameters
()):
ref_p
.
data
.
copy_
(
p
)
optimizer_ref
=
copy
.
deepcopy
(
optimizer
)
lr_scheduler_ref
=
copy
.
deepcopy
(
lr_scheduler
)
check_param_equal
(
model
,
model_ref
)
save_checkpoint
(
'./checkpoint'
,
epoch
,
model
,
optimizer
,
lr_scheduler
)
dist
.
barrier
()
else
:
if
epoch
==
test_epoch
+
1
:
load_checkpoint
(
'./checkpoint'
,
test_epoch
,
dist
.
get_rank
(),
model_reload
,
optimizer_reload
,
lr_scheduler_reload
)
init_spec_func
(
model_reload
,
pg
)
for
i
,
image_dict
in
enumerate
(
train_dataloader
):
if
use_ddp
:
model_ref
.
zero_grad
()
model_reload
.
zero_grad
()
else
:
optimizer_ref
.
zero_grad
()
optimizer_reload
.
zero_grad
()
logits_ref
=
model_ref
(
image_dict
[
'pixel_values'
])
logits_reload
=
model_reload
(
image_dict
[
'pixel_values'
])
loss_ref
=
criterion
(
logits_ref
,
image_dict
[
'label'
])
loss_reload
=
criterion
(
logits_reload
,
image_dict
[
'label'
])
if
use_ddp
:
model_ref
.
backward
(
loss_ref
)
model_reload
.
backward
(
loss_reload
)
else
:
loss_ref
.
backward
()
loss_reload
.
backward
()
optimizer_ref
.
step
()
optimizer_reload
.
step
()
lr_scheduler
.
step
()
check_param_equal
(
model_ref
,
model_reload
)
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
test_epoch
):
if
use_ddp
and
world_size
==
1
:
return
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
test_epoch
,
pg
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'test_epoch'
,
[
1
,
2
,
3
])
@
rerun_if_address_is_in_use
()
def
test_checkpoint
(
world_size
,
use_ddp
,
test_epoch
):
if
not
os
.
path
.
isdir
(
'./checkpoint'
):
os
.
mkdir
(
'./checkpoint'
)
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
use_ddp
=
use_ddp
,
test_epoch
=
test_epoch
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
remove
(
'./checkpoint'
)
if
__name__
==
'__main__'
:
test_checkpoint
(
4
,
True
,
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment