Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c92f84fc
Unverified
Commit
c92f84fc
authored
Jul 12, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 12, 2022
Browse files
[tensor] distributed checkpointing for parameters (#1240)
parent
49114d8d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
77 additions
and
160 deletions
+77
-160
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+2
-2
colossalai/tensor/distspec.py
colossalai/tensor/distspec.py
+1
-1
colossalai/utils/checkpoint/module_checkpoint.py
colossalai/utils/checkpoint/module_checkpoint.py
+31
-36
colossalai/utils/model/colo_init_context.py
colossalai/utils/model/colo_init_context.py
+2
-47
tests/test_tensor/test_tensor.py
tests/test_tensor/test_tensor.py
+14
-0
tests/test_utils/test_colo_checkpoint.py
tests/test_utils/test_colo_checkpoint.py
+27
-74
No files found.
colossalai/tensor/colo_tensor.py
View file @
c92f84fc
...
...
@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
self
.
_redistribute
(
dist_spec
)
def
set_tensor_spec
(
self
,
dist_spec
,
compute_spec
):
if
dist_spec
:
if
dist_spec
is
not
None
:
assert
isinstance
(
dist_spec
,
_DistSpec
),
f
"
{
type
(
dist_spec
)
}
"
self
.
set_dist_spec
(
dist_spec
)
if
compute_spec
:
if
compute_spec
is
not
None
:
self
.
compute_spec
=
compute_spec
def
has_compute_pattern
(
self
,
compute_pattern
):
...
...
colossalai/tensor/distspec.py
View file @
c92f84fc
from
enum
import
Enum
from
typing
import
List
from
typing
import
List
,
Optional
__all__
=
[
'replicate'
,
'shard'
]
...
...
colossalai/utils/checkpoint/module_checkpoint.py
View file @
c92f84fc
import
torch
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
collections
import
inspect
from
colossalai.utils.model.colo_init_context
import
colo_state_dict
def
filter_dict
(
dict_to_filter
,
thing_with_kwargs
):
sig
=
inspect
.
signature
(
thing_with_kwargs
)
filter_keys
=
[
param
.
name
for
param
in
sig
.
parameters
.
values
()
if
param
.
kind
==
param
.
POSITIONAL_OR_KEYWORD
]
filter_dict
=
{}
for
filter_key
in
filter_keys
:
if
filter_key
in
dict_to_filter
:
filter_dict
[
filter_key
]
=
dict_to_filter
[
filter_key
]
return
filter_dict
from
colossalai.tensor
import
ColoTensor
,
DistSpecManager
def
save_checkpoint
(
dire
:
str
,
...
...
@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
model_state
=
{
'epoch'
:
epoch
,
'model'
:
model
.
state_dict
()}
mapping
=
dict
()
new_dict
=
dict
()
# save the dist context about the tensors in a new dict, while still maintain the original dict.
for
k
,
v
in
model
.
state_dict
().
items
():
if
isinstance
(
v
,
ColoTensor
):
mapping
[
k
]
=
(
v
.
dist_spec
,
v
.
compute_spec
)
new_dict
[
k
]
=
v
.
to_replicate
().
detach
()
if
dist
.
get_rank
()
==
0
:
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
for
k
,
v
in
new_dict
.
items
():
if
isinstance
(
v
,
ColoTensor
):
assert
v
.
is_replicate
()
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
# 1. convert SHARD ColoTensor to REPLICATE
# only rank 0 saves the REPLICATE tensors.
optim_state
=
{
'epoch'
:
epoch
,
'optimizer'
:
optimizer
.
state_dict
(),
'lr_scheduler'
:
lr_scheduler
.
state_dict
()}
model_state
=
{
'epoch'
:
epoch
,
'model'
:
new_dict
}
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
torch
.
save
(
optim_state
,
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
dist
.
get_rank
()))
# delete the new dict
del
new_dict
def
load_checkpoint
(
dire
,
epoch
:
int
,
rank
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
...
...
@@ -62,19 +58,18 @@ def load_checkpoint(dire,
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
mapping
=
dict
()
for
k
,
v
in
model
.
named_parameters
():
if
isinstance
(
v
,
ColoTensor
):
mapping
[
k
]
=
(
v
.
dist_spec
,
v
.
compute_spec
)
v
.
to_replicate_
()
model_state
=
torch
.
load
(
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
model_state
[
'model'
]
=
collections
.
OrderedDict
([(
k
.
split
(
'.'
,
1
)[
1
],
v
)
for
k
,
v
in
model_state
[
'model'
].
items
()])
model
.
load_state_dict
(
model_state
[
'model'
])
optim_state
=
torch
.
load
(
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
rank
))
optimizer
.
load_state_dict
(
optim_state
[
'optimizer'
])
lr_scheduler_dict
=
optim_state
[
'lr_scheduler'
]
if
'after_scheduler_type'
in
lr_scheduler_dict
:
after_scheduler_type
=
lr_scheduler_dict
.
pop
(
'after_scheduler_type'
)
after_scheduler_dict
=
lr_scheduler_dict
.
pop
(
'after_scheduler_dict'
)
reload_scheduler
=
getattr
(
torch
.
optim
.
lr_scheduler
,
after_scheduler_type
)
filtered_dict
=
filter_dict
(
after_scheduler_dict
,
reload_scheduler
)
lr_scheduler_dict
[
'after_scheduler'
]
=
reload_scheduler
(
optimizer
,
**
filtered_dict
,
)
lr_scheduler
.
load_state_dict
(
lr_scheduler_dict
)
# reset tensors to original dist spec.
with
DistSpecManager
.
no_grad
():
for
k
,
v
in
model
.
named_parameters
():
if
isinstance
(
v
,
ColoTensor
):
v
.
set_tensor_spec
(
*
mapping
[
k
])
colossalai/utils/model/colo_init_context.py
View file @
c92f84fc
from
.utils
import
InsertPostInitMethodToModuleSubClasses
import
torch
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
,
distspec
,
ProcessGroup
,
ReplicaSpec
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
from
colossalai.nn.parallel.layers
import
register_colo_module
,
\
ColoLinear
,
ColoEmbedding
from
copy
import
copy
from
torch
import
nn
from
typing
import
Iterator
,
Tuple
,
Union
from
functools
import
partialmethod
# find named_params includes replica
...
...
@@ -34,47 +31,6 @@ def ColoModulize(module):
module
.
_colo_visited
=
True
def
colo_state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
,
state_dict_func
=
None
):
# build param to spec mapping
mapping1
=
dict
()
mapping2
=
dict
()
mapping3
=
dict
()
# gather all params
has_dist_parameter
=
False
with
torch
.
no_grad
():
for
param
in
self
.
parameters
():
if
isinstance
(
param
,
ColoParameter
):
has_dist_parameter
=
True
mapping1
[
id
(
param
)]
=
copy
(
param
.
dist_spec
)
mapping2
[
id
(
param
)]
=
copy
(
param
.
compute_spec
)
# TODO(jiaruifang) fixme, we should elegently handle the default PG in init context
if
param
.
get_process_group
()
is
None
:
param
.
process_group
=
ProcessGroup
()
param
.
set_dist_spec
(
distspec
.
replicate
())
mapping3
[
id
(
param
)]
=
param
.
get_process_group
()
param
.
process_group
=
None
# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
# new tensors, but when keep_vars = True, the recovery of spec will be reflected
# in the `ret`, such that the final state dict will still contain process group,
# raising exception as it is not serializable
assert
not
(
keep_vars
and
has_dist_parameter
),
'keep_vars cannot be True when there are distributed ColoParameters.'
ret
=
state_dict_func
(
self
,
destination
,
prefix
,
keep_vars
)
# recover
with
torch
.
no_grad
():
for
param
in
self
.
parameters
():
param_id
=
id
(
param
)
if
param_id
in
mapping1
:
dist_spec
=
mapping1
[
id
(
param
)]
compute_spec
=
mapping2
[
id
(
param
)]
param
.
process_group
=
mapping3
[
id
(
param
)]
param
.
set_tensor_spec
(
dist_spec
,
compute_spec
)
return
ret
class
ColoInitContext
(
InsertPostInitMethodToModuleSubClasses
):
def
__init__
(
self
,
lazy_memory_allocate
:
bool
=
False
,
device
:
torch
.
device
=
torch
.
device
(
'cpu'
)):
...
...
@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
register_colo_module
(
torch
.
nn
.
Embedding
,
ColoEmbedding
())
def
_pre_context_exec
(
self
):
self
.
state_dict_func
=
nn
.
Module
.
state_dict
nn
.
Module
.
state_dict
=
partialmethod
(
colo_state_dict
,
state_dict_func
=
self
.
state_dict_func
)
pass
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
,
**
kwargs
):
"""
...
...
tests/test_tensor/test_tensor.py
View file @
c92f84fc
...
...
@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
assert
t1
.
is_replicate
()
def
_run_set_tensor_spec
(
world_size
):
if
world_size
!=
4
:
return
pg
=
ProcessGroup
(
tp_degree
=
2
,
dp_degree
=
2
)
spec1
=
ColoTensorSpec
(
pg
)
t1
=
ColoTensor
.
from_torch_tensor
(
torch
.
randn
(
2
,
3
,
4
),
spec1
)
dist_spec2
=
(
ShardSpec
([
-
1
],
[
pg
.
tp_world_size
()]),
None
)
assert
t1
.
is_replicate
()
t1
.
set_dist_spec
(
*
dist_spec2
)
assert
t1
.
is_shard_1dcol
()
def
run_dist_tests
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
_run_tensor_shard_init
(
world_size
)
...
...
@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
_run_operand
(
world_size
)
_run_wrapped_tensor_func
()
_run_redistributed
(
world_size
)
_run_set_tensor_spec
(
world_size
)
@
pytest
.
mark
.
dist
...
...
tests/test_utils/test_colo_checkpoint.py
View file @
c92f84fc
...
...
@@ -3,7 +3,6 @@ import os, shutil
import
torch
import
torch.nn
as
nn
import
pytest
import
copy
from
functools
import
partial
import
torch.multiprocessing
as
mp
...
...
@@ -104,7 +103,7 @@ def remove(path):
raise
ValueError
(
"file {} is not a file or dir."
.
format
(
path
))
def
run_checkpoint
(
init_spec_func
,
use_ddp
,
test_epoch
,
test_scheduler
,
pg
):
def
run_checkpoint
(
init_spec_func
,
use_ddp
,
use_mp_reload
,
test_scheduler
,
pg
):
num_epoch
=
5
warmup_epoch
=
2
...
...
@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
feature
=
32
category
=
16
train_dataloader
=
DummyDataLoader
(
batch
,
category
,
feature
,
length
=
16
)
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
MLP
(
feature
,
category
)
with
ColoInitContext
(
device
=
get_current_device
()):
model_reload
=
MLP
(
feature
,
category
)
model_ref
=
MLP
(
feature
,
category
)
model
=
model
.
cuda
()
model_reload
=
model_reload
.
cuda
()
model_ref
=
model_ref
.
cuda
()
if
use_ddp
:
model
=
ColoDDP
(
model
,
pg
)
model_reload
=
ColoDDP
(
model_reload
,
pg
)
model_ref
=
ColoDDP
(
model_ref
,
pg
)
init_spec_func
(
model
,
pg
)
init_spec_func
(
model_ref
,
pg
)
if
use_mp_reload
:
init_spec_func
(
model_reload
,
pg
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_reload
=
torch
.
optim
.
Adam
(
model_reload
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_ref
=
torch
.
optim
.
Adam
(
model_ref
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
lr_scheduler
=
None
if
test_scheduler
==
'colossalai_cosine_warmup'
:
...
...
@@ -154,91 +150,48 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
else
:
raise
TypeError
(
f
"
{
test_scheduler
}
is invalid"
)
for
epoch
in
range
(
0
,
num_epoch
):
if
epoch
<=
test_epoch
:
for
i
,
image_dict
in
enumerate
(
train_dataloader
):
if
use_ddp
:
model
.
zero_grad
()
else
:
optimizer
.
zero_grad
()
logits
=
model
(
image_dict
[
'pixel_values'
])
loss
=
criterion
(
logits
,
image_dict
[
'label'
])
if
use_ddp
:
model
.
backward
(
loss
)
else
:
loss
.
backward
()
optimizer
.
step
()
save_checkpoint
(
'./checkpoint'
,
0
,
model
,
optimizer
,
lr_scheduler
)
dist
.
barrier
()
load_checkpoint
(
'./checkpoint'
,
0
,
model_reload
,
optimizer_reload
,
lr_scheduler_reload
)
if
epoch
==
test_epoch
:
for
ref_p
,
p
in
zip
(
model_ref
.
parameters
(),
model
.
parameters
()):
ref_p
.
data
.
copy_
(
p
)
optimizer_ref
=
copy
.
deepcopy
(
optimizer
)
# Since model is sharded, we merge them before param checking.
for
p
in
model
.
parameters
():
p
.
to_replicate_
()
check_param_equal
(
model
,
model_ref
)
save_checkpoint
(
'./checkpoint'
,
epoch
,
model
,
optimizer
,
lr_scheduler
)
dist
.
barrier
()
else
:
if
epoch
==
test_epoch
+
1
:
load_checkpoint
(
'./checkpoint'
,
test_epoch
,
dist
.
get_rank
(),
model_reload
,
optimizer_reload
,
lr_scheduler_reload
)
init_spec_func
(
model_reload
,
pg
)
for
i
,
image_dict
in
enumerate
(
train_dataloader
):
if
use_ddp
:
model_ref
.
zero_grad
()
model_reload
.
zero_grad
()
else
:
optimizer_ref
.
zero_grad
()
optimizer_reload
.
zero_grad
()
logits_ref
=
model_ref
(
image_dict
[
'pixel_values'
])
logits_reload
=
model_reload
(
image_dict
[
'pixel_values'
])
loss_ref
=
criterion
(
logits_ref
,
image_dict
[
'label'
])
loss_reload
=
criterion
(
logits_reload
,
image_dict
[
'label'
])
if
use_ddp
:
model_ref
.
backward
(
loss_ref
)
model_reload
.
backward
(
loss_reload
)
else
:
loss_ref
.
backward
()
loss_reload
.
backward
()
optimizer_ref
.
step
()
optimizer_reload
.
step
()
lr_scheduler
.
step
()
for
p
in
model_reload
.
parameters
():
p
.
to_replicate_
()
check_param_equal
(
model
_ref
,
model_reload
)
check_param_equal
(
model
,
model_reload
)
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
test_epoch
,
test_scheduler
):
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
use_mp_reload
,
test_scheduler
):
if
use_ddp
and
world_size
==
1
:
return
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
test_epoch
=
test_epoch
,
test_scheduler
=
test_scheduler
,
pg
=
pg
)
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
use_mp_reload
,
test_scheduler
=
test_scheduler
,
pg
=
pg
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'
test_epoch'
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
'
use_mp_reload'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
'test_scheduler'
,
[
'colossalai_cosine_warmup'
,
'torch_cosine'
,
'torch_lambda'
])
@
rerun_if_address_is_in_use
()
def
test_checkpoint
(
world_size
,
use_ddp
,
test_epoch
,
test_scheduler
):
def
test_checkpoint
(
world_size
,
use_ddp
,
use_mp_reload
,
test_scheduler
):
if
not
os
.
path
.
isdir
(
'./checkpoint'
):
os
.
mkdir
(
'./checkpoint'
)
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
use_ddp
=
use_ddp
,
test_epoch
=
test_epoch
,
use_mp_reload
=
use_mp_reload
,
test_scheduler
=
test_scheduler
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
remove
(
'./checkpoint'
)
if
__name__
==
'__main__'
:
test_checkpoint
(
4
,
True
,
1
,
"colossalai_cosine_warmup
"
)
test_checkpoint
(
2
,
True
,
False
,
"torch_cosine
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment