Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1df591b0
Commit
1df591b0
authored
Feb 08, 2024
by
Jennifer
Browse files
updates zero_to_fp32.py for new deepspeed version and import_weight bugfix
parent
bb3f51e5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
258 additions
and
124 deletions
+258
-124
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+3
-2
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+255
-122
No files found.
openfold/utils/import_weights.py
View file @
1df591b0
...
@@ -688,8 +688,9 @@ def convert_deprecated_v1_keys(state_dict):
...
@@ -688,8 +688,9 @@ def convert_deprecated_v1_keys(state_dict):
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
# Add prefix for template modules
# Add prefix for template modules
if
new_key
.
startswith
(
'template'
):
subheader
=
re
.
search
(
'(?<=model.).*$'
,
new_key
).
group
()
new_key
=
f
'template_embedder.
{
new_key
}
'
if
subheader
.
startswith
(
'template'
):
new_key
=
f
'model.template_embedder.
{
subheader
}
'
converted_state_dict
[
new_key
]
=
value
converted_state_dict
[
new_key
]
=
value
...
...
scripts/zero_to_fp32.py
View file @
1df591b0
#!/usr/bin/env python
#!/usr/bin/env python
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
# application.
...
@@ -12,13 +17,27 @@ import torch
...
@@ -12,13 +17,27 @@ import torch
import
glob
import
glob
import
math
import
math
import
os
import
os
from
collections
import
OrderedDict
import
re
import
re
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
# DeepSpeed data structures it has to be available in the current python environment.
import
deepspeed
from
deepspeed.utils
import
logger
from
deepspeed.utils
import
logger
from
deepspeed.checkpoint.constants
import
(
DS_VERSION
,
OPTIMIZER_STATE_DICT
,
SINGLE_PARTITION_OF_FP32_GROUPS
,
FP32_FLAT_GROUPS
,
ZERO_STAGE
,
PARTITION_COUNT
,
PARAM_SHAPES
,
BUFFER_NAMES
,
FROZEN_PARAM_SHAPES
,
FROZEN_PARAM_FRAGMENTS
)
@
dataclass
class
zero_model_state
:
buffers
:
dict
()
param_shapes
:
dict
()
shared_params
:
list
ds_version
:
int
frozen_param_shapes
:
dict
()
frozen_param_fragments
:
dict
()
debug
=
0
debug
=
0
...
@@ -26,12 +45,25 @@ debug = 0
...
@@ -26,12 +45,25 @@ debug = 0
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
def
atoi
(
text
):
return
int
(
text
)
if
text
.
isdigit
()
else
text
def
natural_keys
(
text
):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return
[
atoi
(
c
)
for
c
in
re
.
split
(
r
'(\d+)'
,
text
)]
def
get_model_state_file
(
checkpoint_dir
,
zero_stage
):
def
get_model_state_file
(
checkpoint_dir
,
zero_stage
):
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
raise
FileNotFoundError
(
f
"Directory '
{
checkpoint_dir
}
' doesn't exist"
)
raise
FileNotFoundError
(
f
"Directory '
{
checkpoint_dir
}
' doesn't exist"
)
# there should be only one file
# there should be only one file
if
zero_stage
=
=
2
:
if
zero_stage
<
=
2
:
file
=
os
.
path
.
join
(
checkpoint_dir
,
"mp_rank_00_model_states.pt"
)
file
=
os
.
path
.
join
(
checkpoint_dir
,
"mp_rank_00_model_states.pt"
)
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
file
=
os
.
path
.
join
(
checkpoint_dir
,
"zero_pp_rank_0_mp_rank_00_model_states.pt"
)
file
=
os
.
path
.
join
(
checkpoint_dir
,
"zero_pp_rank_0_mp_rank_00_model_states.pt"
)
...
@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
...
@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
return
file
return
file
def
get_
optim
_files
(
checkpoint_dir
):
def
get_
checkpoint
_files
(
checkpoint_dir
,
glob_pattern
):
# XXX: need to test that this simple glob rule works for multi-node setup too
# XXX: need to test that this simple glob rule works for multi-node setup too
optim
_files
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
checkpoint_dir
,
"*_optim_states.pt"
))
)
ckpt
_files
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
checkpoint_dir
,
glob_pattern
)),
key
=
natural_keys
)
if
len
(
optim_files
)
==
0
:
if
len
(
ckpt_files
)
==
0
:
raise
FileNotFoundError
(
raise
FileNotFoundError
(
f
"can't find
{
glob_pattern
}
files in directory '
{
checkpoint_dir
}
'"
)
f
"can't find '*_optim_states.pt' files in directory '
{
checkpoint_dir
}
'"
)
return
optim
_files
return
ckpt
_files
def
parse_model_state
(
file
):
def
get_optim_files
(
checkpoint_dir
):
state_dict
=
torch
.
load
(
file
,
map_location
=
device
)
return
get_checkpoint_files
(
checkpoint_dir
,
"*_optim_states.pt"
)
if
"buffer_names"
not
in
state_dict
:
raise
ValueError
(
f
"
{
file
}
is not a model state checkpoint"
)
buffer_names
=
state_dict
[
"buffer_names"
]
if
debug
:
print
(
"Found buffers:"
,
buffer_names
)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
def
get_model_state_files
(
checkpoint_dir
):
buffers
=
{
return
get_checkpoint_files
(
checkpoint_dir
,
"*_model_states.pt"
)
k
:
v
.
float
()
for
k
,
v
in
state_dict
[
"module"
].
items
()
if
k
in
buffer_names
def
parse_model_states
(
files
):
}
zero_model_states
=
[]
return
buffers
for
file
in
files
:
state_dict
=
torch
.
load
(
file
,
map_location
=
device
)
if
BUFFER_NAMES
not
in
state_dict
:
raise
ValueError
(
f
"
{
file
}
is not a model state checkpoint"
)
buffer_names
=
state_dict
[
BUFFER_NAMES
]
if
debug
:
print
(
"Found buffers:"
,
buffer_names
)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers
=
{
k
:
v
.
float
()
for
k
,
v
in
state_dict
[
"module"
].
items
()
if
k
in
buffer_names
}
param_shapes
=
state_dict
[
PARAM_SHAPES
]
# collect parameters that are included in param_shapes
param_names
=
[]
for
s
in
param_shapes
:
for
name
in
s
.
keys
():
param_names
.
append
(
name
)
# update with frozen parameters
frozen_param_shapes
=
state_dict
.
get
(
FROZEN_PARAM_SHAPES
,
None
)
if
frozen_param_shapes
is
not
None
:
if
debug
:
print
(
f
"Found frozen_param_shapes:
{
frozen_param_shapes
}
"
)
param_names
+=
list
(
frozen_param_shapes
.
keys
())
# handle shared params
shared_params
=
[[
k
,
v
]
for
k
,
v
in
state_dict
[
"shared_params"
].
items
()]
ds_version
=
state_dict
.
get
(
DS_VERSION
,
None
)
frozen_param_fragments
=
state_dict
.
get
(
FROZEN_PARAM_FRAGMENTS
,
None
)
z_model_state
=
zero_model_state
(
buffers
=
buffers
,
param_shapes
=
param_shapes
,
shared_params
=
shared_params
,
ds_version
=
ds_version
,
frozen_param_shapes
=
frozen_param_shapes
,
frozen_param_fragments
=
frozen_param_fragments
)
zero_model_states
.
append
(
z_model_state
)
return
zero_model_states
def
parse_optim_states
(
files
,
ds_checkpoint_dir
):
def
parse_optim_states
(
files
,
ds_checkpoint_dir
):
...
@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
...
@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files
=
len
(
files
)
total_files
=
len
(
files
)
state_dicts
=
[]
state_dicts
=
[]
for
f
in
files
:
for
f
in
files
:
state_dicts
.
append
(
torch
.
load
(
f
,
map_location
=
device
))
state_dict
=
torch
.
load
(
f
,
map_location
=
device
)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict
[
"optimizer_state_dict"
].
pop
(
"optimizer_state_dict"
,
None
)
state_dicts
.
append
(
state_dict
)
if
not
"zero_stage"
in
state_dicts
[
0
][
'optimizer_state_dict'
]:
if
not
ZERO_STAGE
in
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
]:
raise
ValueError
(
f
"
{
files
[
0
]
}
is not a zero checkpoint"
)
raise
ValueError
(
f
"
{
files
[
0
]
}
is not a zero checkpoint"
)
zero_stage
=
state_dicts
[
0
][
'optimizer_state_dict'
][
"zero_stage"
]
zero_stage
=
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
][
ZERO_STAGE
]
world_size
=
state_dicts
[
0
][
'optimizer_state_dict'
][
"partition_count"
]
world_size
=
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
][
PARTITION_COUNT
]
param_shapes
=
state_dicts
[
0
][
"param_shapes"
]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just
# parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size.
# use the max of the partition_count to get the dp world_size.
...
@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
...
@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
)
)
# the groups are named differently in each stage
# the groups are named differently in each stage
if
zero_stage
=
=
2
:
if
zero_stage
<
=
2
:
fp32_groups_key
=
"single_partition_of_fp32_groups"
fp32_groups_key
=
SINGLE_PARTITION_OF_FP32_GROUPS
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
fp32_groups_key
=
"fp32_flat_groups"
fp32_groups_key
=
FP32_FLAT_GROUPS
else
:
else
:
raise
ValueError
(
f
"unknown zero stage
{
zero_stage
}
"
)
raise
ValueError
(
f
"unknown zero stage
{
zero_stage
}
"
)
if
zero_stage
==
2
:
if
zero_stage
<=
2
:
fp32_flat_groups
=
[
fp32_flat_groups
=
[
state_dicts
[
i
][
OPTIMIZER_STATE_DICT
][
fp32_groups_key
]
for
i
in
range
(
len
(
state_dicts
))]
state_dicts
[
i
][
'optimizer_state_dict'
][
fp32_groups_key
]
for
i
in
range
(
len
(
state_dicts
))
]
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
# if there is more than one param group, there will be multiple flattened tensors - one
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
# flattened tensor per group - for simplicity merge them into a single tensor
...
@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
...
@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
# will require matching the sub-lists of param_shapes for each param group flattened tensor
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups
=
[
fp32_flat_groups
=
[
torch
.
cat
(
state_dicts
[
i
][
'optimizer_state_dict'
][
fp32_groups_key
],
torch
.
cat
(
state_dicts
[
i
][
OPTIMIZER_STATE_DICT
][
fp32_groups_key
],
0
)
for
i
in
range
(
len
(
state_dicts
))
0
)
for
i
in
range
(
len
(
state_dicts
))
]
]
return
zero_stage
,
world_size
,
param_shapes
,
fp32_flat_groups
return
zero_stage
,
world_size
,
fp32_flat_groups
def
_get_fp32_state_dict_from_zero_checkpoint
(
ds_checkpoint_dir
):
def
_get_fp32_state_dict_from_zero_checkpoint
(
ds_checkpoint_dir
):
...
@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
...
@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
print
(
f
"Processing zero checkpoint '
{
ds_checkpoint_dir
}
'"
)
print
(
f
"Processing zero checkpoint '
{
ds_checkpoint_dir
}
'"
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
zero_stage
,
world_size
,
param_shapes
,
fp32_flat_groups
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
zero_stage
,
world_size
,
fp32_flat_groups
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
print
(
print
(
f
"Detected checkpoint of type zero stage
{
zero_stage
}
, world_size:
{
world_size
}
"
)
f
"Detected checkpoint of type zero stage
{
zero_stage
}
, world_size:
{
world_size
}
"
)
model_files
=
get_model_state_files
(
ds_checkpoint_dir
)
model_file
=
get_model_state_file
(
ds_checkpoint_dir
,
zero_stage
)
buffers
=
parse_model_state
(
model_file
)
zero_model_states
=
parse_model_states
(
model_files
)
print
(
f
'Parsing checkpoint created by deepspeed==
{
zero_model_states
[
0
].
ds_version
}
'
)
if
zero_stage
==
2
:
return
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
if
zero_stage
<=
2
:
param_shapes
,
return
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
)
fp32_flat_groups
,
buffers
)
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
return
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
return
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
)
param_shapes
,
fp32_flat_groups
,
buffers
)
def
_zero2_merge_frozen_params
(
state_dict
,
zero_model_states
):
if
zero_model_states
[
0
].
frozen_param_shapes
is
None
or
len
(
zero_model_states
[
0
].
frozen_param_shapes
)
==
0
:
return
frozen_param_shapes
=
zero_model_states
[
0
].
frozen_param_shapes
frozen_param_fragments
=
zero_model_states
[
0
].
frozen_param_fragments
if
debug
:
num_elem
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
print
(
f
'rank 0:
{
FROZEN_PARAM_SHAPES
}
.numel =
{
num_elem
}
'
)
wanted_params
=
len
(
frozen_param_shapes
)
wanted_numel
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
avail_numel
=
sum
([
p
.
numel
()
for
p
in
frozen_param_fragments
.
values
()])
print
(
f
'Frozen params: Have
{
avail_numel
}
numels to process.'
)
print
(
f
'Frozen params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params'
)
total_params
=
0
total_numel
=
0
for
name
,
shape
in
frozen_param_shapes
.
items
():
total_params
+=
1
unpartitioned_numel
=
shape
.
numel
()
total_numel
+=
unpartitioned_numel
state_dict
[
name
]
=
frozen_param_fragments
[
name
]
if
debug
:
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
print
(
f
"Reconstructed Frozen fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
def
_zero2_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
):
param_shapes
,
param_shapes
=
zero_model_states
[
0
].
param_shapes
fp32_flat_groups
,
buffers
):
# Reconstruction protocol:
# Reconstruction protocol:
#
#
...
@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
if
debug
:
if
debug
:
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
for
j
in
range
(
len
(
fp32_flat_groups
[
0
])):
for
j
in
range
(
len
(
fp32_flat_groups
[
0
])):
print
(
f
"
fp32_flat_groups
[
{
i
}
][
{
j
}
].shape=
{
fp32_flat_groups
[
i
][
j
].
shape
}
"
)
print
(
f
"
{
FP32_FLAT_GROUPS
}
[
{
i
}
][
{
j
}
].shape=
{
fp32_flat_groups
[
i
][
j
].
shape
}
"
)
# XXX: memory usage doubles here (zero2)
# XXX: memory usage doubles here (zero2)
num_param_groups
=
len
(
fp32_flat_groups
[
0
])
num_param_groups
=
len
(
fp32_flat_groups
[
0
])
...
@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
merged_partitions
=
[
sd
[
i
]
for
sd
in
fp32_flat_groups
]
merged_partitions
=
[
sd
[
i
]
for
sd
in
fp32_flat_groups
]
full_single_fp32_vector
=
torch
.
cat
(
merged_partitions
,
0
)
full_single_fp32_vector
=
torch
.
cat
(
merged_partitions
,
0
)
merged_single_partition_of_fp32_groups
.
append
(
full_single_fp32_vector
)
merged_single_partition_of_fp32_groups
.
append
(
full_single_fp32_vector
)
avail_numel
=
sum
([
avail_numel
=
sum
(
full_single_fp32_vector
.
numel
()
[
full_single_fp32_vector
.
numel
()
for
full_single_fp32_vector
in
merged_single_partition_of_fp32_groups
])
for
full_single_fp32_vector
in
merged_single_partition_of_fp32_groups
])
if
debug
:
if
debug
:
wanted_params
=
sum
([
len
(
shapes
)
for
shapes
in
param_shapes
])
wanted_params
=
sum
([
len
(
shapes
)
for
shapes
in
param_shapes
])
wanted_numel
=
sum
(
wanted_numel
=
sum
([
sum
(
shape
.
numel
()
for
shape
in
shapes
.
values
())
for
shapes
in
param_shapes
])
[
sum
(
shape
.
numel
()
for
shape
in
shapes
.
values
())
for
shapes
in
param_shapes
])
# not asserting if there is a mismatch due to possible padding
# not asserting if there is a mismatch due to possible padding
print
(
f
"Have
{
avail_numel
}
numels to process."
)
print
(
f
"Have
{
avail_numel
}
numels to process."
)
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
state_dict
=
OrderedDict
()
# buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
# params
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
# out-of-core computing solution
...
@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
total_params
+=
1
total_params
+=
1
if
debug
:
if
debug
:
print
(
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
state_dict
[
name
]
=
full_single_fp32_vector
.
narrow
(
0
,
offset
,
unpartitioned_numel
).
view
(
shape
)
)
state_dict
[
name
]
=
full_single_fp32_vector
.
narrow
(
0
,
offset
,
unpartitioned_numel
).
view
(
shape
)
offset
+=
unpartitioned_numel
offset
+=
unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
...
@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
# Sanity check
# Sanity check
if
offset
!=
avail_numel
:
if
offset
!=
avail_numel
:
raise
ValueError
(
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
print
(
print
(
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
):
state_dict
=
OrderedDict
()
# buffers
buffers
=
zero_model_states
[
0
].
buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
_zero2_merge_frozen_params
(
state_dict
,
zero_model_states
)
_zero2_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
)
# recover shared parameters
for
pair
in
zero_model_states
[
0
].
shared_params
:
if
pair
[
1
]
in
state_dict
:
state_dict
[
pair
[
0
]]
=
state_dict
[
pair
[
1
]]
return
state_dict
return
state_dict
...
@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
...
@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return
partitioned_numel
,
padding_numel
return
partitioned_numel
,
padding_numel
def
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
def
_zero3_merge_frozen_params
(
state_dict
,
world_size
,
zero_model_states
):
param_shapes
,
if
zero_model_states
[
0
].
frozen_param_shapes
is
None
or
len
(
zero_model_states
[
0
].
frozen_param_shapes
)
==
0
:
fp32_flat_groups
,
return
buffers
):
if
debug
:
for
i
in
range
(
world_size
):
num_elem
=
sum
(
s
.
numel
()
for
s
in
zero_model_states
[
i
].
frozen_param_fragments
.
values
())
print
(
f
'rank
{
i
}
:
{
FROZEN_PARAM_SHAPES
}
.numel =
{
num_elem
}
'
)
frozen_param_shapes
=
zero_model_states
[
0
].
frozen_param_shapes
wanted_params
=
len
(
frozen_param_shapes
)
wanted_numel
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
avail_numel
=
sum
([
p
.
numel
()
for
p
in
zero_model_states
[
0
].
frozen_param_fragments
.
values
()])
*
world_size
print
(
f
'Frozen params: Have
{
avail_numel
}
numels to process.'
)
print
(
f
'Frozen params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params'
)
total_params
=
0
total_numel
=
0
for
name
,
shape
in
zero_model_states
[
0
].
frozen_param_shapes
.
items
():
total_params
+=
1
unpartitioned_numel
=
shape
.
numel
()
total_numel
+=
unpartitioned_numel
param_frags
=
tuple
(
model_state
.
frozen_param_fragments
[
name
]
for
model_state
in
zero_model_states
)
state_dict
[
name
]
=
torch
.
cat
(
param_frags
,
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
partitioned_numel
,
partitioned_padding_numel
=
zero3_partitioned_param_info
(
unpartitioned_numel
,
world_size
)
if
debug
:
print
(
f
"Frozen params:
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
)
print
(
f
"Reconstructed Frozen fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_zero3_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
):
param_shapes
=
zero_model_states
[
0
].
param_shapes
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
# param, re-consolidating each param, while dealing with padding if any
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
# merge list of dicts, preserving order
# merge list of dicts, preserving order
param_shapes
=
{
k
:
v
for
d
in
param_shapes
for
k
,
v
in
d
.
items
()}
param_shapes
=
{
k
:
v
for
d
in
param_shapes
for
k
,
v
in
d
.
items
()}
if
debug
:
if
debug
:
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
print
(
f
"
fp32_flat_groups
[
{
i
}
].shape=
{
fp32_flat_groups
[
i
].
shape
}
"
)
print
(
f
"
{
FP32_FLAT_GROUPS
}
[
{
i
}
].shape=
{
fp32_flat_groups
[
i
].
shape
}
"
)
wanted_params
=
len
(
param_shapes
)
wanted_params
=
len
(
param_shapes
)
wanted_numel
=
sum
(
shape
.
numel
()
for
shape
in
param_shapes
.
values
())
wanted_numel
=
sum
(
shape
.
numel
()
for
shape
in
param_shapes
.
values
())
# not asserting if there is a mismatch due to possible padding
# not asserting if there is a mismatch due to possible padding
print
(
f
"Have
{
avail_numel
}
numels to process."
)
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
print
(
f
"Trainable params: Have
{
avail_numel
}
numels to process."
)
print
(
f
"Trainable params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
state_dict
=
OrderedDict
()
# buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
# params
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
...
@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
...
@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
if
debug
:
if
debug
:
print
(
print
(
f
"
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
f
"
Trainable params:
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
)
)
# XXX: memory usage doubles here
# XXX: memory usage doubles here
state_dict
[
name
]
=
torch
.
cat
(
state_dict
[
name
]
=
torch
.
cat
(
tuple
(
fp32_flat_groups
[
i
].
narrow
(
0
,
tuple
(
fp32_flat_groups
[
i
].
narrow
(
0
,
offset
,
partitioned_numel
)
for
i
in
range
(
world_size
)),
offset
,
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
partitioned_numel
)
for
i
in
range
(
world_size
)),
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
offset
+=
partitioned_numel
offset
+=
partitioned_numel
offset
*=
world_size
offset
*=
world_size
# Sanity check
# Sanity check
if
offset
!=
avail_numel
:
if
offset
!=
avail_numel
:
raise
ValueError
(
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
print
(
print
(
f
"Reconstructed Trainable fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
):
state_dict
=
OrderedDict
()
# buffers
buffers
=
zero_model_states
[
0
].
buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
_zero3_merge_frozen_params
(
state_dict
,
world_size
,
zero_model_states
)
_zero3_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
)
# recover shared parameters
for
pair
in
zero_model_states
[
0
].
shared_params
:
if
pair
[
1
]
in
state_dict
:
state_dict
[
pair
[
0
]]
=
state_dict
[
pair
[
1
]]
return
state_dict
return
state_dict
...
@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
...
@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
"checkpoint_dir"
,
"checkpoint_dir"
,
type
=
str
,
type
=
str
,
help
=
"path to the desired checkpoint folder, e.g., path/checkpoint-12"
)
help
=
"path to the desired checkpoint folder, e.g., path/checkpoint-12"
)
parser
.
add_argument
(
parser
.
add_argument
(
"output_file"
,
"output_file"
,
type
=
str
,
type
=
str
,
help
=
help
=
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
)
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
parser
.
add_argument
(
"-t"
,
)
"--tag"
,
type
=
str
,
default
=
None
,
help
=
"checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1"
)
parser
.
add_argument
(
"-d"
,
"--debug"
,
action
=
'store_true'
,
help
=
"enable debug"
)
parser
.
add_argument
(
"-d"
,
"--debug"
,
action
=
'store_true'
,
help
=
"enable debug"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
debug
=
args
.
debug
debug
=
args
.
debug
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
checkpoint_dir
,
args
.
output_file
)
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
checkpoint_dir
,
args
.
output_file
,
tag
=
args
.
tag
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment