"lib/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "8fb1421e5e4346f5218a11cd4adba4fcee878b08"
Commit b7f35dce authored by Christina Floristean's avatar Christina Floristean
Browse files

Use EvoformerFusedAttention directly to avoid all-zero bias term in column attention

parent 5aa54958
...@@ -23,7 +23,7 @@ if deepspeed_is_installed: ...@@ -23,7 +23,7 @@ if deepspeed_is_installed:
import deepspeed import deepspeed
if ds4s_is_installed: if ds4s_is_installed:
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention from deepspeed.ops.deepspeed4science import EvoformerFusedAttention
fa_is_installed = importlib.util.find_spec("flash_attn") is not None fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed: if fa_is_installed:
...@@ -661,18 +661,19 @@ def _deepspeed_evo_attn( ...@@ -661,18 +661,19 @@ def _deepspeed_evo_attn(
v = reshape_dims(v) v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases] biases = [reshape_dims(b) for b in biases]
biases.extend([None] * (2 - len(biases)))
# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16 # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference # Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]: if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16), inputs_bf16 = [x.to(dtype=torch.bfloat16) if x is not None else x
k.to(dtype=torch.bfloat16), for x in (q, k, v, biases[0], biases[1])]
v.to(dtype=torch.bfloat16), o = EvoformerFusedAttention.apply(*inputs_bf16)
[b.to(dtype=torch.bfloat16) for b in biases])
o = o.to(dtype=orig_dtype) o = o.to(dtype=orig_dtype)
else: else:
o = DS4Sci_EvoformerAttention(q, k, v, biases) o = EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1])
o = o.reshape(orig_shape) o = o.reshape(orig_shape)
return o return o
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment