Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
775f77dd
Commit
775f77dd
authored
Feb 12, 2024
by
Jennifer
Browse files
Bug fixes; adds a section to convert optimizer state files
parent
260592e0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
25 deletions
+58
-25
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+10
-3
scripts/convert_v1_to_v2_weights.py
scripts/convert_v1_to_v2_weights.py
+37
-22
train_openfold.py
train_openfold.py
+11
-0
No files found.
openfold/utils/import_weights.py
View file @
775f77dd
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
import
re
import
logging
from
enum
import
Enum
from
dataclasses
import
dataclass
from
functools
import
partial
...
...
@@ -669,6 +670,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
def
convert_deprecated_v1_keys
(
state_dict
):
"""Update older OpenFold model weight names to match the current model code."""
logging
.
warning
(
'converting keys...'
)
replacements
=
{
'template_angle_embedder'
:
'template_single_embedder'
,
...
...
@@ -686,17 +688,22 @@ def convert_deprecated_v1_keys(state_dict):
# Rename every key and collect the results in a fresh mapping.
converted_state_dict = {}
for key, value in state_dict.items():
    # For each regex match, look up the replacement name in `replacements`.
    # The callback must use m.group(1): m.group() would be the entire match,
    # which is not a key of `replacements` (the earlier duplicate assignment
    # using m.group() was dead code and has been removed, together with the
    # '### DEBUG: remove before final commit' print/breakpoint() blocks —
    # breakpoint() would halt any non-interactive run).
    new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
    # Template layers moved under a common 'template_embedder.' prefix in v2;
    # group(1) is an optional leading prefix, group(4) the remainder of the key.
    template_match = re.match(template_emb_re, new_key)
    if template_match:
        prefix = template_match.group(1)
        new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
    converted_state_dict[new_key] = value
...
...
scripts/convert_v1_to_v2_weights.py
View file @
775f77dd
...
...
@@ -15,6 +15,7 @@
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import
logging
import
argparse
import
os
import
shutil
...
...
@@ -23,47 +24,61 @@ import torch
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
def get_latest_checkpoint_dir(checkpoint_dir):
    """Return the directory of the checkpoint named by the 'latest' marker file.

    DeepSpeed drops a file called 'latest' into the checkpoint root whose
    contents are the tag (sub-directory name) of the most recent save.
    Based on zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint.

    Raises:
        ValueError: if the 'latest' marker file is missing.
    """
    marker = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(marker):
        raise ValueError(f"Unable to find 'latest' file at {marker}")
    with open(marker, 'r') as handle:
        tag = handle.read().strip()
    return os.path.join(checkpoint_dir, tag)
def convert_v1_to_v2_weights(args):
    """Convert a v1 OpenFold checkpoint to v2 parameter key names.

    Handles both a DeepSpeed checkpoint directory (resolved through its
    'latest' marker) and a single PyTorch Lightning checkpoint file.
    Deprecated key names are rewritten via convert_deprecated_v1_keys in the
    model state dict, the EMA weights, the ZeRO param_shapes, and each
    optimizer shard's param-slice mappings; the result is written under
    args.output_ckpt_path.
    """
    # TODO can we use zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint here?
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint
        ds_checkpoint_path = get_latest_checkpoint_dir(checkpoint_path)
        # Fixed: the message was missing its f-prefix, and referenced a
        # nonexistent attribute (the parser defines input_ckpt_path, not
        # input_checkpoint_path) — it would have printed the braces literally.
        logging.info(f'Converting checkpoint found at {args.input_ckpt_path}')
        state_dict_key = 'module'
        optim_files = get_optim_files(ds_checkpoint_path)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_path)
        model_file = get_model_state_file(ds_checkpoint_path, zero_stage)
        # Mirror the tagged checkpoint sub-directory under the output path.
        # NOTE(review): tag is derived from the resolved checkpoint dir name;
        # confirm this matches the contents of the 'latest' marker file.
        tag = os.path.basename(ds_checkpoint_path)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
    else:
        # A Pytorch Lightning checkpoint
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model_dict[state_dict_key] = convert_deprecated_v1_keys(
        model_dict[state_dict_key]
    )
    if 'ema' in model_dict:
        # EMA weights carry the same deprecated key names as the live weights.
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(ema_state_dict)

    if is_dir:
        # param_shapes is a single-element list whose entry maps key -> shape.
        param_shapes = convert_deprecated_v1_keys(model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]
        # Copy the rest of the checkpoint tree ('latest' marker, shards, ...)
        # verbatim before overwriting the converted pieces below.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(model_output_path, os.path.basename(model_file))
        # Also rename keys inside each optimizer shard's param-slice mappings.
        for optim_file in optim_files:
            optim_dict = torch.load(optim_file)
            new_optim_dict = optim_dict.copy()
            new_optim_dict['optimizer_state_dict']['param_slice_mappings'][0] = (
                convert_deprecated_v1_keys(
                    optim_dict['optimizer_state_dict']['param_slice_mappings'][0]
                )
            )
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file)
            )
            torch.save(new_optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
...
...
@@ -75,4 +90,4 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
convert_v1_to_v2_weights
(
args
)
\ No newline at end of file
convert_v1_to_v2_weights
(
args
)
train_openfold.py
View file @
775f77dd
...
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
)
from
scripts.zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
...
@@ -288,6 +289,16 @@ def main(args):
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
# Removed the '### DEBUG:' block that re-read optimizer/model state files
# from a hard-coded 'global_step210' directory — leftover debugging
# scaffolding that would break on any other checkpoint tag.
logging.info("Successfully loaded last lr step...")
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment