Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bfc3b3f7
Unverified
Commit
bfc3b3f7
authored
Oct 20, 2025
by
Cheng Wan
Committed by
GitHub
Oct 20, 2025
Browse files
[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)
parent
da5bde4d
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
18 deletions
+38
-18
python/sglang/srt/single_batch_overlap.py
python/sglang/srt/single_batch_overlap.py
+30
-18
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...n/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+2
-0
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
...g/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+2
-0
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+4
-0
No files found.
python/sglang/srt/single_batch_overlap.py
View file @
bfc3b3f7
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
...
@@ -5,12 +21,12 @@ import torch
...
@@ -5,12 +21,12 @@ import torch
from
sglang.srt.layers
import
deep_gemm_wrapper
from
sglang.srt.layers
import
deep_gemm_wrapper
from
sglang.srt.layers.moe
import
get_moe_runner_backend
from
sglang.srt.layers.moe
import
get_moe_runner_backend
from
sglang.srt.layers.moe.topk
import
TopKOutput
from
sglang.srt.layers.moe.utils
import
is_sbo_enabled
from
sglang.srt.layers.moe.utils
import
is_sbo_enabled
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.utils
import
get_int_env_var
from
sglang.srt.utils
import
get_int_env_var
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.moe.
ep_moe.layer
import
DeepEP
MoE
from
sglang.srt.layers.moe.
fused_moe_triton
import
Fused
MoE
class
SboFlags
:
class
SboFlags
:
...
@@ -54,23 +70,22 @@ class DownGemmOverlapArgs:
...
@@ -54,23 +70,22 @@ class DownGemmOverlapArgs:
def
execute_sbo
(
def
execute_sbo
(
forward_shared_experts
:
Callable
[[],
Any
],
forward_shared_experts
:
Callable
[[],
Any
],
experts
:
"DeepEP
MoE
"
,
experts
:
Fused
MoE
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_output
:
TopKOutput
,
topk_weights
:
torch
.
Tensor
,
alt_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
forward_batch
:
ForwardBatch
,
alt_stream
:
Optional
=
None
,
disable_sbo
:
bool
=
False
,
disable_sbo
:
bool
=
False
,
):
):
dispatch_output
=
experts
.
dispatch
(
hidden_states
,
topk_idx
,
topk_weights
,
forward_batch
dispatch_output
=
experts
.
dispatcher
.
dispatch
(
hidden_states
=
hidden_states
,
topk_output
=
topk_output
)
)
combine_overlap_args
,
down_gemm_overlap_args
,
meta_overlap_args
=
(
combine_overlap_args
,
down_gemm_overlap_args
,
meta_overlap_args
=
(
_compute_overlap_args
(
dispatch_output
,
alt_stream
,
disable_sbo
=
disable_sbo
)
_compute_overlap_args
(
dispatch_output
,
alt_stream
,
disable_sbo
=
disable_sbo
)
)
)
hidden_states
=
experts
.
moe_impl
(
hidden_states
=
experts
.
run_moe_core
(
dispatch_output
,
down_gemm_overlap_args
=
down_gemm_overlap_args
dispatch_output
,
down_gemm_overlap_args
=
down_gemm_overlap_args
)
)
if
(
e
:
=
meta_overlap_args
.
get
(
"record_event_after_down"
))
is
not
None
:
if
(
e
:
=
meta_overlap_args
.
get
(
"record_event_after_down"
))
is
not
None
:
...
@@ -83,11 +98,10 @@ def execute_sbo(
...
@@ -83,11 +98,10 @@ def execute_sbo(
):
):
forward_shared_experts
()
forward_shared_experts
()
hidden_states
=
experts
.
combine
(
hidden_states
=
experts
.
dispatcher
.
combine
(
hidden_states
,
hidden_states
=
hidden_states
,
dispatch_output
.
topk_idx
,
topk_ids
=
dispatch_output
.
topk_ids
,
dispatch_output
.
topk_weights
,
topk_weights
=
dispatch_output
.
topk_weights
,
forward_batch
,
overlap_args
=
combine_overlap_args
,
overlap_args
=
combine_overlap_args
,
)
)
...
@@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
...
@@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
):
):
return
None
,
None
,
{}
return
None
,
None
,
{}
hidden_states
=
dispatch_output
.
hidden_states_fp8
hidden_states
=
dispatch_output
.
hidden_states
if
isinstance
(
hidden_states
,
tuple
):
hidden_states
=
hidden_states
[
0
]
num_local_experts
,
num_tokens_static
,
hidden_dim
=
hidden_states
.
shape
num_local_experts
,
num_tokens_static
,
hidden_dim
=
hidden_states
.
shape
...
...
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
bfc3b3f7
...
@@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
...
@@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
get_global_graph_memory_pool
,
get_global_graph_memory_pool
,
model_capture_mode
,
model_capture_mode
,
set_global_graph_memory_pool
,
set_global_graph_memory_pool
,
set_is_extend_in_batch
,
set_torch_compile_config
,
set_torch_compile_config
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
(
from
sglang.srt.model_executor.forward_batch_info
import
(
...
@@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
# Clean intermediate result cache for DP attention
# Clean intermediate result cache for DP attention
forward_batch
.
dp_local_start_pos
=
forward_batch
.
dp_local_num_tokens
=
None
forward_batch
.
dp_local_start_pos
=
forward_batch
.
dp_local_num_tokens
=
None
set_dp_buffer_len
(
global_dp_buffer_len
,
num_tokens
)
set_dp_buffer_len
(
global_dp_buffer_len
,
num_tokens
)
set_is_extend_in_batch
(
False
)
# Backup two fields, which will be modified in-place in `draft_forward`.
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup
=
forward_batch
.
out_cache_loc
output_cache_loc_backup
=
forward_batch
.
out_cache_loc
...
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
View file @
bfc3b3f7
...
@@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
...
@@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
get_global_graph_memory_pool
,
get_global_graph_memory_pool
,
model_capture_mode
,
model_capture_mode
,
set_global_graph_memory_pool
,
set_global_graph_memory_pool
,
set_is_extend_in_batch
,
set_torch_compile_config
,
set_torch_compile_config
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
(
from
sglang.srt.model_executor.forward_batch_info
import
(
...
@@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner:
...
@@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner:
# Clean intermediate result cache for DP attention
# Clean intermediate result cache for DP attention
forward_batch
.
dp_local_start_pos
=
forward_batch
.
dp_local_num_tokens
=
None
forward_batch
.
dp_local_start_pos
=
forward_batch
.
dp_local_num_tokens
=
None
set_dp_buffer_len
(
global_dp_buffer_len
,
num_tokens
)
set_dp_buffer_len
(
global_dp_buffer_len
,
num_tokens
)
set_is_extend_in_batch
(
False
)
# Backup two fields, which will be modified in-place in `draft_forward`.
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup
=
forward_batch
.
out_cache_loc
output_cache_loc_backup
=
forward_batch
.
out_cache_loc
...
...
python/sglang/srt/two_batch_overlap.py
View file @
bfc3b3f7
...
@@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:
...
@@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:
def combine_b(self, **kwargs):
    """Forward the "combine_b" operation to the wrapped dispatcher(s).

    All keyword arguments are passed through untouched; the return value is
    whatever the underlying ``_execute`` call produces.
    """
    return self._execute("combine_b", **kwargs)
def set_quant_config(self, quant_config: dict):
    """Propagate *quant_config* to every inner dispatcher held in ``self._inners``."""
    for dispatcher in self._inners:
        dispatcher.set_quant_config(quant_config)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment