Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a071dc40
"vscode:/vscode.git/clone" did not exist on "dafe46710b5f4a93bfdceb84c7201d1c83423394"
Unverified
Commit
a071dc40
authored
May 21, 2025
by
fzyzcjy
Committed by
GitHub
May 21, 2025
Browse files
Tiny add stage assertions to DeepEPDispatcher to avoid misuse (#6467)
parent
a40aecc5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
1 deletion
+20
-1
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+20
-1
No files found.
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
View file @
a071dc40
import
logging
import
logging
from
dataclasses
import
dataclass
from
sglang.srt.layers.quantization.deep_gemm
import
_ENABLE_JIT_DEEPGEMM
from
sglang.srt.layers.quantization.deep_gemm
import
_ENABLE_JIT_DEEPGEMM
from
sglang.srt.managers.expert_distribution
import
(
from
sglang.srt.managers.expert_distribution
import
(
...
@@ -18,7 +19,7 @@ try:
...
@@ -18,7 +19,7 @@ try:
except
ImportError
:
except
ImportError
:
use_deepep
=
False
use_deepep
=
False
from
enum
import
IntEnum
,
auto
from
enum
import
Enum
,
IntEnum
,
auto
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
)
)
@
dataclass
class
_Stage
(
Enum
):
INITIAL
=
auto
()
AFTER_DISPATCH_A
=
auto
()
AFTER_DISPATCH_B
=
auto
()
AFTER_COMBINE_A
=
auto
()
class
DeepEPDispatcher
:
class
DeepEPDispatcher
:
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -665,6 +674,8 @@ class DeepEPDispatcher:
...
@@ -665,6 +674,8 @@ class DeepEPDispatcher:
**
common_kwargs
,
**
common_kwargs
,
)
)
self
.
_stage
=
_Stage
.
INITIAL
def
dispatch
(
self
,
*
args
,
**
kwargs
)
->
Tuple
:
def
dispatch
(
self
,
*
args
,
**
kwargs
)
->
Tuple
:
self
.
dispatch_a
(
*
args
,
**
kwargs
)
self
.
dispatch_a
(
*
args
,
**
kwargs
)
ret
=
self
.
dispatch_b
()
ret
=
self
.
dispatch_b
()
...
@@ -677,6 +688,7 @@ class DeepEPDispatcher:
...
@@ -677,6 +688,7 @@ class DeepEPDispatcher:
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
=
None
,
forward_mode
:
ForwardMode
=
None
,
):
):
self
.
_update_stage
(
_Stage
.
INITIAL
,
_Stage
.
AFTER_DISPATCH_A
)
inner_state
=
self
.
_get_impl
(
forward_mode
).
dispatch_a
(
inner_state
=
self
.
_get_impl
(
forward_mode
).
dispatch_a
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
...
@@ -685,6 +697,7 @@ class DeepEPDispatcher:
...
@@ -685,6 +697,7 @@ class DeepEPDispatcher:
self
.
_dispatch_intermediate_state
=
forward_mode
,
inner_state
self
.
_dispatch_intermediate_state
=
forward_mode
,
inner_state
def
dispatch_b
(
self
):
def
dispatch_b
(
self
):
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_A
,
_Stage
.
AFTER_DISPATCH_B
)
forward_mode
,
inner_state
=
self
.
_dispatch_intermediate_state
forward_mode
,
inner_state
=
self
.
_dispatch_intermediate_state
del
self
.
_dispatch_intermediate_state
del
self
.
_dispatch_intermediate_state
return
self
.
_get_impl
(
forward_mode
).
dispatch_b
(
*
inner_state
)
return
self
.
_get_impl
(
forward_mode
).
dispatch_b
(
*
inner_state
)
...
@@ -701,6 +714,7 @@ class DeepEPDispatcher:
...
@@ -701,6 +714,7 @@ class DeepEPDispatcher:
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
):
):
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_B
,
_Stage
.
AFTER_COMBINE_A
)
inner_state
=
self
.
_get_impl
(
forward_mode
).
combine_a
(
inner_state
=
self
.
_get_impl
(
forward_mode
).
combine_a
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
...
@@ -709,6 +723,7 @@ class DeepEPDispatcher:
...
@@ -709,6 +723,7 @@ class DeepEPDispatcher:
self
.
_combine_intermediate_state
=
forward_mode
,
inner_state
self
.
_combine_intermediate_state
=
forward_mode
,
inner_state
def
combine_b
(
self
):
def
combine_b
(
self
):
self
.
_update_stage
(
_Stage
.
AFTER_COMBINE_A
,
_Stage
.
INITIAL
)
forward_mode
,
inner_state
=
self
.
_combine_intermediate_state
forward_mode
,
inner_state
=
self
.
_combine_intermediate_state
del
self
.
_combine_intermediate_state
del
self
.
_combine_intermediate_state
return
self
.
_get_impl
(
forward_mode
).
combine_b
(
*
inner_state
)
return
self
.
_get_impl
(
forward_mode
).
combine_b
(
*
inner_state
)
...
@@ -721,3 +736,7 @@ class DeepEPDispatcher:
...
@@ -721,3 +736,7 @@ class DeepEPDispatcher:
return
self
.
_low_latency_dispatcher
return
self
.
_low_latency_dispatcher
else
:
else
:
raise
ValueError
(
f
"Invalid deepep_mode:
{
self
.
deepep_mode
}
"
)
raise
ValueError
(
f
"Invalid deepep_mode:
{
self
.
deepep_mode
}
"
)
def
_update_stage
(
self
,
old_stage
,
new_stage
):
assert
self
.
_stage
==
old_stage
self
.
_stage
=
new_stage
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment