Commit d07ae9c4 authored by Gustaf Ahdritz

Make FlashAttention installation optional

parent 1c279f90
@@ -27,5 +27,4 @@ dependencies:
     - typing-extensions==3.10.0.2
     - pytorch_lightning==1.5.10
     - wandb==0.12.21
-    - git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d
     - git+https://github.com/NVIDIA/dllogger.git
 import copy
+import importlib
 import ml_collections as mlc
@@ -36,6 +37,10 @@ def enforce_config_constraints(config):
         if(s1_setting and s2_setting):
             raise ValueError(f"Only one of {s1} and {s2} may be set at a time")

+    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+    if(config.globals.use_flash and not fa_is_installed):
+        raise ValueError("use_flash requires that FlashAttention is installed")
+

 def model_config(name, train=False, low_prec=False):
     c = copy.deepcopy(config)
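The new constraint fails fast at configuration time instead of erroring deep inside the model the first time attention is called. A minimal standalone sketch of the same pattern (the helper name and message below are illustrative, not part of the commit):

import importlib.util

def check_optional_feature(flag_name, flag_value, package):
    # Reject a feature flag up front when its optional backend is missing,
    # rather than failing later inside the model.
    if flag_value and importlib.util.find_spec(package) is None:
        raise ValueError(f"{flag_name} requires that {package} is installed")

# Mirrors the constraint added above:
# check_optional_feature("use_flash", config.globals.use_flash, "flash_attn")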
@@ -57,6 +62,24 @@ def model_config(name, train=False, low_prec=False):
         c.loss.experimentally_resolved.weight = 0.01
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    elif name == "finetuning_no_templ":
+        # AF2 Suppl. Table 4, "finetuning" setting
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.model.template.enabled = False
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+    elif name == "finetuning_no_templ_ptm":
+        # AF2 Suppl. Table 4, "finetuning" setting
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.model.template.enabled = False
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+        c.model.heads.tm.enabled = True
+        c.loss.tm.weight = 0.1
     elif name == "model_1":
         # AF2 Suppl. Table 5, Model 1.1.1
         c.data.train.max_extra_msa = 5120
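The two presets added here can be selected by name through model_config. A short usage sketch (the import path openfold.config is assumed from the upstream repository layout):

from openfold.config import model_config   # import path assumed

# Build a training config with templates disabled and the pTM head enabled.
c = model_config("finetuning_no_templ_ptm", train=True)
assert c.data.train.crop_size == 384     # AF2 Suppl. Table 4, "finetuning"
assert not c.model.template.enabled      # no-template preset
assert c.model.heads.tm.enabled          # pTM head switched on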
@@ -324,7 +347,7 @@ config = mlc.ConfigDict(
             "use_lma": False,
             # Use FlashAttention in selected modules. Mutually exclusive with
             # use_lma.
-            "use_flash": True,
+            "use_flash": False,
             "offload_inference": False,
             "c_z": c_z,
             "c_m": c_m,
@@ -13,14 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np
 import deepspeed
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attention import FlashAttention
-from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
+fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+if(fa_is_installed):
+    from flash_attn.bert_padding import unpad_input, pad_input
+    from flash_attn.flash_attention import FlashAttention
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm
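A quick smoke test of the guarded import (the module path openfold.model.primitives is assumed from the upstream layout): the module should now import cleanly whether or not flash_attn is present, and the module-level flag records which case applies.

import importlib.util

import openfold.model.primitives as primitives   # module path assumed

# Importing the module no longer requires flash_attn; the flag set at import
# time reflects whether the optional dependency was found.
assert primitives.fa_is_installed == (
    importlib.util.find_spec("flash_attn") is not None
)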
@@ -643,6 +648,11 @@ def _lma(

 @torch.jit.ignore
 def _flash_attn(q, k, v, kv_mask):
+    if(not fa_is_installed):
+        raise ValueError(
+            "_flash_attn requires that FlashAttention be installed"
+        )
+
     batch_dims = q.shape[:-3]
     no_heads, n, c = q.shape[-3:]
     dtype = q.dtype
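Callers that want to degrade gracefully can branch on the same flag before reaching _flash_attn. A sketch of such a fallback (shapes assumed to follow _flash_attn's inputs, q/k/v of shape [*, heads, seq, channel] and kv_mask of shape [*, seq]; the fallback shown is plain masked scaled dot-product attention, not one of OpenFold's own kernels):

import math
import torch

def attention_with_optional_flash(q, k, v, kv_mask, use_flash, flash_fn=None):
    # Use the FlashAttention path only when it is both requested and available;
    # otherwise fall back to standard masked scaled dot-product attention.
    if use_flash and flash_fn is not None:
        return flash_fn(q, k, v, kv_mask)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
    # Large negative bias wherever kv_mask == 0 so those keys get ~0 weight.
    bias = (kv_mask[..., None, None, :].to(scores.dtype) - 1.0) * 1e9
    weights = torch.softmax(scores + bias, dim=-1)
    return torch.matmul(weights, v)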
@@ -17,6 +17,9 @@ lib/conda/bin/python3 -m pip install nvidia-pyindex
 conda env create --name=${ENV_NAME} -f environment.yml
 source activate ${ENV_NAME}

+echo "Attempting to install FlashAttention"
+pip install git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d && echo "Installation successful"
+
 # Install DeepMind's OpenMM patch
 OPENFOLD_DIR=$PWD
 pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \
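The script now attempts the FlashAttention install explicitly rather than pulling it in through environment.yml. A quick post-install check inside the activated environment (a sketch, not part of the script) confirms whether the optional package actually ended up available:

import importlib.util

# Report whether the optional FlashAttention install succeeded; this mirrors
# the fa_is_installed checks used in config.py and primitives.py above.
if importlib.util.find_spec("flash_attn") is not None:
    print("flash_attn found: config.globals.use_flash can be enabled")
else:
    print("flash_attn not found: OpenFold will run with use_flash disabled")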