Commit f5113baf authored by zhuww's avatar zhuww
Browse files

add the func of parsing _ptm model

parent 5a7db20f
...@@ -126,8 +126,11 @@ def assign(translation_dict, orig_weights): ...@@ -126,8 +126,11 @@ def assign(translation_dict, orig_weights):
print(ref[0].shape) print(ref[0].shape)
print(weights[0].shape) print(weights[0].shape)
raise raise
def get_translation_dict(model, is_multimer: bool = False):
def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False):
data = np.load(npz_path)
# translations = get_translation_dict(model, is_multimer=("multimer" in version))
####################### #######################
# Some templates # Some templates
####################### #######################
...@@ -537,23 +540,16 @@ def get_translation_dict(model, is_multimer: bool = False): ...@@ -537,23 +540,16 @@ def get_translation_dict(model, is_multimer: bool = False):
}, },
} }
return translations # return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, is_multimer=("multimer" in version))
no_templ = [ no_templ = [
"model_3", "model_3",
"model_4", "model_4",
"model_5", "model_5",
"model_3_ptm", "model_3_ptm",
"model_4_ptm", "model_4_ptm",
"model_5_ptm", "model_5_ptm",
] ]
if version in no_templ: if version in no_templ:
evo_dict = translations["evoformer"] evo_dict = translations["evoformer"]
keys = list(evo_dict.keys()) keys = list(evo_dict.keys())
...@@ -582,3 +578,5 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -582,3 +578,5 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Set weights # Set weights
assign(flat, data) assign(flat, data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment