OpenDAS / ColossalAI · Commits

Commit 3ef3791a (Unverified)
Authored Jul 14, 2022 by Jiarui Fang; committed by GitHub on Jul 14, 2022

[checkpoint] add test for bert and hotfix save bugs (#1297)
Parent: bd71e2a8

Showing 2 changed files with 110 additions and 114 deletions:

    colossalai/utils/checkpoint/module_checkpoint.py    +4    -3
    tests/test_utils/test_colo_checkpoint.py             +106  -111
colossalai/utils/checkpoint/module_checkpoint.py @ 3ef3791a
@@ -28,7 +28,8 @@ def save_checkpoint(dire: str,
        if isinstance(v, ColoTensor):
            mapping[k] = (v.dist_spec, v.compute_spec)
            new_dict[k] = v.to_replicate().detach()
        else:
            new_dict[k] = v
    if dist.get_rank() == 0:
        for k, v in new_dict.items():
            if isinstance(v, ColoTensor):

@@ -60,7 +61,7 @@ def load_checkpoint(dire,
    """
    mapping = dict()
-    for k, v in model.named_parameters():
+    for k, v in model.state_dict().items():
        if isinstance(v, ColoTensor):
            mapping[k] = (v.dist_spec, v.compute_spec)
            v.to_replicate_()

@@ -70,6 +71,6 @@ def load_checkpoint(dire,
    # reset tensors to original dist spec.
    with DistSpecManager.no_grad():
-        for k, v in model.named_parameters():
+        for k, v in model.state_dict().items():
            if isinstance(v, ColoTensor):
                v.set_tensor_spec(*mapping[k])
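For orientation, the call pattern these helpers are exercised with in the test file below is roughly the following. This is only a sketch: the roundtrip_checkpoint wrapper name is illustrative, and it assumes the models were built under ColoInitContext with a ProcessGroup already initialized, as the test does.

import torch.distributed as dist
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

def roundtrip_checkpoint(model, model_reload, optimizer, optimizer_reload, lr_scheduler, lr_scheduler_reload):
    # Every rank calls save_checkpoint: ColoTensors are gathered to replicated form
    # inside it, but only rank 0 writes files (see the dist.get_rank() == 0 branch above).
    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
    dist.barrier()  # make sure the files are on disk before any rank tries to read them
    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)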
tests/test_utils/test_colo_checkpoint.py @ 3ef3791a
from abc import ABC, abstractmethod
import os, shutil
import torch
import torch.nn as nn
import pytest
from functools import partial
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import MultiplicativeLR
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColoOptimizer
+from tests.components_to_test.registry import non_distributed_component_funcs

-class DummyDataGenerator(ABC):

-    def __init__(self, length=10):
-        self.length = length

-    @abstractmethod
-    def generate(self):
-        pass

-    def __iter__(self):
-        self.step = 0
-        return self

-    def __next__(self):
-        if self.step < self.length:
-            self.step += 1
-            return self.generate()
-        else:
-            raise StopIteration

-    def __len__(self):
-        return self.length

+def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)

-class DummyDataLoader(DummyDataGenerator):

-    def __init__(self, batch_size, category, feature_size, length=10):
-        super().__init__(length)
-        self.batch_size = batch_size
-        self.category = category
-        self.feature_size = feature_size

-    def generate(self):
-        image_dict = {}
-        image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(self.category, (self.batch_size,), dtype=torch.int64, device=get_current_device())
-        return image_dict

+def init_1d_col_linear(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)

+def init_1d_row_embedding(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)

-class MLP(nn.Module):

-    def __init__(self, in_features, out_features, hidden_features=None):
-        super().__init__()
-        if hidden_features is None:
-            hidden_features = out_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.activation = nn.ReLU()

-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.activation(x)
-        x = self.fc2(x)
-        return x

+def init_1d_col_embedding(weight, pg):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)

def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)

    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)

def check_param_equal(model, torch_model):
@@ -103,56 +77,75 @@ def remove(path):
        raise ValueError("file {} is not a file or dir.".format(path))

-def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
+def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    num_epoch = 5
    warmup_epoch = 2

+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    batch = 3
-    feature = 32
-    category = 16

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
-        model = MLP(feature, category)
+        model = model_builder(checkpoint=True)
+        model_reload = model_builder(checkpoint=True)
-    with ColoInitContext(device=get_current_device()):
-        model_reload = MLP(feature, category)

    if use_mp_reload:
+        if 'bert' == model_name:
+            for name, p in model.named_parameters():
+                if not isinstance(p, ColoTensor):
+                    continue
+                # num_class = type_vocab_size = 2 | (8, 2)
+                if 'classifier' in name and 'weight' in name:
+                    init_1d_row_linear(p, pg)
+                # num_class = vocab_size = 30524 | (30524, 8)
+                elif 'word_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = seq_len = 512 | (512, 8)
+                elif 'position_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = type_vocab_size = 2 | (2, 8)
+                elif 'token_type_embeddings' in name and 'weight' in name:
+                    init_1d_col_embedding(p, pg)
+                elif p.process_group.tp_world_size() == 1:
+                    p.redistribute(ReplicaSpec(), pg)
+        elif "simple_net" == model_name:
+            init_spec_func(model, pg)

    model = model.cuda()
    model.train()
    model_reload = model_reload.cuda()
    if use_ddp:
        model = ColoDDP(model, pg)
        model_reload = ColoDDP(model_reload, pg)
    model_reload.train()

    init_spec_func(model, pg)
    if use_mp_reload:
        init_spec_func(model_reload, pg)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer_reload = torch.optim.Adam(model_reload.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    lr_scheduler = None
    if test_scheduler == 'colossalai_cosine_warmup':
        lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
        lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=num_epoch, warmup_steps=warmup_epoch)
    elif test_scheduler == 'torch_cosine':
        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
        lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
    elif test_scheduler == 'torch_lambda':
        lr_lambda = lambda epoch: 0.95
        lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
        lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
    else:
        raise TypeError(f"{test_scheduler} is invalid")

    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)

    for i, (data, label) in enumerate(train_dataloader):
        save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)

        # Zero grad
        colo_optimizer.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 2:
            break

    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
    save_checkpoint('./checkpoint', 0, model, None, None)
    dist.barrier()
    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
    load_checkpoint('./checkpoint', 0, model_reload, None, None)

    # Since model is sharded, we merge them before param checking.
    for p in model.parameters():
@@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    check_param_equal(model, model_reload)

    if rank == 0:
        remove('./checkpoint')

def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
+    for model_name in ['bert', 'simple_net']:
+        _run_checkpoint(model_name, init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [True, False])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
    if not os.path.isdir('./checkpoint'):
        os.mkdir('./checkpoint')
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    run_func = partial(run_dist, world_size=world_size, port=free_port(),
@@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
                       use_mp_reload=use_mp_reload,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)
-    remove('./checkpoint')

if __name__ == '__main__':
-    test_checkpoint(2, True, False, "torch_cosine")
+    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
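The comment near the end of the test ("Since model is sharded, we merge them before param checking") refers to a loop whose body is collapsed in the diff above. A rough sketch of what such a merge step typically looks like, using the same to_replicate_ call that load_checkpoint itself relies on; this is an illustration, not the collapsed code, and the gather_for_comparison name is hypothetical.

from colossalai.tensor import ColoTensor

def gather_for_comparison(model):
    # Pull every sharded ColoTensor parameter back to replicated form in place,
    # so rank-local copies can afterwards be compared element-wise.
    for p in model.parameters():
        if isinstance(p, ColoTensor):
            p.to_replicate_()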