OpenDAS / ColossalAI / Commits

Commit 99870726 (unverified), authored Nov 08, 2022 by ver217, committed by GitHub on Nov 08, 2022
Parent: 629172b3

[CheckpointIO] a uniform checkpoint I/O module (#1689)

Showing 17 changed files with 2111 additions and 0 deletions (+2111, -0)
colossalai/utils/checkpoint_io/__init__.py                      +2    -0
colossalai/utils/checkpoint_io/backend.py                       +74   -0
colossalai/utils/checkpoint_io/constant.py                      +9    -0
colossalai/utils/checkpoint_io/convertor.py                     +227  -0
colossalai/utils/checkpoint_io/distributed.py                   +127  -0
colossalai/utils/checkpoint_io/io.py                            +170  -0
colossalai/utils/checkpoint_io/meta.py                          +81   -0
colossalai/utils/checkpoint_io/reader.py                        +131  -0
colossalai/utils/checkpoint_io/utils.py                         +223  -0
colossalai/utils/checkpoint_io/writer.py                        +98   -0
tests/test_utils/test_checkpoint_io/test_build_checkpoints.py   +120  -0
tests/test_utils/test_checkpoint_io/test_load.py                +188  -0
tests/test_utils/test_checkpoint_io/test_merge.py               +127  -0
tests/test_utils/test_checkpoint_io/test_merge_param.py         +101  -0
tests/test_utils/test_checkpoint_io/test_redist.py              +149  -0
tests/test_utils/test_checkpoint_io/test_save.py                +147  -0
tests/test_utils/test_checkpoint_io/test_unmerge_param.py       +137  -0
colossalai/utils/checkpoint_io/__init__.py (new file, mode 100644)

from .io import load, merge, redist, save
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
colossalai/utils/checkpoint_io/backend.py (new file, mode 100644)

import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type

from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter

_backends: Dict[str, Type['CheckpointIOBackend']] = {}


def register(name: str):
    assert name not in _backends, f'"{name}" is registered'

    def wrapper(cls):
        _backends[name] = cls
        return cls

    return wrapper


def get_backend(name: str) -> 'CheckpointIOBackend':
    assert name in _backends, f'Unsupported backend "{name}"'
    return _backends[name]()


class CheckpointIOBackend(ABC):

    def __init__(self) -> None:
        super().__init__()
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        pass

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        pass

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        pass

    @abstractmethod
    def clean_temp(self) -> None:
        pass


@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        for temp_dir_name in self.temps:
            shutil.rmtree(temp_dir_name)
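
The backend registry above maps a string name to a CheckpointIOBackend class. The snippet below is a minimal usage sketch, not part of this commit: the import path simply assumes this commit's layout and './ckpt' is a placeholder directory.

# Usage sketch only -- not part of the commit. Assumes this commit's module layout.
from colossalai.utils.checkpoint_io.backend import get_backend

io_backend = get_backend('disk')                          # instance of the class registered via @register('disk')
writer = io_backend.get_writer('./ckpt', overwrite=True)  # creates ./ckpt and writes global_meta.bin
# ... push model/optimizer/meta shards through `writer` ...
# reader = io_backend.get_reader('./ckpt')                # valid once ./ckpt holds a complete checkpoint
temp_dir = io_backend.get_temp('./ckpt')                  # scratch directory tracked by the backend
io_backend.clean_temp()                                   # removes every directory returned by get_temp()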
colossalai/utils/checkpoint_io/constant.py (new file, mode 100644)

import re

GLOBAL_META_FILE_NAME = 'global_meta.bin'
MODEL_CKPT_FILE_NAME = 'model.bin'
OPTIM_CKPT_FILE_NAME = 'optim.bin'
META_CKPT_FILE_NAME = 'meta.bin'
OTHER_CKPT_FILE_NAME = 'other.bin'
CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
colossalai/utils/checkpoint_io/convertor.py (new file, mode 100644)

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

from torch import Tensor

from .distributed import merge_param, unmerge_param
from .meta import ParamDistMeta, RedistMeta
from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)


class CheckpointConvertor(ABC):

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        pass

    @abstractmethod
    def complete(self) -> None:
        pass


class ModelCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        self.param_count = param_count
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            for k, tensor in state_dict.items():
                self.buffer[k][rank] = tensor
        converted_keys = set()
        for k, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[k]:
                tensors = []
                dist_metas = []
                for rank, tensor in rank_dict.items():
                    tensors.append(tensor)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][k])
                self.convert_tensors(k, tensors, dist_metas)
                converted_keys.add(k)
        for k in converted_keys:
            del self.buffer[k]

    def complete(self) -> None:
        assert len(self.buffer) == 0


class ModelCheckpointMerger(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(tensors)
        tensor = merge_param(tensors, dist_metas)
        shard = self.sharder.append(key, tensor)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())


class ModelCheckpointRedistor(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        nprocs = len(save_fns)
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        if len(dist_metas) == 0:
            # already global
            tensor = tensors[0]
        else:
            assert len(dist_metas) == len(tensors)
            tensor = merge_param(tensors, dist_metas)
        for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
            for dp_rank, t in enumerate(tensor_list):
                for rank in self.rank_map[key][tp_rank][dp_rank]:
                    shard = self.sharders[rank].append(key, t)
                    run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())


class OptimizerCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        self.os_to_param = {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        pass

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            self.setup(state_dict['param_groups'])
            for idx, state in state_dict['state'].items():
                self.buffer[idx][rank] = state
        converted_indices = set()
        for idx, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
                states = []
                dist_metas = []
                for rank, state in rank_dict.items():
                    states.append(state)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
                self.convert_states(idx, states, dist_metas)
                converted_indices.add(idx)
        for idx in converted_indices:
            del self.buffer[idx]

    def complete(self) -> None:
        assert len(self.buffer) == 0


class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
            else:
                new_state[state_key] = state_tensor
        shard = self.sharder.append(idx, new_state)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())


class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        self.sharders: List[OptimizerCheckpointSharder] = []
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        if len(self.sharders) == 0:
            nprocs = len(self.save_fns)
            for _ in range(nprocs):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        need_merge: bool = True
        if len(dist_metas) == 0:
            need_merge = False
        else:
            assert len(dist_metas) == len(states)
        new_states = [{} for _ in range(len(self.save_fns))]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                if need_merge:
                    tensor = merge_param([state[state_key] for state in states], dist_metas)
                else:
                    tensor = state_tensor
                for tp_rank, tensor_list in enumerate(
                        unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
                    for dp_rank, t in enumerate(tensor_list):
                        for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
                            new_states[rank][state_key] = t
            else:
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            shard = self.sharders[rank].append(idx, new_state)
            run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
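
The convertors are push-based: io.py feeds them per-rank shard dicts through append(), and they emit merged or redistributed shards through the save callbacks given at construction. Below is a small self-contained sketch, not part of this commit; the parameter name 'w' and the shapes are made up.

# Sketch only -- not part of the commit. 'w' is a made-up parameter split 2-way by tensor parallelism.
import torch
from colossalai.utils.checkpoint_io.convertor import ModelCheckpointMerger
from colossalai.utils.checkpoint_io.meta import ParamDistMeta

merged_shards = []
dist_meta_list = [
    {'w': ParamDistMeta(0, 1, 0, 2, tp_shard_dims=[0], tp_num_parts=[2])},    # rank 0 holds the top half
    {'w': ParamDistMeta(0, 1, 1, 2, tp_shard_dims=[0], tp_num_parts=[2])},    # rank 1 holds the bottom half
]
merger = ModelCheckpointMerger(max_shard_size=0,                  # 0 = keep everything in one output shard
                               save_fn=merged_shards.append,      # called with each completed shard dict
                               param_count={'w': 2})              # expect two shards of 'w' in total
merger.append({0: {'w': torch.zeros(2, 4)}, 1: {'w': torch.ones(2, 4)}}, dist_meta_list)
merger.complete()
assert merged_shards[0]['w'].shape == (4, 4)                      # halves concatenated along dim 0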
colossalai/utils/checkpoint_io/distributed.py (new file, mode 100644)

import torch
from numpy import prod
from torch import Tensor
from typing import List, Optional, Tuple
from collections import defaultdict

from .meta import ParamDistMeta, ParamRedistMeta


def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
    if not dist_metas[0].used_zero:
        # tensors are replicate
        return tensors[0]
    numel = dist_metas[0].zero_numel
    orig_shape = dist_metas[0].zero_orig_shape
    tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)]
    assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.'
    return torch.cat(tensors).reshape(orig_shape)


def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
    for t in tensors[1:]:
        assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
    if not dist_metas[0].used_tp:
        # tensors are replicate
        return tensors[0]
    total_parts = prod(dist_meta.tp_num_parts)
    assert dist_meta.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.'
    shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True)
    for dim, num_parts in shard_info:
        buffer = []
        for start in range(0, len(tensors), num_parts):
            buffer.append(torch.cat(tensors[start:start + num_parts], dim))
        tensors = buffer
    assert len(tensors) == 1
    return tensors[0]


def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
    assert len(dist_metas) > 0
    # check world size
    for dist_meta in dist_metas[1:]:
        assert dist_meta.dp_world_size == dist_metas[0].dp_world_size, 'Expect all dist meta have the same dp_world_size'
        assert dist_meta.tp_world_size == dist_metas[0].tp_world_size, 'Expect all dist meta have the same tp_world_size'


def deduplicate_params(tensors: List[Tensor],
                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
    unique_dist_meta = []
    unique_idx = []
    for i, dist_meta in enumerate(dist_metas):
        if dist_meta not in unique_dist_meta:
            unique_dist_meta.append(dist_meta)
            unique_idx.append(i)
    return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx]


def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # validate parallel info
    validate_parallel_info(dist_metas)
    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
    unflattened_tensors = []
    # group zero params by tp rank
    tensor_dict = defaultdict(list)
    dist_meta_dict = defaultdict(list)
    for t, dist_meta in zip(tensors, dist_metas):
        tensor_dict[dist_meta.tp_rank].append(t)
        dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
    assert len(tensor_dict) == dist_metas[0].tp_world_size, \
        f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
    for tp_rank in tensor_dict.keys():
        unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
    return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])


def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    if not redist_meta.used_tp:
        assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
        return [tensor]
    total_parts = prod(redist_meta.tp_num_parts)
    assert redist_meta.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
    shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
    tensors = [tensor]
    for dim, num_parts in shard_info:
        buffer = []
        for t in tensors:
            assert t.size(dim) % num_parts == 0, \
                f'Expect dim {dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
            chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)]
            buffer.extend(chunks)
        tensors = buffer
    assert len(tensors) == redist_meta.tp_world_size
    return tensors


def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    if not redist_meta.used_zero:
        return [tensor] * redist_meta.dp_world_size
    tensors: List[Optional[Tensor]] = [
        torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank)
    ]
    offsets = redist_meta.zero_offsets + [tensor.numel()]
    for i, offset in enumerate(offsets[:-1]):
        end = offsets[i + 1]
        tensors.append(tensor.view(-1)[offset:end])
    if len(tensors) < redist_meta.dp_world_size:
        tensors.extend([
            torch.empty(0, dtype=tensor.dtype, device=tensor.device)
            for _ in range(redist_meta.dp_world_size - len(tensors))
        ])
    assert len(tensors) == redist_meta.dp_world_size
    return tensors


def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
    tensors = split_tp_param(tensor, redist_meta)
    tensors = [flatten_zero_param(t, redist_meta) for t in tensors]
    return tensors
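
merge_param() and unmerge_param() are the tensor-level primitives behind merge and redist: the first glues TP/ZeRO shards back into one global tensor, the second splits a global tensor for a target layout. A minimal round-trip sketch (not part of this commit) with a 2-way tensor-parallel split along dim 1:

# Sketch only -- not part of the commit.
import torch
from colossalai.utils.checkpoint_io.distributed import merge_param, unmerge_param
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta

full = torch.arange(8.0).reshape(2, 4)
shards = list(full.chunk(2, dim=1))                        # what two TP ranks would each hold
dist_metas = [ParamDistMeta(0, 1, r, 2, tp_shard_dims=[1], tp_num_parts=[2]) for r in range(2)]
assert torch.equal(merge_param(shards, dist_metas), full)  # shards -> global tensor

redist_meta = ParamRedistMeta(dp_world_size=1, tp_world_size=2, tp_shard_dims=[1], tp_num_parts=[2])
new_shards = unmerge_param(full, redist_meta)              # global tensor -> [tp_rank][dp_rank] shards
assert torch.equal(new_shards[0][0], shards[0]) and torch.equal(new_shards[1][0], shards[1])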
colossalai/utils/checkpoint_io/io.py (new file, mode 100644)

import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer

from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
                        OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict


def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    if world_size == 1:
        # global doesn't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                  param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_checkpoint in model_checkpoints:
        writer.save_model(model_checkpoint)
    for optimizer_checkpoint in optimizer_checkpoints:
        writer.save_optimizer(optimizer_checkpoint)
    writer.save_meta(meta_checkpoint)


def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
        reader.load_optimizers(), dist_meta_list)
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True


def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    do_redist: bool = False
    if len(dist_meta_list) == nprocs:
        for a, b in zip(dist_metas, dist_meta_list):
            if a != b:
                do_redist = True
                break
    else:
        do_redist = True
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
        return False
    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
                                    param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True


def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    for shard_dict in shard_generator:
        convertor.append(shard_dict, dist_meta_list)
    convertor.complete()


def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        if wrote:
            read_path = temp_path
    if not is_global:
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            # optimizer.load_state_dict(shard)
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict
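
A usage sketch of the public API follows; it is not part of this commit and './ckpt' is only a placeholder path. With a single process, save() writes one global checkpoint and load() reads it straight back; extra keyword arguments to save() round-trip through save_others()/load_others().

# Sketch only -- not part of the commit. './ckpt' is a placeholder directory.
import torch.nn as nn
from torch.optim import Adam
from colossalai.utils.checkpoint_io.io import save, load

model = nn.Linear(4, 2)
optimizer = Adam(model.parameters(), lr=1e-3)
save('./ckpt', model, optimizer, overwrite=True, epoch=3)   # 'epoch' is stored via save_others()

new_model = nn.Linear(4, 2)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
others = load('./ckpt', new_model, new_optimizer)           # world_size == 1, so no redistribution happens
assert others['epoch'] == 3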
colossalai/utils/checkpoint_io/meta.py (new file, mode 100644)

from dataclasses import dataclass
from typing import List, Optional, Set, Dict


@dataclass
class ParamDistMeta:
    # parallel info
    dp_rank: int
    dp_world_size: int
    tp_rank: int
    tp_world_size: int
    # tp info
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info
    zero_numel: Optional[int] = None
    zero_orig_shape: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        return self.zero_numel is not None and self.zero_orig_shape is not None

    @property
    def parallel_meta(self) -> tuple:
        return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size

    @property
    def tp_meta(self) -> tuple:
        return self.tp_shard_dims, self.tp_num_parts

    @property
    def zero_meta(self) -> tuple:
        return self.zero_numel, self.zero_orig_shape

    @staticmethod
    def from_dict(d: dict) -> 'ParamDistMeta':
        return ParamDistMeta(**d)


@dataclass
class ParamRedistMeta:
    # parallel info
    dp_world_size: int
    tp_world_size: int
    # tp info
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info
    zero_start_dp_rank: Optional[int] = None
    zero_offsets: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        return self.zero_start_dp_rank is not None and self.zero_offsets is not None


@dataclass
class RankRedistMeta:
    dp_rank: int
    tp_rank: int
    pp_rank: int


@dataclass
class PipelineRedistMeta:
    params: Set[str]


@dataclass
class RedistMeta:
    rank_meta: Dict[str, Dict[int, RankRedistMeta]]
    pipeline_meta: List[PipelineRedistMeta]
    param_meta: Dict[str, ParamRedistMeta]
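
A short sketch (not part of this commit) of what ParamDistMeta records for one parameter that is split 2-way by tensor parallelism and flattened by ZeRO, and how the helper properties read it back:

# Sketch only -- not part of the commit.
from colossalai.utils.checkpoint_io.meta import ParamDistMeta

meta = ParamDistMeta(dp_rank=1, dp_world_size=2, tp_rank=0, tp_world_size=2,
                     tp_shard_dims=[1], tp_num_parts=[2],      # split along dim 1 into 2 parts
                     zero_numel=10, zero_orig_shape=[1, 10])   # local ZeRO slice of a 10-element param
assert meta.used_tp and meta.used_zero
assert meta.tp_meta == ([1], [2]) and meta.zero_meta == (10, [1, 10])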
colossalai/utils/checkpoint_io/reader.py (new file, mode 100644)

import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple

import torch

from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list


class CheckpointReader(ABC):

    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        pass

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        pass

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_others(self) -> dict:
        pass


class DiskCheckpointReader(CheckpointReader):

    def __init__(self, base_name: str) -> None:
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', None),
                       meta.get('paired_os', None)) for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        param_groups = []
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups.append(shards[i]['param_groups'])
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_others(self) -> dict:
        return self.read(OTHER_CKPT_FILE_NAME)
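
A sketch (not part of this commit) of inspecting a checkpoint directory with the disk reader; it assumes './ckpt' was previously written by save(), as in the sketch after io.py above.

# Sketch only -- not part of the commit. Assumes './ckpt' already holds a saved checkpoint.
from colossalai.utils.checkpoint_io.backend import get_backend

reader = get_backend('disk').get_reader('./ckpt')
dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
print(param_count)              # Counter: parameter name -> number of shards it appears in
for shard in reader.load_model(rank=0):
    print(list(shard.keys()))   # tensors stored in each model shard file of rank 0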
colossalai/utils/checkpoint_io/utils.py (new file, mode 100644)

import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from .meta import ParamDistMeta


def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    if arg is not None:
        return fn(arg)


def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    # ensure all params in optimizer are in model state dict
    params_set = set(id(p) for p in model.parameters())
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in params_set
    param_mappings = {}
    start_index = 0

    def get_group_mapping(group):
        nonlocal start_index
        param_mappings.update(
            {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
        start_index += len(group['params'])

    for g in optimizer.param_groups:
        get_group_mapping(g)
    return {k: param_mappings[id(p)] for k, p in model.named_parameters()}


def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    size = 0
    for v in state.values():
        if isinstance(v, Tensor):
            size += v.numel() * v.element_size()
    return size


class ModelCheckpointSharder:

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {}
            self.buffer_size = 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return retval

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer) > 0 else None


class OptimizerCheckpointSharder:

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return retval

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer['state']) > 0 else None


def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    has_optimizer: bool = False
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        os_to_param = {v: k for k, v in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
        has_optimizer = True
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not has_optimizer:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards


def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict,
                  param_to_os: Dict[str, int]) -> dict:
    os_to_param = {v: k for k, v in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        paired_os[idx] = {}
        p = model_state_dict[os_to_param[idx]]
        for k, v in state.items():
            if isinstance(v, Tensor) and v.shape == p.shape:
                paired_os[idx][k] = True
            else:
                paired_os[idx][k] = False
    return paired_os


def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    save_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp replicated params
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if optimizer:
            optimizer_state_dict['state'] = {
                param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
                for k in model_state_dict.keys()
                if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta


def is_duplicated_list(list_: List[Any]) -> bool:
    if len(list_) == 0:
        return True
    elem = list_[0]
    for x in list_[1:]:
        if x != elem:
            return False
    return True


def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    for k, v in src_state.items():
        if k in dest_state:
            old_v = dest_state[k]
            if isinstance(old_v, Tensor):
                old_v.copy_(v)
            else:
                dest_state[k] = v


def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    idx_to_p: Dict[str, Parameter] = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
                                       chain.from_iterable((g['params'] for g in groups)))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                     "\n\t".join(error_msgs)))
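
build_checkpoints() is the entry point save() uses: it turns a model (and optionally an optimizer) into lists of shard dicts plus a meta dict, flushing a new shard whenever the running byte count reaches max_size. A small sketch, not part of this commit:

# Sketch only -- not part of the commit. max_size is in bytes; 0 disables sharding.
import torch.nn as nn
from colossalai.utils.checkpoint_io.utils import build_checkpoints

model = nn.Linear(20, 1)                      # weight: 20 fp32 values (80 bytes), bias: 1 value
model_shards, optim_shards, meta = build_checkpoints(80, model)
assert len(model_shards) == 2                 # weight fills the first shard, bias spills into a second
assert optim_shards == [] and meta['dist_meta'] is None
assert set(meta['params']) == {'weight', 'bias'}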
colossalai/utils/checkpoint_io/writer.py (new file, mode 100644)

from abc import ABC, abstractmethod
from typing import Optional
from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME
import torch
import os


class CheckpointWriter(ABC):

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name = base_name
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        pass

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        pass


class DiskCheckpointWriter(CheckpointWriter):

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name

    def save_model(self, model_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)
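
A sketch (not part of this commit; './ckpt_demo' is a placeholder) of the on-disk layout DiskCheckpointWriter produces. A single process writes global_meta.bin, other.bin, model-shard{i}.bin, optim-shard{i}.bin and meta.bin; with world_size > 1 every per-rank file also gets a -rank{r} suffix (e.g. model-rank0-shard0.bin, meta-rank0.bin).

# Sketch only -- not part of the commit. './ckpt_demo' is a placeholder directory.
import torch.nn as nn
from colossalai.utils.checkpoint_io.writer import DiskCheckpointWriter

writer = DiskCheckpointWriter('./ckpt_demo', overwrite=True)         # creates the dir, writes global_meta.bin
writer.save_others({'step': 100})                                    # -> other.bin
writer.save_model(nn.Linear(4, 2).state_dict())                      # -> model-shard0.bin
writer.save_meta({'dist_meta': None, 'params': ['weight', 'bias']})  # -> meta.bin, records the shard names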
tests/test_utils/test_checkpoint_io/test_build_checkpoints.py (new file, mode 100644)

import torch
import torch.nn as nn
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.utils import build_checkpoints
from torch.optim import Adam


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def test_global_model():
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model)
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    global_state_dict = model_checkpoints[0]
    assert set(orig_state_dict.keys()) == set(global_state_dict.keys())
    for k, v in orig_state_dict.items():
        assert torch.equal(v, global_state_dict[k])


def test_global_model_shard():
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model)
    assert len(model_checkpoints) == 2
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys())
    assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0
    for k, v in orig_state_dict.items():
        for state_dict in model_checkpoints:
            if k in state_dict:
                assert torch.equal(v, state_dict[k])


def test_global_optimizer():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    for state in meta['paired_os'].values():
        for k, is_paired in state.items():
            if k == 'step':
                assert not is_paired
            else:
                assert is_paired
    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    for k, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']


def test_global_optimizer_shard():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
    assert len(optimizer_checkpoints) == 2
    assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
    orig_state_dict = optimizer.state_dict()
    assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
        optimizer_checkpoints[1]['state'].keys())
    assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
    for k, orig_state in orig_state_dict['state'].items():
        state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][
            'state'] else optimizer_checkpoints[1]['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']


def test_dist_model_optimizer():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1
    assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
    assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
    dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1


if __name__ == '__main__':
    test_global_model()
    test_global_model_shard()
    test_global_optimizer()
    test_global_optimizer_shard()
    test_dist_model_optimizer()
tests/test_utils/test_checkpoint_io/test_load.py (new file, mode 100644)

from copy import deepcopy
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta)
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam, Optimizer


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())
    for k, v in a.items():
        assert torch.equal(v, b[k])


def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
    assert set(a['state'].keys()) == set(b['state'].keys())
    for k, state in a['state'].items():
        b_state = b['state'][k]
        for v1, v2 in zip(state.values(), b_state.values()):
            if isinstance(v1, Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    if not ignore_param_groups:
        assert a['param_groups'] == b['param_groups']


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(scalar)
        for state in optimizer.state.values():
            for v in state.values():
                if isinstance(v, Tensor):
                    v.fill_(scalar)


def get_dist_metas(nprocs: int, zero: bool = False):
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2,
                                  dp_world_size,
                                  rank % 2,
                                  2,
                                  tp_shard_dims=[1],
                                  tp_num_parts=[2],
                                  zero_numel=10,
                                  zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas


def get_redist_meta(nprocs: int):
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)


@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
def test_save_global_load_global(max_shard_size_gb: float):
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb)
        new_model, new_optimizer = prepare_model_optim()
        load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb)
        check_model_state_dict(model.state_dict(), new_model.state_dict())
        check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())


def run_dist(rank, world_size, port, func):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func()


def launch_dist(fn, world_size: int):
    proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
    mp.spawn(proc_fn, nprocs=world_size)


def save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    reset_model_optim(model, optimizer)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank])


def load_and_check_dist(dir_name: str):
    world_size = dist.get_world_size()
    model, optimizer = prepare_model_optim(shard=True)
    reset_model_optim(model, optimizer)
    model_state_dict = deepcopy(model.state_dict())
    optimizer_state_dict = deepcopy(optimizer.state_dict())
    reset_model_optim(model, optimizer, 1)
    load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
    check_model_state_dict(model_state_dict, model.state_dict())
    check_optim_state_dict(optimizer_state_dict, optimizer.state_dict())


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_global_load_dist():
    model, optimizer = prepare_model_optim()
    reset_model_optim(model, optimizer)
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 4)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist_load_dist():
    with TemporaryDirectory() as dir_name:
        # save tp + dp
        fn = partial(save_dist, dir_name, False)
        launch_dist(fn, 2)
        # load tp + dp
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
    with TemporaryDirectory() as dir_name:
        # save tp + zero
        fn = partial(save_dist, dir_name, True)
        launch_dist(fn, 4)
        # load tp + dp
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
        launch_dist(fn, 4)


if __name__ == '__main__':
    test_save_global_load_global(80 / 1024**3)
    test_save_global_load_global(0)
    test_save_global_load_dist()
    test_save_dist_load_dist()
tests/test_utils/test_checkpoint_io/test_merge.py
0 → 100644
View file @
99870726
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import save, merge
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tempfile import TemporaryDirectory
from torch.optim import Adam
from functools import partial
import torch
import os
import pytest
import colossalai
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def test_merge_global():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0


def run_dist(rank, world_size, port, func):
    colossalai.launch(config={'parallel': {'tensor': {'mode': '1d', 'size': 2}}},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()


def run_save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    dp_world_size = dist.get_world_size() // 2
    if not zero:
        dist_metas = {
            'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
            'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
        }
    else:
        dist_metas = {
            'fc.weight':
                ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2],
                              zero_numel=10, zero_orig_shape=[1, 10]),
            'fc.bias':
                ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        }
    save(dir_name, model, optimizer, dist_meta=dist_metas)


@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_merge_tp_dp(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 5
            global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME))
            assert len(global_meta['meta']) == 1
            meta = torch.load(os.path.join(output_dir, global_meta['meta'][0]))
            assert meta['dist_meta'] is None
            assert len(meta['params']) == 2
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0]))
            assert len(model_state_dict) == 2
            assert model_state_dict['fc.weight'].size(1) == 20
            optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
            assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20
            assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20


if __name__ == '__main__':
    test_merge_global()
    test_merge_tp_dp(False)
    test_merge_tp_dp(True)
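As a usage note: merge takes an existing checkpoint directory and a separate output directory and writes a consolidated global checkpoint there; per test_merge_global above it is a no-op when the input checkpoint carries no dist_meta. A minimal sketch, assuming colossalai is installed:

# Sketch only: consolidating an on-disk checkpoint into a global one.
from tempfile import TemporaryDirectory

import torch.nn as nn
from torch.optim import Adam

from colossalai.utils.checkpoint_io import save, merge

model = nn.Linear(20, 1)
optimizer = Adam(model.parameters(), lr=1e-3)

with TemporaryDirectory() as ckpt_dir, TemporaryDirectory() as out_dir:
    save(ckpt_dir, model, optimizer)    # this checkpoint is already global, so...
    merge(ckpt_dir, out_dir)            # ...merge writes nothing here (see test_merge_global)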
tests/test_utils/test_checkpoint_io/test_merge_param.py
0 → 100644
View file @
99870726
import torch
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param


def test_unflatten_zero_param_even() -> None:
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_unflatten_zero_param_uneven() -> None:
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_1d_row() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_1d_col() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_2d() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_2d_reverse() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)]
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_merge_param_hybrid() -> None:
    dist_metas = [
        ParamDistMeta(i % 2, 2, i // 2, 6,
                      tp_shard_dims=[1, 0], tp_num_parts=[3, 2],
                      zero_numel=4, zero_orig_shape=[2, 2]) for i in range(12)
    ]
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_merge_param_dummy() -> None:
    dist_metas = [ParamDistMeta(0, 1, 0, 1)]
    orig_tensor = torch.rand(4, 6)
    merged_tensor = merge_param([orig_tensor], dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


if __name__ == '__main__':
    test_unflatten_zero_param_even()
    test_unflatten_zero_param_uneven()
    test_gather_tp_param_1d_row()
    test_gather_tp_param_1d_col()
    test_gather_tp_param_2d()
    test_gather_tp_param_2d_reverse()
    test_merge_param_hybrid()
    test_merge_param_dummy()
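The expected shard layouts in the gather tests above are built with plain torch ops. The following standalone snippet (torch only, no colossalai) shows the same 2-by-3 tensor-parallel split and how concatenation reverses it, which is what gather_tp_param is checked against.

# Standalone illustration of the 2x3 TP layout used in test_gather_tp_param_2d.
import torch

orig = torch.rand(4, 6)
# Shard dim 0 into 2 parts, then dim 1 into 3 parts -> 6 shards, row-major over (dim0, dim1).
shards = [t.contiguous() for row in orig.chunk(2, 0) for t in row.chunk(3, 1)]

# Reassemble: concatenate each row of 3 shards along dim 1, then the 2 rows along dim 0.
rows = [torch.cat(shards[i * 3:(i + 1) * 3], dim=1) for i in range(2)]
restored = torch.cat(rows, dim=0)
assert torch.equal(orig, restored)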
tests/test_utils/test_checkpoint_io/test_redist.py
0 → 100644
View file @
99870726
import os
from functools import partial
from tempfile import TemporaryDirectory

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
                                                 RedistMeta)
from torch.optim import Adam


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def get_dist_metas(nprocs: int, zero: bool = False):
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2],
                                  zero_numel=10, zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas


def get_redist_meta(nprocs: int):
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)


def check_checkpoint_shape(dir_name: str):
    global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
    for meta_name in global_meta['meta']:
        meta = torch.load(os.path.join(dir_name, meta_name))
        assert meta['dist_meta'] is not None
        assert len(meta['params']) == 2
        assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
        assert len(model_state_dict) == 2
        assert model_state_dict['fc.weight'].size(1) == 10
        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        assert len(optimizer_state_dict['state']) == 2
        assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
        assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10
        assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10


def test_global_to_dist():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            check_checkpoint_shape(output_dir)


def run_dist(rank, world_size, port, func):
    colossalai.launch(config={'parallel': {'tensor': {'mode': '1d', 'size': 2}}},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()


def run_save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])


@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_dist_to_dist(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            if not zero:
                assert len(os.listdir(output_dir)) == 0
            else:
                check_checkpoint_shape(output_dir)


if __name__ == '__main__':
    test_global_to_dist()
    test_dist_to_dist(False)
    test_dist_to_dist(True)
tests/test_utils/test_checkpoint_io/test_save.py
0 → 100644
View file @
99870726
import os
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
                                                     OTHER_CKPT_FILE_NAME)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from torch import Tensor
from torch.optim import Adam


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())
    for k, v in a.items():
        assert torch.equal(v, b[k])


def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
    assert set(a['state'].keys()) == set(b['state'].keys())
    for k, state in a['state'].items():
        b_state = b['state'][k]
        for v1, v2 in zip(state.values(), b_state.values()):
            if isinstance(v1, Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    if not ignore_param_groups:
        assert a['param_groups'] == b['param_groups']


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def test_overwrite():
    model = DummyModel()
    with TemporaryDirectory() as dir_name:
        with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f:
            pass
        with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
            save(dir_name, model)


def test_save_global():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        assert len(os.listdir(dir_name)) == 5
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 1
        assert len(meta['optimizer']) == 1
        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
        check_model_state_dict(model.state_dict(), model_state_dict)
        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict)
        other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
        assert len(other_state_dict) == 0


def test_save_global_shard():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        assert len(os.listdir(dir_name)) == 7
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
        model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
        assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0
        check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]})
        optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
        assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0
        assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1]
        check_optim_state_dict(
            optimizer.state_dict(), {
                'state': {
                    **optimizer_state_dicts[0]['state'],
                    **optimizer_state_dicts[1]['state']
                },
                'param_groups': optimizer_state_dicts[0]['param_groups']
            })


def run_dist(rank, world_size, port, func):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func()


def run_save_dist(dir_name):
    model, optimizer = prepare_model_optim()
    dist_metas = {
        'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1),
        'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
    }
    save(dir_name, model, optimizer, dist_meta=dist_metas)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist():
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name)
        world_size = 2
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        assert len(os.listdir(dir_name)) == 8
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 2
        for rank, meta_name in enumerate(global_meta['meta']):
            meta = torch.load(os.path.join(dir_name, meta_name))
            assert meta.get('dist_meta', None) is not None
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
            assert len(model_state_dict) == 2
            optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict


if __name__ == '__main__':
    test_overwrite()
    test_save_global()
    test_save_global_shard()
    test_save_dist()
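The assertions above imply a two-level on-disk layout: a global meta file lists per-rank meta files, and each meta file points at model and optimizer shard files. A small standalone walker, assuming only that these files are torch-serialized exactly as the tests load them:

# Sketch: walking the checkpoint layout that test_save_global / test_save_dist assert.
import os

import torch

from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME


def describe_checkpoint(dir_name: str) -> None:
    global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
    for meta_name in global_meta['meta']:          # one meta file per saving rank
        meta = torch.load(os.path.join(dir_name, meta_name))
        print(meta_name, 'has dist_meta:', meta.get('dist_meta') is not None)
        for shard_name in meta['model']:           # model shard files
            print('  model shard:', shard_name)
        for shard_name in meta['optimizer']:       # optimizer shard files
            print('  optimizer shard:', shard_name)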
tests/test_utils/test_checkpoint_io/test_unmerge_param.py
0 → 100644
View file @
99870726
import torch
from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param


def test_flatten_zero_param_even() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


def test_flatten_zero_param_uneven() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
    flat_tensors = flat_tensors[1:-1]
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
    unmerged_tensors = unmerged_tensors[1:-1]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


def test_split_tp_param_1d_row() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_1d_col() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_2d() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_2d_reverse() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_unmerge_param_hybrid() -> None:
    redist_meta = ParamRedistMeta(2, 6,
                                  tp_shard_dims=[1, 0], tp_num_parts=[3, 2],
                                  zero_start_dp_rank=0, zero_offsets=[0, 1])
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
    for tp_rank in range(6):
        for dp_rank in range(2):
            assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])


def test_unmerge_param_dummy() -> None:
    redist_meta = ParamRedistMeta(1, 1)
    orig_tensor = torch.rand(4, 6)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
    assert torch.equal(orig_tensor, unmerged_tensors[0][0])


if __name__ == '__main__':
    test_flatten_zero_param_even()
    test_flatten_zero_param_uneven()
    test_split_tp_param_1d_row()
    test_split_tp_param_1d_col()
    test_split_tp_param_2d()
    test_split_tp_param_2d_reverse()
    test_unmerge_param_hybrid()
    test_unmerge_param_dummy()
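The expected ZeRO shards in the flatten tests above are again built with plain torch. This standalone snippet (torch only, no colossalai) shows the flatten-then-split view that flatten_zero_param is compared against, using the uneven [13, 3] split from test_flatten_zero_param_uneven.

# Standalone illustration of the flattened ZeRO split used in test_flatten_zero_param_uneven.
import torch

orig = torch.rand(4, 4)
flat = orig.reshape(-1)                     # 16 elements, flattened
shards = list(flat.split([13, 3]))          # uneven split across two dp ranks
restored = torch.cat(shards).reshape(4, 4)  # concatenating the flat shards restores the tensor
assert torch.equal(orig, restored)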