Commit ffd8e40d authored by zhuwenwen's avatar zhuwenwen
Browse files

skip apply_rotary_emb from vllm_fa

parent daefd764
......@@ -77,7 +77,7 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() and not current_platform.is_rocm():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment