Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c92f84fc
Unverified
Commit
c92f84fc
authored
Jul 12, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 12, 2022
Browse files
[tensor] distributed checkpointing for parameters (#1240)
parent
49114d8d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
77 additions
and
160 deletions
+77
-160
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+2
-2
colossalai/tensor/distspec.py
colossalai/tensor/distspec.py
+1
-1
colossalai/utils/checkpoint/module_checkpoint.py
colossalai/utils/checkpoint/module_checkpoint.py
+31
-36
colossalai/utils/model/colo_init_context.py
colossalai/utils/model/colo_init_context.py
+2
-47
tests/test_tensor/test_tensor.py
tests/test_tensor/test_tensor.py
+14
-0
tests/test_utils/test_colo_checkpoint.py
tests/test_utils/test_colo_checkpoint.py
+27
-74
No files found.
colossalai/tensor/colo_tensor.py
View file @
c92f84fc
...
@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
...
@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
self
.
_redistribute
(
dist_spec
)
self
.
_redistribute
(
dist_spec
)
def
set_tensor_spec
(
self
,
dist_spec
,
compute_spec
):
def
set_tensor_spec
(
self
,
dist_spec
,
compute_spec
):
if
dist_spec
:
if
dist_spec
is
not
None
:
assert
isinstance
(
dist_spec
,
_DistSpec
),
f
"
{
type
(
dist_spec
)
}
"
assert
isinstance
(
dist_spec
,
_DistSpec
),
f
"
{
type
(
dist_spec
)
}
"
self
.
set_dist_spec
(
dist_spec
)
self
.
set_dist_spec
(
dist_spec
)
if
compute_spec
:
if
compute_spec
is
not
None
:
self
.
compute_spec
=
compute_spec
self
.
compute_spec
=
compute_spec
def
has_compute_pattern
(
self
,
compute_pattern
):
def
has_compute_pattern
(
self
,
compute_pattern
):
...
...
colossalai/tensor/distspec.py
View file @
c92f84fc
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
List
from
typing
import
List
,
Optional
__all__
=
[
'replicate'
,
'shard'
]
__all__
=
[
'replicate'
,
'shard'
]
...
...
colossalai/utils/checkpoint/module_checkpoint.py
View file @
c92f84fc
import
torch
import
torch
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
collections
from
colossalai.tensor
import
ColoTensor
,
DistSpecManager
import
inspect
from
colossalai.utils.model.colo_init_context
import
colo_state_dict
def
filter_dict
(
dict_to_filter
,
thing_with_kwargs
):
sig
=
inspect
.
signature
(
thing_with_kwargs
)
filter_keys
=
[
param
.
name
for
param
in
sig
.
parameters
.
values
()
if
param
.
kind
==
param
.
POSITIONAL_OR_KEYWORD
]
filter_dict
=
{}
for
filter_key
in
filter_keys
:
if
filter_key
in
dict_to_filter
:
filter_dict
[
filter_key
]
=
dict_to_filter
[
filter_key
]
return
filter_dict
def
save_checkpoint
(
dire
:
str
,
def
save_checkpoint
(
dire
:
str
,
...
@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
...
@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
"""
model_state
=
{
'epoch'
:
epoch
,
'model'
:
model
.
state_dict
()}
mapping
=
dict
()
new_dict
=
dict
()
# save the dist context of the tensors in a new dict, while still maintaining the original dict.
for
k
,
v
in
model
.
state_dict
().
items
():
if
isinstance
(
v
,
ColoTensor
):
mapping
[
k
]
=
(
v
.
dist_spec
,
v
.
compute_spec
)
new_dict
[
k
]
=
v
.
to_replicate
().
detach
()
if
dist
.
get_rank
()
==
0
:
if
dist
.
get_rank
()
==
0
:
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
for
k
,
v
in
new_dict
.
items
():
if
isinstance
(
v
,
ColoTensor
):
assert
v
.
is_replicate
()
# TODO() If tensor parallelism is used, optim_states contain SHARD ColoTensors.
model_state
=
{
'epoch'
:
epoch
,
'model'
:
new_dict
}
# 1. convert SHARD ColoTensor to REPLICATE
torch
.
save
(
model_state
,
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
# only rank 0 saves the REPLICATE tensors.
optim_state
=
{
'epoch'
:
epoch
,
'optimizer'
:
optimizer
.
state_dict
(),
'lr_scheduler'
:
lr_scheduler
.
state_dict
()}
torch
.
save
(
optim_state
,
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
dist
.
get_rank
()))
# delete the new dict
del
new_dict
def
load_checkpoint
(
dire
,
def
load_checkpoint
(
dire
,
epoch
:
int
,
epoch
:
int
,
rank
:
int
,
model
:
torch
.
nn
.
Module
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
...
@@ -62,19 +58,18 @@ def load_checkpoint(dire,
...
@@ -62,19 +58,18 @@ def load_checkpoint(dire,
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
"""
mapping
=
dict
()
for
k
,
v
in
model
.
named_parameters
():
if
isinstance
(
v
,
ColoTensor
):
mapping
[
k
]
=
(
v
.
dist_spec
,
v
.
compute_spec
)
v
.
to_replicate_
()
model_state
=
torch
.
load
(
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
model_state
=
torch
.
load
(
dire
+
'/epoch_{}_model.pth'
.
format
(
epoch
))
model_state
[
'model'
]
=
collections
.
OrderedDict
([(
k
.
split
(
'.'
,
1
)[
1
],
v
)
for
k
,
v
in
model_state
[
'model'
].
items
()])
model
.
load_state_dict
(
model_state
[
'model'
])
model
.
load_state_dict
(
model_state
[
'model'
])
optim_state
=
torch
.
load
(
dire
+
'/epoch_{}_optim_rank_{}.pth'
.
format
(
epoch
,
rank
))
optimizer
.
load_state_dict
(
optim_state
[
'optimizer'
])
# reset tensors to original dist spec.
lr_scheduler_dict
=
optim_state
[
'lr_scheduler'
]
with
DistSpecManager
.
no_grad
():
if
'after_scheduler_type'
in
lr_scheduler_dict
:
for
k
,
v
in
model
.
named_parameters
():
after_scheduler_type
=
lr_scheduler_dict
.
pop
(
'after_scheduler_type'
)
if
isinstance
(
v
,
ColoTensor
):
after_scheduler_dict
=
lr_scheduler_dict
.
pop
(
'after_scheduler_dict'
)
v
.
set_tensor_spec
(
*
mapping
[
k
])
reload_scheduler
=
getattr
(
torch
.
optim
.
lr_scheduler
,
after_scheduler_type
)
filtered_dict
=
filter_dict
(
after_scheduler_dict
,
reload_scheduler
)
lr_scheduler_dict
[
'after_scheduler'
]
=
reload_scheduler
(
optimizer
,
**
filtered_dict
,
)
lr_scheduler
.
load_state_dict
(
lr_scheduler_dict
)
colossalai/utils/model/colo_init_context.py
View file @
c92f84fc
from
.utils
import
InsertPostInitMethodToModuleSubClasses
from
.utils
import
InsertPostInitMethodToModuleSubClasses
import
torch
import
torch
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
,
distspec
,
ProcessGroup
,
ReplicaSpec
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
from
colossalai.nn.parallel.layers
import
register_colo_module
,
\
from
colossalai.nn.parallel.layers
import
register_colo_module
,
\
ColoLinear
,
ColoEmbedding
ColoLinear
,
ColoEmbedding
from
copy
import
copy
from
torch
import
nn
from
torch
import
nn
from
typing
import
Iterator
,
Tuple
,
Union
from
typing
import
Iterator
,
Tuple
,
Union
from
functools
import
partialmethod
# find named_params includes replica
# find named_params includes replica
...
@@ -34,47 +31,6 @@ def ColoModulize(module):
...
@@ -34,47 +31,6 @@ def ColoModulize(module):
module
.
_colo_visited
=
True
module
.
_colo_visited
=
True
def
colo_state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
,
state_dict_func
=
None
):
# build param to spec mapping
mapping1
=
dict
()
mapping2
=
dict
()
mapping3
=
dict
()
# gather all params
has_dist_parameter
=
False
with
torch
.
no_grad
():
for
param
in
self
.
parameters
():
if
isinstance
(
param
,
ColoParameter
):
has_dist_parameter
=
True
mapping1
[
id
(
param
)]
=
copy
(
param
.
dist_spec
)
mapping2
[
id
(
param
)]
=
copy
(
param
.
compute_spec
)
# TODO(jiaruifang) fixme, we should elegantly handle the default PG in init context
if
param
.
get_process_group
()
is
None
:
param
.
process_group
=
ProcessGroup
()
param
.
set_dist_spec
(
distspec
.
replicate
())
mapping3
[
id
(
param
)]
=
param
.
get_process_group
()
param
.
process_group
=
None
# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
# new tensors, but when keep_vars = True, the recovery of spec will be reflected
# in the `ret`, such that the final state dict will still contain process group,
# raising exception as it is not serializable
assert
not
(
keep_vars
and
has_dist_parameter
),
'keep_vars cannot be True when there are distributed ColoParameters.'
ret
=
state_dict_func
(
self
,
destination
,
prefix
,
keep_vars
)
# recover
with
torch
.
no_grad
():
for
param
in
self
.
parameters
():
param_id
=
id
(
param
)
if
param_id
in
mapping1
:
dist_spec
=
mapping1
[
id
(
param
)]
compute_spec
=
mapping2
[
id
(
param
)]
param
.
process_group
=
mapping3
[
id
(
param
)]
param
.
set_tensor_spec
(
dist_spec
,
compute_spec
)
return
ret
class
ColoInitContext
(
InsertPostInitMethodToModuleSubClasses
):
class
ColoInitContext
(
InsertPostInitMethodToModuleSubClasses
):
def
__init__
(
self
,
lazy_memory_allocate
:
bool
=
False
,
device
:
torch
.
device
=
torch
.
device
(
'cpu'
)):
def
__init__
(
self
,
lazy_memory_allocate
:
bool
=
False
,
device
:
torch
.
device
=
torch
.
device
(
'cpu'
)):
...
@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
register_colo_module
(
torch
.
nn
.
Embedding
,
ColoEmbedding
())
register_colo_module
(
torch
.
nn
.
Embedding
,
ColoEmbedding
())
def
_pre_context_exec
(
self
):
def
_pre_context_exec
(
self
):
self
.
state_dict_func
=
nn
.
Module
.
state_dict
pass
nn
.
Module
.
state_dict
=
partialmethod
(
colo_state_dict
,
state_dict_func
=
self
.
state_dict_func
)
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
,
**
kwargs
):
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
,
**
kwargs
):
"""
"""
...
...
tests/test_tensor/test_tensor.py
View file @
c92f84fc
...
@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
...
@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
assert
t1
.
is_replicate
()
assert
t1
.
is_replicate
()
def
_run_set_tensor_spec
(
world_size
):
if
world_size
!=
4
:
return
pg
=
ProcessGroup
(
tp_degree
=
2
,
dp_degree
=
2
)
spec1
=
ColoTensorSpec
(
pg
)
t1
=
ColoTensor
.
from_torch_tensor
(
torch
.
randn
(
2
,
3
,
4
),
spec1
)
dist_spec2
=
(
ShardSpec
([
-
1
],
[
pg
.
tp_world_size
()]),
None
)
assert
t1
.
is_replicate
()
t1
.
set_dist_spec
(
*
dist_spec2
)
assert
t1
.
is_shard_1dcol
()
def
run_dist_tests
(
rank
,
world_size
,
port
):
def
run_dist_tests
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
_run_tensor_shard_init
(
world_size
)
_run_tensor_shard_init
(
world_size
)
...
@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
...
@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
_run_operand
(
world_size
)
_run_operand
(
world_size
)
_run_wrapped_tensor_func
()
_run_wrapped_tensor_func
()
_run_redistributed
(
world_size
)
_run_redistributed
(
world_size
)
_run_set_tensor_spec
(
world_size
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_utils/test_colo_checkpoint.py
View file @
c92f84fc
...
@@ -3,7 +3,6 @@ import os, shutil
...
@@ -3,7 +3,6 @@ import os, shutil
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
pytest
import
pytest
import
copy
from
functools
import
partial
from
functools
import
partial
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
...
@@ -104,7 +103,7 @@ def remove(path):
...
@@ -104,7 +103,7 @@ def remove(path):
raise
ValueError
(
"file {} is not a file or dir."
.
format
(
path
))
raise
ValueError
(
"file {} is not a file or dir."
.
format
(
path
))
def
run_checkpoint
(
init_spec_func
,
use_ddp
,
test_epoch
,
test_scheduler
,
pg
):
def
run_checkpoint
(
init_spec_func
,
use_ddp
,
use_mp_reload
,
test_scheduler
,
pg
):
num_epoch
=
5
num_epoch
=
5
warmup_epoch
=
2
warmup_epoch
=
2
...
@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
...
@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
feature
=
32
feature
=
32
category
=
16
category
=
16
train_dataloader
=
DummyDataLoader
(
batch
,
category
,
feature
,
length
=
16
)
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
MLP
(
feature
,
category
)
model
=
MLP
(
feature
,
category
)
with
ColoInitContext
(
device
=
get_current_device
()):
model_reload
=
MLP
(
feature
,
category
)
model_reload
=
MLP
(
feature
,
category
)
model_ref
=
MLP
(
feature
,
category
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
model_reload
=
model_reload
.
cuda
()
model_reload
=
model_reload
.
cuda
()
model_ref
=
model_ref
.
cuda
()
if
use_ddp
:
if
use_ddp
:
model
=
ColoDDP
(
model
,
pg
)
model
=
ColoDDP
(
model
,
pg
)
model_reload
=
ColoDDP
(
model_reload
,
pg
)
model_reload
=
ColoDDP
(
model_reload
,
pg
)
model_ref
=
ColoDDP
(
model_ref
,
pg
)
init_spec_func
(
model
,
pg
)
init_spec_func
(
model
,
pg
)
init_spec_func
(
model_ref
,
pg
)
if
use_mp_reload
:
init_spec_func
(
model_reload
,
pg
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
optimizer_reload
=
torch
.
optim
.
Adam
(
model_reload
.
parameters
(),
optimizer_reload
=
torch
.
optim
.
Adam
(
model_reload
.
parameters
(),
lr
=
0.001
,
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
eps
=
1e-08
,
weight_decay
=
0
)
weight_decay
=
0
)
optimizer_ref
=
torch
.
optim
.
Adam
(
model_ref
.
parameters
(),
lr
=
0.001
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-08
,
weight_decay
=
0
)
lr_scheduler
=
None
lr_scheduler
=
None
if
test_scheduler
==
'colossalai_cosine_warmup'
:
if
test_scheduler
==
'colossalai_cosine_warmup'
:
...
@@ -154,91 +150,48 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
...
@@ -154,91 +150,48 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
else
:
else
:
raise
TypeError
(
f
"
{
test_scheduler
}
is invalid"
)
raise
TypeError
(
f
"
{
test_scheduler
}
is invalid"
)
for
epoch
in
range
(
0
,
num_epoch
):
save_checkpoint
(
'./checkpoint'
,
0
,
model
,
optimizer
,
lr_scheduler
)
if
epoch
<=
test_epoch
:
dist
.
barrier
()
for
i
,
image_dict
in
enumerate
(
train_dataloader
):
load_checkpoint
(
'./checkpoint'
,
0
,
model_reload
,
optimizer_reload
,
lr_scheduler_reload
)
if
use_ddp
:
model
.
zero_grad
()
# Since model is sharded, we merge them before param checking.
else
:
for
p
in
model
.
parameters
():
optimizer
.
zero_grad
()
p
.
to_replicate_
()
logits
=
model
(
image_dict
[
'pixel_values'
])
loss
=
criterion
(
logits
,
image_dict
[
'label'
])
for
p
in
model_reload
.
parameters
():
if
use_ddp
:
p
.
to_replicate_
()
model
.
backward
(
loss
)
else
:
check_param_equal
(
model
,
model_reload
)
loss
.
backward
()
optimizer
.
step
()
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
use_mp_reload
,
test_scheduler
):
if
epoch
==
test_epoch
:
for
ref_p
,
p
in
zip
(
model_ref
.
parameters
(),
model
.
parameters
()):
ref_p
.
data
.
copy_
(
p
)
optimizer_ref
=
copy
.
deepcopy
(
optimizer
)
check_param_equal
(
model
,
model_ref
)
save_checkpoint
(
'./checkpoint'
,
epoch
,
model
,
optimizer
,
lr_scheduler
)
dist
.
barrier
()
else
:
if
epoch
==
test_epoch
+
1
:
load_checkpoint
(
'./checkpoint'
,
test_epoch
,
dist
.
get_rank
(),
model_reload
,
optimizer_reload
,
lr_scheduler_reload
)
init_spec_func
(
model_reload
,
pg
)
for
i
,
image_dict
in
enumerate
(
train_dataloader
):
if
use_ddp
:
model_ref
.
zero_grad
()
model_reload
.
zero_grad
()
else
:
optimizer_ref
.
zero_grad
()
optimizer_reload
.
zero_grad
()
logits_ref
=
model_ref
(
image_dict
[
'pixel_values'
])
logits_reload
=
model_reload
(
image_dict
[
'pixel_values'
])
loss_ref
=
criterion
(
logits_ref
,
image_dict
[
'label'
])
loss_reload
=
criterion
(
logits_reload
,
image_dict
[
'label'
])
if
use_ddp
:
model_ref
.
backward
(
loss_ref
)
model_reload
.
backward
(
loss_reload
)
else
:
loss_ref
.
backward
()
loss_reload
.
backward
()
optimizer_ref
.
step
()
optimizer_reload
.
step
()
lr_scheduler
.
step
()
check_param_equal
(
model_ref
,
model_reload
)
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
test_epoch
,
test_scheduler
):
if
use_ddp
and
world_size
==
1
:
if
use_ddp
and
world_size
==
1
:
return
return
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
use_mp_reload
,
test_scheduler
=
test_scheduler
,
pg
=
pg
)
use_ddp
,
test_epoch
=
test_epoch
,
test_scheduler
=
test_scheduler
,
pg
=
pg
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
'
test_epoch'
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
'
use_mp_reload'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
'test_scheduler'
,
[
'colossalai_cosine_warmup'
,
'torch_cosine'
,
'torch_lambda'
])
@
pytest
.
mark
.
parametrize
(
'test_scheduler'
,
[
'colossalai_cosine_warmup'
,
'torch_cosine'
,
'torch_lambda'
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_checkpoint
(
world_size
,
use_ddp
,
test_epoch
,
test_scheduler
):
def
test_checkpoint
(
world_size
,
use_ddp
,
use_mp_reload
,
test_scheduler
):
if
not
os
.
path
.
isdir
(
'./checkpoint'
):
if
not
os
.
path
.
isdir
(
'./checkpoint'
):
os
.
mkdir
(
'./checkpoint'
)
os
.
mkdir
(
'./checkpoint'
)
run_func
=
partial
(
run_dist
,
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
world_size
=
world_size
,
port
=
free_port
(),
port
=
free_port
(),
use_ddp
=
use_ddp
,
use_ddp
=
use_ddp
,
test_epoch
=
test_epoch
,
use_mp_reload
=
use_mp_reload
,
test_scheduler
=
test_scheduler
)
test_scheduler
=
test_scheduler
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
remove
(
'./checkpoint'
)
remove
(
'./checkpoint'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_checkpoint
(
4
,
True
,
1
,
"colossalai_cosine_warmup
"
)
test_checkpoint
(
2
,
True
,
False
,
"torch_cosine
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment