Commit e49c6dc7 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add long sequence inference preset

parent 14a79a12
......@@ -215,11 +215,14 @@ see the aforementioned Staats & Rabe preprint.
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences.
Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more
efficient AlphaFold-Multimer more than double the time.
efficient AlphaFold-Multimer more than double the time. Use the
`long_sequence_inference` config option to enable all of these interventions
at once.
### Training
......
......@@ -41,8 +41,19 @@ def enforce_config_constraints(config):
if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
):
config.model.template.offload_templates = True
def model_config(name, train=False, low_prec=False):
def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
):
c = copy.deepcopy(config)
# TRAINING PRESETS
if name == "initial_training":
......@@ -144,6 +155,16 @@ def model_config(name, train=False, low_prec=False):
else:
raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
......@@ -151,6 +172,7 @@ def model_config(name, train=False, low_prec=False):
c.globals.offload_inference = False
c.model.template.average_templates = False
c.model.template.offload_templates = False
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
......@@ -426,7 +448,8 @@ config = mlc.ConfigDict(
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates.
# Mutually exclusive with average_templates. Automatically
# enabled if offload_inference is set.
"offload_templates": False,
},
"extra_msa": {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment