Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
260592e0
Commit
260592e0
authored
Feb 09, 2024
by
Jennifer
Browse files
Adjust weight conversion and add a script for converting checkpoints.
parent
1df591b0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
5 deletions
+89
-5
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+11
-5
scripts/convert_v1_to_v2_weights.py
scripts/convert_v1_to_v2_weights.py
+78
-0
No files found.
openfold/utils/import_weights.py
View file @
260592e0
...
@@ -681,16 +681,22 @@ def convert_deprecated_v1_keys(state_dict):
...
@@ -681,16 +681,22 @@ def convert_deprecated_v1_keys(state_dict):
}
}
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
template_emb_re
=
re
.
compile
(
"((module
\\
.)?(model
\\
.))?(template(?!_embedder).*)"
)
converted_state_dict
=
{}
converted_state_dict
=
{}
for
key
,
value
in
state_dict
.
items
():
for
key
,
value
in
state_dict
.
items
():
# For each match, look-up replacement value in the dictionary
# For each match, look-up replacement value in the dictionary
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
if
key
==
'template_angle_embedder.linear_1.weight'
:
# Add prefix for template modules
print
(
f
'old key:
{
key
}
, new_key:
{
new_key
}
'
)
subheader
=
re
.
search
(
'(?<=model.).*$'
,
new_key
).
group
()
if
subheader
.
startswith
(
'template'
):
# Add prefix for template layers
new_key
=
f
'model.template_embedder.
{
subheader
}
'
template_match
=
re
.
match
(
template_emb_re
,
new_key
)
if
template_match
:
prefix
=
template_match
.
group
(
1
)
new_key
=
f
'
{
prefix
if
prefix
else
""
}
template_embedder.
{
template_match
.
group
(
4
)
}
'
if
key
==
'template_angle_embedder.linear_1.weight'
:
print
(
f
'old key:
{
key
}
, new_key:
{
new_key
}
'
)
converted_state_dict
[
new_key
]
=
value
converted_state_dict
[
new_key
]
=
value
...
...
scripts/convert_v1_to_v2_weights.py
0 → 100755
View file @
260592e0
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import
argparse
import
os
import
shutil
import
torch
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
def get_latest_checkpoint_dir(checkpoint_dir):
    """Return the path of the newest tagged checkpoint subdirectory.

    DeepSpeed writes a ``latest`` marker file at the root of a checkpoint
    directory containing the tag of the most recent save; this resolves that
    tag into a full path.
    (Based on zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint.)

    Raises:
        ValueError: if no ``latest`` marker file exists in *checkpoint_dir*.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as latest_file:
        tag = latest_file.read().strip()
    return os.path.join(checkpoint_dir, tag)
def convert_v1_to_v2_weights(args):
    """Convert a v1 OpenFold checkpoint to the v2 key-naming scheme and save it.

    Handles two checkpoint layouts:
      * a DeepSpeed checkpoint directory (detected when ``args.input_ckpt_path``
        is a directory): the whole tree is copied to ``args.output_ckpt_path``
        and the model-state file inside it is rewritten with converted keys;
      * a single PyTorch Lightning checkpoint file: converted and written to
        ``args.output_ckpt_path``.

    Args:
        args: namespace with ``input_ckpt_path`` and ``output_ckpt_path``.
    """
    # TODO can we use zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint here?
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint
        ds_checkpoint_path = get_latest_checkpoint_dir(checkpoint_path)
        state_dict_key = 'module'
        optim_files = get_optim_files(ds_checkpoint_path)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_path)
        model_file = get_model_state_file(ds_checkpoint_path, zero_stage)
    else:
        # A Pytorch Lightning checkpoint
        state_dict_key = 'state_dict'
        model_file = checkpoint_path
    model_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model_dict[state_dict_key] = convert_deprecated_v1_keys(model_dict[state_dict_key])
    # Convert the EMA copy of the weights too, if present, so both stay in sync.
    if 'ema' in model_dict:
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(ema_state_dict)
    if is_dir:
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        # BUGFIX: model_file lives inside the tagged subdirectory (presumably
        # e.g. <ckpt>/<tag>/mp_rank_00_model_states.pt, per get_model_state_file
        # — confirm against zero_to_fp32). Saving to the output root, as the
        # original basename() join did, would leave the copied tree holding the
        # *unconverted* weights. Preserve the relative path so the converted
        # dict replaces the copied file in place.
        out_fname = os.path.join(
            args.output_ckpt_path,
            os.path.relpath(model_file, checkpoint_path),
        )
    else:
        out_fname = args.output_ckpt_path
    torch.save(model_dict, out_fname)
if __name__ == "__main__":
    # CLI entry point: convert a single v1 checkpoint to v2 key naming.
    parser = argparse.ArgumentParser(
        description="Convert deprecated v1 OpenFold checkpoint keys to the "
                    "v2 naming scheme."
    )
    parser.add_argument(
        "input_ckpt_path",
        type=str,
        help="Path to a v1 checkpoint: either a DeepSpeed checkpoint "
             "directory or a single PyTorch Lightning checkpoint file.",
    )
    parser.add_argument(
        "output_ckpt_path",
        type=str,
        help="Where to write the converted checkpoint (directory or file, "
             "matching the input layout).",
    )
    args = parser.parse_args()
    convert_v1_to_v2_weights(args)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment