Unverified Commit 2a0309a6 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Misc][Bugfix] FA3 support to ViT MHA layer (#12435)


Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent 324960a9
......@@ -251,9 +251,28 @@ class MultiHeadAttention(nn.Module):
_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1,
}:
from vllm.vllm_flash_attn import flash_attn_func
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
from vllm.vllm_flash_attn import flash_attn_varlen_func
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
step=kv_len,
dtype=torch.int32,
device=key.device)
out = flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
out = out.reshape(bsz, q_len, -1)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment