Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b1e13e7c
Unverified
Commit
b1e13e7c
authored
Oct 28, 2025
by
Cheng Wan
Committed by
GitHub
Oct 28, 2025
Browse files
[hotfix] Incorrect CombineOverlapArgs in SBO (#12230)
parent
cc7b04a2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
11 deletions
+16
-11
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+0
-1
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+12
-9
python/sglang/srt/single_batch_overlap.py
python/sglang/srt/single_batch_overlap.py
+4
-1
No files found.
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
b1e13e7c
...
@@ -235,7 +235,6 @@ class DeepEPMoE(FusedMoE):
...
@@ -235,7 +235,6 @@ class DeepEPMoE(FusedMoE):
hidden_states
=
output
,
hidden_states
=
output
,
topk_ids
=
dispatch_output
.
topk_ids
,
topk_ids
=
dispatch_output
.
topk_ids
,
topk_weights
=
dispatch_output
.
topk_weights
,
topk_weights
=
dispatch_output
.
topk_weights
,
overlap_args
=
down_gemm_overlap_args
,
)
)
def
combine
(
def
combine
(
...
...
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
View file @
b1e13e7c
...
@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple):
...
@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple):
hidden_states
:
torch
.
Tensor
hidden_states
:
torch
.
Tensor
topk_ids
:
torch
.
Tensor
topk_ids
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
@
property
@
property
def
format
(
self
)
->
CombineInputFormat
:
def
format
(
self
)
->
CombineInputFormat
:
...
@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple):
...
@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple):
hidden_states
:
torch
.
Tensor
hidden_states
:
torch
.
Tensor
topk_ids
:
torch
.
Tensor
topk_ids
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
@
property
@
property
def
format
(
self
)
->
CombineInputFormat
:
def
format
(
self
)
->
CombineInputFormat
:
...
@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase:
...
@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase:
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"
CombineOverlapArgs
"
],
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
,
):
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"
CombineOverlapArgs
"
],
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
,
):
):
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
or
_use_aiter
or
_is_npu
:
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
or
_use_aiter
or
_is_npu
:
...
@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"
CombineOverlapArgs
"
],
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
,
):
):
hidden_states
,
event
,
hook
=
self
.
_combine_core
(
hidden_states
,
event
,
hook
=
self
.
_combine_core
(
hidden_states
,
hidden_states
,
...
@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"
CombineOverlapArgs
"
],
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
,
):
):
buffer
=
self
.
_get_buffer
()
buffer
=
self
.
_get_buffer
()
...
@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
...
@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
del
self
.
_dispatch_intermediate_state
del
self
.
_dispatch_intermediate_state
return
self
.
_get_impl
().
dispatch_b
(
*
inner_state
)
return
self
.
_get_impl
().
dispatch_b
(
*
inner_state
)
def
combine
(
self
,
combine_input
:
CombineInput
)
->
Tuple
:
def
combine
(
self
.
combine_a
(
combine_input
)
self
,
combine_input
:
CombineInput
,
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
,
)
->
Tuple
:
self
.
combine_a
(
combine_input
,
overlap_args
)
ret
=
self
.
combine_b
()
ret
=
self
.
combine_b
()
return
ret
return
ret
def
combine_a
(
def
combine_a
(
self
,
self
,
combine_input
:
CombineInput
,
combine_input
:
CombineInput
,
overlap_args
:
Optional
[
CombineOverlapArgs
]
=
None
,
):
):
hidden_states
,
topk_ids
,
topk_weights
,
overlap_args
=
combine_input
hidden_states
,
topk_ids
,
topk_weights
=
combine_input
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_B
,
_Stage
.
AFTER_COMBINE_A
)
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_B
,
_Stage
.
AFTER_COMBINE_A
)
inner_state
=
self
.
_get_impl
().
combine_a
(
inner_state
=
self
.
_get_impl
().
combine_a
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
...
python/sglang/srt/single_batch_overlap.py
View file @
b1e13e7c
...
@@ -98,7 +98,10 @@ def execute_sbo(
...
@@ -98,7 +98,10 @@ def execute_sbo(
):
):
forward_shared_experts
()
forward_shared_experts
()
hidden_states
=
experts
.
dispatcher
.
combine
(
combine_input
=
combine_input
)
hidden_states
=
experts
.
dispatcher
.
combine
(
combine_input
=
combine_input
,
overlap_args
=
combine_overlap_args
,
)
return
hidden_states
return
hidden_states
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment