Unverified Commit c9ee7385 authored by Jiaqi Gu, committed by GitHub
Browse files

Fuse writing KV buffer into rope kernel (part 2: srt) (#9014)


Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
parent 1f9ec653
......@@ -119,7 +119,7 @@ jobs:
python3 -m pip --no-cache-dir install -e "python[all]" --break-system-packages
python3 -m pip --no-cache-dir install mooncake-transfer-engine==0.3.5
python3 -m pip --no-cache-dir install --user --force-reinstall genai-bench==0.0.1
python3 -m pip --no-cache-dir install sgl-kernel==0.3.3
python3 -m pip --no-cache-dir install sgl-kernel==0.3.4
- name: Build and install sgl-router
run: |
......
......@@ -58,7 +58,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.3.3",
"sgl-kernel==0.3.4",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
......
......@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
assert_pkg_version(
"sgl-kernel",
"0.3.3",
"0.3.4",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
......
......@@ -222,6 +222,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
) -> Tuple[torch.Tensor, torch.Tensor]:
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
apply_rope_with_cos_sin_cache_inplace(
......@@ -231,8 +232,17 @@ class RotaryEmbedding(CustomOp):
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
# Compatible with old sgl-kernel
**(
dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
if fused_set_kv_buffer_arg is not None
else {}
),
)
else:
assert (
fused_set_kv_buffer_arg is None
), "save kv cache is not supported for vllm_rotary_embedding."
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
self.vllm_rotary_embedding(
positions,
......
......@@ -66,10 +66,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg
class GptOssConfig(PretrainedConfig):
model_type = "gpt_oss"
......@@ -196,6 +201,32 @@ class GptOssSparseMoeBlock(nn.Module):
return ans
def _enable_fused_set_kv_buffer():
    """Report whether the fused write-KV-buffer rope path should be used.

    Currently gated only on CUDA availability (module-level ``_is_cuda``),
    since the fused kernel lives in ``sgl_kernel``.
    """
    return _is_cuda
# TODO maybe move to a model-common utils
def _create_fused_set_kv_buffer_arg(
    value: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
):
    """Assemble the ``FusedSetKVBufferArg`` for a given attention layer.

    Fetches this layer's K/V cache buffers from the batch's KV pool,
    flattens each buffer to 2-D (token dimension kept, head/dim folded),
    and bundles them together with the new values, per-layer scales, and
    the output cache locations so the rope kernel can write the KV cache
    in a single fused pass.
    """
    pool = forward_batch.token_to_kv_pool
    key_buf = pool.get_key_buffer(layer.layer_id)
    value_buf = pool.get_value_buffer(layer.layer_id)

    def _flat(buf):
        # Collapse all trailing dims so the buffer is (num_tokens, head_dim_total).
        return buf.view(buf.shape[0], -1)

    return FusedSetKVBufferArg(
        value=value,
        k_buffer=_flat(key_buf),
        v_buffer=_flat(value_buf),
        k_scale=layer.k_scale,
        v_scale=layer.v_scale,
        cache_loc=forward_batch.out_cache_loc,
    )
class GptOssAttention(nn.Module):
def __init__(
self,
......@@ -303,7 +334,21 @@ class GptOssAttention(nn.Module):
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
_create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if _enable_fused_set_kv_buffer()
else None
),
)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
......@@ -311,7 +356,11 @@ class GptOssAttention(nn.Module):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state, sinks=self.sinks)
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
save_kv_cache=not _enable_fused_set_kv_buffer(),
)
output, _ = self.o_proj(attn_output)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment