Unverified Commit a5095d62 authored by Yuan Luo, committed by GitHub

Fuse write kv buffer into rope for qwen3 moe & bailing moe (#10749)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 6c2c467d
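
This commit applies to the Bailing MoE and Qwen3 MoE attention layers the fused KV-cache write already used by gpt_oss: when the fused path is enabled, the rotary-embedding kernel writes K and V straight into the KV-cache buffers, and the attention call is then told not to save the cache a second time. A minimal sketch of the pattern, using the helper names from the diff below (the surrounding attention-module code is elided):

use_fused = enable_fused_set_kv_buffer(forward_batch)
q, k = self.rotary_emb(
    positions,
    q,
    k,
    # When this argument is not None, the RoPE kernel also scatters K/V into the cache buffers.
    fused_set_kv_buffer_arg=(
        create_fused_set_kv_buffer_arg(value=v, layer=self.attn, forward_batch=forward_batch)
        if use_fused
        else None
    ),
)
# save_kv_cache is disabled only when the fused path above has already stored K/V.
attn_output = self.attn(q, k, v, forward_batch, save_kv_cache=not use_fused)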
@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
context_layer = self.attn(q, k, v, forward_batch)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
)
context_layer = self.attn(
q,
k,
v,
forward_batch,
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
)
attn_output, _ = self.dense(context_layer)
return attn_output
......
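
For context, a rough unfused equivalent of the write that FusedSetKVBufferArg lets the RoPE kernel perform in one pass; this is an illustrative sketch only, not the sgl_kernel implementation, and it omits the k_scale/v_scale handling:

def unfused_set_kv_buffer_sketch(k, v, forward_batch, layer):
    # Without fusion, K/V are scattered into this layer's cache buffers in a
    # separate step after RoPE; the fused path folds this write into the
    # rotary-embedding kernel instead.
    pool = forward_batch.token_to_kv_pool
    k_buffer = pool.get_key_buffer(layer.layer_id)
    v_buffer = pool.get_value_buffer(layer.layer_id)
    loc = forward_batch.out_cache_loc  # destination cache slot for each input token
    k_buffer.view(k_buffer.shape[0], -1)[loc] = k
    v_buffer.view(v_buffer.shape[0], -1)[loc] = v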
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.utils import (
LazyValue,
add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
return ans
def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
# TODO maybe move to a model-common utils
def _create_fused_set_kv_buffer_arg(
value: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
layer_id = layer.layer_id
token_to_kv_pool = forward_batch.token_to_kv_pool
k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
return FusedSetKVBufferArg(
value=value,
k_buffer=k_buffer.view(k_buffer.shape[0], -1),
v_buffer=v_buffer.view(v_buffer.shape[0], -1),
k_scale=layer.k_scale,
v_scale=layer.v_scale,
cache_loc=forward_batch.out_cache_loc,
)
class GptOssAttention(nn.Module):
def __init__(
self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
q,
k,
fused_set_kv_buffer_arg=(
_create_fused_set_kv_buffer_arg(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if _enable_fused_set_kv_buffer(forward_batch)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
)
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
)
output, _ = self.o_proj(attn_output)
return output
......
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.utils import (
add_prefix,
is_cuda,
@@ -412,7 +416,20 @@ class Qwen3MoeAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
@@ -420,7 +437,10 @@ class Qwen3MoeAttention(nn.Module):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state)
attn_output = self.attn(
*inner_state,
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
)
output, _ = self.o_proj(attn_output)
return output
......
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
def create_fused_set_kv_buffer_arg(
value: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
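    # Hand the RoPE kernel per-token flattened views of this layer's K/V cache
    # buffers and the destination slots, so it can write K and V directly into
    # the cache as part of the rotary-embedding launch.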
layer_id = layer.layer_id
token_to_kv_pool = forward_batch.token_to_kv_pool
k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
return FusedSetKVBufferArg(
value=value,
k_buffer=k_buffer.view(k_buffer.shape[0], -1),
v_buffer=v_buffer.view(v_buffer.shape[0], -1),
k_scale=layer.k_scale,
v_scale=layer.v_scale,
cache_loc=forward_batch.out_cache_loc,
)
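
Note that FusedSetKVBufferArg is only imported when _is_cuda is true, so create_fused_set_kv_buffer_arg should only be called after enable_fused_set_kv_buffer has returned True; that is how the call sites in this commit gate it. A condensed sketch of the call pattern:

fused_arg = (
    create_fused_set_kv_buffer_arg(value=v, layer=self.attn, forward_batch=forward_batch)
    if enable_fused_set_kv_buffer(forward_batch)  # also guards the CUDA-only import above
    else None
)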