Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9776b696
Commit
9776b696
authored
Feb 21, 2024
by
jnwei
Browse files
Merge weight-loading changes into setup-improvements
parents
9f346d35
ddfccd56
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
366 additions
and
128 deletions
+366
-128
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+8
-4
scripts/convert_v1_to_v2_weights.py
scripts/convert_v1_to_v2_weights.py
+95
-0
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+255
-122
train_openfold.py
train_openfold.py
+8
-2
No files found.
openfold/utils/import_weights.py
View file @
9776b696
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
import
re
import
logging
from
enum
import
Enum
from
dataclasses
import
dataclass
from
functools
import
partial
...
...
@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
}
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
template_emb_re
=
re
.
compile
(
r
"^((module\.)?(model\.)?)(template(?!_embedder).*)"
)
converted_state_dict
=
{}
for
key
,
value
in
state_dict
.
items
():
# For each match, look-up replacement value in the dictionary
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
(
1
)],
key
)
# Add prefix for template modules
if
new_key
.
startswith
(
'template'
):
new_key
=
f
'template_embedder.
{
new_key
}
'
# Add prefix for template layers
template_match
=
re
.
match
(
template_emb_re
,
new_key
)
if
template_match
:
prefix
=
template_match
.
group
(
1
)
new_key
=
f
'
{
prefix
if
prefix
else
""
}
template_embedder.
{
template_match
.
group
(
4
)
}
'
converted_state_dict
[
new_key
]
=
value
...
...
scripts/convert_v1_to_v2_weights.py
0 → 100755
View file @
9776b696
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import
logging
import
argparse
import
os
import
shutil
import
torch
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
def
convert_v1_to_v2_weights
(
args
):
checkpoint_path
=
args
.
input_ckpt_path
is_dir
=
os
.
path
.
isdir
(
checkpoint_path
)
if
is_dir
:
# A DeepSpeed checkpoint
logging
.
info
(
'Converting deepspeed checkpoint found at {args.input_checkpoint_path}'
)
state_dict_key
=
'module'
latest_path
=
os
.
path
.
join
(
checkpoint_path
,
'latest'
)
if
os
.
path
.
isfile
(
latest_path
):
with
open
(
latest_path
,
'r'
)
as
fd
:
tag
=
fd
.
read
().
strip
()
else
:
raise
ValueError
(
f
"Unable to find 'latest' file at
{
latest_path
}
"
)
ds_checkpoint_dir
=
os
.
path
.
join
(
checkpoint_path
,
tag
)
model_output_path
=
os
.
path
.
join
(
args
.
output_ckpt_path
,
tag
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
zero_stage
,
_
,
_
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
model_file
=
get_model_state_file
(
ds_checkpoint_dir
,
zero_stage
)
else
:
# A Pytorch Lightning checkpoint
logging
.
info
(
'Converting pytorch lightning checkpoint found at {args.input_checkpoint_path}'
)
state_dict_key
=
'state_dict'
model_output_path
=
args
.
output_ckpt_path
model_file
=
checkpoint_path
model_dict
=
torch
.
load
(
model_file
,
map_location
=
torch
.
device
(
'cpu'
))
model_dict
[
state_dict_key
]
=
convert_deprecated_v1_keys
(
model_dict
[
state_dict_key
])
if
'ema'
in
model_dict
:
ema_state_dict
=
model_dict
[
'ema'
][
'params'
]
model_dict
[
'ema'
][
'params'
]
=
convert_deprecated_v1_keys
(
ema_state_dict
)
if
is_dir
:
param_shapes
=
convert_deprecated_v1_keys
(
model_dict
[
'param_shapes'
][
0
])
model_dict
[
'param_shapes'
]
=
[
param_shapes
]
shutil
.
copytree
(
checkpoint_path
,
args
.
output_ckpt_path
)
out_fname
=
os
.
path
.
join
(
model_output_path
,
os
.
path
.
basename
(
model_file
))
for
optim_file
in
optim_files
:
optim_dict
=
torch
.
load
(
optim_file
)
new_optim_dict
=
optim_dict
.
copy
()
new_optim_dict
[
'optimizer_state_dict'
][
'param_slice_mappings'
][
0
]
=
convert_deprecated_v1_keys
(
optim_dict
[
'optimizer_state_dict'
][
'param_slice_mappings'
][
0
])
out_optim_fname
=
os
.
path
.
join
(
model_output_path
,
os
.
path
.
basename
(
optim_file
))
torch
.
save
(
new_optim_dict
,
out_optim_fname
)
else
:
out_fname
=
model_output_path
torch
.
save
(
model_dict
,
out_fname
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"input_ckpt_path"
,
type
=
str
)
parser
.
add_argument
(
"output_ckpt_path"
,
type
=
str
)
args
=
parser
.
parse_args
()
convert_v1_to_v2_weights
(
args
)
scripts/zero_to_fp32.py
View file @
9776b696
#!/usr/bin/env python
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
...
...
@@ -12,13 +17,27 @@ import torch
import
glob
import
math
import
os
from
collections
import
OrderedDict
import
re
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
import
deepspeed
from
deepspeed.utils
import
logger
from
deepspeed.checkpoint.constants
import
(
DS_VERSION
,
OPTIMIZER_STATE_DICT
,
SINGLE_PARTITION_OF_FP32_GROUPS
,
FP32_FLAT_GROUPS
,
ZERO_STAGE
,
PARTITION_COUNT
,
PARAM_SHAPES
,
BUFFER_NAMES
,
FROZEN_PARAM_SHAPES
,
FROZEN_PARAM_FRAGMENTS
)
@
dataclass
class
zero_model_state
:
buffers
:
dict
()
param_shapes
:
dict
()
shared_params
:
list
ds_version
:
int
frozen_param_shapes
:
dict
()
frozen_param_fragments
:
dict
()
debug
=
0
...
...
@@ -26,12 +45,25 @@ debug = 0
device
=
torch
.
device
(
'cpu'
)
def
atoi
(
text
):
return
int
(
text
)
if
text
.
isdigit
()
else
text
def
natural_keys
(
text
):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return
[
atoi
(
c
)
for
c
in
re
.
split
(
r
'(\d+)'
,
text
)]
def
get_model_state_file
(
checkpoint_dir
,
zero_stage
):
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
raise
FileNotFoundError
(
f
"Directory '
{
checkpoint_dir
}
' doesn't exist"
)
# there should be only one file
if
zero_stage
=
=
2
:
if
zero_stage
<
=
2
:
file
=
os
.
path
.
join
(
checkpoint_dir
,
"mp_rank_00_model_states.pt"
)
elif
zero_stage
==
3
:
file
=
os
.
path
.
join
(
checkpoint_dir
,
"zero_pp_rank_0_mp_rank_00_model_states.pt"
)
...
...
@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
return
file
def
get_
optim
_files
(
checkpoint_dir
):
def
get_
checkpoint
_files
(
checkpoint_dir
,
glob_pattern
):
# XXX: need to test that this simple glob rule works for multi-node setup too
optim
_files
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
checkpoint_dir
,
"*_optim_states.pt"
))
)
ckpt
_files
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
checkpoint_dir
,
glob_pattern
)),
key
=
natural_keys
)
if
len
(
optim_files
)
==
0
:
raise
FileNotFoundError
(
f
"can't find '*_optim_states.pt' files in directory '
{
checkpoint_dir
}
'"
)
if
len
(
ckpt_files
)
==
0
:
raise
FileNotFoundError
(
f
"can't find
{
glob_pattern
}
files in directory '
{
checkpoint_dir
}
'"
)
return
optim
_files
return
ckpt
_files
def
parse_model_state
(
file
):
state_dict
=
torch
.
load
(
file
,
map_location
=
device
)
def
get_optim_files
(
checkpoint_dir
):
return
get_checkpoint_files
(
checkpoint_dir
,
"*_optim_states.pt"
)
if
"buffer_names"
not
in
state_dict
:
raise
ValueError
(
f
"
{
file
}
is not a model state checkpoint"
)
buffer_names
=
state_dict
[
"buffer_names"
]
if
debug
:
print
(
"Found buffers:"
,
buffer_names
)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers
=
{
k
:
v
.
float
()
for
k
,
v
in
state_dict
[
"module"
].
items
()
if
k
in
buffer_names
}
return
buffers
def
get_model_state_files
(
checkpoint_dir
):
return
get_checkpoint_files
(
checkpoint_dir
,
"*_model_states.pt"
)
def
parse_model_states
(
files
):
zero_model_states
=
[]
for
file
in
files
:
state_dict
=
torch
.
load
(
file
,
map_location
=
device
)
if
BUFFER_NAMES
not
in
state_dict
:
raise
ValueError
(
f
"
{
file
}
is not a model state checkpoint"
)
buffer_names
=
state_dict
[
BUFFER_NAMES
]
if
debug
:
print
(
"Found buffers:"
,
buffer_names
)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers
=
{
k
:
v
.
float
()
for
k
,
v
in
state_dict
[
"module"
].
items
()
if
k
in
buffer_names
}
param_shapes
=
state_dict
[
PARAM_SHAPES
]
# collect parameters that are included in param_shapes
param_names
=
[]
for
s
in
param_shapes
:
for
name
in
s
.
keys
():
param_names
.
append
(
name
)
# update with frozen parameters
frozen_param_shapes
=
state_dict
.
get
(
FROZEN_PARAM_SHAPES
,
None
)
if
frozen_param_shapes
is
not
None
:
if
debug
:
print
(
f
"Found frozen_param_shapes:
{
frozen_param_shapes
}
"
)
param_names
+=
list
(
frozen_param_shapes
.
keys
())
# handle shared params
shared_params
=
[[
k
,
v
]
for
k
,
v
in
state_dict
[
"shared_params"
].
items
()]
ds_version
=
state_dict
.
get
(
DS_VERSION
,
None
)
frozen_param_fragments
=
state_dict
.
get
(
FROZEN_PARAM_FRAGMENTS
,
None
)
z_model_state
=
zero_model_state
(
buffers
=
buffers
,
param_shapes
=
param_shapes
,
shared_params
=
shared_params
,
ds_version
=
ds_version
,
frozen_param_shapes
=
frozen_param_shapes
,
frozen_param_fragments
=
frozen_param_fragments
)
zero_model_states
.
append
(
z_model_state
)
return
zero_model_states
def
parse_optim_states
(
files
,
ds_checkpoint_dir
):
...
...
@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files
=
len
(
files
)
state_dicts
=
[]
for
f
in
files
:
state_dicts
.
append
(
torch
.
load
(
f
,
map_location
=
device
))
state_dict
=
torch
.
load
(
f
,
map_location
=
device
)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict
[
"optimizer_state_dict"
].
pop
(
"optimizer_state_dict"
,
None
)
state_dicts
.
append
(
state_dict
)
if
not
"zero_stage"
in
state_dicts
[
0
][
'optimizer_state_dict'
]:
if
not
ZERO_STAGE
in
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
]:
raise
ValueError
(
f
"
{
files
[
0
]
}
is not a zero checkpoint"
)
zero_stage
=
state_dicts
[
0
][
'optimizer_state_dict'
][
"zero_stage"
]
world_size
=
state_dicts
[
0
][
'optimizer_state_dict'
][
"partition_count"
]
param_shapes
=
state_dicts
[
0
][
"param_shapes"
]
zero_stage
=
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
][
ZERO_STAGE
]
world_size
=
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
][
PARTITION_COUNT
]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size.
...
...
@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
)
# the groups are named differently in each stage
if
zero_stage
=
=
2
:
fp32_groups_key
=
"single_partition_of_fp32_groups"
if
zero_stage
<
=
2
:
fp32_groups_key
=
SINGLE_PARTITION_OF_FP32_GROUPS
elif
zero_stage
==
3
:
fp32_groups_key
=
"fp32_flat_groups"
fp32_groups_key
=
FP32_FLAT_GROUPS
else
:
raise
ValueError
(
f
"unknown zero stage
{
zero_stage
}
"
)
if
zero_stage
==
2
:
fp32_flat_groups
=
[
state_dicts
[
i
][
'optimizer_state_dict'
][
fp32_groups_key
]
for
i
in
range
(
len
(
state_dicts
))
]
if
zero_stage
<=
2
:
fp32_flat_groups
=
[
state_dicts
[
i
][
OPTIMIZER_STATE_DICT
][
fp32_groups_key
]
for
i
in
range
(
len
(
state_dicts
))]
elif
zero_stage
==
3
:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
...
...
@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups
=
[
torch
.
cat
(
state_dicts
[
i
][
'optimizer_state_dict'
][
fp32_groups_key
],
0
)
for
i
in
range
(
len
(
state_dicts
))
torch
.
cat
(
state_dicts
[
i
][
OPTIMIZER_STATE_DICT
][
fp32_groups_key
],
0
)
for
i
in
range
(
len
(
state_dicts
))
]
return
zero_stage
,
world_size
,
param_shapes
,
fp32_flat_groups
return
zero_stage
,
world_size
,
fp32_flat_groups
def
_get_fp32_state_dict_from_zero_checkpoint
(
ds_checkpoint_dir
):
...
...
@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
print
(
f
"Processing zero checkpoint '
{
ds_checkpoint_dir
}
'"
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
zero_stage
,
world_size
,
param_shapes
,
fp32_flat_groups
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
print
(
f
"Detected checkpoint of type zero stage
{
zero_stage
}
, world_size:
{
world_size
}
"
)
model_file
=
get_model_state_file
(
ds_checkpoint_dir
,
zero_stage
)
buffers
=
parse_model_state
(
model_file
)
if
zero_stage
==
2
:
return
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
param_shapes
,
fp32_flat_groups
,
buffers
)
zero_stage
,
world_size
,
fp32_flat_groups
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
print
(
f
"Detected checkpoint of type zero stage
{
zero_stage
}
, world_size:
{
world_size
}
"
)
model_files
=
get_model_state_files
(
ds_checkpoint_dir
)
zero_model_states
=
parse_model_states
(
model_files
)
print
(
f
'Parsing checkpoint created by deepspeed==
{
zero_model_states
[
0
].
ds_version
}
'
)
if
zero_stage
<=
2
:
return
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
)
elif
zero_stage
==
3
:
return
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
param_shapes
,
fp32_flat_groups
,
buffers
)
return
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
)
def
_zero2_merge_frozen_params
(
state_dict
,
zero_model_states
):
if
zero_model_states
[
0
].
frozen_param_shapes
is
None
or
len
(
zero_model_states
[
0
].
frozen_param_shapes
)
==
0
:
return
frozen_param_shapes
=
zero_model_states
[
0
].
frozen_param_shapes
frozen_param_fragments
=
zero_model_states
[
0
].
frozen_param_fragments
if
debug
:
num_elem
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
print
(
f
'rank 0:
{
FROZEN_PARAM_SHAPES
}
.numel =
{
num_elem
}
'
)
wanted_params
=
len
(
frozen_param_shapes
)
wanted_numel
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
avail_numel
=
sum
([
p
.
numel
()
for
p
in
frozen_param_fragments
.
values
()])
print
(
f
'Frozen params: Have
{
avail_numel
}
numels to process.'
)
print
(
f
'Frozen params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params'
)
total_params
=
0
total_numel
=
0
for
name
,
shape
in
frozen_param_shapes
.
items
():
total_params
+=
1
unpartitioned_numel
=
shape
.
numel
()
total_numel
+=
unpartitioned_numel
state_dict
[
name
]
=
frozen_param_fragments
[
name
]
if
debug
:
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
print
(
f
"Reconstructed Frozen fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
param_shapes
,
fp32_flat_groups
,
buffers
):
def
_zero2_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
):
param_shapes
=
zero_model_states
[
0
].
param_shapes
# Reconstruction protocol:
#
...
...
@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
if
debug
:
for
i
in
range
(
world_size
):
for
j
in
range
(
len
(
fp32_flat_groups
[
0
])):
print
(
f
"
fp32_flat_groups
[
{
i
}
][
{
j
}
].shape=
{
fp32_flat_groups
[
i
][
j
].
shape
}
"
)
print
(
f
"
{
FP32_FLAT_GROUPS
}
[
{
i
}
][
{
j
}
].shape=
{
fp32_flat_groups
[
i
][
j
].
shape
}
"
)
# XXX: memory usage doubles here (zero2)
num_param_groups
=
len
(
fp32_flat_groups
[
0
])
...
...
@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
merged_partitions
=
[
sd
[
i
]
for
sd
in
fp32_flat_groups
]
full_single_fp32_vector
=
torch
.
cat
(
merged_partitions
,
0
)
merged_single_partition_of_fp32_groups
.
append
(
full_single_fp32_vector
)
avail_numel
=
sum
([
full_single_fp32_vector
.
numel
()
for
full_single_fp32_vector
in
merged_single_partition_of_fp32_groups
])
avail_numel
=
sum
(
[
full_single_fp32_vector
.
numel
()
for
full_single_fp32_vector
in
merged_single_partition_of_fp32_groups
])
if
debug
:
wanted_params
=
sum
([
len
(
shapes
)
for
shapes
in
param_shapes
])
wanted_numel
=
sum
(
[
sum
(
shape
.
numel
()
for
shape
in
shapes
.
values
())
for
shapes
in
param_shapes
])
wanted_numel
=
sum
([
sum
(
shape
.
numel
()
for
shape
in
shapes
.
values
())
for
shapes
in
param_shapes
])
# not asserting if there is a mismatch due to possible padding
print
(
f
"Have
{
avail_numel
}
numels to process."
)
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
state_dict
=
OrderedDict
()
# buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
...
...
@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
total_params
+=
1
if
debug
:
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
state_dict
[
name
]
=
full_single_fp32_vector
.
narrow
(
0
,
offset
,
unpartitioned_numel
).
view
(
shape
)
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
state_dict
[
name
]
=
full_single_fp32_vector
.
narrow
(
0
,
offset
,
unpartitioned_numel
).
view
(
shape
)
offset
+=
unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
...
...
@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
# Sanity check
if
offset
!=
avail_numel
:
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
print
(
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
print
(
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
):
state_dict
=
OrderedDict
()
# buffers
buffers
=
zero_model_states
[
0
].
buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
_zero2_merge_frozen_params
(
state_dict
,
zero_model_states
)
_zero2_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
)
# recover shared parameters
for
pair
in
zero_model_states
[
0
].
shared_params
:
if
pair
[
1
]
in
state_dict
:
state_dict
[
pair
[
0
]]
=
state_dict
[
pair
[
1
]]
return
state_dict
...
...
@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return
partitioned_numel
,
padding_numel
def
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
param_shapes
,
fp32_flat_groups
,
buffers
):
def
_zero3_merge_frozen_params
(
state_dict
,
world_size
,
zero_model_states
):
if
zero_model_states
[
0
].
frozen_param_shapes
is
None
or
len
(
zero_model_states
[
0
].
frozen_param_shapes
)
==
0
:
return
if
debug
:
for
i
in
range
(
world_size
):
num_elem
=
sum
(
s
.
numel
()
for
s
in
zero_model_states
[
i
].
frozen_param_fragments
.
values
())
print
(
f
'rank
{
i
}
:
{
FROZEN_PARAM_SHAPES
}
.numel =
{
num_elem
}
'
)
frozen_param_shapes
=
zero_model_states
[
0
].
frozen_param_shapes
wanted_params
=
len
(
frozen_param_shapes
)
wanted_numel
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
avail_numel
=
sum
([
p
.
numel
()
for
p
in
zero_model_states
[
0
].
frozen_param_fragments
.
values
()])
*
world_size
print
(
f
'Frozen params: Have
{
avail_numel
}
numels to process.'
)
print
(
f
'Frozen params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params'
)
total_params
=
0
total_numel
=
0
for
name
,
shape
in
zero_model_states
[
0
].
frozen_param_shapes
.
items
():
total_params
+=
1
unpartitioned_numel
=
shape
.
numel
()
total_numel
+=
unpartitioned_numel
param_frags
=
tuple
(
model_state
.
frozen_param_fragments
[
name
]
for
model_state
in
zero_model_states
)
state_dict
[
name
]
=
torch
.
cat
(
param_frags
,
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
partitioned_numel
,
partitioned_padding_numel
=
zero3_partitioned_param_info
(
unpartitioned_numel
,
world_size
)
if
debug
:
print
(
f
"Frozen params:
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
)
print
(
f
"Reconstructed Frozen fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_zero3_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
):
param_shapes
=
zero_model_states
[
0
].
param_shapes
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
# merge list of dicts, preserving order
param_shapes
=
{
k
:
v
for
d
in
param_shapes
for
k
,
v
in
d
.
items
()}
if
debug
:
for
i
in
range
(
world_size
):
print
(
f
"
fp32_flat_groups
[
{
i
}
].shape=
{
fp32_flat_groups
[
i
].
shape
}
"
)
print
(
f
"
{
FP32_FLAT_GROUPS
}
[
{
i
}
].shape=
{
fp32_flat_groups
[
i
].
shape
}
"
)
wanted_params
=
len
(
param_shapes
)
wanted_numel
=
sum
(
shape
.
numel
()
for
shape
in
param_shapes
.
values
())
# not asserting if there is a mismatch due to possible padding
print
(
f
"Have
{
avail_numel
}
numels to process."
)
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
state_dict
=
OrderedDict
()
# buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
print
(
f
"Trainable params: Have
{
avail_numel
}
numels to process."
)
print
(
f
"Trainable params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
...
...
@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
if
debug
:
print
(
f
"
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
f
"
Trainable params:
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
)
# XXX: memory usage doubles here
state_dict
[
name
]
=
torch
.
cat
(
tuple
(
fp32_flat_groups
[
i
].
narrow
(
0
,
offset
,
partitioned_numel
)
for
i
in
range
(
world_size
)),
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
tuple
(
fp32_flat_groups
[
i
].
narrow
(
0
,
offset
,
partitioned_numel
)
for
i
in
range
(
world_size
)),
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
offset
+=
partitioned_numel
offset
*=
world_size
# Sanity check
if
offset
!=
avail_numel
:
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
print
(
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
print
(
f
"Reconstructed Trainable fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
):
state_dict
=
OrderedDict
()
# buffers
buffers
=
zero_model_states
[
0
].
buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
_zero3_merge_frozen_params
(
state_dict
,
world_size
,
zero_model_states
)
_zero3_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
)
# recover shared parameters
for
pair
in
zero_model_states
[
0
].
shared_params
:
if
pair
[
1
]
in
state_dict
:
state_dict
[
pair
[
0
]]
=
state_dict
[
pair
[
1
]]
return
state_dict
...
...
@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"checkpoint_dir"
,
type
=
str
,
help
=
"path to the desired checkpoint folder, e.g., path/checkpoint-12"
)
parser
.
add_argument
(
"checkpoint_dir"
,
type
=
str
,
help
=
"path to the desired checkpoint folder, e.g., path/checkpoint-12"
)
parser
.
add_argument
(
"output_file"
,
type
=
str
,
help
=
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
)
help
=
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
)
parser
.
add_argument
(
"-t"
,
"--tag"
,
type
=
str
,
default
=
None
,
help
=
"checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1"
)
parser
.
add_argument
(
"-d"
,
"--debug"
,
action
=
'store_true'
,
help
=
"enable debug"
)
args
=
parser
.
parse_args
()
debug
=
args
.
debug
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
checkpoint_dir
,
args
.
output_file
)
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
checkpoint_dir
,
args
.
output_file
,
tag
=
args
.
tag
)
train_openfold.py
View file @
9776b696
...
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
)
from
scripts.zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
...
@@ -294,8 +295,13 @@ def main(args):
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
if
'module'
in
sd
:
module_sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
[
'module'
].
items
()}
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
module_sd
)
elif
'state_dict'
in
sd
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
[
'state_dict'
])
else
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
resume_from_jax_params
):
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment