OpenDAS / ColossalAI · Commits

Commit f92c100d (unverified) · parent f3ce7b83
Authored Jul 19, 2022 by HELSON; committed by GitHub on Jul 19, 2022

[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)

Showing 6 changed files with 200 additions and 82 deletions (+200 / -82)
colossalai/tensor/colo_tensor.py                    +1 / -1
colossalai/tensor/process_group.py                  +9 / -0
colossalai/utils/checkpoint/module_checkpoint.py    +68 / -50
colossalai/utils/checkpoint/utils.py                +50 / -0
tests/test_utils/test_colo_checkpoint.py            +25 / -31
tests/test_utils/test_colo_checkpoint_tools.py      +47 / -0
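
The heart of the change is the save path: instead of building a replicated copy of every tensor on all ranks before saving (the old `to_replicate().detach()` path), each `ColoTensor` is now gathered to rank 0, written out there, and restored to its original distribution spec afterwards. Below is a minimal sketch of that save-side pattern in plain `torch.distributed`, not the ColossalAI API; the function name and the even 1-D row split are illustrative, and a backend that supports `dist.gather` (e.g. gloo) is assumed:

```python
import torch
import torch.distributed as dist


def save_sharded(shard: torch.Tensor, path: str) -> None:
    """Gather equal row-shards to rank 0 and write the full tensor there."""
    world_size = dist.get_world_size()
    # only the destination rank needs receive buffers
    parts = [torch.empty_like(shard) for _ in range(world_size)] if dist.get_rank() == 0 else None
    dist.gather(shard, gather_list=parts, dst=0)
    if dist.get_rank() == 0:
        torch.save(torch.cat(parts, dim=0), path)    # full tensor exists on rank 0 only
    dist.barrier()    # keep ranks in lockstep, as the commit's barriers do
```

In this sketch only rank 0 ever materializes the full tensor; the commit's `gather_tensor` achieves a similar effect by replicating within the process group that contains rank 0 and immediately re-sharding on the non-zero ranks.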
colossalai/tensor/colo_tensor.py · view file @ f92c100d (+1 / -1)

```diff
@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)

-    def size_global(self, args: Optional[int] = None):
+    def size_global(self, args: Optional[int] = None) -> torch.Size:
         """override the torch buildin size()
         the shape passed in must be in a replicate placement.
         Returns:
```
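
The one-line change above only adds a return annotation, but the distinction it documents is easy to trip over: `size_global()` reports the shape of the logical tensor, while `size()` on a shard reports the local shape. A toy illustration in plain PyTorch, assuming 1-D row sharding over two ranks (this is not the `ColoTensor` internals):

```python
import torch

full = torch.randn(4, 4)           # the logical tensor: what size_global() describes
local = torch.chunk(full, 2)[0]    # one rank's shard
assert full.size() == torch.Size([4, 4])
assert local.size() == torch.Size([2, 4])
```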
colossalai/tensor/process_group.py · view file @ f92c100d (+9 / -0)

```diff
@@ -141,9 +141,18 @@ class ProcessGroup:
     def rank(self):
         return self._rank

     def ranks_in_group(self):
         return self._rank_list

     def world_size(self):
         return self._world_size

+    def tp_rank_list(self):
+        return self._tp_rank_list
+
+    def dp_rank_list(self):
+        return self._dp_rank_list
+
+    def tp_local_rank(self):
+        return self._rank % self._tp_degree
```
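
The three new accessors expose the rank bookkeeping that `gather_tensor` relies on, in particular the check `pg.tp_rank_list()[0] == 0` for "does my tensor-parallel group contain global rank 0". A standalone sketch of the arithmetic, assuming the contiguous-group layout that `rank % tp_degree` implies (consecutive global ranks form one TP group); this is an illustration, not the library's code:

```python
def tp_local_rank(rank: int, tp_degree: int) -> int:
    # position of this process inside its tensor-parallel group
    return rank % tp_degree


def tp_group_ranks(rank: int, tp_degree: int) -> list:
    # global ranks that share a tensor-parallel group with `rank`
    first = rank - rank % tp_degree
    return list(range(first, first + tp_degree))


# With world_size=4 and tp_degree=2, ranks (0, 1) and (2, 3) form the two TP
# groups; only the first group's rank list starts with global rank 0, which is
# exactly the condition gather_tensor checks via pg.tp_rank_list()[0] == 0.
assert tp_group_ranks(1, 2) == [0, 1]
assert tp_group_ranks(3, 2) == [2, 3]
```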
colossalai/utils/checkpoint/module_checkpoint.py · view file @ f92c100d (+68 / -50)

```diff
 import torch
 import torch.distributed as dist
-from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.tensor import ColoTensor
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from copy import copy
+from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
 from typing import Optional
@@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
         optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    mapping = dict()
-    new_dict = dict()
+    rank = dist.get_rank()
+    model_state = model.state_dict()

     # save the dist context about the tensors in a new dict, while still maintain the original dict.
-    for k, v in model.state_dict().items():
+    for k, v in model_state.items():
         if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            new_dict[k] = v.to_replicate().detach()
-        else:
-            new_dict[k] = v

-    if dist.get_rank() == 0:
-        for k, v in new_dict.items():
+            gather_tensor(v)    # gather shared tensors to rank0
+            # don't recover tensors in rank0, since the dict is only a copy of model
+
+    if rank == 0:
+        # sanity check
+        for k, v in model_state.items():
             if isinstance(v, ColoTensor):
+                assert v.save_ready
                 assert v.is_replicate()
+                delattr(v, 'save_ready')
-        model_state = {'epoch': epoch, 'model': new_dict}
-        torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-    # delete the new dict
-    del new_dict
+        # model saving
+        save_state = {'epoch': epoch, 'model': model_state}
+        torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))
+    # delete old dicts
+    del model_state
+
+    # synchronize all the processes
+    dist.barrier()

-    optim_state_copy = copy(optimizer.state_dict())
-    for k, v in optim_state_copy['state'].items():
+    mapping = dict()
+    optim_state = optimizer.state_dict()
+    for k, v in optim_state['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                t.to_replicate_()
-
-    if dist.get_rank() == 0:
-        model_state = {'epoch': epoch, 'optim': optim_state_copy}
-        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-    del optim_state_copy
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)
+
+    if rank == 0:
+        save_state = {'epoch': epoch, 'optim': optim_state}
+        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+        # recover colo tensors in rank0
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    assert hasattr(t, 'save_ready')
+                    t.set_dist_spec(mapping[(k, n)])
+                    delattr(t, 'save_ready')
+
+    del optim_state
+    del mapping
+    dist.barrier()


 def load_checkpoint(dire,
@@ -72,39 +87,42 @@ def load_checkpoint(dire,
         optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
+    rank = dist.get_rank()
     mapping = dict()
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            v.to_replicate_()
-
-    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
-    model.load_state_dict(model_state['model'])
-
-    # reset tensors to original dist spec.
-    with DistSpecManager.no_grad():
-        for k, v in model.state_dict().items():
-            if isinstance(v, ColoTensor):
-                v.set_tensor_spec(*mapping[k])
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            mapping[n] = p.dist_spec
+            gather_tensor(p)
+
+    if rank == 0:
+        load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
+        model.load_state_dict(load_state['model'])
+    dist.barrier()
+
+    # scatter loaded parameters
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            scatter_tensor(p, mapping[n])
+            if rank == 0:
+                assert hasattr(p, 'save_ready')
+                delattr(p, 'save_ready')
+    del mapping

     mapping = dict()
     for k, v in optimizer.state_dict()['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
-                t.to_replicate_()
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)

-    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-    optimizer.load_state_dict(colo_checkpoint['optim'])
+    if rank == 0:
+        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+        optimizer.load_state_dict(colo_checkpoint['optim'])
+    dist.barrier()

     for k, v in optimizer.state_dict()['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
                 # skip key not in mapping.
                 # For Adam, if it dose not execute step() once, there will be not exp_avg and exp_avg_sq in optimizer
                 if (k, n) not in mapping:
                     continue
-                t.set_tensor_spec(*mapping[(k, n)])
+                scatter_tensor(t, mapping[(k, n)])
     del mapping
```
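
A hypothetical end-to-end use of the updated pair, mirroring the call pattern in the tests below (the positional order `dire, epoch, model, optimizer, lr_scheduler` is taken from this commit; the wrapper function itself is invented for illustration). Every rank must make both calls, since the collectives inside would otherwise deadlock:

```python
import torch.distributed as dist
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint


def checkpoint_roundtrip(model, optimizer, dire='./checkpoint', epoch=0):
    # all ranks participate: gather_tensor/scatter_tensor and the barriers
    # inside are collective operations
    save_checkpoint(dire, epoch, model, optimizer, None)
    dist.barrier()
    load_checkpoint(dire, epoch, model, optimizer, None)
    dist.barrier()
```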
colossalai/utils/checkpoint/utils.py · new file (0 → 100644) @ f92c100d (+50 / -0)

```python
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec


def gather_tensor(colo_tensor: ColoTensor) -> None:
    """Make colo_tensor replicated when the rank is 0
    """
    if not colo_tensor.is_replicate():
        pg = colo_tensor.get_process_group()
        # for the group which contains rank 0
        if pg.tp_rank_list()[0] == 0:
            old_dist_spec = colo_tensor.dist_spec
            colo_tensor.to_replicate_()
            if dist.get_rank() != 0:
                colo_tensor.set_dist_spec(old_dist_spec)

        # synchronize all processes for unexpected problems
        dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)    # set saving signitrue


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    """Reversal operation of `gather_tensor`.
    """
    if dist_spec.placement == 'r':
        dist.broadcast(colo_tensor.data, 0)
    else:
        global_size = colo_tensor.size_global()

        if dist.get_rank() == 0:
            entire_data = colo_tensor.data
        else:
            entire_data = torch.empty(global_size, device=colo_tensor.device)
        dist.broadcast(entire_data, 0)

        if dist.get_rank() == 0:
            colo_tensor.set_dist_spec(dist_spec)
        else:
            rep_tensor = ColoTensor(entire_data,
                                    ColoTensorSpec(pg=colo_tensor.get_process_group(),
                                                   compute_attr=colo_tensor.compute_spec))
            rep_tensor.set_dist_spec(dist_spec)
            with torch.no_grad():
                colo_tensor.data.copy_(rep_tensor.data)

    # synchronize all processes for unexpected problems
    dist.barrier()
```
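
On the load side, `scatter_tensor` rebuilds each rank's shard by broadcasting the full tensor from rank 0 and letting the distribution spec re-shard it locally. A plain-PyTorch sketch of that broadcast-then-slice idea, under the assumption of an even 1-D row split (illustrative names; not the ColossalAI implementation):

```python
from typing import Optional

import torch
import torch.distributed as dist


def scatter_by_broadcast(shard: torch.Tensor, full_on_rank0: Optional[torch.Tensor],
                         global_rows: int) -> None:
    rank, world = dist.get_rank(), dist.get_world_size()
    full = full_on_rank0 if rank == 0 else torch.empty(
        global_rows, *shard.shape[1:], device=shard.device)
    dist.broadcast(full, 0)    # every rank now holds the full tensor
    rows = global_rows // world
    shard.copy_(full[rank * rows:(rank + 1) * rows])    # keep only the local slice
```

Broadcasting the whole tensor transfers more data than a true scatter, but the simplicity is a reasonable trade-off for an operation that only runs at checkpoint time.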
tests/test_utils/test_colo_checkpoint.py · view file @ f92c100d (+25 / -31)

```diff
 import os, shutil
 import torch
 import pytest
+from copy import deepcopy
 from functools import partial
 import torch.multiprocessing as mp
@@ -15,8 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.optimizer import ColossalaiOptimizer
@@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
 def check_param_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
+        assert torch.all(p.data == tp.data), "{} went wrong.\n{} vs {}\n{}".format(n, p, tp, p.shape)


 def remove(path):
@@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
         if k not in state2:
             continue
         p2 = state2[k]
-        if isinstance(p1, ColoTensor):
-            assert isinstance(p2, ColoTensor)
-            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+        for n, t1 in p1.items():
+            if n not in p2:
+                continue
+            t2 = p2[n]
+            if isinstance(t1, ColoTensor):
+                assert isinstance(t2, ColoTensor)
+                assert torch.allclose(t1, t2, rtol=0, atol=0)


 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
@@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     # set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
-        model_reload = model_builder(checkpoint=True)

     if use_mp_reload:
         if 'bert' == model_name:
@@ -119,25 +122,26 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
             elif 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
             elif p.process_group.tp_world_size() == 1:
-                p.redistribute(ReplicaSpec(), pg)
+                p.set_process_group(pg)
     elif "simple_net" == model_name:
         init_spec_func(model, pg)

+    model_reload = deepcopy(model)
     model = model.cuda()
-    model.train()
+    model.eval()

     model_reload = model_reload.cuda()
-    model_reload.train()
+    model_reload.eval()

     opt_class = torch.optim.Adam
     colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
     colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))

-    run_reload = False
     for i, (data, label) in enumerate(train_dataloader):
         # Zero grad
         colo_optimizer.zero_grad()
+        colo_optimizer_reload.zero_grad()

         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -155,43 +159,33 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         loss.backward()
+        loss_reload.backward()

-        if run_reload:
-            colo_optimizer_reload.zero_grad()
-            if criterion:
-                output_reload = model_reload(data)
-                loss_reload = criterion(output_reload, label)
-            else:
-                loss_reload = model_reload(data, label)
-            loss_reload.backward()
-            colo_optimizer_reload.step()

         colo_optimizer.step()
+        colo_optimizer_reload.step()

         if i > 2:
             break

     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
-
-    # Since model is sharded, we merge them before param checking.
-    for p in model.parameters():
-        p.to_replicate_()
-
-    for p in model_reload.parameters():
-        p.to_replicate_()
-
-    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
-    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
+    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
+    dist.barrier()
+    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
+    dist.barrier()

     check_param_equal(model, model_reload)
+    compare_optims(colo_optimizer, colo_optimizer_reload)

     if rank == 0:
         remove('./checkpoint')
     dist.barrier()


 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['simple_net', 'bert']:
+    # TODO(haichen) add BERT in the test
+    # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
+    for model_name in ['simple_net']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
```
tests/test_utils/test_colo_checkpoint_tools.py · new file (0 → 100644) @ f92c100d (+47 / -0)

```python
import torch
import pytest
from functools import partial

import torch.multiprocessing as mp
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from tests.test_tensor._utils import tensor_shard_equal


def run_dist(rank, world_size, port, dp_degree, tp_degree):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
    x = torch.randn(4, 4, device=get_current_device())
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))

    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(*spec)

    gather_tensor(param)
    if dist.get_rank() == 0:
        assert torch.allclose(x, param.data, rtol=0, atol=0)
    else:
        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    dist.barrier()

    scatter_tensor(param, spec[0])
    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    assert param.requires_grad is True
    dist.barrier()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_checkpoint(world_size=4)
```
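
For readers new to the spec tuple used in this test: `ShardSpec([-1], [pg.tp_world_size()])` splits the tensor's last dimension evenly across the TP group. A plain-PyTorch illustration of that layout, with `torch.chunk` standing in for the library's sharding (not the `ColoTensor` implementation):

```python
import torch

x = torch.randn(4, 4)
tp_world_size = 2
shards = torch.chunk(x, tp_world_size, dim=-1)    # dim=-1 matches ShardSpec([-1], ...)
assert shards[0].shape == (4, 2)    # each rank holds a (4, 2) slice of the (4, 4) tensor
```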