Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
f65b75fe
Commit
f65b75fe
authored
Nov 03, 2023
by
Christina Floristean
Browse files
Fix for loading old OF weights into refactored model
parent
5fcd6ed2
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
72 additions
and
7 deletions
+72
-7
openfold/model/structure_module.py
openfold/model/structure_module.py
+4
-2
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+42
-0
openfold/utils/script_utils.py
openfold/utils/script_utils.py
+3
-2
scripts/convert_of_weights_to_jax.py
scripts/convert_of_weights_to_jax.py
+2
-1
tests/test_import_weights.py
tests/test_import_weights.py
+19
-1
train_openfold.py
train_openfold.py
+2
-1
No files found.
openfold/model/structure_module.py
View file @
f65b75fe
...
@@ -174,8 +174,10 @@ class PointProjection(nn.Module):
...
@@ -174,8 +174,10 @@ class PointProjection(nn.Module):
self
.
no_heads
=
no_heads
self
.
no_heads
=
no_heads
self
.
num_points
=
num_points
self
.
num_points
=
num_points
self
.
is_multimer
=
is_multimer
self
.
is_multimer
=
is_multimer
self
.
linear
=
Linear
(
c_hidden
,
no_heads
*
3
*
num_points
,
precision
=
torch
.
float32
)
# Multimer requires this to be run with fp32 precision during training
precision
=
torch
.
float32
if
self
.
is_multimer
else
None
self
.
linear
=
Linear
(
c_hidden
,
no_heads
*
3
*
num_points
,
precision
=
precision
)
def
forward
(
self
,
def
forward
(
self
,
activations
:
torch
.
Tensor
,
activations
:
torch
.
Tensor
,
...
...
openfold/utils/import_weights.py
View file @
f65b75fe
...
@@ -665,3 +665,45 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -665,3 +665,45 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Set weights
# Set weights
assign
(
flat
,
data
)
assign
(
flat
,
data
)
def
convert_deprecated_v1_keys
(
state_dict
):
"""Update older OpenFold model weight names to match the current model code."""
replacements
=
{
'template_angle_embedder'
:
'template_single_embedder'
,
'core.msa_transition'
:
'msa_transition'
,
'core.outer_product_mean'
:
'outer_product_mean'
,
'core.tri_'
:
'pair_stack.tri_'
,
'core.pair_transition'
:
'pair_stack.pair_transition'
,
'ipa.linear_q_points'
:
'ipa.linear_q_points.linear'
,
'ipa.linear_kv_points'
:
'ipa.linear_kv_points.linear'
}
convert_key_re
=
re
.
compile
(
"(%s)"
%
"|"
.
join
(
map
(
re
.
escape
,
replacements
.
keys
())))
converted_state_dict
=
{}
for
key
,
value
in
state_dict
.
items
():
# For each match, look-up replacement value in the dictionary
new_key
=
convert_key_re
.
sub
(
lambda
m
:
replacements
[
m
.
group
()],
key
)
# Add prefix for template modules
if
new_key
.
startswith
(
'template'
):
new_key
=
f
'template_embedder.
{
new_key
}
'
converted_state_dict
[
new_key
]
=
value
return
converted_state_dict
def
import_openfold_weights_
(
model
,
state_dict
):
"""
Import model weights. Several parts of the model were refactored in the process
of adding support for Multimer. The state dicts of older models are translated
to match the refactored model code.
"""
try
:
model
.
load_state_dict
(
state_dict
)
except
RuntimeError
:
converted_state_dict
=
convert_deprecated_v1_keys
(
state_dict
)
model
.
load_state_dict
(
converted_state_dict
)
openfold/utils/script_utils.py
View file @
f65b75fe
...
@@ -12,6 +12,7 @@ from openfold.np import residue_constants, protein
...
@@ -12,6 +12,7 @@ from openfold.np import residue_constants, protein
from
openfold.np.relax
import
relax
from
openfold.np.relax
import
relax
from
openfold.utils.import_weights
import
(
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
import_jax_weights_
,
import_openfold_weights_
)
)
from
pytorch_lightning.utilities.deepspeed
import
(
from
pytorch_lightning.utilities.deepspeed
import
(
...
@@ -90,7 +91,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
...
@@ -90,7 +91,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
ckpt_path
,
ckpt_path
,
)
)
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
model
.
load_
state_dict
(
d
[
"ema"
][
"params"
])
import_openfold_weights_
(
model
=
model
,
state_dict
=
d
[
"ema"
][
"params"
])
else
:
else
:
ckpt_path
=
path
ckpt_path
=
path
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
...
@@ -98,7 +99,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
...
@@ -98,7 +99,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
if
"ema"
in
d
:
if
"ema"
in
d
:
# The public weights have had this done to them already
# The public weights have had this done to them already
d
=
d
[
"ema"
][
"params"
]
d
=
d
[
"ema"
][
"params"
]
model
.
load_
state_dict
(
d
)
import_openfold_weights_
(
model
=
model
,
state_dict
=
d
)
model
=
model
.
to
(
model_device
)
model
=
model
.
to
(
model_device
)
logger
.
info
(
logger
.
info
(
...
...
scripts/convert_of_weights_to_jax.py
View file @
f65b75fe
...
@@ -26,6 +26,7 @@ from openfold.utils.import_weights import (
...
@@ -26,6 +26,7 @@ from openfold.utils.import_weights import (
ParamType
,
ParamType
,
generate_translation_dict
,
generate_translation_dict
,
process_translation_dict
,
process_translation_dict
,
import_openfold_weights_
)
)
from
openfold.utils.tensor_utils
import
tree_map
from
openfold.utils.tensor_utils
import
tree_map
...
@@ -63,7 +64,7 @@ def main(args):
...
@@ -63,7 +64,7 @@ def main(args):
config
=
model_config
(
args
.
config_preset
)
config
=
model_config
(
args
.
config_preset
)
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
.
load_
state_dict
(
d
)
import_openfold_weights_
(
model
=
model
,
state_dict
=
d
)
translation
=
generate_translation_dict
(
model
,
args
.
config_preset
)
translation
=
generate_translation_dict
(
model
,
args
.
config_preset
)
translation
=
process_translation_dict
(
translation
)
translation
=
process_translation_dict
(
translation
)
...
...
tests/test_import_weights.py
View file @
f65b75fe
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
...
@@ -20,7 +21,7 @@ from pathlib import Path
...
@@ -20,7 +21,7 @@ from pathlib import Path
from
tests.config
import
consts
from
tests.config
import
consts
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.import_weights
import
import_jax_weights_
from
openfold.utils.import_weights
import
import_jax_weights_
,
import_openfold_weights_
class
TestImportWeights
(
unittest
.
TestCase
):
class
TestImportWeights
(
unittest
.
TestCase
):
...
@@ -75,3 +76,20 @@ class TestImportWeights(unittest.TestCase):
...
@@ -75,3 +76,20 @@ class TestImportWeights(unittest.TestCase):
for
w_alpha
,
w_repro
in
test_pairs
:
for
w_alpha
,
w_repro
in
test_pairs
:
self
.
assertTrue
(
torch
.
all
(
w_alpha
==
w_repro
))
self
.
assertTrue
(
torch
.
all
(
w_alpha
==
w_repro
))
def
test_import_openfold_weights_
(
self
):
model_name
=
'initial_training'
pt_path
=
Path
(
__file__
).
parent
.
resolve
()
/
f
"../openfold/resources/openfold_params/
{
model_name
}
.pt"
if
os
.
path
.
exists
(
pt_path
):
c
=
model_config
(
model_name
)
c
.
globals
.
blocks_per_ckpt
=
None
model
=
AlphaFold
(
c
)
model
.
eval
()
d
=
torch
.
load
(
pt_path
)
import_openfold_weights_
(
model
=
model
,
state_dict
=
d
,
)
train_openfold.py
View file @
f65b75fe
...
@@ -33,6 +33,7 @@ from openfold.utils.validation_metrics import (
...
@@ -33,6 +33,7 @@ from openfold.utils.validation_metrics import (
)
)
from
openfold.utils.import_weights
import
(
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
import_jax_weights_
,
import_openfold_weights_
)
)
from
scripts.zero_to_fp32
import
(
from
scripts.zero_to_fp32
import
(
get_fp32_state_dict_from_zero_checkpoint
,
get_fp32_state_dict_from_zero_checkpoint
,
...
@@ -293,7 +294,7 @@ def main(args):
...
@@ -293,7 +294,7 @@ def main(args):
else
:
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
model_module
.
load_
state_dict
(
sd
)
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
resume_from_jax_params
):
if
(
args
.
resume_from_jax_params
):
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment