Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
dd1dd641
Commit
dd1dd641
authored
Feb 12, 2024
by
Jennifer
Browse files
Fix bugs and add a section to convert optimizer state files
parent
0a0dbb39
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
25 deletions
+58
-25
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+10
-3
scripts/convert_v1_to_v2_weights.py
scripts/convert_v1_to_v2_weights.py
+37
-22
train_openfold.py
train_openfold.py
+11
-0
No files found.
openfold/utils/import_weights.py
View file @
dd1dd641
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
import
re
import
re
import
logging
from
enum
import
Enum
from
enum
import
Enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
...
@@ -669,6 +670,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -669,6 +670,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
def
convert_deprecated_v1_keys
(
state_dict
):
def
convert_deprecated_v1_keys
(
state_dict
):
"""Update older OpenFold model weight names to match the current model code."""
"""Update older OpenFold model weight names to match the current model code."""
logging
.
warning
(
'converting keys...'
)
replacements
=
{
replacements
=
{
'template_angle_embedder'
:
'template_single_embedder'
,
'template_angle_embedder'
:
'template_single_embedder'
,
...
@@ -686,17 +688,22 @@ def convert_deprecated_v1_keys(state_dict):
...
@@ -686,17 +688,22 @@ def convert_deprecated_v1_keys(state_dict):
converted_state_dict
=
{}
converted_state_dict
=
{}
for
key
,
value
in
state_dict
.
items
():
for
key
,
value
in
state_dict
.
items
():
# For each match, look-up replacement value in the dictionary
# For each match, look-up replacement value in the dictionary
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
(
1
)],
key
)
### DEBUG: remove before final commit
if
key
==
'template_angle_embedder.linear_1.weight'
:
if
key
==
'template_angle_embedder.linear_1.weight'
:
print
(
f
'old key:
{
key
}
, new_key:
{
new_key
}
'
)
logging
.
warning
(
f
'old key:
{
key
}
, new_key:
{
new_key
}
'
)
### DEBUG: remove before final commit
# Add prefix for template layers
# Add prefix for template layers
template_match
=
re
.
match
(
template_emb_re
,
new_key
)
template_match
=
re
.
match
(
template_emb_re
,
new_key
)
if
template_match
:
if
template_match
:
prefix
=
template_match
.
group
(
1
)
prefix
=
template_match
.
group
(
1
)
new_key
=
f
'
{
prefix
if
prefix
else
""
}
template_embedder.
{
template_match
.
group
(
4
)
}
'
new_key
=
f
'
{
prefix
if
prefix
else
""
}
template_embedder.
{
template_match
.
group
(
4
)
}
'
# DEBUG: remove before final commit
if
key
==
'template_angle_embedder.linear_1.weight'
:
if
key
==
'template_angle_embedder.linear_1.weight'
:
print
(
f
'old key:
{
key
}
, new_key:
{
new_key
}
'
)
breakpoint
()
logging
.
warning
(
f
'old key:
{
key
}
, new_key:
{
new_key
}
'
)
### DEBUG: remove before final commit
converted_state_dict
[
new_key
]
=
value
converted_state_dict
[
new_key
]
=
value
...
...
scripts/convert_v1_to_v2_weights.py
View file @
dd1dd641
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
# used to run inference using DeepMind's JAX code.
import
logging
import
argparse
import
argparse
import
os
import
os
import
shutil
import
shutil
...
@@ -23,47 +24,61 @@ import torch
...
@@ -23,47 +24,61 @@ import torch
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
openfold.utils.import_weights
import
convert_deprecated_v1_keys
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
def get_latest_checkpoint_dir(checkpoint_dir):
    """Return the path of the newest DeepSpeed checkpoint under ``checkpoint_dir``.

    DeepSpeed writes a plain-text marker file named ``latest`` inside the
    checkpoint directory; its contents are the tag (sub-directory name) of the
    most recent save. The resolved path is ``checkpoint_dir/<tag>``.

    Raises:
        ValueError: if the ``latest`` marker file does not exist.
    """
    # Based on zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint
    marker = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(marker):
        raise ValueError(f"Unable to find 'latest' file at {marker}")
    with open(marker, 'r') as fh:
        tag = fh.read().strip()
    return os.path.join(checkpoint_dir, tag)
def convert_v1_to_v2_weights(args):
    """Convert an OpenFold v1 checkpoint's weight names to the v2 scheme.

    Handles two checkpoint layouts:
      * a directory -> treated as a DeepSpeed checkpoint: the model state,
        EMA params, ``param_shapes`` and every per-rank optimizer file are
        re-keyed and written under ``args.output_ckpt_path/<tag>`` (the whole
        tree is copied first, then the converted files overwrite the copies);
      * a single file -> treated as a PyTorch Lightning checkpoint: the
        converted dict is saved to ``args.output_ckpt_path``.

    Args:
        args: argparse namespace providing ``input_ckpt_path`` and
            ``output_ckpt_path``.
    """
    # TODO can we use zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint here?
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint
        ds_checkpoint_path = get_latest_checkpoint_dir(checkpoint_path)
        # BUGFIX: message was a plain string missing the f-prefix and referenced
        # a non-existent attribute (`input_checkpoint_path`).
        logging.info(f'Converting checkpoint found at {args.input_ckpt_path}')
        state_dict_key = 'module'
        optim_files = get_optim_files(ds_checkpoint_path)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_path)
        model_file = get_model_state_file(ds_checkpoint_path, zero_stage)
        # BUGFIX: `tag` was referenced below but never assigned after its
        # computation moved into get_latest_checkpoint_dir(); recover it from
        # the resolved checkpoint path.
        tag = os.path.basename(ds_checkpoint_path)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
    else:
        # A Pytorch Lightning checkpoint
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model_dict[state_dict_key] = convert_deprecated_v1_keys(model_dict[state_dict_key])

    # The EMA copy of the weights (if present) uses the same key scheme.
    if 'ema' in model_dict:
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(ema_state_dict)

    if is_dir:
        # DeepSpeed stores parameter shapes separately; re-key them too.
        param_shapes = convert_deprecated_v1_keys(model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]
        # Copy the whole checkpoint tree first; the converted model/optimizer
        # files written below overwrite their unconverted copies.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(model_output_path, os.path.basename(model_file))
        for optim_file in optim_files:
            # Load on CPU so conversion does not require a GPU.
            optim_dict = torch.load(optim_file, map_location=torch.device('cpu'))
            new_optim_dict = optim_dict.copy()
            new_optim_dict['optimizer_state_dict']['param_slice_mappings'][0] = \
                convert_deprecated_v1_keys(
                    optim_dict['optimizer_state_dict']['param_slice_mappings'][0]
                )
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file)
            )
            torch.save(new_optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
...
@@ -75,4 +90,4 @@ if __name__ == "__main__":
...
@@ -75,4 +90,4 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
convert_v1_to_v2_weights
(
args
)
convert_v1_to_v2_weights
(
args
)
\ No newline at end of file
train_openfold.py
View file @
dd1dd641
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
...
@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
,
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
get_global_step_from_zero_checkpoint
)
)
from
scripts.zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
openfold.utils.logger
import
PerformanceLoggingCallback
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
@@ -288,6 +289,16 @@ def main(args):
...
@@ -288,6 +289,16 @@ def main(args):
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
last_global_step
=
int
(
sd
[
'global_step'
])
last_global_step
=
int
(
sd
[
'global_step'
])
model_module
.
resume_last_lr_step
(
last_global_step
)
model_module
.
resume_last_lr_step
(
last_global_step
)
### DEBUG:
ds_checkpoint_dir
=
os
.
path
.
join
(
args
.
resume_from_ckpt
,
'global_step210'
)
optim_files
=
get_optim_files
(
ds_checkpoint_dir
)
zero_stage
,
_
,
_
=
parse_optim_states
(
optim_files
,
ds_checkpoint_dir
)
model_file
=
get_model_state_file
(
ds_checkpoint_dir
,
zero_stage
)
model_dict
=
torch
.
load
(
model_file
,
map_location
=
torch
.
device
(
'cpu'
))
###
logging
.
info
(
"Successfully loaded last lr step..."
)
logging
.
info
(
"Successfully loaded last lr step..."
)
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment