Commit 54d414e4 authored by Christina Floristean's avatar Christina Floristean
Browse files

Return to regular kernel usage

parent b7f35dce
...@@ -30,7 +30,7 @@ dependencies: ...@@ -30,7 +30,7 @@ dependencies:
- bioconda::kalign2==2.04 - bioconda::kalign2==2.04
- pytorch::pytorch=1.12.* - pytorch::pytorch=1.12.*
- pip: - pip:
- deepspeed==0.12.2 - git+https://github.com/microsoft/DeepSpeed.git@4388a60 # Replace when version becomes available
- dm-tree==0.1.6 - dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git - git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
...@@ -23,7 +23,7 @@ if deepspeed_is_installed: ...@@ -23,7 +23,7 @@ if deepspeed_is_installed:
import deepspeed import deepspeed
if ds4s_is_installed: if ds4s_is_installed:
from deepspeed.ops.deepspeed4science import EvoformerFusedAttention from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
fa_is_installed = importlib.util.find_spec("flash_attn") is not None fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed: if fa_is_installed:
...@@ -661,19 +661,18 @@ def _deepspeed_evo_attn( ...@@ -661,19 +661,18 @@ def _deepspeed_evo_attn(
v = reshape_dims(v) v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases] biases = [reshape_dims(b) for b in biases]
biases.extend([None] * (2 - len(biases)))
# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16 # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference # Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]: if orig_dtype not in [torch.bfloat16, torch.float16]:
inputs_bf16 = [x.to(dtype=torch.bfloat16) if x is not None else x o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
for x in (q, k, v, biases[0], biases[1])] k.to(dtype=torch.bfloat16),
o = EvoformerFusedAttention.apply(*inputs_bf16) v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])
o = o.to(dtype=orig_dtype) o = o.to(dtype=orig_dtype)
else: else:
o = EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1]) o = DS4Sci_EvoformerAttention(q, k, v, biases)
o = o.reshape(orig_shape) o = o.reshape(orig_shape)
return o return o
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment