Commit 423e7ecf authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix typo

parent a7c4159d
...@@ -52,14 +52,14 @@ class Param: ...@@ -52,14 +52,14 @@ class Param:
stacked: bool = False stacked: bool = False
def _process_translations_dict(d, top_layer=True): def process_translation_dict(d, top_layer=True):
flat = {} flat = {}
for k, v in d.items(): for k, v in d.items():
if type(v) == dict: if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else "" prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = { sub_flat = {
(prefix + "/".join([k, k_prime])): v_prime (prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict( for k_prime, v_prime in process_translation_dict(
v, top_layer=False v, top_layer=False
).items() ).items()
} }
...@@ -122,9 +122,7 @@ def assign(translation_dict, orig_weights): ...@@ -122,9 +122,7 @@ def assign(translation_dict, orig_weights):
raise raise
def import_jax_weights_(model, npz_path, version="model_1"): def generate_translation_dict(model, version):
data = np.load(npz_path)
####################### #######################
# Some templates # Some templates
####################### #######################
...@@ -431,8 +429,17 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -431,8 +429,17 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"logits": LinearParams(model.aux_heads.tm.linear) "logits": LinearParams(model.aux_heads.tm.linear)
} }
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = generate_translation_dict(model, version)
# Flatten keys and insert missing key prefixes # Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations) flat = process_translation_dict(translations)
# Sanity check # Sanity check
keys = list(data.keys()) keys = list(data.keys())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment