Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8e10fec9
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "7e33a017c086d2dbde3be0a546d40818bb1e9c16"
Unverified
Commit
8e10fec9
authored
Apr 03, 2025
by
fzyzcjy
Committed by
GitHub
Apr 03, 2025
Browse files
Small refactor DeepEPMode to clean up code a bit (#4992)
parent
e8999b13
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
30 deletions
+44
-30
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+6
-10
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+11
-15
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+3
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-2
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+22
-0
No files found.
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
8e10fec9
...
@@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
...
@@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
)
)
from
sglang.srt.layers.quantization.fp8
import
Fp8Config
,
Fp8MoEMethod
from
sglang.srt.layers.quantization.fp8
import
Fp8Config
,
Fp8MoEMethod
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
set_weight_attrs
from
sglang.srt.utils
import
DeepEPMode
,
is_cuda
,
is_hip
,
set_weight_attrs
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
...
@@ -47,7 +47,6 @@ if _is_cuda:
...
@@ -47,7 +47,6 @@ if _is_cuda:
else
:
else
:
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm
import
_custom_ops
as
vllm_ops
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
_is_hip
=
is_hip
()
_is_hip
=
is_hip
()
...
@@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE):
...
@@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE):
correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
deepep_mode
:
str
=
"
auto
"
,
deepep_mode
:
DeepEPMode
=
DeepEPMode
.
auto
,
):
):
super
().
__init__
(
super
().
__init__
(
num_experts
,
num_experts
,
...
@@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE):
...
@@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE):
activation
,
activation
,
)
)
self
.
deepep_mode
=
deepep_mode
self
.
deepep_mode
=
deepep_mode
if
self
.
deepep_mode
in
[
"
low_latency
"
,
"auto"
]
:
if
self
.
deepep_mode
.
enable_
low_latency
()
:
assert
use_deep_gemm
,
f
"DeepEP
{
self
.
deepep_mode
}
mode requires deep_gemm"
assert
use_deep_gemm
,
f
"DeepEP
{
self
.
deepep_mode
}
mode requires deep_gemm"
self
.
w13_weight_fp8
=
(
self
.
w13_weight_fp8
=
(
self
.
w13_weight
,
self
.
w13_weight
,
...
@@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE):
...
@@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE):
expected_m
:
int
,
expected_m
:
int
,
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
):
):
if
self
.
deepep_mode
==
"normal"
or
(
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
self
.
deepep_mode
==
"auto"
and
not
forward_mode
.
is_decode
()
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
):
return
self
.
forward_normal
(
hidden_states
,
reorder_topk_ids
,
seg_indptr
)
return
self
.
forward_normal
(
hidden_states
,
reorder_topk_ids
,
seg_indptr
)
elif
self
.
deepep_mode
==
"low_latency"
or
(
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
self
.
deepep_mode
==
"auto"
and
forward_mode
.
is_decode
()
):
return
self
.
forward_deepgemm_masked
(
hidden_states
,
masked_m
,
expected_m
)
return
self
.
forward_deepgemm_masked
(
hidden_states
,
masked_m
,
expected_m
)
else
:
else
:
raise
ValueError
(
f
"Invalid deepep_mode:
{
self
.
deepep_mode
}
"
)
raise
ValueError
(
f
"Invalid deepep_mode:
{
self
.
deepep_mode
}
"
)
...
...
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
View file @
8e10fec9
from
sglang.srt.utils
import
DeepEPMode
try
:
try
:
from
deep_ep
import
Buffer
from
deep_ep
import
Buffer
...
@@ -98,7 +100,7 @@ class DeepEPDispatcher:
...
@@ -98,7 +100,7 @@ class DeepEPDispatcher:
num_local_experts
:
int
=
None
,
num_local_experts
:
int
=
None
,
hidden_size
:
int
=
None
,
hidden_size
:
int
=
None
,
params_dtype
:
torch
.
dtype
=
None
,
params_dtype
:
torch
.
dtype
=
None
,
deepep_mode
:
str
=
"
auto
"
,
deepep_mode
:
DeepEPMode
=
DeepEPMode
.
auto
,
async_finish
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
):
):
...
@@ -120,13 +122,13 @@ class DeepEPDispatcher:
...
@@ -120,13 +122,13 @@ class DeepEPDispatcher:
self
.
deepep_mode
=
deepep_mode
self
.
deepep_mode
=
deepep_mode
self
.
handle
=
None
self
.
handle
=
None
if
self
.
deepep_mode
in
[
"normal"
,
"auto"
]:
# for normal / auto mode
if
self
.
deepep_mode
.
enable_normal
():
self
.
buffer_normal
=
get_buffer_normal
(
self
.
buffer_normal
=
get_buffer_normal
(
self
.
group
,
self
.
hidden_size
*
self
.
params_bytes
self
.
group
,
self
.
hidden_size
*
self
.
params_bytes
)
)
self
.
async_finish
=
async_finish
self
.
async_finish
=
async_finish
self
.
src2dst
=
None
self
.
src2dst
=
None
if
self
.
deepep_mode
in
[
"low_latency"
,
"auto"
]:
# for low_latency / auto mode
if
self
.
deepep_mode
.
enable_low_latency
():
"""
"""
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
...
@@ -196,9 +198,8 @@ class DeepEPDispatcher:
...
@@ -196,9 +198,8 @@ class DeepEPDispatcher:
)
)
expected_m
=
0
expected_m
=
0
if
self
.
deepep_mode
==
"normal"
or
(
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
self
.
deepep_mode
==
"auto"
and
not
forward_mode
.
is_decode
()
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
):
(
(
hidden_states
,
hidden_states
,
topk_idx
,
topk_idx
,
...
@@ -210,9 +211,7 @@ class DeepEPDispatcher:
...
@@ -210,9 +211,7 @@ class DeepEPDispatcher:
reorder_topk_ids
,
seg_indptr
,
hidden_states
=
self
.
deepep_permute
(
reorder_topk_ids
,
seg_indptr
,
hidden_states
=
self
.
deepep_permute
(
hidden_states
,
topk_idx
,
fp8_dtype
=
hidden_states
.
dtype
hidden_states
,
topk_idx
,
fp8_dtype
=
hidden_states
.
dtype
)
)
elif
self
.
deepep_mode
==
"low_latency"
or
(
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
self
.
deepep_mode
==
"auto"
and
forward_mode
.
is_decode
()
):
expected_m
=
(
expected_m
=
(
hidden_states
.
shape
[
0
]
hidden_states
.
shape
[
0
]
*
self
.
buffer_low_latency
.
group_size
*
self
.
buffer_low_latency
.
group_size
...
@@ -354,9 +353,8 @@ class DeepEPDispatcher:
...
@@ -354,9 +353,8 @@ class DeepEPDispatcher:
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
deepep_mode
==
"normal"
or
(
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
self
.
deepep_mode
==
"auto"
and
not
forward_mode
.
is_decode
()
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
):
if
hidden_states
.
shape
[
0
]
>
0
:
if
hidden_states
.
shape
[
0
]
>
0
:
num_tokens
=
self
.
src2dst
.
shape
[
0
]
//
self
.
router_topk
num_tokens
=
self
.
src2dst
.
shape
[
0
]
//
self
.
router_topk
output
=
torch
.
empty
(
output
=
torch
.
empty
(
...
@@ -384,9 +382,7 @@ class DeepEPDispatcher:
...
@@ -384,9 +382,7 @@ class DeepEPDispatcher:
output
,
output
,
)
)
event
.
current_stream_wait
()
if
self
.
async_finish
else
()
event
.
current_stream_wait
()
if
self
.
async_finish
else
()
elif
self
.
deepep_mode
==
"low_latency"
or
(
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
self
.
deepep_mode
==
"auto"
and
forward_mode
.
is_decode
()
):
hidden_states
,
event
,
hook
=
self
.
combine_low_latency
(
hidden_states
,
event
,
hook
=
self
.
combine_low_latency
(
hidden_states
,
hidden_states
,
topk_idx
,
topk_idx
,
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
8e10fec9
...
@@ -70,7 +70,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
...
@@ -70,7 +70,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.utils
import
add_prefix
,
is_cuda
,
is_hip
from
sglang.srt.utils
import
DeepEPMode
,
add_prefix
,
is_cuda
,
is_hip
_is_hip
=
is_hip
()
_is_hip
=
is_hip
()
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
...
@@ -215,7 +215,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -215,7 +215,7 @@ class DeepseekV2MoE(nn.Module):
topk_group
=
config
.
topk_group
,
topk_group
=
config
.
topk_group
,
correction_bias
=
self
.
gate
.
e_score_correction_bias
,
correction_bias
=
self
.
gate
.
e_score_correction_bias
,
prefix
=
add_prefix
(
"experts"
,
prefix
),
prefix
=
add_prefix
(
"experts"
,
prefix
),
deepep_mode
=
global_server_args_dict
[
"deepep_mode"
],
deepep_mode
=
DeepEPMode
[
global_server_args_dict
[
"deepep_mode"
]
]
,
)
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
...
@@ -264,7 +264,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -264,7 +264,7 @@ class DeepseekV2MoE(nn.Module):
num_local_experts
=
config
.
n_routed_experts
//
self
.
tp_size
,
num_local_experts
=
config
.
n_routed_experts
//
self
.
tp_size
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
params_dtype
=
config
.
torch_dtype
,
params_dtype
=
config
.
torch_dtype
,
deepep_mode
=
global_server_args_dict
[
"deepep_mode"
],
deepep_mode
=
DeepEPMode
[
global_server_args_dict
[
"deepep_mode"
]
]
,
async_finish
=
True
,
# TODO
async_finish
=
True
,
# TODO
return_recv_hook
=
True
,
return_recv_hook
=
True
,
)
)
...
...
python/sglang/srt/server_args.py
View file @
8e10fec9
...
@@ -20,7 +20,7 @@ import logging
...
@@ -20,7 +20,7 @@ import logging
import
os
import
os
import
random
import
random
import
tempfile
import
tempfile
from
typing
import
List
,
Optional
from
typing
import
List
,
Literal
,
Optional
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
from
sglang.srt.reasoning_parser
import
ReasoningParser
from
sglang.srt.reasoning_parser
import
ReasoningParser
...
@@ -161,7 +161,7 @@ class ServerArgs:
...
@@ -161,7 +161,7 @@ class ServerArgs:
enable_dp_attention
:
bool
=
False
enable_dp_attention
:
bool
=
False
enable_ep_moe
:
bool
=
False
enable_ep_moe
:
bool
=
False
enable_deepep_moe
:
bool
=
False
enable_deepep_moe
:
bool
=
False
deepep_mode
:
Optional
[
str
]
=
"auto"
deepep_mode
:
Optional
[
Literal
[
"auto"
,
"normal"
,
"low_latency"
]
]
=
"auto"
enable_torch_compile
:
bool
=
False
enable_torch_compile
:
bool
=
False
torch_compile_max_bs
:
int
=
32
torch_compile_max_bs
:
int
=
32
cuda_graph_max_bs
:
Optional
[
int
]
=
None
cuda_graph_max_bs
:
Optional
[
int
]
=
None
...
...
python/sglang/srt/utils.py
View file @
8e10fec9
...
@@ -37,6 +37,7 @@ import time
...
@@ -37,6 +37,7 @@ import time
import
traceback
import
traceback
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
enum
import
Enum
from
functools
import
lru_cache
from
functools
import
lru_cache
from
importlib.metadata
import
PackageNotFoundError
,
version
from
importlib.metadata
import
PackageNotFoundError
,
version
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
...
@@ -1838,3 +1839,24 @@ def flatten_nested_list(nested_list):
...
@@ -1838,3 +1839,24 @@ def flatten_nested_list(nested_list):
]
]
else
:
else
:
return
[
nested_list
]
return
[
nested_list
]
class DeepEPMode(Enum):
    """Selectable DeepEP mode: ``normal``, ``low_latency`` or ``auto``.

    ``auto`` is not a concrete mode by itself; :meth:`resolve` turns it into
    ``normal`` or ``low_latency`` depending on the forward mode it is given.
    """

    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_normal(self):
        """Return True for ``normal`` and ``auto``, False for ``low_latency``."""
        return self == DeepEPMode.normal or self == DeepEPMode.auto

    def enable_low_latency(self):
        """Return True for ``low_latency`` and ``auto``, False for ``normal``."""
        return self == DeepEPMode.low_latency or self == DeepEPMode.auto

    def resolve(self, forward_mode):
        """Map this mode to a concrete one for the given forward mode.

        Args:
            forward_mode: Object exposing ``is_decode()``; it is only
                consulted when this mode is ``auto``.

        Returns:
            ``self`` unchanged when it is already ``normal`` or
            ``low_latency``. For ``auto``: ``low_latency`` when
            ``forward_mode.is_decode()`` is truthy, otherwise ``normal``.
        """
        if self == DeepEPMode.auto:
            decoding = forward_mode.is_decode()
            return DeepEPMode.low_latency if decoding else DeepEPMode.normal

        # Explicit modes never depend on the forward mode.
        return self
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment