Unverified Commit 68e1b711, authored by Wang, Yiting; committed by GitHub
Browse files

[XPU] Add deepseek_scaling_rope fused kernel (#36612)


Signed-off-by: yitingw1 <yiting.wang@intel.com>
parent 0024f39a
...@@ -8,6 +8,7 @@ from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func ...@@ -8,6 +8,7 @@ from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"): ...@@ -54,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device) return torch.empty((M, N), dtype=input.dtype, device=input.device)
def _xpu_ops_deepseek_scaling_rope_impl(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
assert key is not None
return torch.ops._xpu_C.deepseek_scaling_rope(
positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style
)
def _xpu_ops_deepseek_scaling_rope_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return query, key
# Module-level guard read and set by ``xpu_ops.register_ops_once()`` so that
# the custom ops are registered at most once, however often it is called.
_OPS_REGISTERED = False
class xpu_ops: class xpu_ops:
@staticmethod @staticmethod
def flash_attn_varlen_func( def flash_attn_varlen_func(
...@@ -402,3 +434,21 @@ class xpu_ops: ...@@ -402,3 +434,21 @@ class xpu_ops:
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = ( raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices topk_indices
) )
@staticmethod
def register_ops_once() -> None:
    """Register the XPU custom ops, doing the work only on the first call.

    The module-level ``_OPS_REGISTERED`` flag records that registration has
    happened; later calls return immediately.
    """
    global _OPS_REGISTERED
    if _OPS_REGISTERED:
        return

    # register all the custom ops here
    direct_register_custom_op(
        op_name="xpu_ops_deepseek_scaling_rope",
        op_func=_xpu_ops_deepseek_scaling_rope_impl,
        mutates_args=[],
        fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
        dispatch_key=current_platform.dispatch_key,
    )
    # Flip the flag only after registration completed without raising.
    _OPS_REGISTERED = True
# Register the custom ops as a side effect of executing this module.
xpu_ops.register_ops_once()
...@@ -152,6 +152,23 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ...@@ -152,6 +152,23 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key = key_rot key = key_rot
return query, key return query, key
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """XPU forward path: delegate to the ``vllm::xpu_ops_deepseek_scaling_rope`` op.

    The cos/sin cache passed to the op is obtained from
    ``self._match_cos_sin_cache_dtype(query)``; ``self.rotary_dim`` and
    ``self.is_neox_style`` are forwarded as-is. The op's result is returned
    unchanged.
    """
    # NOTE(review): ``key`` defaults to None here -- confirm that the
    # registered op implementation accepts a missing key.
    rope_op = torch.ops.vllm.xpu_ops_deepseek_scaling_rope
    cos_sin_cache = self._match_cos_sin_cache_dtype(query)
    return rope_op(
        positions,
        query,
        key,
        offsets,
        cos_sin_cache,
        self.rotary_dim,
        self.is_neox_style,
    )
def forward_hip( def forward_hip(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment