Commit a51b08cd authored by Jennifer, committed by Jennifer Wei

initial compatibility changes for upgrading multimer

parent af457b95
@@ -3,6 +3,7 @@ channels:
   - conda-forge
   - bioconda
   - pytorch
+  - nvidia
 dependencies:
   - python=3.9
   - libgcc=7.2
@@ -16,9 +17,9 @@ dependencies:
   - pandas
   - PyYAML==5.4.1
   - requests
-  - scipy==1.7
+  - scipy
   - tqdm==4.62.2
-  - typing-extensions==4.0
+  - typing-extensions
   - wandb
   - modelcif==0.7
   - awscli
@@ -30,10 +31,11 @@ dependencies:
   - bioconda::hmmer==3.3.2
   - bioconda::hhsuite==3.3.0
   - bioconda::kalign2==2.04
-  - pytorch::pytorch=1.12.*
+  - pytorch::pytorch=2.1
+  - pytorch::pytorch-cuda=12.1
   - pip:
       - mpi4py==3.1.5
       - deepspeed==0.12.4
       - dm-tree==0.1.6
       - git+https://github.com/NVIDIA/dllogger.git
-      - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
+      - flash-attn
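The environment changes move the stack to PyTorch 2.1 with CUDA 12.1 packages from the newly added nvidia channel, and replace the flash-attention git checkout pinned at 5b838a8 with the flash-attn PyPI package. A quick post-install sanity check, a sketch that is not part of this commit, could look like:

    # Hypothetical sanity check for the upgraded environment (not in this commit).
    import importlib.util

    import torch

    print(torch.__version__)          # expect 2.1.x
    print(torch.version.cuda)         # expect "12.1"
    print(torch.cuda.is_available())  # True on a CUDA-capable machine

    # flash-attn now comes from PyPI rather than the pinned git checkout.
    if importlib.util.find_spec("flash_attn") is not None:
        import flash_attn
        print(flash_attn.__version__)  # 2.x releases expose the varlen_* API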
@@ -28,7 +28,7 @@ if ds4s_is_installed:
 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
     from flash_attn.bert_padding import unpad_input
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+    from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func

 import torch
 import torch.nn as nn
@@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask):
     kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
     kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
-    out = flash_attn_unpadded_kvpacked_func(
+    out = flash_attn_varlen_kvpacked_func(
         q,
         kv_unpad,
         q_cu_seqlens,
...
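Only the symbol name changes in these two hunks: flash-attn v2 renamed its flash_attn_unpadded_* entry points to flash_attn_varlen_*, while the unpad_input / cu_seqlens calling convention stays the same. A minimal compatibility shim, assuming only that each release exports its own spelling from flash_attn.flash_attn_interface, would let the same call site run against either version:

    # Hypothetical shim (not in this commit): resolve the kvpacked kernel under
    # either the flash-attn >= 2.0 name or the 1.x name.
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
    except ImportError:
        from flash_attn.flash_attn_interface import (
            flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
        )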
@@ -29,7 +29,7 @@ version_dependent_macros = [
 ]
 extra_cuda_flags = [
-    '-std=c++14',
+    '-std=c++17',
     '-maxrregcount=50',
     '-U__CUDA_NO_HALF_OPERATORS__',
     '-U__CUDA_NO_HALF_CONVERSIONS__',
...
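The bump from -std=c++14 to -std=c++17 tracks PyTorch 2.x, whose C++ extension headers require C++17. For context, a sketch of how flags like these are typically wired into a CUDA extension via torch.utils.cpp_extension; the extension name and source paths below are illustrative placeholders, not taken from this diff:

    # Sketch only: passing extra_cuda_flags to a CUDAExtension build.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    extra_cuda_flags = [
        '-std=c++17',          # PyTorch 2.x headers require C++17
        '-maxrregcount=50',
        '-U__CUDA_NO_HALF_OPERATORS__',
        '-U__CUDA_NO_HALF_CONVERSIONS__',
    ]

    setup(
        name='example_cuda_ext',  # placeholder name
        ext_modules=[
            CUDAExtension(
                name='example_cuda_ext',
                sources=['csrc/example.cpp', 'csrc/example_kernel.cu'],  # placeholders
                extra_compile_args={
                    'cxx': ['-std=c++17'],   # host compiler must match the standard
                    'nvcc': extra_cuda_flags,
                },
            )
        ],
        cmdclass={'build_ext': BuildExtension},
    )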