OpenDAS / OpenFold · Commits

Commit d07ae9c4, authored Jul 14, 2022 by Gustaf Ahdritz (parent 1c279f90)

Make FlashAttention installation optional
Showing 4 changed files, with 40 additions and 5 deletions:

- environment.yml (+0 / -1)
- openfold/config.py (+24 / -1)
- openfold/model/primitives.py (+13 / -3)
- scripts/install_third_party_dependencies.sh (+3 / -0)
environment.yml

```diff
@@ -27,5 +27,4 @@ dependencies:
   - typing-extensions==3.10.0.2
   - pytorch_lightning==1.5.10
   - wandb==0.12.21
-  - git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d
   - git+https://github.com/NVIDIA/dllogger.git
```
openfold/config.py

```diff
 import copy
+import importlib
 import ml_collections as mlc
```
```diff
@@ -36,6 +37,10 @@ def enforce_config_constraints(config):
         if(s1_setting and s2_setting):
             raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
 
+    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+    if(config.globals.use_flash and not fa_is_installed):
+        raise ValueError("use_flash requires that FlashAttention is installed")
+
 def model_config(name, train=False, low_prec=False):
     c = copy.deepcopy(config)
```
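The net effect of the new guard is that a config requesting FlashAttention fails fast at validation time when the package is missing, instead of erroring deep inside the model. A minimal sketch of that failure mode (`model_config` and `enforce_config_constraints` are from the diff above; that building a preset triggers the check is an assumption based on their placement in this file):

```python
# Hedged sketch of the new constraint in action.
from openfold.config import enforce_config_constraints, model_config

c = model_config("model_1")
c.globals.use_flash = True  # request FlashAttention kernels

# On a machine without the flash_attn package, this now raises:
#   ValueError: use_flash requires that FlashAttention is installed
enforce_config_constraints(c)
```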
```diff
@@ -57,6 +62,24 @@ def model_config(name, train=False, low_prec=False):
         c.loss.experimentally_resolved.weight = 0.01
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    elif name == "finetuning_no_templ":
+        # AF2 Suppl. Table 4, "finetuning" setting
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.model.template.enabled = False
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+    elif name == "finetuning_no_templ_ptm":
+        # AF2 Suppl. Table 4, "finetuning" setting
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.model.template.enabled = False
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+        c.model.heads.tm.enabled = True
+        c.loss.tm.weight = 0.1
     elif name == "model_1":
         # AF2 Suppl. Table 5, Model 1.1.1
         c.data.train.max_extra_msa = 5120
```
```diff
@@ -324,7 +347,7 @@ config = mlc.ConfigDict(
         "use_lma": False,
         # Use FlashAttention in selected modules. Mutually exclusive with
         # use_lma.
-        "use_flash": True,
+        "use_flash": False,
         "offload_inference": False,
         "c_z": c_z,
         "c_m": c_m,
```
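With the default flipped from `True` to `False`, FlashAttention is now strictly opt-in. A short sketch of a safe way to opt back in (an illustration, not code from this commit):

```python
import importlib.util

from openfold.config import model_config

c = model_config("model_1")
# Re-enable FlashAttention only when the optional package is actually
# discoverable, mirroring the detection logic this commit adds.
c.globals.use_flash = importlib.util.find_spec("flash_attn") is not None
```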
openfold/model/primitives.py

```diff
@@ -13,14 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from functools import partial
+import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
 import numpy as np
 import deepspeed
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attention import FlashAttention
-from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
+fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+if(fa_is_installed):
+    from flash_attn.bert_padding import unpad_input, pad_input
+    from flash_attn.flash_attention import FlashAttention
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm
```
```diff
@@ -643,6 +648,11 @@ def _lma(
 
 @torch.jit.ignore
 def _flash_attn(q, k, v, kv_mask):
+    if(not fa_is_installed):
+        raise ValueError(
+            "_flash_attn requires that FlashAttention be installed"
+        )
+
     batch_dims = q.shape[:-3]
     no_heads, n, c = q.shape[-3:]
     dtype = q.dtype
```
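The pattern applied here is worth isolating: probe with `importlib.util.find_spec` (which locates the package without importing it), gate the imports on the result, and defer the hard failure to the call site with an actionable message. A self-contained sketch of the pattern, with hypothetical names (`optional_backend`, `fused_op`, `accelerated_op`):

```python
import importlib.util

# find_spec only checks that the package can be located; it does not run the
# package's import-time code, which for a compiled CUDA extension may be
# costly or may fail outright on machines without a supported GPU toolchain.
_backend_is_installed = importlib.util.find_spec("optional_backend") is not None

if _backend_is_installed:
    from optional_backend import fused_op  # guarded optional import

def accelerated_op(x):
    # Fail at the point of use rather than at module import time, so that
    # code paths not using the optional backend keep working.
    if not _backend_is_installed:
        raise ValueError(
            "accelerated_op requires that optional_backend be installed"
        )
    return fused_op(x)
```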
scripts/install_third_party_dependencies.sh

```diff
@@ -17,6 +17,9 @@ lib/conda/bin/python3 -m pip install nvidia-pyindex
 conda env create --name=${ENV_NAME} -f environment.yml
 source activate ${ENV_NAME}
 
+echo "Attempting to install FlashAttention"
+pip install git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d \
+    && echo "Installation successful"
+
 # Install DeepMind's OpenMM patch
 OPENFOLD_DIR=$PWD
 pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \
```
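Because of the `&&`, the "Installation successful" message prints only when pip succeeds; in line with the commit's intent, a failed install leaves the environment without `flash_attn` rather than blocking setup. A quick post-install check, using the same probe the new code relies on (assumed to be run inside the activated environment):

```python
import importlib.util

# Reports whether the optional FlashAttention package made it into the env.
available = importlib.util.find_spec("flash_attn") is not None
print(f"flash_attn installed: {available}")
```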