"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "f50c306af3e4837184d6a9e5219db72b1ecc3cf1"
Commit 03518fd1 authored by Gustaf Ahdritz
Browse files

Add use_low_memory_attention parameters

parent 984370ce
......@@ -44,13 +44,14 @@ def enforce_config_constraints(config):
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
# TRAINING PRESETS
if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting
pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
......@@ -64,22 +65,23 @@ def model_config(name, train=False, low_prec=False):
c.loss.tm.weight = 0.1
elif name == "finetuning_no_templ":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_no_templ_ptm":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# INFERENCE PRESETS
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.train.max_extra_msa = 5120
......@@ -172,7 +174,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(False, field_type=bool)
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
......@@ -334,7 +336,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 16,
"num_workers": 8,
},
},
},
......@@ -346,7 +348,7 @@ config = mlc.ConfigDict(
# exclusive with use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma.
# use_lma. Doesn't work that well on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z,
......
......@@ -257,6 +257,7 @@ class EvoformerBlockCore(nn.Module):
z,
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -275,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
z,
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -386,6 +388,7 @@ class EvoformerBlock(nn.Module):
z=z,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
)
),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment