Unverified commit c11b34d5, authored by Chunyuan WU, committed by GitHub
Browse files

rope xpu: fix missing argument 'fused_set_kv_buffer_arg' and replace native...

rope xpu: fix missing argument 'fused_set_kv_buffer_arg' and replace native with sgl_kernel_xpu impl (#12006)
parent 05ad28f2
......@@ -312,10 +312,20 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: make a wrapper, and XPU will implement this kernel later.
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
return self.forward_native(positions, query, key, offsets)
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for xpu implementation"
positions = torch.add(positions, offsets) if offsets is not None else positions
return torch.ops.sgl_kernel.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
class LinearScalingRotaryEmbedding(RotaryEmbedding):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment