Commit 16d75057 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix config presets

parent 17015611
...@@ -12,27 +12,61 @@ def set_inf(c, inf): ...@@ -12,27 +12,61 @@ def set_inf(c, inf):
def model_config(name, train=False, low_prec=False): def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config) c = copy.deepcopy(config)
if name == "model_1": if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting
pass pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_2": elif name == "model_2":
pass # AF2 Suppl. Table 5, Model 1.1.2
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_3": elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_4": elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_5": elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_1_ptm": elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_2_ptm": elif name == "model_2_ptm":
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_3_ptm": elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_4_ptm": elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
...@@ -53,6 +87,8 @@ def model_config(name, train=False, low_prec=False): ...@@ -53,6 +87,8 @@ def model_config(name, train=False, low_prec=False):
# a global constant # a global constant
set_inf(c, 1e4) set_inf(c, 1e4)
if tm:
return c return c
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment