"fair_dev/testing/testing.py" did not exist on "49a198c99cdf61cf869ced2dc1e4e8b69926ceed"
Unverified commit 07440f5f, authored by Lianmin Zheng, committed by GitHub

Fix FusedSetKVBufferArg in RotaryEmbedding (#11003)

parent 9816989b
@@ -27,7 +27,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None
 
 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope
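The core of the fix is the import fallback above: on non-CUDA builds, `FusedSetKVBufferArg` is bound to `None` so the name always exists, and the `Optional[FusedSetKVBufferArg]` annotations added in the hunks below still evaluate (at runtime, `Optional[None]` collapses to `NoneType`). A minimal self-contained sketch of the same pattern, using a try/except as a stand-in for the `_is_cuda` check:

```python
from typing import Optional

try:
    # Present only when SGLang's CUDA kernel package is installed.
    from sgl_kernel import FusedSetKVBufferArg
except ImportError:
    # Fallback stub: the name still exists, so the annotation below
    # evaluates without error (Optional[None] collapses to NoneType).
    FusedSetKVBufferArg = None


def forward(fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None):
    # On builds without the CUDA kernels, the only legal value is None,
    # which the backend-specific asserts in this commit enforce.
    assert fused_set_kv_buffer_arg is None
```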
@@ -146,8 +149,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -176,12 +184,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
         import os
 
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
+
         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
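The NPU path gates on `SGLANG_ENABLE_TORCH_COMPILE` via `get_bool_env_var` and now forwards the new argument to `forward_native`, whose own assert rejects anything non-None. For reference, a sketch of the assumed helper behavior (the exact implementation lives in SGLang's utils and may differ in which spellings it accepts):

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed behavior: case-insensitive match against common truthy
    # spellings, e.g. SGLANG_ENABLE_TORCH_COMPILE=1 or =true.
    return os.getenv(name, default).lower() in ("true", "1")
```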
@@ -206,8 +219,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -219,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
 
     def forward_cuda(
         self,
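Note the belt-and-braces structure in `forward_cpu` above: the argument is asserted to be `None` at the top of the function, yet still forwarded to `forward_native` in the non-AMX branch, where the native assert would run again. Since the value is provably `None` by that point, the pass-through cannot fail; it simply keeps the call shape uniform across backends.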
@@ -227,7 +246,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
...

-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import pytest
 import torch
-from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
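Since the test module imports `pytest` and now `FusedSetKVBufferArg`, a natural regression check (hypothetical, not part of this commit) is that the native path still rejects any non-None value. A plain sentinel suffices because the guard only tests `is None`; the `RotaryEmbedding` constructor arguments below are assumptions based on the surrounding test code:

```python
import pytest
import torch


def test_forward_native_rejects_fused_arg():
    # Hypothetical test, not part of this commit. The assert fires
    # before any tensor math, so the shapes barely matter here.
    rope = RotaryEmbedding(
        head_size=64,
        rotary_dim=64,
        max_position_embeddings=128,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
    )
    positions = torch.zeros(4, dtype=torch.int64)
    query = torch.randn(4, 64)
    key = torch.randn(4, 64)
    with pytest.raises(AssertionError):
        # Any non-None sentinel trips the `is None` guard.
        rope.forward_native(positions, query, key, fused_set_kv_buffer_arg=object())
```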
@@ -84,8 +83,13 @@ class RotaryEmbedding(torch.nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
@@ -125,8 +129,8 @@ class FlashInferRotaryEmbedding(RotaryEmbedding):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         apply_rope_with_cos_sin_cache_inplace(
...
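The `FlashInferRotaryEmbedding` change is a pure reorder: `fused_set_kv_buffer_arg` moves after `offsets` so the signature matches the base class. That matters for positional callers; a sketch of the hazard the reorder removes:

```python
# Illustrative only: a caller passing offsets positionally.
q_out, k_out = rope.forward(positions, query, key, offsets)

# Old order (..., fused_set_kv_buffer_arg, offsets): the fourth argument
# silently binds to fused_set_kv_buffer_arg and the offsets are dropped.
# New order (..., offsets, fused_set_kv_buffer_arg): it binds to offsets,
# consistent with RotaryEmbedding.forward_* in the main file.
```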