Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9776b696
Commit
9776b696
authored
Feb 21, 2024
by
jnwei
Browse files
Merge weight-loading changes into setup-improvements
parents
9f346d35
ddfccd56
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
366 additions
and
128 deletions
+366
-128
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+8
-4
scripts/convert_v1_to_v2_weights.py
scripts/convert_v1_to_v2_weights.py
+95
-0
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+255
-122
train_openfold.py
train_openfold.py
+8
-2
No files found.
openfold/utils/import_weights.py
View file @
9776b696
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
import
re
import
logging
from
enum
import
Enum
from
dataclasses
import
dataclass
from
functools
import
partial
...
...
@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
}
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
template_emb_re
=
re
.
compile
(
r
"^((module\.)?(model\.)?)(template(?!_embedder).*)"
)
converted_state_dict
=
{}
for
key
,
value
in
state_dict
.
items
():
# For each match, look-up replacement value in the dictionary
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
(
1
)],
key
)
# Add prefix for template modules
if
new_key
.
startswith
(
'template'
):
new_key
=
f
'template_embedder.
{
new_key
}
'
# Add prefix for template layers
template_match
=
re
.
match
(
template_emb_re
,
new_key
)
if
template_match
:
prefix
=
template_match
.
group
(
1
)
new_key
=
f
'
{
prefix
if
prefix
else
""
}
template_embedder.
{
template_match
.
group
(
4
)
}
'
converted_state_dict
[
new_key
]
=
value
...
...
scripts/convert_v1_to_v2_weights.py
0 → 100755
View file @
9776b696
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold checkpoints (DeepSpeed checkpoint directories or PyTorch
# Lightning checkpoint files) whose parameter keys use the deprecated v1 naming
# into checkpoints that use the current (v2) key names.
import
logging
import
argparse
import
os
import
shutil
import
torch
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
def convert_v1_to_v2_weights(args):
    """Rewrite a checkpoint so its parameter keys use the non-deprecated names.

    Two input layouts are handled, chosen by whether ``args.input_ckpt_path``
    is a directory:

    * Directory: treated as a DeepSpeed checkpoint. The tag is read from the
      ``latest`` file inside it, the whole directory is copied to
      ``args.output_ckpt_path``, and the model-state file and every optimizer
      file under the tag directory are then overwritten in the copy with
      converted keys.
    * File: treated as a PyTorch Lightning checkpoint. The converted
      checkpoint is written directly to ``args.output_ckpt_path``.

    Args:
        args: Object exposing ``input_ckpt_path`` and ``output_ckpt_path``
            string attributes (e.g. an ``argparse.Namespace``).

    Raises:
        ValueError: If the input is a directory without a ``latest`` file.
        FileExistsError: Raised by ``shutil.copytree`` when the input is a
            directory and ``args.output_ckpt_path`` already exists.
    """
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    cpu = torch.device('cpu')

    if is_dir:
        # A DeepSpeed checkpoint
        logging.info(
            'Converting deepspeed checkpoint found at %s', checkpoint_path
        )
        state_dict_key = 'module'

        latest_path = os.path.join(checkpoint_path, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

        ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
        optim_files = get_optim_files(ds_checkpoint_dir)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
        model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
    else:
        # A Pytorch Lightning checkpoint
        logging.info(
            'Converting pytorch lightning checkpoint found at %s',
            checkpoint_path,
        )
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=cpu)
    model_dict[state_dict_key] = convert_deprecated_v1_keys(
        model_dict[state_dict_key]
    )

    if 'ema' in model_dict:
        model_dict['ema']['params'] = convert_deprecated_v1_keys(
            model_dict['ema']['params']
        )

    if is_dir:
        # NOTE(review): only the first entry of 'param_shapes' is converted and
        # kept -- presumably there is a single parameter group; confirm.
        param_shapes = convert_deprecated_v1_keys(
            model_dict['param_shapes'][0]
        )
        model_dict['param_shapes'] = [param_shapes]

        # Copy everything first so files that need no conversion are carried
        # over; the converted files below overwrite their copies.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(
            model_output_path, os.path.basename(model_file)
        )

        for optim_file in optim_files:
            # Load onto CPU, like the model file, so the conversion does not
            # need the devices the checkpoint was saved from.
            optim_dict = torch.load(optim_file, map_location=cpu)
            optimizer_state = optim_dict['optimizer_state_dict']
            # NOTE(review): as with 'param_shapes', only index 0 is converted.
            optimizer_state['param_slice_mappings'][0] = (
                convert_deprecated_v1_keys(
                    optimizer_state['param_slice_mappings'][0]
                )
            )
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file)
            )
            torch.save(optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
if __name__ == "__main__":
    # The root logger defaults to WARNING, which would drop the INFO progress
    # messages this script emits; enable them for command-line runs.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description=(
            "Convert a checkpoint with deprecated v1 parameter keys to the "
            "current key names."
        )
    )
    parser.add_argument(
        "input_ckpt_path",
        type=str,
        help=(
            "Checkpoint to convert: a DeepSpeed checkpoint directory or a "
            "PyTorch Lightning checkpoint file"
        ),
    )
    parser.add_argument(
        "output_ckpt_path",
        type=str,
        help="Where the converted checkpoint is written",
    )
    args = parser.parse_args()

    convert_v1_to_v2_weights(args)
scripts/zero_to_fp32.py
View file @
9776b696
This diff is collapsed.
Click to expand it.
train_openfold.py
View file @
9776b696
...
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
)
from
scripts.zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
...
@@ -294,8 +295,13 @@ def main(args):
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
if
'module'
in
sd
:
module_sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
[
'module'
].
items
()}
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
module_sd
)
elif
'state_dict'
in
sd
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
[
'state_dict'
])
else
:
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
resume_from_jax_params
):
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment