Commit 3f592307 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Reduce redundancy in seq embedding config presets

parent 6aefa986
@@ -154,38 +154,20 @@ def model_config(
c.loss.tm.weight = 0.1
# SINGLE SEQUENCE EMBEDDING PRESETS
elif name == "seqemb_initial_training":
# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "seq_model_esm1b":
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
elif name == "seq_model_esm1b_ptm":
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
@@ -195,6 +177,14 @@ def model_config(
else:
raise ValueError("Invalid model name")
if name.startswith("seq"):
# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment