Unverified commit b9ce9a30, authored by Fanli Lin, committed by GitHub
Browse files

[BugFix] Add fallback path in `apply_rotary_pos_emb_flashatt` for non-cuda platforms (#28447)


Signed-off-by: Lin, Fanli <fanli.lin@intel.com>
parent 4ccffe56
......@@ -346,6 +346,13 @@ def apply_rotary_pos_emb_flashatt(
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
else:
# For other platforms, use PyTorch fallback
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
)
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment