Commit 03518fd1 authored by Gustaf Ahdritz

Add use_low_memory_attention parameters

parent 984370ce
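The diff below threads a use_memory_efficient_kernel keyword through the Evoformer's attention call sites. A minimal sketch of the dispatch such a flag implies (hypothetical standalone code; the repo's real kernels, presumably in openfold/model/primitives.py, are not reproduced here):

    import torch
    import torch.nn.functional as F

    def _standard_attention(q, k, v):
        # Vanilla scaled dot-product attention; materializes the full
        # [*, Q, K] score matrix, which is what the specialized kernels avoid.
        scores = torch.matmul(q, k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
        return torch.matmul(F.softmax(scores, dim=-1), v)

    def attention(q, k, v, use_memory_efficient_kernel=False,
                  use_lma=False, use_flash=False):
        # The config comments in this commit document these flags as
        # mutually exclusive.
        if sum((use_memory_efficient_kernel, use_lma, use_flash)) > 1:
            raise ValueError("attention kernel flags are mutually exclusive")
        if use_memory_efficient_kernel or use_lma or use_flash:
            # Placeholders: custom memory-efficient kernel, low-memory
            # attention, and FlashAttention respectively, omitted here.
            raise NotImplementedError
        return _standard_attention(q, k, v)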
@@ -44,13 +44,14 @@ def enforce_config_constraints(config):
 def model_config(name, train=False, low_prec=False):
     c = copy.deepcopy(config)
+    # TRAINING PRESETS
     if name == "initial_training":
         # AF2 Suppl. Table 4, "initial training" setting
         pass
     elif name == "finetuning":
         # AF2 Suppl. Table 4, "finetuning" setting
-        c.data.train.max_extra_msa = 5120
         c.data.train.crop_size = 384
+        c.data.train.max_extra_msa = 5120
         c.data.train.max_msa_clusters = 512
         c.loss.violation.weight = 1.
         c.loss.experimentally_resolved.weight = 0.01
@@ -64,22 +65,23 @@ def model_config(name, train=False, low_prec=False):
         c.loss.tm.weight = 0.1
     elif name == "finetuning_no_templ":
         # AF2 Suppl. Table 4, "finetuning" setting
-        c.data.train.max_extra_msa = 5120
         c.data.train.crop_size = 384
+        c.data.train.max_extra_msa = 5120
         c.data.train.max_msa_clusters = 512
         c.model.template.enabled = False
         c.loss.violation.weight = 1.
         c.loss.experimentally_resolved.weight = 0.01
     elif name == "finetuning_no_templ_ptm":
         # AF2 Suppl. Table 4, "finetuning" setting
-        c.data.train.max_extra_msa = 5120
         c.data.train.crop_size = 384
+        c.data.train.max_extra_msa = 5120
         c.data.train.max_msa_clusters = 512
         c.model.template.enabled = False
         c.loss.violation.weight = 1.
         c.loss.experimentally_resolved.weight = 0.01
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    # INFERENCE PRESETS
     elif name == "model_1":
         # AF2 Suppl. Table 5, Model 1.1.1
         c.data.train.max_extra_msa = 5120
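For orientation, these presets are selected by name through model_config; a short usage sketch (the import path openfold.config is the repo's, the rest is illustrative):

    from openfold.config import model_config

    # Build the "finetuning" preset; the settings above should be
    # visible on the returned ConfigDict.
    c = model_config("finetuning", train=True)
    assert c.data.train.crop_size == 384
    assert c.data.train.max_extra_msa == 5120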
@@ -172,7 +174,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
 templates_enabled = mlc.FieldReference(True, field_type=bool)
 embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
-tune_chunk_size = mlc.FieldReference(True, field_type=bool)
+tune_chunk_size = mlc.FieldReference(False, field_type=bool)

 NUM_RES = "num residues placeholder"
 NUM_MSA_SEQ = "msa placeholder"
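The tune_chunk_size default flips from True to False here; because it is an mlc.FieldReference, a single default fans out to every config entry that references it. A standalone illustration (not OpenFold code):

    import ml_collections as mlc

    tune_chunk_size = mlc.FieldReference(False, field_type=bool)
    cfg = mlc.ConfigDict({
        # Hypothetical nesting; both entries share the one reference.
        "globals": {"tune_chunk_size": tune_chunk_size},
        "model": {"evoformer": {"tune_chunk_size": tune_chunk_size}},
    })

    # Updating the reference updates every field that points at it.
    tune_chunk_size.set(True)
    assert cfg.globals.tune_chunk_size is True
    assert cfg.model.evoformer.tune_chunk_size is True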
@@ -334,7 +336,7 @@ config = mlc.ConfigDict(
             "use_small_bfd": False,
             "data_loaders": {
                 "batch_size": 1,
-                "num_workers": 16,
+                "num_workers": 8,
             },
         },
     },
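The data_loaders block reads like standard PyTorch DataLoader arguments; roughly how it would be consumed (a sketch, not the repo's actual data module):

    from torch.utils.data import DataLoader

    def make_loader(dataset, dl_cfg):
        # Hypothetical helper: forward the config block above to PyTorch.
        return DataLoader(
            dataset,
            batch_size=dl_cfg.batch_size,
            num_workers=dl_cfg.num_workers,
        )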
@@ -346,7 +348,7 @@ config = mlc.ConfigDict(
             # exclusive with use_flash.
             "use_lma": False,
             # Use FlashAttention in selected modules. Mutually exclusive with
-            # use_lma.
+            # use_lma. Doesn't work that well on long sequences (>1000 residues).
             "use_flash": False,
             "offload_inference": False,
             "c_z": c_z,
...
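The first hunk's context mentions enforce_config_constraints; the "mutually exclusive" comments above suggest a check along these lines (a sketch that assumes the flags live under config.globals, as the surrounding hunk indicates):

    def enforce_config_constraints(config):
        # Sketch only: reject configs that enable both specialized
        # attention variants at once.
        if config.globals.use_lma and config.globals.use_flash:
            raise ValueError("use_lma and use_flash are mutually exclusive")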
@@ -257,6 +257,7 @@ class EvoformerBlockCore(nn.Module):
                 z,
                 mask=pair_mask,
                 chunk_size=_attn_chunk_size,
+                use_memory_efficient_kernel=False,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
             )
@@ -275,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
                 z,
                 mask=pair_mask.transpose(-1, -2),
                 chunk_size=_attn_chunk_size,
+                use_memory_efficient_kernel=False,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
             )
@@ -386,6 +388,7 @@ class EvoformerBlock(nn.Module):
                     z=z,
                     mask=msa_mask,
                     chunk_size=_attn_chunk_size,
+                    use_memory_efficient_kernel=False,
                     use_lma=use_lma,
                 )
             ),
...
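Note that each patched call site pins use_memory_efficient_kernel=False on the chunked path, so chunking and the memory-efficient kernel are not combined. A toy stand-in showing that keyword threading (hypothetical module, not the repo's attention classes):

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        # Records which kernel path the caller requested.
        def forward(self, z, mask=None, chunk_size=None,
                    use_memory_efficient_kernel=False, use_lma=False):
            self.last_kernel = (
                "memory_efficient" if use_memory_efficient_kernel
                else "lma" if use_lma
                else "standard"
            )
            return z

    attn = ToyAttention()
    z = torch.zeros(1, 8, 8, 4)
    # Mirrors the new call sites: the flag is pinned to False when a
    # chunk_size is given, leaving use_lma to pick the fallback kernel.
    attn(z, chunk_size=4, use_memory_efficient_kernel=False, use_lma=True)
    assert attn.last_kernel == "lma"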