OpenDAS / ColossalAI / Commits / 3ef3791a

Unverified commit 3ef3791a, authored Jul 14, 2022 by Jiarui Fang, committed by GitHub on Jul 14, 2022

[checkpoint] add test for bert and hotfix save bugs (#1297)

parent bd71e2a8

Showing 2 changed files with 110 additions and 114 deletions:

    colossalai/utils/checkpoint/module_checkpoint.py    +4   -3
    tests/test_utils/test_colo_checkpoint.py            +106 -111
colossalai/utils/checkpoint/module_checkpoint.py

...
@@ -28,7 +28,8 @@ def save_checkpoint(dire: str,
         if isinstance(v, ColoTensor):
             mapping[k] = (v.dist_spec, v.compute_spec)
             new_dict[k] = v.to_replicate().detach()
+        else:
+            new_dict[k] = v

     if dist.get_rank() == 0:
         for k, v in new_dict.items():
             if isinstance(v, ColoTensor):
...
@@ -60,7 +61,7 @@ def load_checkpoint(dire,
     """
     mapping = dict()
-    for k, v in model.named_parameters():
+    for k, v in model.state_dict().items():
         if isinstance(v, ColoTensor):
             mapping[k] = (v.dist_spec, v.compute_spec)
             v.to_replicate_()
...
@@ -70,6 +71,6 @@ def load_checkpoint(dire,
     # reset tensors to original dist spec.
     with DistSpecManager.no_grad():
-        for k, v in model.named_parameters():
+        for k, v in model.state_dict().items():
             if isinstance(v, ColoTensor):
                 v.set_tensor_spec(*mapping[k])
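
Two things change in this file. On the save side, the new else branch keeps entries that are not ColoTensors (plain tensors and buffers) in new_dict, so they are written into the checkpoint instead of being silently dropped. On the load side, both loops now walk model.state_dict().items() rather than model.named_parameters(): named_parameters() yields only learnable parameters, while state_dict() also contains buffers and uses exactly the keys the checkpoint was saved under. A minimal plain-PyTorch sketch of that gap (no ColossalAI required):

    # Buffers such as BatchNorm's running statistics live in state_dict()
    # but never appear in named_parameters(), so a checkpoint pass that
    # walks only parameters silently skips them.
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

    param_keys = {name for name, _ in model.named_parameters()}
    state_keys = set(model.state_dict().keys())

    print(sorted(state_keys - param_keys))
    # -> ['1.num_batches_tracked', '1.running_mean', '1.running_var']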
tests/test_utils/test_colo_checkpoint.py

-from abc import ABC, abstractmethod
 import os, shutil
 import torch
-import torch.nn as nn
 import pytest
 from functools import partial

 import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.optim.lr_scheduler import MultiplicativeLR
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

 import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import ColoOptimizer
+
+from tests.components_to_test.registry import non_distributed_component_funcs


-class DummyDataGenerator(ABC):
-
-    def __init__(self, length=10):
-        self.length = length
-
-    @abstractmethod
-    def generate(self):
-        pass
-
-    def __iter__(self):
-        self.step = 0
-        return self
-
-    def __next__(self):
-        if self.step < self.length:
-            self.step += 1
-            return self.generate()
-        else:
-            raise StopIteration
-
-    def __len__(self):
-        return self.length
-
-
-class DummyDataLoader(DummyDataGenerator):
-
-    def __init__(self, batch_size, category, feature_size, length=10):
-        super().__init__(length)
-        self.batch_size = batch_size
-        self.category = category
-        self.feature_size = feature_size
-
-    def generate(self):
-        image_dict = {}
-        image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
-                                            dtype=torch.int64, device=get_current_device())
-        return image_dict
-
-
-class MLP(nn.Module):
-
-    def __init__(self, in_features, out_features, hidden_features=None):
-        super().__init__()
-        if hidden_features is None:
-            hidden_features = out_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.activation = nn.ReLU()
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.activation(x)
-        x = self.fc2(x)
-        return x
+def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
+
+
+def init_1d_col_linear(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
+
+
+def init_1d_row_embedding(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
+
+
+def init_1d_col_embedding(weight, pg):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
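
Each helper above pairs a ShardSpec with a ComputeSpec(ComputePattern.TP1D): the first argument of ShardSpec names the dimension to split, the second the number of partitions (here the tensor-parallel world size). Below is a plain-PyTorch sketch of which slice each rank would hold under the two layouts; it only illustrates the geometry, not ColossalAI's actual ShardSpec machinery, and shard_1d is a hypothetical helper:

    import torch

    def shard_1d(weight: torch.Tensor, dim: int, tp_world_size: int, tp_rank: int) -> torch.Tensor:
        # ShardSpec([dim], [tp_world_size]): cut `weight` into equal chunks
        # along `dim` and keep the chunk owned by this rank.
        return torch.chunk(weight, tp_world_size, dim=dim)[tp_rank]

    w = torch.randn(8, 4)
    print(shard_1d(w, dim=-1, tp_world_size=2, tp_rank=0).shape)  # ShardSpec([-1], [2]) -> (8, 2)
    print(shard_1d(w, dim=0, tp_world_size=2, tp_rank=0).shape)   # ShardSpec([0], [2])  -> (4, 4)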
 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n:
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
+    for name, p in model.named_parameters():
+        if not isinstance(p, ColoTensor):
+            continue
+        if 'embed' in name and 'weight' in name:
+            init_1d_col_embedding(p, pg)
+        if 'proj1' in name and ('weight' in name or 'bias' in name):
+            init_1d_col_linear(p, pg)
+        if 'proj2' in name and 'weight' in name:
+            init_1d_row_linear(p, pg)
+        if 'classifier' in name and ('weight' in name or 'bias' in name):
+            init_1d_col_linear(p, pg)


 def check_param_equal(model, torch_model):
...
@@ -103,56 +77,75 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))


-def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
-    num_epoch = 5
-    warmup_epoch = 2
-    batch = 3
-    feature = 32
-    category = 16
-
+def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+
+    # set_seed(1)
     with ColoInitContext(device=get_current_device()):
-        model = MLP(feature, category)
-
-    with ColoInitContext(device=get_current_device()):
-        model_reload = MLP(feature, category)
+        model = model_builder(checkpoint=True)
+        model_reload = model_builder(checkpoint=True)
+
+    if use_mp_reload:
+        if 'bert' == model_name:
+            for name, p in model.named_parameters():
+                if not isinstance(p, ColoTensor):
+                    continue
+                # num_class = type_vocab_size = 2 | (8, 2)
+                if 'classifier' in name and 'weight' in name:
+                    init_1d_row_linear(p, pg)
+                # num_class = vocab_size = 30524 | (30524, 8)
+                elif 'word_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = seq_len = 512 | (512, 8)
+                elif 'position_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = type_vocab_size = 2 | (2, 8)
+                elif 'token_type_embeddings' in name and 'weight' in name:
+                    init_1d_col_embedding(p, pg)
+                elif p.process_group.tp_world_size() == 1:
+                    p.redistribute(ReplicaSpec(), pg)
+        elif "simple_net" == model_name:
+            init_spec_func(model, pg)

     model = model.cuda()
+    model.train()
+
     model_reload = model_reload.cuda()
+    model_reload.train()

-    if use_ddp:
-        model = ColoDDP(model, pg)
-        model_reload = ColoDDP(model_reload, pg)
-
-    init_spec_func(model, pg)
-    if use_mp_reload:
-        init_spec_func(model_reload, pg)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-    optimizer_reload = torch.optim.Adam(model_reload.parameters(),
-                                        lr=0.001, betas=(0.9, 0.999),
-                                        eps=1e-08,
-                                        weight_decay=0)
-
-    lr_scheduler = None
-    if test_scheduler == 'colossalai_cosine_warmup':
-        lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
-        lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=num_epoch, warmup_steps=warmup_epoch)
-    elif test_scheduler == 'torch_cosine':
-        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
-        lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
-    elif test_scheduler == 'torch_lambda':
-        lr_lambda = lambda epoch: 0.95
-        lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
-        lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
-    else:
-        raise TypeError(f"{test_scheduler} is invalid")
+    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
+
+    for i, (data, label) in enumerate(train_dataloader):
+
+        # Zero grad
+        colo_optimizer.zero_grad()
+
+        data = data.to(get_current_device())
+        label = label.to(get_current_device())
+
+        # Bcast rank0 data to all processes
+        if criterion:
+            output = model(data)
+            loss = criterion(output, label)
+        else:
+            output = model(data, label)
+            loss = output
+
+        loss.backward()
+        colo_optimizer.step()
+
+        if i > 2:
+            break

-    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
+    if not os.path.isdir('./checkpoint') and rank == 0:
+        os.mkdir('./checkpoint')
+    save_checkpoint('./checkpoint', 0, model, None, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
+    load_checkpoint('./checkpoint', 0, model_reload, None, None)

     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():
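
Note the filesystem handshake the rewritten test performs on a shared directory: only rank 0 creates ./checkpoint, only rank 0 writes files inside save_checkpoint (see the dist.get_rank() == 0 guard in module_checkpoint.py above), and every rank then meets at dist.barrier() before load_checkpoint, so no process can try to read files that do not exist yet. A generic sketch of the pattern, assuming torch.distributed is already initialized (e.g. by colossalai.launch); save_and_sync is a hypothetical name, not part of the test file:

    import os
    import torch.distributed as dist

    def save_and_sync(dire: str, save_fn):
        # Only rank 0 creates the directory; it is also the only rank that
        # writes files inside save_fn, so there is no filesystem race.
        if dist.get_rank() == 0 and not os.path.isdir(dire):
            os.mkdir(dire)
        save_fn(dire)       # collective call: every rank participates
        dist.barrier()      # no rank starts loading before the files exist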
...
@@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
         check_param_equal(model, model_reload)

+    if rank == 0:
+        remove('./checkpoint')
+

 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
-    if use_ddp and world_size == 1:
-        return
-    tp_world_size = world_size // 2 if use_ddp else world_size
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
+    for model_name in ['bert', 'simple_net']:
+        _run_checkpoint(model_name, init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('use_ddp', [True, False])
+@pytest.mark.parametrize('use_ddp', [False])
 @pytest.mark.parametrize('use_mp_reload', [True, False])
-@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
+# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
-def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
-    if not os.path.isdir('./checkpoint'):
-        os.mkdir('./checkpoint')
+def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
...
@@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
                        use_mp_reload=use_mp_reload,
                        test_scheduler=test_scheduler)
     mp.spawn(run_func, nprocs=world_size)
-    remove('./checkpoint')


 if __name__ == '__main__':
-    test_checkpoint(2, True, False, "torch_cosine")
+    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
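
The launcher idiom above is worth noting: partial(run_dist, ...) pins every argument except the rank, and mp.spawn(run_func, nprocs=world_size) then calls the function once per process with the rank as the first positional argument. A minimal, self-contained sketch of the same pattern, using the gloo backend and a fixed port so it runs on CPU (the real test draws a port from free_port() and initializes NCCL via colossalai.launch):

    from functools import partial

    import torch.distributed as dist
    import torch.multiprocessing as mp

    def run_worker(rank, world_size, port):
        # mp.spawn supplies `rank`; everything else was pinned by partial().
        dist.init_process_group('gloo', init_method=f'tcp://localhost:{port}',
                                rank=rank, world_size=world_size)
        print(f'rank {rank}/{world_size} initialized')
        dist.destroy_process_group()

    if __name__ == '__main__':
        world_size = 2
        run_func = partial(run_worker, world_size=world_size, port=29501)
        mp.spawn(run_func, nprocs=world_size)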