Unverified Commit 4d7c1d53 authored by Yan Ma's avatar Yan Ma Committed by GitHub
Browse files

[Bugfix] Fix MRoPE dispatch on XPU (#24724)


Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
parent 41f17bf2
...@@ -300,6 +300,15 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -300,6 +300,15 @@ class MRotaryEmbedding(RotaryEmbedding):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
def forward_xpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
return self.forward_native(positions, query, key, offsets)
def forward_cpu( def forward_cpu(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment