change / sglang · Commits · 587b4c6e

EPLB support for MTP (#7510)

Unverified commit 587b4c6e, authored Jun 25, 2025 by yilian49, committed by GitHub on Jun 25, 2025. Parent: 7b9a174a.

Showing 2 changed files with 28 additions and 4 deletions (+28 -4):

- python/sglang/srt/managers/expert_distribution.py (+21 -0)
- python/sglang/srt/models/deepseek_nextn.py (+7 -4)
python/sglang/srt/managers/expert_distribution.py (view file @ 587b4c6e)
@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
     def with_debug_name(self, debug_name):
         yield

+    @contextmanager
+    def disable_this_region(self):
+        yield
+
     @contextmanager
     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
         yield
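In the base class this is a no-op generator that simply yields, so call sites can always write with recorder.disable_this_region(): ... regardless of whether a real recorder is installed. A minimal sketch of that pattern (the _NoopRecorder name below is illustrative, not taken from the file):

    from contextlib import contextmanager

    class _NoopRecorder:
        # Inert variant: entering the region changes nothing, so callers
        # never have to check whether recording is enabled.
        @contextmanager
        def disable_this_region(self):
            yield

    with _NoopRecorder().disable_this_region():
        pass  # model code runs unchanged when no recording is active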
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         self._expert_location_metadata = expert_location_metadata

         self._recording = False
+        self._disable_all = False
         self._current_forward_pass_id = Withable()
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         finally:
             self._on_forward_pass_end(forward_pass_id)

+    @contextmanager
+    def disable_this_region(self):
+        """Context manager to temporarily disable recording."""
+        previous_disable_all = self._disable_all
+        self._disable_all = True
+        try:
+            yield
+        finally:
+            self._disable_all = previous_disable_all
+
     def _on_forward_pass_start(self, forward_batch: ForwardBatch):
         if not self._recording:
             return
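The real implementation saves the previous flag and restores it in a finally block, so nested disable regions and exceptions leave the outer state intact. A small self-contained sketch of that save/restore behavior (using a module-level flag instead of the recorder object, for brevity):

    from contextlib import contextmanager

    _disable_all = False

    @contextmanager
    def disable_this_region():
        # Save, set, and restore the flag, mirroring the method added above.
        global _disable_all
        previous = _disable_all
        _disable_all = True
        try:
            yield
        finally:
            _disable_all = previous

    with disable_this_region():
        with disable_this_region():
            pass
        assert _disable_all is True   # inner exit restores the outer region, not False
    assert _disable_all is False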
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         )

     def _on_hook(self, hook_name: str, **kwargs):
+        if self._disable_all:
+            return
         if not (self._recording or torch.cuda.is_current_stream_capturing()):
             return
         gatherer = self._single_pass_gatherers[
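Note the guard order: the new _disable_all check runs before the existing one, so a disabled region is skipped even while torch.cuda.is_current_stream_capturing() returns True (i.e. during CUDA graph capture, when hooks would otherwise be recorded). A condensed sketch of that decision logic, with the flags passed in explicitly for illustration:

    import torch

    def should_record(disable_all: bool, recording: bool) -> bool:
        # An explicitly disabled region wins even during CUDA graph capture.
        if disable_all:
            return False
        return recording or torch.cuda.is_current_stream_capturing()

    assert should_record(disable_all=True, recording=True) is False
    assert should_record(disable_all=False, recording=True) is True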
@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         topk_ids = topk_ids.flatten()
         mask = topk_ids != -1
+        assert self._data[layer_idx, :].shape == topk_ids.shape, (
+            "Shape mismatch between data and topk_ids."
+            "Selecting expert is not supported for multiple token prediction at the moment."
+        )
         self._data[layer_idx, :].scatter_add_(
             dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
         )
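The surrounding code flattens topk_ids, masks out the -1 padding entries, and scatter-adds one count per selected expert; the new assert makes the buffer-size assumption explicit and rejects multi-token prediction here. A standalone toy example of the same masked counting (shapes and values chosen arbitrarily for illustration):

    import torch

    num_experts = 8
    counts = torch.zeros(num_experts, dtype=torch.int32)

    # Two tokens picked experts {1, 3} and {3}; -1 marks a padded slot.
    topk_ids = torch.tensor([[1, 3], [3, -1]]).flatten()
    mask = topk_ids != -1

    counts.scatter_add_(
        dim=0,
        index=topk_ids.masked_fill(~mask, 0).long(),  # padded slots point at expert 0...
        src=mask.int(),                               # ...but contribute 0 there
    )
    print(counts.tolist())  # [0, 1, 0, 2, 0, 0, 0, 0]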
python/sglang/srt/models/deepseek_nextn.py (view file @ 587b4c6e)
@@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module):
         )

         residual = None
-        hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual, zero_allocator
-        )
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )

         if not forward_batch.forward_mode.is_idle():
             if residual is not None:
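Reading the two files together: the NextN (MTP draft) decoder now runs inside disable_this_region(), so the draft pass's expert selections are not accumulated into the EPLB expert-distribution statistics; this appears to be what keeps the gatherer's single-token-sized assumption (see the new assert in on_select_experts) valid while still allowing multi-token prediction. A toy end-to-end sketch of the effect (the _Recorder class and its events list are illustrative stand-ins, not the real API):

    from contextlib import contextmanager

    class _Recorder:
        def __init__(self):
            self._disabled = False
            self.events = []

        @contextmanager
        def disable_this_region(self):
            prev, self._disabled = self._disabled, True
            try:
                yield
            finally:
                self._disabled = prev

        def on_select_experts(self, expert_id):
            if not self._disabled:
                self.events.append(expert_id)

    recorder = _Recorder()
    recorder.on_select_experts(2)            # target-model routing: recorded
    with recorder.disable_this_region():     # draft/MTP routing: skipped
        recorder.on_select_experts(5)
    print(recorder.events)                   # [2]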