Unverified commit b798e39f, authored by Yan Ma and committed by GitHub
Browse files

[XPU][bugfix] fix rope for llama4 and deepseek (#25145)


Signed-off-by: Yan Ma <yan.ma@intel.com>
parent 48eb8eba
......@@ -14,7 +14,7 @@ from .rocm_aiter_rope_ops import (
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
class RotaryEmbeddingBase(CustomOp):
"""Original rotary positional embedding."""
def __init__(
......@@ -86,6 +86,21 @@ class RotaryEmbedding(CustomOp):
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
class RotaryEmbedding(RotaryEmbeddingBase):
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    """Initialize the instance by delegating to the parent initializer.

    All six arguments are forwarded positionally, unchanged and in the
    order they are declared here; this method sets no state of its own.
    """
    super().__init__(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    )
def forward_native(
self,
positions: torch.Tensor,
......
......@@ -7,7 +7,7 @@ import torch
from vllm.platforms import current_platform
from .base import RotaryEmbedding
from .base import RotaryEmbeddingBase
from .common import (
rotate_gptj,
rotate_neox,
......@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
......
......@@ -5,10 +5,10 @@ import math
import torch
from .base import RotaryEmbedding
from .base import RotaryEmbeddingBase
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
def __init__(
self,
head_size: int,
......@@ -78,10 +78,3 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(query, key)
def forward_hip( # type: ignore[override]
self,
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(query, key)
......@@ -7,7 +7,7 @@ import torch
from vllm.triton_utils import tl, triton
from .base import RotaryEmbedding
from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
......@@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.T
return x_t
class MRotaryEmbedding(RotaryEmbedding):
class MRotaryEmbedding(RotaryEmbeddingBase):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
......@@ -357,24 +357,6 @@ class MRotaryEmbedding(RotaryEmbedding):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_xpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment