Commit d07ae9c4 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Make FlashAttention installation optional

parent 1c279f90
......@@ -27,5 +27,4 @@ dependencies:
- typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10
- wandb==0.12.21
- git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d
- git+https://github.com/NVIDIA/dllogger.git
import copy
import importlib
import ml_collections as mlc
......@@ -36,6 +37,10 @@ def enforce_config_constraints(config):
# Fragment of enforce_config_constraints (enclosing def/loop not visible in
# this diff hunk). s1/s2 are mutually-exclusive config settings checked above.
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
# Probe for the optional flash_attn package WITHOUT importing it —
# importlib.util.find_spec returns None when the module cannot be located.
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
# Fail fast at config-validation time: enabling use_flash is an error when
# the optional FlashAttention dependency is absent.
if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed")
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
......@@ -57,6 +62,24 @@ def model_config(name, train=False, low_prec=False):
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# New preset: the AF2 fine-tuning configuration with template features
# disabled.
elif name == "finetuning_no_templ":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
# Same settings as "finetuning_no_templ" above, plus the TM head and its
# loss term enabled (the "_ptm" variant).
elif name == "finetuning_no_templ_ptm":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.train.max_extra_msa = 5120
......@@ -324,7 +347,7 @@ config = mlc.ConfigDict(
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma.
"use_flash": True,
"use_flash": False,
"offload_inference": False,
"c_z": c_z,
"c_m": c_m,
......
......@@ -13,14 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import importlib
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import deepspeed
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
# FlashAttention is an optional dependency: detect it with find_spec (which
# does not execute the package) and import its kernels only when present.
# Code that needs these names must check fa_is_installed first.
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch
import torch.nn as nn
from scipy.stats import truncnorm
......@@ -643,6 +648,11 @@ def _lma(
# Excluded from TorchScript compilation: the conditional flash_attn imports
# above are not scriptable.
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
# Guard: the FlashAttention kernels were only imported when the optional
# flash_attn package was found (fa_is_installed, set at module scope).
if(not fa_is_installed):
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
# q is indexed as [..., no_heads, n, c]; everything before the last three
# dims is treated as batch dims.
# NOTE(review): function body continues past this diff hunk; the remainder
# is not visible here.
batch_dims = q.shape[:-3]
no_heads, n, c = q.shape[-3:]
dtype = q.dtype
......
......@@ -17,6 +17,9 @@ lib/conda/bin/python3 -m pip install nvidia-pyindex
conda env create --name=${ENV_NAME} -f environment.yml
source activate ${ENV_NAME}
echo "Attempting to install FlashAttention"
# Install the optional FlashAttention dependency pinned to a known-good
# commit; the trailing echo runs only when pip succeeds.
pip install git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d && echo "Installation successful"
# Install DeepMind's OpenMM patch
OPENFOLD_DIR=$PWD
pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment