Unverified Commit a5095d62 authored by Yuan Luo, committed by GitHub

Fuse write kv buffer into rope for qwen3 moe & bailing moe (#10749)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 6c2c467d
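
This change passes a fused_set_kv_buffer_arg into the rotary embedding for the Qwen3-MoE and Bailing-MoE attention layers, so the rope kernel writes the rotated keys (and the values) directly into the layer's KV-cache buffers in the same pass; the subsequent attention call then sets save_kv_cache=False to avoid writing the cache a second time. The enable_fused_set_kv_buffer / create_fused_set_kv_buffer_arg helpers that gpt_oss previously defined locally are moved to the shared sglang.srt.models.utils module and reused. The fused path is taken only on CUDA with a bfloat16 KV cache; otherwise the argument is None and the existing write path is kept.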
@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None
@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         attn_output, _ = self.dense(context_layer)
         return attn_output
...
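
The two call sites above are meant to stay in lockstep: when the fused path is enabled, rope writes the KV cache and the attention call skips its own write; when it is disabled, the calls collapse to the previous behavior. Below is a minimal, self-contained sketch of that guard logic, using hypothetical stand-in types instead of the real ForwardBatch and KV pool.

import torch
from dataclasses import dataclass

# Hypothetical stand-ins for ForwardBatch and the KV pool, only to exercise the guard.
@dataclass
class FakeKVPool:
    dtype: torch.dtype

@dataclass
class FakeForwardBatch:
    token_to_kv_pool: FakeKVPool

def enable_fused(forward_batch: FakeForwardBatch, is_cuda: bool = True) -> bool:
    # Mirrors enable_fused_set_kv_buffer: CUDA + bfloat16 KV cache only.
    return is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16

def plan_kv_write(forward_batch: FakeForwardBatch):
    """Return (pass fused arg to rope?, save_kv_cache in attention?)."""
    fused = enable_fused(forward_batch)
    return fused, not fused

# bf16 KV cache: rope writes the cache, attention must not write it again.
assert plan_kv_write(FakeForwardBatch(FakeKVPool(torch.bfloat16))) == (True, False)
# Any other KV dtype falls back to the unfused path (attention writes the cache).
assert plan_kv_write(FakeForwardBatch(FakeKVPool(torch.float16))) == (False, True)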
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -412,7 +416,20 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -420,7 +437,10 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         output, _ = self.o_proj(attn_output)
         return output
...
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import FusedSetKVBufferArg


def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


def create_fused_set_kv_buffer_arg(
    value: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
):
    """Build the argument that lets the fused rope kernel write K/V directly
    into this layer's KV-cache buffers at forward_batch.out_cache_loc."""
    layer_id = layer.layer_id
    token_to_kv_pool = forward_batch.token_to_kv_pool

    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)

    return FusedSetKVBufferArg(
        value=value,
        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
        k_scale=layer.k_scale,
        v_scale=layer.v_scale,
        cache_loc=forward_batch.out_cache_loc,
    )
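
For orientation, here is a small sketch of the flattening that create_fused_set_kv_buffer_arg applies to the cache buffers. It assumes the per-layer K/V buffers are laid out as [num_slots, num_kv_heads, head_dim]; the shapes and names below are illustrative, not taken from the actual pool implementation.

import torch

# Assumed per-layer cache layout: one row per cache slot.
num_slots, num_kv_heads, head_dim = 16, 4, 128
k_buffer = torch.empty(num_slots, num_kv_heads, head_dim, dtype=torch.bfloat16)
v_buffer = torch.empty_like(k_buffer)

# The helper hands the fused rope kernel 2-D views (one flat row per slot), so the
# kernel can scatter each token's K/V row to the slots in forward_batch.out_cache_loc.
k_flat = k_buffer.view(k_buffer.shape[0], -1)  # [num_slots, num_kv_heads * head_dim]
v_flat = v_buffer.view(v_buffer.shape[0], -1)

assert k_flat.shape == (num_slots, num_kv_heads * head_dim)
assert k_flat.data_ptr() == k_buffer.data_ptr()  # the views share storage with the cache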