Commit a51b08cd authored by Jennifer's avatar Jennifer Committed by Jennifer Wei
Browse files

initial compatibility changes for upgrading multimer

parent af457b95
......@@ -3,6 +3,7 @@ channels:
- conda-forge
- bioconda
- pytorch
- nvidia
dependencies:
- python=3.9
- libgcc=7.2
......@@ -16,9 +17,9 @@ dependencies:
- pandas
- PyYAML==5.4.1
- requests
- scipy==1.7
- scipy
- tqdm==4.62.2
- typing-extensions==4.0
- typing-extensions
- wandb
- modelcif==0.7
- awscli
......@@ -30,10 +31,11 @@ dependencies:
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pytorch::pytorch=2.1
- pytorch::pytorch-cuda=12.1
- pip:
- mpi4py==3.1.5
- deepspeed==0.12.4
- dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
- flash-attn
......@@ -28,7 +28,7 @@ if ds4s_is_installed:
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
import torch
import torch.nn as nn
......@@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask):
kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
out = flash_attn_unpadded_kvpacked_func(
out = flash_attn_varlen_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
......
......@@ -29,7 +29,7 @@ version_dependent_macros = [
]
extra_cuda_flags = [
'-std=c++14',
'-std=c++17',
'-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment