OpenDAS / OpenFold · Commits

Commit e49c6dc7, authored Aug 05, 2022 by Gustaf Ahdritz

Add long sequence inference preset

parent 14a79a12
Showing 2 changed files with 29 additions and 3 deletions:

- README.md (+4 −1)
- openfold/config.py (+25 −2)
README.md

```diff
@@ -215,11 +215,14 @@ see the aforementioned Staats & Rabe preprint.
   wastes time.
 - As a last resort, consider enabling `offload_inference`. This enables more
   extensive CPU offloading at various bottlenecks throughout the model.
+- Disable FlashAttention, which seems unstable on long sequences.
 
 Using the most conservative settings, we were able to run inference on a
 4600-residue complex with a single A100. Compared to AlphaFold's own memory
 offloading mode, ours is considerably faster; the same complex takes the more
-efficent AlphaFold-Multimer more than double the time.
+efficent AlphaFold-Multimer more than double the time. Use the
+`long_sequence_inference` config option to enable all of these interventions
+at once.
```

### Training
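The constraint this commit adds to `enforce_config_constraints` in `openfold/config.py` (below) is that `offload_inference` automatically turns on `offload_templates`, unless templates are being averaged instead. A minimal self-contained sketch of that rule, using plain namespaces as a stand-in for OpenFold's actual `mlc.ConfigDict` tree:

```python
from types import SimpleNamespace

def enforce_config_constraints(config):
    # Sketch of the new rule: offload_inference implies offload_templates,
    # unless average_templates is already handling template memory.
    if (config.globals.offload_inference
            and not config.model.template.average_templates):
        config.model.template.offload_templates = True

# Illustrative stand-in config, not OpenFold's real object.
cfg = SimpleNamespace(
    globals=SimpleNamespace(offload_inference=True),
    model=SimpleNamespace(
        template=SimpleNamespace(average_templates=False, offload_templates=False),
    ),
)
enforce_config_constraints(cfg)
print(cfg.model.template.offload_templates)  # -> True
```

Because `average_templates` and `offload_templates` are mutually exclusive, the rule only fires when averaging is off.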
openfold/config.py

```diff
@@ -41,8 +41,19 @@ def enforce_config_constraints(config):
     if(config.globals.use_flash and not fa_is_installed):
         raise ValueError("use_flash requires that FlashAttention is installed")
 
+    if(
+        config.globals.offload_inference and
+        not config.model.template.average_templates
+    ):
+        config.model.template.offload_templates = True
+
-def model_config(name, train=False, low_prec=False):
+def model_config(
+    name,
+    train=False,
+    low_prec=False,
+    long_sequence_inference=False
+):
     c = copy.deepcopy(config)
 
     # TRAINING PRESETS
     if(name == "initial_training"):
@@ -144,6 +155,16 @@ def model_config(name, train=False, low_prec=False):
     else:
         raise ValueError("Invalid model name")
 
+    if(long_sequence_inference):
+        assert(not train)
+        c.globals.offload_inference = True
+        c.globals.use_lma = True
+        c.globals.use_flash = False
+        c.model.template.offload_inference = True
+        c.model.template.template_pair_stack.tune_chunk_size = False
+        c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
+        c.model.evoformer_stack.tune_chunk_size = False
+
     if(train):
         c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
@@ -151,6 +172,7 @@ def model_config(name, train=False, low_prec=False):
+        c.globals.offload_inference = False
         c.model.template.average_templates = False
         c.model.template.offload_templates = False
 
     if(low_prec):
         c.globals.eps = 1e-4
         # If we want exact numerical parity with the original, inf can't be
@@ -426,7 +448,8 @@ config = mlc.ConfigDict(
             # Offload template embeddings to CPU memory. Vastly reduced
             # memory consumption at the cost of a modest increase in
             # runtime. Useful for inference on very long sequences.
-            # Mutually exclusive with average_templates.
+            # Mutually exclusive with average_templates. Automatically
+            # enabled if offload_inference is set.
             "offload_templates": False,
         },
         "extra_msa": {
```
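Taken together, the new preset is just a bundle of existing flags: turn CPU offloading and low-memory attention on, turn FlashAttention and chunk-size tuning off. A self-contained sketch of that pattern, using `types.SimpleNamespace` in place of OpenFold's `mlc.ConfigDict` (the harness is hypothetical; the flag names mirror the diff):

```python
from types import SimpleNamespace

def make_config():
    """Build a tiny stand-in for OpenFold's config tree (illustrative only)."""
    return SimpleNamespace(
        globals=SimpleNamespace(
            offload_inference=False, use_lma=False, use_flash=True,
        ),
        model=SimpleNamespace(
            template=SimpleNamespace(
                offload_inference=False,
                template_pair_stack=SimpleNamespace(tune_chunk_size=True),
            ),
            extra_msa=SimpleNamespace(
                extra_msa_stack=SimpleNamespace(tune_chunk_size=True),
            ),
            evoformer_stack=SimpleNamespace(tune_chunk_size=True),
        ),
    )

def apply_long_sequence_inference(c, train=False):
    """Mirror the preset from the diff: trade runtime for memory headroom."""
    assert not train  # the preset is inference-only
    c.globals.offload_inference = True   # CPU-offload at model bottlenecks
    c.globals.use_lma = True             # low-memory attention kernel
    c.globals.use_flash = False          # FlashAttention unstable on long seqs
    c.model.template.offload_inference = True
    # Chunk-size tuning probes large chunks, which can OOM on long inputs.
    c.model.template.template_pair_stack.tune_chunk_size = False
    c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
    c.model.evoformer_stack.tune_chunk_size = False
    return c

c = apply_long_sequence_inference(make_config())
print(c.globals.use_flash, c.globals.use_lma)  # -> False True
```

In OpenFold itself the same effect is a single call, `model_config(name, long_sequence_inference=True)`, which applies these assignments after the named preset is selected.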