Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9776b696
Commit
9776b696
authored
Feb 21, 2024
by
jnwei
Browse files
Merge weight-loading changes into setup-improvements
parents
9f346d35
ddfccd56
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
366 additions
and
128 deletions
+366
-128
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+8
-4
scripts/convert_v1_to_v2_weights.py
scripts/convert_v1_to_v2_weights.py
+95
-0
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+255
-122
train_openfold.py
train_openfold.py
+8
-2
No files found.
openfold/utils/import_weights.py
View file @
9776b696
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
import
re
import
re
import
logging
from
enum
import
Enum
from
enum
import
Enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
...
@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
...
@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
}
}
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
template_emb_re
=
re
.
compile
(
r
"^((module\.)?(model\.)?)(template(?!_embedder).*)"
)
converted_state_dict
=
{}
converted_state_dict
=
{}
for
key
,
value
in
state_dict
.
items
():
for
key
,
value
in
state_dict
.
items
():
# For each match, look-up replacement value in the dictionary
# For each match, look-up replacement value in the dictionary
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
(
1
)],
key
)
# Add prefix for template modules
# Add prefix for template layers
if
new_key
.
startswith
(
'template'
):
template_match
=
re
.
match
(
template_emb_re
,
new_key
)
new_key
=
f
'template_embedder.
{
new_key
}
'
if
template_match
:
prefix
=
template_match
.
group
(
1
)
new_key
=
f
'
{
prefix
if
prefix
else
""
}
template_embedder.
{
template_match
.
group
(
4
)
}
'
converted_state_dict
[
new_key
]
=
value
converted_state_dict
[
new_key
]
=
value
...
...
scripts/convert_v1_to_v2_weights.py
0 → 100755
View file @
9776b696
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold checkpoints whose parameters use the deprecated v1 key
# names into checkpoints using the v2 key names. Accepts either a DeepSpeed
# checkpoint directory or a single PyTorch Lightning checkpoint file.
import
logging
import
argparse
import
os
import
shutil
import
torch
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
def convert_v1_to_v2_weights(args):
    """Write a copy of a checkpoint with deprecated v1 key names converted.

    Args:
        args: Namespace providing ``input_ckpt_path`` and ``output_ckpt_path``.
            If ``input_ckpt_path`` is a directory it is treated as a
            DeepSpeed checkpoint; otherwise it is treated as a single
            PyTorch Lightning checkpoint file.

    Raises:
        ValueError: If a DeepSpeed checkpoint directory has no ``latest`` file.
    """
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    cpu = torch.device('cpu')

    if is_dir:
        # A DeepSpeed checkpoint
        logging.info('Converting deepspeed checkpoint found at %s', checkpoint_path)
        state_dict_key = 'module'

        # The 'latest' file names the tag sub-directory that gets converted.
        latest_path = os.path.join(checkpoint_path, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

        ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
        optim_files = get_optim_files(ds_checkpoint_dir)
        # Only the zero stage is needed here, to pick the model-state file name.
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
        model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
    else:
        # A Pytorch Lightning checkpoint
        logging.info('Converting pytorch lightning checkpoint found at %s', checkpoint_path)
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=cpu)
    model_dict[state_dict_key] = convert_deprecated_v1_keys(model_dict[state_dict_key])

    if 'ema' in model_dict:
        model_dict['ema']['params'] = convert_deprecated_v1_keys(model_dict['ema']['params'])

    if is_dir:
        # NOTE(review): only the first entry of 'param_shapes' is converted and
        # kept, as in the original code - confirm single-param-group checkpoints.
        param_shapes = convert_deprecated_v1_keys(model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]

        # Copy the whole checkpoint first, then overwrite the converted files.
        # copytree fails if the output directory already exists.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(model_output_path, os.path.basename(model_file))

        for optim_file in optim_files:
            # Load on CPU, like the model file above, so conversion does not
            # need the device the shard was saved from.
            optim_dict = torch.load(optim_file, map_location=cpu)
            optim_state = optim_dict['optimizer_state_dict']
            optim_state['param_slice_mappings'][0] = convert_deprecated_v1_keys(
                optim_state['param_slice_mappings'][0]
            )
            out_optim_fname = os.path.join(model_output_path, os.path.basename(optim_file))
            torch.save(optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
if __name__ == "__main__":
    # Two positional string arguments: source checkpoint, then destination.
    cli = argparse.ArgumentParser()
    for positional in ("input_ckpt_path", "output_ckpt_path"):
        cli.add_argument(positional, type=str)

    convert_v1_to_v2_weights(cli.parse_args())
scripts/zero_to_fp32.py
View file @
9776b696
#!/usr/bin/env python
#!/usr/bin/env python
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
# application.
...
@@ -12,13 +17,27 @@ import torch
...
@@ -12,13 +17,27 @@ import torch
import
glob
import
glob
import
math
import
math
import
os
import
os
from
collections
import
OrderedDict
import
re
import
re
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
# DeepSpeed data structures it has to be available in the current python environment.
import
deepspeed
from
deepspeed.utils
import
logger
from
deepspeed.utils
import
logger
from
deepspeed.checkpoint.constants
import
(
DS_VERSION
,
OPTIMIZER_STATE_DICT
,
SINGLE_PARTITION_OF_FP32_GROUPS
,
FP32_FLAT_GROUPS
,
ZERO_STAGE
,
PARTITION_COUNT
,
PARAM_SHAPES
,
BUFFER_NAMES
,
FROZEN_PARAM_SHAPES
,
FROZEN_PARAM_FRAGMENTS
)
@dataclass
class zero_model_state:
    """Values extracted from one model-states checkpoint file by ``parse_model_states``."""

    # Entries of state_dict["module"] whose names are listed as buffers,
    # converted with .float().
    buffers: dict
    # state_dict[PARAM_SHAPES]; consumed as a sequence of name -> shape dicts.
    param_shapes: list
    # [name, name] pairs built from state_dict["shared_params"].items().
    shared_params: list
    # state_dict.get(DS_VERSION); None when the key is absent.
    # NOTE(review): annotated int in the original - confirm the stored type.
    ds_version: int
    # state_dict.get(FROZEN_PARAM_SHAPES); None when the key is absent.
    frozen_param_shapes: dict
    # state_dict.get(FROZEN_PARAM_FRAGMENTS); None when the key is absent.
    frozen_param_fragments: dict
debug
=
0
debug
=
0
...
@@ -26,12 +45,25 @@ debug = 0
...
@@ -26,12 +45,25 @@ debug = 0
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
def atoi(text):
    """Return ``int(text)`` when ``text.isdigit()`` holds, otherwise ``text`` itself."""
    if not text.isdigit():
        return text
    return int(text)
def natural_keys(text):
    """Split ``text`` into a sort key whose digit runs become integers.

    Intended for ``alist.sort(key=natural_keys)``, which sorts in human order.
    See http://nedbatchelder.com/blog/200712/human_sorting.html
    (Toothy's implementation in the comments).
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))
def
get_model_state_file
(
checkpoint_dir
,
zero_stage
):
def
get_model_state_file
(
checkpoint_dir
,
zero_stage
):
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
raise
FileNotFoundError
(
f
"Directory '
{
checkpoint_dir
}
' doesn't exist"
)
raise
FileNotFoundError
(
f
"Directory '
{
checkpoint_dir
}
' doesn't exist"
)
# there should be only one file
# there should be only one file
if
zero_stage
=
=
2
:
if
zero_stage
<
=
2
:
file
=
os
.
path
.
join
(
checkpoint_dir
,
"mp_rank_00_model_states.pt"
)
file
=
os
.
path
.
join
(
checkpoint_dir
,
"mp_rank_00_model_states.pt"
)
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
file
=
os
.
path
.
join
(
checkpoint_dir
,
"zero_pp_rank_0_mp_rank_00_model_states.pt"
)
file
=
os
.
path
.
join
(
checkpoint_dir
,
"zero_pp_rank_0_mp_rank_00_model_states.pt"
)
...
@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
...
@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
return
file
return
file
def
get_
optim
_files
(
checkpoint_dir
):
def
get_
checkpoint
_files
(
checkpoint_dir
,
glob_pattern
):
# XXX: need to test that this simple glob rule works for multi-node setup too
# XXX: need to test that this simple glob rule works for multi-node setup too
optim
_files
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
checkpoint_dir
,
"*_optim_states.pt"
))
)
ckpt
_files
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
checkpoint_dir
,
glob_pattern
)),
key
=
natural_keys
)
if
len
(
optim_files
)
==
0
:
if
len
(
ckpt_files
)
==
0
:
raise
FileNotFoundError
(
raise
FileNotFoundError
(
f
"can't find
{
glob_pattern
}
files in directory '
{
checkpoint_dir
}
'"
)
f
"can't find '*_optim_states.pt' files in directory '
{
checkpoint_dir
}
'"
)
return
optim
_files
return
ckpt
_files
def
parse_model_state
(
file
):
def
get_optim_files
(
checkpoint_dir
):
state_dict
=
torch
.
load
(
file
,
map_location
=
device
)
return
get_checkpoint_files
(
checkpoint_dir
,
"*_optim_states.pt"
)
if
"buffer_names"
not
in
state_dict
:
raise
ValueError
(
f
"
{
file
}
is not a model state checkpoint"
)
buffer_names
=
state_dict
[
"buffer_names"
]
if
debug
:
print
(
"Found buffers:"
,
buffer_names
)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
def
get_model_state_files
(
checkpoint_dir
):
buffers
=
{
return
get_checkpoint_files
(
checkpoint_dir
,
"*_model_states.pt"
)
k
:
v
.
float
()
for
k
,
v
in
state_dict
[
"module"
].
items
()
if
k
in
buffer_names
def
parse_model_states
(
files
):
}
zero_model_states
=
[]
return
buffers
for
file
in
files
:
state_dict
=
torch
.
load
(
file
,
map_location
=
device
)
if
BUFFER_NAMES
not
in
state_dict
:
raise
ValueError
(
f
"
{
file
}
is not a model state checkpoint"
)
buffer_names
=
state_dict
[
BUFFER_NAMES
]
if
debug
:
print
(
"Found buffers:"
,
buffer_names
)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers
=
{
k
:
v
.
float
()
for
k
,
v
in
state_dict
[
"module"
].
items
()
if
k
in
buffer_names
}
param_shapes
=
state_dict
[
PARAM_SHAPES
]
# collect parameters that are included in param_shapes
param_names
=
[]
for
s
in
param_shapes
:
for
name
in
s
.
keys
():
param_names
.
append
(
name
)
# update with frozen parameters
frozen_param_shapes
=
state_dict
.
get
(
FROZEN_PARAM_SHAPES
,
None
)
if
frozen_param_shapes
is
not
None
:
if
debug
:
print
(
f
"Found frozen_param_shapes:
{
frozen_param_shapes
}
"
)
param_names
+=
list
(
frozen_param_shapes
.
keys
())
# handle shared params
shared_params
=
[[
k
,
v
]
for
k
,
v
in
state_dict
[
"shared_params"
].
items
()]
ds_version
=
state_dict
.
get
(
DS_VERSION
,
None
)
frozen_param_fragments
=
state_dict
.
get
(
FROZEN_PARAM_FRAGMENTS
,
None
)
z_model_state
=
zero_model_state
(
buffers
=
buffers
,
param_shapes
=
param_shapes
,
shared_params
=
shared_params
,
ds_version
=
ds_version
,
frozen_param_shapes
=
frozen_param_shapes
,
frozen_param_fragments
=
frozen_param_fragments
)
zero_model_states
.
append
(
z_model_state
)
return
zero_model_states
def
parse_optim_states
(
files
,
ds_checkpoint_dir
):
def
parse_optim_states
(
files
,
ds_checkpoint_dir
):
...
@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
...
@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files
=
len
(
files
)
total_files
=
len
(
files
)
state_dicts
=
[]
state_dicts
=
[]
for
f
in
files
:
for
f
in
files
:
state_dicts
.
append
(
torch
.
load
(
f
,
map_location
=
device
))
state_dict
=
torch
.
load
(
f
,
map_location
=
device
)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict
[
"optimizer_state_dict"
].
pop
(
"optimizer_state_dict"
,
None
)
state_dicts
.
append
(
state_dict
)
if
not
"zero_stage"
in
state_dicts
[
0
][
'optimizer_state_dict'
]:
if
not
ZERO_STAGE
in
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
]:
raise
ValueError
(
f
"
{
files
[
0
]
}
is not a zero checkpoint"
)
raise
ValueError
(
f
"
{
files
[
0
]
}
is not a zero checkpoint"
)
zero_stage
=
state_dicts
[
0
][
'optimizer_state_dict'
][
"zero_stage"
]
zero_stage
=
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
][
ZERO_STAGE
]
world_size
=
state_dicts
[
0
][
'optimizer_state_dict'
][
"partition_count"
]
world_size
=
state_dicts
[
0
][
OPTIMIZER_STATE_DICT
][
PARTITION_COUNT
]
param_shapes
=
state_dicts
[
0
][
"param_shapes"
]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just
# parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size.
# use the max of the partition_count to get the dp world_size.
...
@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
...
@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
)
)
# the groups are named differently in each stage
# the groups are named differently in each stage
if
zero_stage
=
=
2
:
if
zero_stage
<
=
2
:
fp32_groups_key
=
"single_partition_of_fp32_groups"
fp32_groups_key
=
SINGLE_PARTITION_OF_FP32_GROUPS
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
fp32_groups_key
=
"fp32_flat_groups"
fp32_groups_key
=
FP32_FLAT_GROUPS
else
:
else
:
raise
ValueError
(
f
"unknown zero stage
{
zero_stage
}
"
)
raise
ValueError
(
f
"unknown zero stage
{
zero_stage
}
"
)
if
zero_stage
==
2
:
if
zero_stage
<=
2
:
fp32_flat_groups
=
[
fp32_flat_groups
=
[
state_dicts
[
i
][
OPTIMIZER_STATE_DICT
][
fp32_groups_key
]
for
i
in
range
(
len
(
state_dicts
))]
state_dicts
[
i
][
'optimizer_state_dict'
][
fp32_groups_key
]
for
i
in
range
(
len
(
state_dicts
))
]
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
# if there is more than one param group, there will be multiple flattened tensors - one
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
# flattened tensor per group - for simplicity merge them into a single tensor
...
@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
...
@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
# will require matching the sub-lists of param_shapes for each param group flattened tensor
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups
=
[
fp32_flat_groups
=
[
torch
.
cat
(
state_dicts
[
i
][
'optimizer_state_dict'
][
fp32_groups_key
],
torch
.
cat
(
state_dicts
[
i
][
OPTIMIZER_STATE_DICT
][
fp32_groups_key
],
0
)
for
i
in
range
(
len
(
state_dicts
))
0
)
for
i
in
range
(
len
(
state_dicts
))
]
]
return
zero_stage
,
world_size
,
param_shapes
,
fp32_flat_groups
return
zero_stage
,
world_size
,
fp32_flat_groups
def
_get_fp32_state_dict_from_zero_checkpoint
(
ds_checkpoint_dir
):
def
_get_fp32_state_dict_from_zero_checkpoint
(
ds_checkpoint_dir
):
...
@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
...
@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
print
(
f
"Processing zero checkpoint '
{
ds_checkpoint_dir
}
'"
)
print
(
f
"Processing zero checkpoint '
{
ds_checkpoint_dir
}
'"
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
zero_stage
,
world_size
,
param_shapes
,
fp32_flat_groups
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
zero_stage
,
world_size
,
fp32_flat_groups
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
print
(
print
(
f
"Detected checkpoint of type zero stage
{
zero_stage
}
, world_size:
{
world_size
}
"
)
f
"Detected checkpoint of type zero stage
{
zero_stage
}
, world_size:
{
world_size
}
"
)
model_files
=
get_model_state_files
(
ds_checkpoint_dir
)
model_file
=
get_model_state_file
(
ds_checkpoint_dir
,
zero_stage
)
buffers
=
parse_model_state
(
model_file
)
zero_model_states
=
parse_model_states
(
model_files
)
print
(
f
'Parsing checkpoint created by deepspeed==
{
zero_model_states
[
0
].
ds_version
}
'
)
if
zero_stage
==
2
:
return
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
if
zero_stage
<=
2
:
param_shapes
,
return
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
)
fp32_flat_groups
,
buffers
)
elif
zero_stage
==
3
:
elif
zero_stage
==
3
:
return
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
return
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
)
param_shapes
,
fp32_flat_groups
,
buffers
)
def
_zero2_merge_frozen_params
(
state_dict
,
zero_model_states
):
if
zero_model_states
[
0
].
frozen_param_shapes
is
None
or
len
(
zero_model_states
[
0
].
frozen_param_shapes
)
==
0
:
return
frozen_param_shapes
=
zero_model_states
[
0
].
frozen_param_shapes
frozen_param_fragments
=
zero_model_states
[
0
].
frozen_param_fragments
if
debug
:
num_elem
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
print
(
f
'rank 0:
{
FROZEN_PARAM_SHAPES
}
.numel =
{
num_elem
}
'
)
wanted_params
=
len
(
frozen_param_shapes
)
wanted_numel
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
avail_numel
=
sum
([
p
.
numel
()
for
p
in
frozen_param_fragments
.
values
()])
print
(
f
'Frozen params: Have
{
avail_numel
}
numels to process.'
)
print
(
f
'Frozen params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params'
)
total_params
=
0
total_numel
=
0
for
name
,
shape
in
frozen_param_shapes
.
items
():
total_params
+=
1
unpartitioned_numel
=
shape
.
numel
()
total_numel
+=
unpartitioned_numel
state_dict
[
name
]
=
frozen_param_fragments
[
name
]
if
debug
:
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
print
(
f
"Reconstructed Frozen fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
def
_zero2_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
):
param_shapes
,
param_shapes
=
zero_model_states
[
0
].
param_shapes
fp32_flat_groups
,
buffers
):
# Reconstruction protocol:
# Reconstruction protocol:
#
#
...
@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
if
debug
:
if
debug
:
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
for
j
in
range
(
len
(
fp32_flat_groups
[
0
])):
for
j
in
range
(
len
(
fp32_flat_groups
[
0
])):
print
(
f
"
fp32_flat_groups
[
{
i
}
][
{
j
}
].shape=
{
fp32_flat_groups
[
i
][
j
].
shape
}
"
)
print
(
f
"
{
FP32_FLAT_GROUPS
}
[
{
i
}
][
{
j
}
].shape=
{
fp32_flat_groups
[
i
][
j
].
shape
}
"
)
# XXX: memory usage doubles here (zero2)
# XXX: memory usage doubles here (zero2)
num_param_groups
=
len
(
fp32_flat_groups
[
0
])
num_param_groups
=
len
(
fp32_flat_groups
[
0
])
...
@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
merged_partitions
=
[
sd
[
i
]
for
sd
in
fp32_flat_groups
]
merged_partitions
=
[
sd
[
i
]
for
sd
in
fp32_flat_groups
]
full_single_fp32_vector
=
torch
.
cat
(
merged_partitions
,
0
)
full_single_fp32_vector
=
torch
.
cat
(
merged_partitions
,
0
)
merged_single_partition_of_fp32_groups
.
append
(
full_single_fp32_vector
)
merged_single_partition_of_fp32_groups
.
append
(
full_single_fp32_vector
)
avail_numel
=
sum
([
avail_numel
=
sum
(
full_single_fp32_vector
.
numel
()
[
full_single_fp32_vector
.
numel
()
for
full_single_fp32_vector
in
merged_single_partition_of_fp32_groups
])
for
full_single_fp32_vector
in
merged_single_partition_of_fp32_groups
])
if
debug
:
if
debug
:
wanted_params
=
sum
([
len
(
shapes
)
for
shapes
in
param_shapes
])
wanted_params
=
sum
([
len
(
shapes
)
for
shapes
in
param_shapes
])
wanted_numel
=
sum
(
wanted_numel
=
sum
([
sum
(
shape
.
numel
()
for
shape
in
shapes
.
values
())
for
shapes
in
param_shapes
])
[
sum
(
shape
.
numel
()
for
shape
in
shapes
.
values
())
for
shapes
in
param_shapes
])
# not asserting if there is a mismatch due to possible padding
# not asserting if there is a mismatch due to possible padding
print
(
f
"Have
{
avail_numel
}
numels to process."
)
print
(
f
"Have
{
avail_numel
}
numels to process."
)
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
state_dict
=
OrderedDict
()
# buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
# params
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
# out-of-core computing solution
...
@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
total_params
+=
1
total_params
+=
1
if
debug
:
if
debug
:
print
(
print
(
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
)
f
"
{
name
}
full shape:
{
shape
}
unpartitioned numel
{
unpartitioned_numel
}
"
state_dict
[
name
]
=
full_single_fp32_vector
.
narrow
(
0
,
offset
,
unpartitioned_numel
).
view
(
shape
)
)
state_dict
[
name
]
=
full_single_fp32_vector
.
narrow
(
0
,
offset
,
unpartitioned_numel
).
view
(
shape
)
offset
+=
unpartitioned_numel
offset
+=
unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
...
@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
...
@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
# Sanity check
# Sanity check
if
offset
!=
avail_numel
:
if
offset
!=
avail_numel
:
raise
ValueError
(
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
print
(
print
(
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero2_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
):
state_dict
=
OrderedDict
()
# buffers
buffers
=
zero_model_states
[
0
].
buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
_zero2_merge_frozen_params
(
state_dict
,
zero_model_states
)
_zero2_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
)
# recover shared parameters
for
pair
in
zero_model_states
[
0
].
shared_params
:
if
pair
[
1
]
in
state_dict
:
state_dict
[
pair
[
0
]]
=
state_dict
[
pair
[
1
]]
return
state_dict
return
state_dict
...
@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
...
@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return
partitioned_numel
,
padding_numel
return
partitioned_numel
,
padding_numel
def
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
def
_zero3_merge_frozen_params
(
state_dict
,
world_size
,
zero_model_states
):
param_shapes
,
if
zero_model_states
[
0
].
frozen_param_shapes
is
None
or
len
(
zero_model_states
[
0
].
frozen_param_shapes
)
==
0
:
fp32_flat_groups
,
return
buffers
):
if
debug
:
for
i
in
range
(
world_size
):
num_elem
=
sum
(
s
.
numel
()
for
s
in
zero_model_states
[
i
].
frozen_param_fragments
.
values
())
print
(
f
'rank
{
i
}
:
{
FROZEN_PARAM_SHAPES
}
.numel =
{
num_elem
}
'
)
frozen_param_shapes
=
zero_model_states
[
0
].
frozen_param_shapes
wanted_params
=
len
(
frozen_param_shapes
)
wanted_numel
=
sum
(
s
.
numel
()
for
s
in
frozen_param_shapes
.
values
())
avail_numel
=
sum
([
p
.
numel
()
for
p
in
zero_model_states
[
0
].
frozen_param_fragments
.
values
()])
*
world_size
print
(
f
'Frozen params: Have
{
avail_numel
}
numels to process.'
)
print
(
f
'Frozen params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params'
)
total_params
=
0
total_numel
=
0
for
name
,
shape
in
zero_model_states
[
0
].
frozen_param_shapes
.
items
():
total_params
+=
1
unpartitioned_numel
=
shape
.
numel
()
total_numel
+=
unpartitioned_numel
param_frags
=
tuple
(
model_state
.
frozen_param_fragments
[
name
]
for
model_state
in
zero_model_states
)
state_dict
[
name
]
=
torch
.
cat
(
param_frags
,
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
partitioned_numel
,
partitioned_padding_numel
=
zero3_partitioned_param_info
(
unpartitioned_numel
,
world_size
)
if
debug
:
print
(
f
"Frozen params:
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
)
print
(
f
"Reconstructed Frozen fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_zero3_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
):
param_shapes
=
zero_model_states
[
0
].
param_shapes
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
# param, re-consolidating each param, while dealing with padding if any
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
# merge list of dicts, preserving order
# merge list of dicts, preserving order
param_shapes
=
{
k
:
v
for
d
in
param_shapes
for
k
,
v
in
d
.
items
()}
param_shapes
=
{
k
:
v
for
d
in
param_shapes
for
k
,
v
in
d
.
items
()}
if
debug
:
if
debug
:
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
print
(
f
"
fp32_flat_groups
[
{
i
}
].shape=
{
fp32_flat_groups
[
i
].
shape
}
"
)
print
(
f
"
{
FP32_FLAT_GROUPS
}
[
{
i
}
].shape=
{
fp32_flat_groups
[
i
].
shape
}
"
)
wanted_params
=
len
(
param_shapes
)
wanted_params
=
len
(
param_shapes
)
wanted_numel
=
sum
(
shape
.
numel
()
for
shape
in
param_shapes
.
values
())
wanted_numel
=
sum
(
shape
.
numel
()
for
shape
in
param_shapes
.
values
())
# not asserting if there is a mismatch due to possible padding
# not asserting if there is a mismatch due to possible padding
print
(
f
"Have
{
avail_numel
}
numels to process."
)
avail_numel
=
fp32_flat_groups
[
0
].
numel
()
*
world_size
print
(
f
"Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
print
(
f
"Trainable params: Have
{
avail_numel
}
numels to process."
)
print
(
f
"Trainable params: Need
{
wanted_numel
}
numels in
{
wanted_params
}
params."
)
state_dict
=
OrderedDict
()
# buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
# params
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
...
@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
...
@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
if
debug
:
if
debug
:
print
(
print
(
f
"
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
f
"
Trainable params:
{
total_params
}
{
name
}
full shape:
{
shape
}
partition0 numel=
{
partitioned_numel
}
partitioned_padding_numel=
{
partitioned_padding_numel
}
"
)
)
# XXX: memory usage doubles here
# XXX: memory usage doubles here
state_dict
[
name
]
=
torch
.
cat
(
state_dict
[
name
]
=
torch
.
cat
(
tuple
(
fp32_flat_groups
[
i
].
narrow
(
0
,
tuple
(
fp32_flat_groups
[
i
].
narrow
(
0
,
offset
,
partitioned_numel
)
for
i
in
range
(
world_size
)),
offset
,
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
partitioned_numel
)
for
i
in
range
(
world_size
)),
0
).
narrow
(
0
,
0
,
unpartitioned_numel
).
view
(
shape
)
offset
+=
partitioned_numel
offset
+=
partitioned_numel
offset
*=
world_size
offset
*=
world_size
# Sanity check
# Sanity check
if
offset
!=
avail_numel
:
if
offset
!=
avail_numel
:
raise
ValueError
(
raise
ValueError
(
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
f
"consumed
{
offset
}
numels out of
{
avail_numel
}
- something is wrong"
)
print
(
print
(
f
"Reconstructed Trainable fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
f
"Reconstructed fp32 state dict with
{
total_params
}
params
{
total_numel
}
elements"
)
def
_get_fp32_state_dict_from_zero3_checkpoint
(
world_size
,
fp32_flat_groups
,
zero_model_states
):
state_dict
=
OrderedDict
()
# buffers
buffers
=
zero_model_states
[
0
].
buffers
state_dict
.
update
(
buffers
)
if
debug
:
print
(
f
"added
{
len
(
buffers
)
}
buffers"
)
_zero3_merge_frozen_params
(
state_dict
,
world_size
,
zero_model_states
)
_zero3_merge_trainable_params
(
state_dict
,
world_size
,
fp32_flat_groups
,
zero_model_states
)
# recover shared parameters
for
pair
in
zero_model_states
[
0
].
shared_params
:
if
pair
[
1
]
in
state_dict
:
state_dict
[
pair
[
0
]]
=
state_dict
[
pair
[
1
]]
return
state_dict
return
state_dict
...
@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
...
@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
"checkpoint_dir"
,
"checkpoint_dir"
,
type
=
str
,
type
=
str
,
help
=
"path to the desired checkpoint folder, e.g., path/checkpoint-12"
)
help
=
"path to the desired checkpoint folder, e.g., path/checkpoint-12"
)
parser
.
add_argument
(
parser
.
add_argument
(
"output_file"
,
"output_file"
,
type
=
str
,
type
=
str
,
help
=
help
=
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
)
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
parser
.
add_argument
(
"-t"
,
)
"--tag"
,
type
=
str
,
default
=
None
,
help
=
"checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1"
)
parser
.
add_argument
(
"-d"
,
"--debug"
,
action
=
'store_true'
,
help
=
"enable debug"
)
parser
.
add_argument
(
"-d"
,
"--debug"
,
action
=
'store_true'
,
help
=
"enable debug"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
debug
=
args
.
debug
debug
=
args
.
debug
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
checkpoint_dir
,
args
.
output_file
)
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
checkpoint_dir
,
args
.
output_file
,
tag
=
args
.
tag
)
train_openfold.py
View file @
9776b696
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
,
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
get_global_step_from_zero_checkpoint
)
)
from
scripts.zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
openfold.utils.logger
import
PerformanceLoggingCallback
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
@@ -294,8 +295,13 @@ def main(args):
...
@@ -294,8 +295,13 @@ def main(args):
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
else
:
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
if
'module'
in
sd
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
module_sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
[
'module'
].
items
()}
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
module_sd
)
elif
'state_dict'
in
sd
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
[
'state_dict'
])
else
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
resume_from_jax_params
):
if
(
args
.
resume_from_jax_params
):
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment