Commit b1780f4e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bug in weight import script

parent 4c37c386
...@@ -322,11 +322,6 @@ def generate_translation_dict(model, version): ...@@ -322,11 +322,6 @@ def generate_translation_dict(model, version):
# translations dict overflow # translations dict overflow
############################ ############################
tps_blocks = model.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
ems_blocks = model.extra_msa_stack.blocks ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks]) ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
...@@ -349,30 +344,10 @@ def generate_translation_dict(model, version): ...@@ -349,30 +344,10 @@ def generate_translation_dict(model, version):
"pair_activiations": LinearParams( "pair_activiations": LinearParams(
model.input_embedder.linear_relpos model.input_embedder.linear_relpos
), ),
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations": LinearParams( "extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear model.extra_msa_embedder.linear
), ),
"extra_msa_stack": ems_blocks_params, "extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params, "evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear), "single_activations": LinearParams(model.evoformer.linear),
}, },
...@@ -417,12 +392,36 @@ def generate_translation_dict(model, version): ...@@ -417,12 +392,36 @@ def generate_translation_dict(model, version):
"model_4_ptm", "model_4_ptm",
"model_5_ptm", "model_5_ptm",
] ]
if version in no_templ: if version in no_templ:
evo_dict = translations["evoformer"] tps_blocks = model.template_pair_stack.blocks
keys = list(evo_dict.keys()) tps_blocks_params = stacked(
for k in keys: [TemplatePairBlockParams(b) for b in tps_blocks]
if "template_" in k: )
evo_dict.pop(k)
template_param_dict = {
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
},
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
),
}
translations["evoformer"].update(template_param_dict)
if "_ptm" in version: if "_ptm" in version:
translations["predicted_aligned_error_head"] = { translations["predicted_aligned_error_head"] = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment