Commit 48d8707e authored by zhuwenwen's avatar zhuwenwen
Browse files

use _apply_rotary_emb_torch for z100l&k100

parent 46dd30e7
...@@ -39,10 +39,11 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -39,10 +39,11 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.utils import SUPPORT_TC
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm(): if current_platform.is_rocm() and SUPPORT_TC:
from flash_attn.layers.rotary import apply_rotary_emb from flash_attn.layers.rotary import apply_rotary_emb
...@@ -91,8 +92,11 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ...@@ -91,8 +92,11 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
positional embeddings. positional embeddings.
""" """
if current_platform.is_cuda(): if current_platform.is_cuda():
return apply_rotary_emb(x.unsqueeze(0), cos, sin, if SUPPORT_TC:
not is_neox_style).squeeze(0) return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
else: else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style) return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment