sglang · Commit a5095d62 (unverified)

Fuse write kv buffer into rope for qwen3 moe & bailing moe (#10749)

Authored by Yuan Luo on Sep 26, 2025; committed by GitHub on Sep 26, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Parent: 6c2c467d

Showing 4 changed files with 105 additions and 34 deletions (+105, -34)
python/sglang/srt/models/bailing_moe.py    +25   -2
python/sglang/srt/models/gpt_oss.py         +7  -30
python/sglang/srt/models/qwen3_moe.py      +22   -2
python/sglang/srt/models/utils.py          +51   -0
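In short, the commit changes the attention forward of these models so that the rotary-embedding op also writes K and V into the KV cache, and the attention call skips its own save step. The sketch below contrasts the two call patterns; it is illustrative only (a stand-in function with assumed self.rotary_emb / self.attn attributes), not code taken verbatim from the diff.

from sglang.srt.models.utils import (
    create_fused_set_kv_buffer_arg,
    enable_fused_set_kv_buffer,
)


def apply_rope_and_attention(self, positions, q, k, v, forward_batch):
    """Hypothetical helper illustrating the fused call pattern used by this commit."""
    # Before this commit the two steps were separate:
    #   q, k = self.rotary_emb(positions, q, k)
    #   return self.attn(q, k, v, forward_batch)   # attn wrote k/v to the cache
    fused = enable_fused_set_kv_buffer(forward_batch)  # CUDA + bf16 KV cache only
    q, k = self.rotary_emb(
        positions,
        q,
        k,
        # When set, the RoPE kernel also stores k/v into this layer's KV buffers.
        fused_set_kv_buffer_arg=(
            create_fused_set_kv_buffer_arg(
                value=v, layer=self.attn, forward_batch=forward_batch
            )
            if fused
            else None
        ),
    )
    # The attention call must then skip its own KV-cache write to avoid storing twice.
    return self.attn(q, k, v, forward_batch, save_kv_cache=not fused)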
python/sglang/srt/models/bailing_moe.py

@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers

 LoraConfig = None

@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )

         attn_output, _ = self.dense(context_layer)
         return attn_output
python/sglang/srt/models/gpt_oss.py

@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,

@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans


-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,

@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )

@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
python/sglang/srt/models/qwen3_moe.py

@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,

@@ -412,7 +416,20 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state

@@ -420,7 +437,10 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         output, _ = self.o_proj(attn_output)
         return output
python/sglang/srt/models/utils.py (new file, mode 100644)

# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import FusedSetKVBufferArg


def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


def create_fused_set_kv_buffer_arg(
    value: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
):
    layer_id = layer.layer_id
    token_to_kv_pool = forward_batch.token_to_kv_pool

    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)

    return FusedSetKVBufferArg(
        value=value,
        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
        k_scale=layer.k_scale,
        v_scale=layer.v_scale,
        cache_loc=forward_batch.out_cache_loc,
    )
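For a model adopting these shared helpers, the wiring mirrors the three call sites above: gate on enable_fused_set_kv_buffer, build the argument with create_fused_set_kv_buffer_arg, pass it to the rotary embedding, and disable save_kv_cache on the attention call. A hedged sketch of that gating step (maybe_fused_kv_arg is an illustrative name, and layer / forward_batch are assumed to come from SGLang's normal model-runner plumbing):

from typing import Optional

import torch

from sglang.srt.models.utils import (
    create_fused_set_kv_buffer_arg,
    enable_fused_set_kv_buffer,
)


def maybe_fused_kv_arg(value: torch.Tensor, layer, forward_batch) -> Optional[object]:
    """Illustrative helper: return the fused-write argument when the fused path is
    available (CUDA build with a bfloat16 KV cache), or None so the caller falls
    back to the attention layer's regular save_kv_cache path."""
    if enable_fused_set_kv_buffer(forward_batch):
        return create_fused_set_kv_buffer_arg(
            value=value, layer=layer, forward_batch=forward_batch
        )
    return None

Centralizing the two functions here is also what lets gpt_oss.py drop its private _enable_fused_set_kv_buffer / _create_fused_set_kv_buffer_arg copies in the diff above.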