OpenDAS / deepspeed · Commits · 7531c6bf

Unverified commit 7531c6bf, authored Mar 26, 2021 by Stas Bekman, committed by GitHub on Mar 26, 2021
full fp32 weights reconstruction for zero 2+3 (#892)
parent 7f03282c

Showing 2 changed files with 177 additions and 1 deletion:

  deepspeed/runtime/engine.py      +26  -1
  deepspeed/utils/zero_to_fp32.py  +151 -0
deepspeed/runtime/engine.py
@@ -3,10 +3,13 @@ Copyright 2019 The Microsoft DeepSpeed Team
 '''
 import os
+import stat
 import torch
 import warnings
 import hashlib
 import torch.distributed as dist
+from collections import OrderedDict
+from shutil import copyfile
 from torch.nn.modules import Module
 from torch.distributed.distributed_c10d import _get_global_rank
@@ -1684,8 +1687,30 @@ class DeepSpeedEngine(Module):
         torch.save(state, save_path)
         self._curr_save_path = None

+    def _get_param_shapes(self):
+        param_shapes = OrderedDict()
+        for name, param in self.module.named_parameters():
+            param_shapes[name] = param.ds_shape if hasattr(param, "ds_shape") else param.shape
+            # print(f"saving param {name} {param_shapes[name]}")
+        return param_shapes
+
+    def _copy_recovery_script(self, save_path):
+        base_dir = os.path.dirname(os.path.dirname(__file__))
+        script = "zero_to_fp32.py"
+        src = os.path.join(base_dir, "utils", script)
+        dst = os.path.join(save_path, script)
+        logger.info(f"creating recovery script {dst}")
+        copyfile(src, dst)
+        # make executable
+        os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
+
     def _save_zero_checkpoint(self, save_path, tag):
         zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
-        zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
+        zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
+                       param_shapes=self._get_param_shapes(),
+                       )
         torch.save(zero_sd, zero_checkpoint_name)
+        self._copy_recovery_script(save_path)
         logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
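
After this change, each per-rank ZeRO checkpoint file written by _save_zero_checkpoint carries the module's full parameter shapes next to the optimizer state, and a copy of zero_to_fp32.py lands in the checkpoint folder. A minimal inspection sketch, not part of the commit (the path is hypothetical; the "*_optim_states.pt" pattern is the one the recovery script below globs for):

# inspect_zero_ckpt.py -- hypothetical helper, only to show what the saved file now contains
import glob
import torch
import deepspeed  # the pickled checkpoint references DeepSpeed data structures

ckpt_dir = "checkpoint-1/global_step1"  # hypothetical checkpoint folder
optim_file = sorted(glob.glob(f"{ckpt_dir}/*_optim_states.pt"))[0]
sd = torch.load(optim_file)

print(list(sd.keys()))                                # now includes 'param_shapes'
print(sd['optimizer_state_dict']['zero_stage'])       # 2 or 3
print(sd['optimizer_state_dict']['partition_count'])  # world size used for training
for name, shape in list(sd['param_shapes'].items())[:3]:
    print(name, shape)                                # full (unpartitioned) parameter shapes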
deepspeed/utils/zero_to_fp32.py (new file, 0 → 100644)
#!/usr/bin/env python

# This script extracts fp32 consolidated weights from zero 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py global_step1 pytorch_model.bin

import argparse
import torch
import glob
import os
from collections import OrderedDict

import deepspeed
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.


def get_optim_files(checkpoint_dir):
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # XXX: need to test that this simple glob rule works for multi-node setup too
    optim_files = sorted(glob.glob(f"{checkpoint_dir}/*_optim_states.pt"))

    if len(optim_files) == 0:
        raise FileNotFoundError(
            f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")

    return optim_files


def parse_optim_states(files):
    state_dicts = []
    for f in files:
        state_dicts.append(torch.load(f))

    if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
        raise ValueError(f"non zero checkpoint")
    zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]

    # the groups are named differently in each stage
    if zero_stage == 2:
        fp32_groups_key = "single_partition_of_fp32_groups"
    elif zero_stage == 3:
        fp32_groups_key = "fp32_flat_groups"
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    param_shapes = state_dicts[0]["param_shapes"]
    fp32_flat_groups = [
        state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0]
        for i in range(len(state_dicts))
    ]
    world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]

    return zero_stage, world_size, param_shapes, fp32_flat_groups


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = int(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
    """
    Convert a zero 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be
    loaded with ``torch.load(file)`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the deepspeed checkpoint folder
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
    """

    print(f"Processing zero checkpoint '{checkpoint_dir}'")

    optim_files = get_optim_files(checkpoint_dir)
    zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    # Reconstruction protocol:
    #
    # - for zero2 we just need to concat the partitions back to back and reconsolidate over one huge
    # flat buffer - no need to deal with padding since if there is any it will be only in the tail
    # of the last partition so there it will be just left out
    #
    # - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating
    # each param, while dealing with padding if any

    if zero_stage == 2:
        # XXX: memory usage doubles here (zero2)
        full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)

    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    state_dict = OrderedDict()
    offset = 0
    total_numel = 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        if zero_stage == 2:
            # print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel}")
            state_dict[name] = full_single_fp32_vector.narrow(
                0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        elif zero_stage == 3:
            partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
                unpartitioned_numel, world_size)
            # print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}")

            # XXX: memory usage doubles here (zero3)
            state_dict[name] = torch.cat(
                tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel)
                      for i in range(world_size)),
                0).view(shape)
            offset += partitioned_numel + partitioned_padding_numel

    # the job is done
    print(f"Saving fp32 state dict to {output_file} (total_numel={total_numel})")
    torch.save(state_dict, output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the deepspeed checkpoint folder, e.g., path/checkpoint-1/global_step1")
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-1/pytorch_model.bin)")
    args = parser.parse_args()

    convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)
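
The "Reconstruction protocol" comment is easiest to see on a toy example. The sketch below is not part of the commit; the two parameters, their shapes and world_size=2 are made up, but the concat-then-narrow(...).view(shape) pattern is the same one the zero-2 branch applies to the real single_partition_of_fp32_groups partitions:

# toy_zero2_reconstruction.py -- hypothetical illustration of the zero-2 branch above
import torch
from collections import OrderedDict

# two toy parameters and their full shapes, mirroring what _get_param_shapes() records
param_shapes = OrderedDict(w=torch.Size([2, 3]), b=torch.Size([2]))
w = torch.arange(6, dtype=torch.float32).view(2, 3)
b = torch.tensor([100.0, 200.0])

# one flat fp32 buffer over all params, split contiguously across world_size=2 ranks
flat = torch.cat([w.flatten(), b.flatten()])   # 8 elements total
fp32_flat_groups = [flat[:4], flat[4:]]        # rank 0 / rank 1 partitions

# zero-2 branch: concat the partitions back to back, then carve out each param in order
full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)
state_dict, offset = OrderedDict(), 0
for name, shape in param_shapes.items():
    unpartitioned_numel = shape.numel()
    state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
    offset += unpartitioned_numel

assert torch.equal(state_dict["w"], w)
assert torch.equal(state_dict["b"], b)

On a real checkpoint the conversion is the one-liner from the script's header comment, e.g. python zero_to_fp32.py global_step1 pytorch_model.bin, and per the docstring the result can then be consumed without DeepSpeed, e.g. model.load_state_dict(torch.load("pytorch_model.bin")) for a model with the matching architecture.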