Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8e10fec9
Unverified
Commit
8e10fec9
authored
Apr 03, 2025
by
fzyzcjy
Committed by
GitHub
Apr 03, 2025
Browse files
Small refactor DeepEPMode to clean up code a bit (#4992)
parent
e8999b13
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
30 deletions
+44
-30
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+6
-10
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+11
-15
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+3
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-2
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+22
-0
No files found.
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
8e10fec9
...
...
@@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
)
from
sglang.srt.layers.quantization.fp8
import
Fp8Config
,
Fp8MoEMethod
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
set_weight_attrs
from
sglang.srt.utils
import
DeepEPMode
,
is_cuda
,
is_hip
,
set_weight_attrs
_is_cuda
=
is_cuda
()
...
...
@@ -47,7 +47,6 @@ if _is_cuda:
else
:
from
vllm
import
_custom_ops
as
vllm_ops
logger
=
logging
.
getLogger
(
__name__
)
_is_hip
=
is_hip
()
...
...
@@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE):
correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
activation
:
str
=
"silu"
,
deepep_mode
:
str
=
"
auto
"
,
deepep_mode
:
DeepEPMode
=
DeepEPMode
.
auto
,
):
super
().
__init__
(
num_experts
,
...
...
@@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE):
activation
,
)
self
.
deepep_mode
=
deepep_mode
if
self
.
deepep_mode
in
[
"
low_latency
"
,
"auto"
]
:
if
self
.
deepep_mode
.
enable_
low_latency
()
:
assert
use_deep_gemm
,
f
"DeepEP
{
self
.
deepep_mode
}
mode requires deep_gemm"
self
.
w13_weight_fp8
=
(
self
.
w13_weight
,
...
...
@@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE):
expected_m
:
int
,
forward_mode
:
ForwardMode
,
):
if
self
.
deepep_mode
==
"normal"
or
(
self
.
deepep_mode
==
"auto"
and
not
forward_mode
.
is_decode
()
):
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
return
self
.
forward_normal
(
hidden_states
,
reorder_topk_ids
,
seg_indptr
)
elif
self
.
deepep_mode
==
"low_latency"
or
(
self
.
deepep_mode
==
"auto"
and
forward_mode
.
is_decode
()
):
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
return
self
.
forward_deepgemm_masked
(
hidden_states
,
masked_m
,
expected_m
)
else
:
raise
ValueError
(
f
"Invalid deepep_mode:
{
self
.
deepep_mode
}
"
)
...
...
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
View file @
8e10fec9
from
sglang.srt.utils
import
DeepEPMode
try
:
from
deep_ep
import
Buffer
...
...
@@ -98,7 +100,7 @@ class DeepEPDispatcher:
num_local_experts
:
int
=
None
,
hidden_size
:
int
=
None
,
params_dtype
:
torch
.
dtype
=
None
,
deepep_mode
:
str
=
"
auto
"
,
deepep_mode
:
DeepEPMode
=
DeepEPMode
.
auto
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
):
...
...
@@ -120,13 +122,13 @@ class DeepEPDispatcher:
self
.
deepep_mode
=
deepep_mode
self
.
handle
=
None
if
self
.
deepep_mode
in
[
"normal"
,
"auto"
]:
# for normal / auto mode
if
self
.
deepep_mode
.
enable_normal
():
self
.
buffer_normal
=
get_buffer_normal
(
self
.
group
,
self
.
hidden_size
*
self
.
params_bytes
)
self
.
async_finish
=
async_finish
self
.
src2dst
=
None
if
self
.
deepep_mode
in
[
"low_latency"
,
"auto"
]:
# for low_latency / auto mode
if
self
.
deepep_mode
.
enable_low_latency
():
"""
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
...
...
@@ -196,9 +198,8 @@ class DeepEPDispatcher:
)
expected_m
=
0
if
self
.
deepep_mode
==
"normal"
or
(
self
.
deepep_mode
==
"auto"
and
not
forward_mode
.
is_decode
()
):
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
(
hidden_states
,
topk_idx
,
...
...
@@ -210,9 +211,7 @@ class DeepEPDispatcher:
reorder_topk_ids
,
seg_indptr
,
hidden_states
=
self
.
deepep_permute
(
hidden_states
,
topk_idx
,
fp8_dtype
=
hidden_states
.
dtype
)
elif
self
.
deepep_mode
==
"low_latency"
or
(
self
.
deepep_mode
==
"auto"
and
forward_mode
.
is_decode
()
):
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
expected_m
=
(
hidden_states
.
shape
[
0
]
*
self
.
buffer_low_latency
.
group_size
...
...
@@ -354,9 +353,8 @@ class DeepEPDispatcher:
topk_weights
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
,
)
->
torch
.
Tensor
:
if
self
.
deepep_mode
==
"normal"
or
(
self
.
deepep_mode
==
"auto"
and
not
forward_mode
.
is_decode
()
):
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
if
hidden_states
.
shape
[
0
]
>
0
:
num_tokens
=
self
.
src2dst
.
shape
[
0
]
//
self
.
router_topk
output
=
torch
.
empty
(
...
...
@@ -384,9 +382,7 @@ class DeepEPDispatcher:
output
,
)
event
.
current_stream_wait
()
if
self
.
async_finish
else
()
elif
self
.
deepep_mode
==
"low_latency"
or
(
self
.
deepep_mode
==
"auto"
and
forward_mode
.
is_decode
()
):
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
hidden_states
,
event
,
hook
=
self
.
combine_low_latency
(
hidden_states
,
topk_idx
,
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
8e10fec9
...
...
@@ -70,7 +70,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.utils
import
add_prefix
,
is_cuda
,
is_hip
from
sglang.srt.utils
import
DeepEPMode
,
add_prefix
,
is_cuda
,
is_hip
_is_hip
=
is_hip
()
_is_cuda
=
is_cuda
()
...
...
@@ -215,7 +215,7 @@ class DeepseekV2MoE(nn.Module):
topk_group
=
config
.
topk_group
,
correction_bias
=
self
.
gate
.
e_score_correction_bias
,
prefix
=
add_prefix
(
"experts"
,
prefix
),
deepep_mode
=
global_server_args_dict
[
"deepep_mode"
],
deepep_mode
=
DeepEPMode
[
global_server_args_dict
[
"deepep_mode"
]
]
,
)
if
config
.
n_shared_experts
is
not
None
:
...
...
@@ -264,7 +264,7 @@ class DeepseekV2MoE(nn.Module):
num_local_experts
=
config
.
n_routed_experts
//
self
.
tp_size
,
hidden_size
=
config
.
hidden_size
,
params_dtype
=
config
.
torch_dtype
,
deepep_mode
=
global_server_args_dict
[
"deepep_mode"
],
deepep_mode
=
DeepEPMode
[
global_server_args_dict
[
"deepep_mode"
]
]
,
async_finish
=
True
,
# TODO
return_recv_hook
=
True
,
)
...
...
python/sglang/srt/server_args.py
View file @
8e10fec9
...
...
@@ -20,7 +20,7 @@ import logging
import
os
import
random
import
tempfile
from
typing
import
List
,
Optional
from
typing
import
List
,
Literal
,
Optional
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
from
sglang.srt.reasoning_parser
import
ReasoningParser
...
...
@@ -161,7 +161,7 @@ class ServerArgs:
enable_dp_attention
:
bool
=
False
enable_ep_moe
:
bool
=
False
enable_deepep_moe
:
bool
=
False
deepep_mode
:
Optional
[
str
]
=
"auto"
deepep_mode
:
Optional
[
Literal
[
"auto"
,
"normal"
,
"low_latency"
]
]
=
"auto"
enable_torch_compile
:
bool
=
False
torch_compile_max_bs
:
int
=
32
cuda_graph_max_bs
:
Optional
[
int
]
=
None
...
...
python/sglang/srt/utils.py
View file @
8e10fec9
...
...
@@ -37,6 +37,7 @@ import time
import
traceback
import
warnings
from
contextlib
import
contextmanager
from
enum
import
Enum
from
functools
import
lru_cache
from
importlib.metadata
import
PackageNotFoundError
,
version
from
importlib.util
import
find_spec
...
...
@@ -1838,3 +1839,24 @@ def flatten_nested_list(nested_list):
]
else
:
return
[
nested_list
]
class DeepEPMode(Enum):
    """Enumeration of the DeepEP modes: ``normal``, ``low_latency`` and ``auto``.

    ``auto`` is not a concrete mode; it counts as enabling both of the other
    modes and is turned into one of them by :meth:`resolve`.
    """

    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_normal(self):
        """Return True if this mode is ``normal`` or ``auto``."""
        return self is DeepEPMode.normal or self is DeepEPMode.auto

    def enable_low_latency(self):
        """Return True if this mode is ``low_latency`` or ``auto``."""
        return self is DeepEPMode.low_latency or self is DeepEPMode.auto

    def resolve(self, forward_mode):
        """Return the concrete mode to use for ``forward_mode``.

        Args:
            forward_mode: Object exposing ``is_decode()``; it is only
                consulted when this mode is ``auto``.

        Returns:
            ``self`` unchanged for ``normal`` / ``low_latency``. For ``auto``,
            ``low_latency`` when ``forward_mode.is_decode()`` is truthy and
            ``normal`` otherwise.
        """
        if self is DeepEPMode.auto:
            if forward_mode.is_decode():
                return DeepEPMode.low_latency
            return DeepEPMode.normal
        # Explicit modes are already concrete and never inspect forward_mode.
        return self
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment