Commit e49c6dc7 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add long sequence inference preset

parent 14a79a12
...@@ -215,11 +215,14 @@ see the aforementioned Staats & Rabe preprint. ...@@ -215,11 +215,14 @@ see the aforementioned Staats & Rabe preprint.
wastes time. wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more - As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model. extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences.
Using the most conservative settings, we were able to run inference on a Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory 4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more offloading mode, ours is considerably faster; the same complex takes the more
efficient AlphaFold-Multimer more than double the time. efficient AlphaFold-Multimer more than double the time. Use the
`long_sequence_inference` config option to enable all of these interventions
at once.
### Training ### Training
......
...@@ -41,8 +41,19 @@ def enforce_config_constraints(config): ...@@ -41,8 +41,19 @@ def enforce_config_constraints(config):
if(config.globals.use_flash and not fa_is_installed): if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed") raise ValueError("use_flash requires that FlashAttention is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
):
config.model.template.offload_templates = True
def model_config(name, train=False, low_prec=False):
def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
):
c = copy.deepcopy(config) c = copy.deepcopy(config)
# TRAINING PRESETS # TRAINING PRESETS
if name == "initial_training": if name == "initial_training":
...@@ -144,6 +155,16 @@ def model_config(name, train=False, low_prec=False): ...@@ -144,6 +155,16 @@ def model_config(name, train=False, low_prec=False):
else: else:
raise ValueError("Invalid model name") raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if train: if train:
c.globals.blocks_per_ckpt = 1 c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None c.globals.chunk_size = None
...@@ -151,6 +172,7 @@ def model_config(name, train=False, low_prec=False): ...@@ -151,6 +172,7 @@ def model_config(name, train=False, low_prec=False):
c.globals.offload_inference = False c.globals.offload_inference = False
c.model.template.average_templates = False c.model.template.average_templates = False
c.model.template.offload_templates = False c.model.template.offload_templates = False
if low_prec: if low_prec:
c.globals.eps = 1e-4 c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be # If we want exact numerical parity with the original, inf can't be
...@@ -426,7 +448,8 @@ config = mlc.ConfigDict( ...@@ -426,7 +448,8 @@ config = mlc.ConfigDict(
# Offload template embeddings to CPU memory. Vastly reduced # Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in # memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences. # runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates. # Mutually exclusive with average_templates. Automatically
# enabled if offload_inference is set.
"offload_templates": False, "offload_templates": False,
}, },
"extra_msa": { "extra_msa": {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment