Commit 48d8707e authored by zhuwenwen's avatar zhuwenwen
Browse files

use _apply_rotary_emb_torch for z100l&k100

parent 46dd30e7
......@@ -39,10 +39,11 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
from vllm.utils import SUPPORT_TC
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
if current_platform.is_rocm() and SUPPORT_TC:
from flash_attn.layers.rotary import apply_rotary_emb
......@@ -91,8 +92,11 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
positional embeddings.
"""
if current_platform.is_cuda():
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
if SUPPORT_TC:
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment