Unverified Commit 7d10334b authored by shenggan's avatar shenggan Committed by GitHub
Browse files

fix weights import function (#124)

parent f55dca95
...@@ -127,7 +127,8 @@ def assign(translation_dict, orig_weights): ...@@ -127,7 +127,8 @@ def assign(translation_dict, orig_weights):
print(weights[0].shape) print(weights[0].shape)
raise raise
def get_translation_dict(model, is_multimer: bool = False): def get_translation_dict(model, version):
is_multimer = "multimer" in version
####################### #######################
# Some templates # Some templates
####################### #######################
...@@ -537,15 +538,6 @@ def get_translation_dict(model, is_multimer: bool = False): ...@@ -537,15 +538,6 @@ def get_translation_dict(model, is_multimer: bool = False):
}, },
} }
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, is_multimer=("multimer" in version))
no_templ = [ no_templ = [
"model_3", "model_3",
"model_4", "model_4",
...@@ -566,6 +558,14 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -566,6 +558,14 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"logits": LinearParams(model.aux_heads.tm.linear) "logits": LinearParams(model.aux_heads.tm.linear)
} }
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, version)
# Flatten keys and insert missing key prefixes # Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations) flat = _process_translations_dict(translations)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment