"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "b0655a3465904ff265bdc1e1ccdcff009a448bb0"
Commit f65b75fe authored by Christina Floristean's avatar Christina Floristean
Browse files

Fix for loading old OF weights into refactored model

parent 5fcd6ed2
...@@ -174,8 +174,10 @@ class PointProjection(nn.Module): ...@@ -174,8 +174,10 @@ class PointProjection(nn.Module):
self.no_heads = no_heads self.no_heads = no_heads
self.num_points = num_points self.num_points = num_points
self.is_multimer = is_multimer self.is_multimer = is_multimer
self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=torch.float32) # Multimer requires this to be run with fp32 precision during training
precision = torch.float32 if self.is_multimer else None
self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision)
def forward(self, def forward(self,
activations: torch.Tensor, activations: torch.Tensor,
......
...@@ -665,3 +665,45 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -665,3 +665,45 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Set weights # Set weights
assign(flat, data) assign(flat, data)
def convert_deprecated_v1_keys(state_dict):
    """Update older OpenFold model weight names to match the current model code.

    Older (v1) checkpoints predate the Multimer refactor; several submodules
    were renamed or re-nested. Each key in ``state_dict`` is rewritten by
    substituting the deprecated substrings below and, for template modules,
    prepending the ``template_embedder.`` parent introduced by the refactor.
    Returns a new dict; the input is not modified.
    """
    replacements = {
        'template_angle_embedder': 'template_single_embedder',
        'core.msa_transition': 'msa_transition',
        'core.outer_product_mean': 'outer_product_mean',
        'core.tri_': 'pair_stack.tri_',
        'core.pair_transition': 'pair_stack.pair_transition',
        'ipa.linear_q_points': 'ipa.linear_q_points.linear',
        'ipa.linear_kv_points': 'ipa.linear_kv_points.linear'
    }
    # One alternation pattern matching any deprecated substring.
    pattern = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))

    def _rename(old_key):
        # Swap every deprecated fragment for its current name.
        renamed = pattern.sub(lambda match: replacements[match.group()], old_key)
        # Template modules were moved under a new template_embedder parent.
        if renamed.startswith('template'):
            renamed = f'template_embedder.{renamed}'
        return renamed

    return {_rename(key): value for key, value in state_dict.items()}
def import_openfold_weights_(model, state_dict):
    """Load OpenFold weights into ``model``, translating legacy key names.

    Parts of the model were refactored while adding Multimer support, so
    state dicts saved by older versions no longer match the current module
    tree. A direct load is attempted first; if PyTorch rejects the keys
    (RuntimeError), the dict is converted via ``convert_deprecated_v1_keys``
    and the load is retried.
    """
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: assume a pre-refactor checkpoint and translate it.
        model.load_state_dict(convert_deprecated_v1_keys(state_dict))
...@@ -12,6 +12,7 @@ from openfold.np import residue_constants, protein ...@@ -12,6 +12,7 @@ from openfold.np import residue_constants, protein
from openfold.np.relax import relax from openfold.np.relax import relax
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
import_openfold_weights_
) )
from pytorch_lightning.utilities.deepspeed import ( from pytorch_lightning.utilities.deepspeed import (
...@@ -90,7 +91,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path ...@@ -90,7 +91,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
ckpt_path, ckpt_path,
) )
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"]) import_openfold_weights_(model=model, state_dict=d["ema"]["params"])
else: else:
ckpt_path = path ckpt_path = path
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
...@@ -98,7 +99,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path ...@@ -98,7 +99,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
if "ema" in d: if "ema" in d:
# The public weights have had this done to them already # The public weights have had this done to them already
d = d["ema"]["params"] d = d["ema"]["params"]
model.load_state_dict(d) import_openfold_weights_(model=model, state_dict=d)
model = model.to(model_device) model = model.to(model_device)
logger.info( logger.info(
......
...@@ -26,6 +26,7 @@ from openfold.utils.import_weights import ( ...@@ -26,6 +26,7 @@ from openfold.utils.import_weights import (
ParamType, ParamType,
generate_translation_dict, generate_translation_dict,
process_translation_dict, process_translation_dict,
import_openfold_weights_
) )
from openfold.utils.tensor_utils import tree_map from openfold.utils.tensor_utils import tree_map
...@@ -63,7 +64,7 @@ def main(args): ...@@ -63,7 +64,7 @@ def main(args):
config = model_config(args.config_preset) config = model_config(args.config_preset)
model = AlphaFold(config) model = AlphaFold(config)
model.load_state_dict(d) import_openfold_weights_(model=model, state_dict=d)
translation = generate_translation_dict(model, args.config_preset) translation = generate_translation_dict(model, args.config_preset)
translation = process_translation_dict(translation) translation = process_translation_dict(translation)
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import torch import torch
import numpy as np import numpy as np
import unittest import unittest
...@@ -20,7 +21,7 @@ from pathlib import Path ...@@ -20,7 +21,7 @@ from pathlib import Path
from tests.config import consts from tests.config import consts
from openfold.config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_, import_openfold_weights_
class TestImportWeights(unittest.TestCase): class TestImportWeights(unittest.TestCase):
...@@ -75,3 +76,20 @@ class TestImportWeights(unittest.TestCase): ...@@ -75,3 +76,20 @@ class TestImportWeights(unittest.TestCase):
for w_alpha, w_repro in test_pairs: for w_alpha, w_repro in test_pairs:
self.assertTrue(torch.all(w_alpha == w_repro)) self.assertTrue(torch.all(w_alpha == w_repro))
def test_import_openfold_weights_(self):
    """Smoke-test loading a released OpenFold checkpoint into the model.

    Skips (rather than silently passing) when the pretrained weights are
    not present locally, so a missing fixture is visible in the test report.
    """
    model_name = 'initial_training'
    pt_path = Path(__file__).parent.resolve() / f"../openfold/resources/openfold_params/{model_name}.pt"
    if not os.path.exists(pt_path):
        # Original version fell through and passed vacuously; report a skip instead.
        self.skipTest(f"Pretrained weights not found at {pt_path}")
    c = model_config(model_name)
    # Gradient checkpointing is irrelevant for a pure weight-loading test.
    c.globals.blocks_per_ckpt = None
    model = AlphaFold(c)
    model.eval()
    d = torch.load(pt_path)
    import_openfold_weights_(
        model=model,
        state_dict=d,
    )
...@@ -33,6 +33,7 @@ from openfold.utils.validation_metrics import ( ...@@ -33,6 +33,7 @@ from openfold.utils.validation_metrics import (
) )
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
import_openfold_weights_
) )
from scripts.zero_to_fp32 import ( from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint, get_fp32_state_dict_from_zero_checkpoint,
...@@ -293,7 +294,7 @@ def main(args): ...@@ -293,7 +294,7 @@ def main(args):
else: else:
sd = torch.load(args.resume_from_ckpt) sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()} sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd) import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params): if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment