Commit b7f35dce authored by Christina Floristean

Use EvoformerFusedAttention directly to avoid all-zero bias term in column attention

parent 5aa54958
@@ -23,7 +23,7 @@ if deepspeed_is_installed:
     import deepspeed
 if ds4s_is_installed:
-    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
+    from deepspeed.ops.deepspeed4science import EvoformerFusedAttention
 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
@@ -661,18 +661,19 @@ def _deepspeed_evo_attn(
     v = reshape_dims(v)
     biases = [reshape_dims(b) for b in biases]
+    biases.extend([None] * (2 - len(biases)))
     # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
     # Cast to bf16 so kernel can be used during inference
     orig_dtype = q.dtype
     if orig_dtype not in [torch.bfloat16, torch.float16]:
-        o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
-                                      k.to(dtype=torch.bfloat16),
-                                      v.to(dtype=torch.bfloat16),
-                                      [b.to(dtype=torch.bfloat16) for b in biases])
+        inputs_bf16 = [x.to(dtype=torch.bfloat16) if x is not None else x
+                       for x in (q, k, v, biases[0], biases[1])]
+        o = EvoformerFusedAttention.apply(*inputs_bf16)
         o = o.to(dtype=orig_dtype)
     else:
-        o = EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1])
     o = o.reshape(orig_shape)
     return o
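
For context, a minimal sketch of how the new code path is exercised (the tensor shapes, the mask_bias tensor, and the CUDA device are assumptions for illustration, not part of this commit): column attention supplies only a mask bias, so the second bias slot is passed as None and, per the commit message, the fused kernel no longer attends through an all-zero pair-bias term.

# Hypothetical usage sketch; the [batch, N, seq, heads, dim] layout and the
# bias shape below are assumptions about the DeepSpeed Evoformer kernel inputs.
import torch
from deepspeed.ops.deepspeed4science import EvoformerFusedAttention

batch, n, seq, heads, dim = 1, 64, 128, 8, 32
q = torch.randn(batch, n, seq, heads, dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Column attention has only a mask bias (assumed shape [batch, n, 1, 1, seq]);
# zeros are a placeholder here, real code puts large negative values at masked
# positions. There is no pair bias, so the second slot stays None.
mask_bias = torch.zeros(batch, n, 1, 1, seq, dtype=torch.bfloat16, device="cuda")

# Calling the autograd Function directly lets the missing pair bias stay None
# instead of being filled in with an all-zero tensor.
o = EvoformerFusedAttention.apply(q, k, v, mask_bias, None)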