change / sglang / Commits / c2bd094d

Unverified commit c2bd094d, authored Mar 23, 2025 by xutizhou, committed by GitHub Mar 22, 2025
Parent: f8f9244a

Optimize Permute Kernel in DeepEP (#4643)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Showing 4 changed files with 101 additions and 230 deletions:

- python/sglang/srt/layers/moe/ep_moe/kernels.py (+47, -49)
- python/sglang/srt/layers/moe/ep_moe/layer.py (+12, -15)
- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (+39, -164)
- python/sglang/srt/models/deepseek_v2.py (+3, -2)
python/sglang/srt/layers/moe/ep_moe/kernels.py (+47, -49)

```diff
@@ -17,52 +17,6 @@ if _is_cuda:

 logger = logging.getLogger(__name__)


-@triton.jit
-def compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    tl.store(src2dst + src_id, dst_id, mask=mask)
-
-
-@triton.jit
-def deepep_compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    num_invalid = tl.load(num_minus_one)
-    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
-
-
-def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
-    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
-    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
-    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
-
-    # Find offset
-    expert_ids = torch.arange(
-        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
-    )
-    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
-    num_minus_one = seg_indptr[0]
-    seg_indptr = seg_indptr - num_minus_one
-
-    BLOCK_SIZE = 512
-    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
-    deepep_compute_src2dst_triton_kernel[grid](
-        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
-    )
-    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
-    return reorder_topk_ids, src2dst, seg_indptr
-
-
 @triton.jit
 def deepep_permute_triton_kernel(
     input_ptr,
@@ -85,14 +39,13 @@ def deepep_permute_triton_kernel(
     for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
         offset = start_offset + tl.arange(0, BLOCK_SIZE)
         mask = offset < hidden_size
-        in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)

         for idx in range(topk):
             dst_idx = tl.load(src2dst_ptr + idx)
             if dst_idx >= 0:
                 dst_ptr = gateup_input_ptr + dst_idx * hidden_size
-                out_data = (in_data).to(OutDtype)
-                tl.store(dst_ptr + offset, out_data, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
@@ -128,6 +81,51 @@ def deepep_post_reorder_triton_kernel(
     tl.store(store_ptr + offset, sum_vec, mask=mask)


+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def deepep_compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    num_invalid = tl.load(num_minus_one)
+    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
+
+
+def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
+
+    # Find offset
+    expert_ids = torch.arange(
+        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
+    )
+    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
+    num_minus_one = seg_indptr[0]
+    seg_indptr = seg_indptr - num_minus_one
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    deepep_compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
+    )
+    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     expert = tl.program_id(0)
```
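The relocated `deepep_run_moe_deep_preprocess` is the core of this optimization: instead of the Megatron-style multihot `permute`/`unpermute` path (deleted from `token_dispatcher.py` below), a single stable sort over the flattened `topk_ids` yields both the per-expert segment offsets (`seg_indptr`) and a `src2dst` scatter map, with dropped `-1` slots compacted away. A minimal eager-PyTorch restatement of that logic may help; `preprocess_reference` is an illustrative name, not part of the commit, and the Triton kernel is replaced here by plain indexing:

```python
import torch

def preprocess_reference(topk_ids: torch.Tensor, num_experts: int):
    """Eager sketch of deepep_run_moe_deep_preprocess."""
    # Stable sort of the flattened assignments; dropped (-1) entries sort first.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    # seg_indptr[e] = first row of expert e in the sorted order.
    expert_ids = torch.arange(num_experts + 1, dtype=reorder_topk_ids.dtype)
    seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)

    # Everything before expert 0 is -1 padding; shift all offsets past it.
    num_minus_one = seg_indptr[0]
    seg_indptr = seg_indptr - num_minus_one

    # src2dst[src] = destination row of flattened slot src, compacted past the
    # padding; padded slots end up with a negative destination.
    src2dst = torch.empty_like(reorder_ids)
    src2dst[reorder_ids] = torch.arange(reorder_ids.numel()) - num_minus_one

    return reorder_topk_ids[num_minus_one:], src2dst, seg_indptr

topk_ids = torch.tensor([[0, 2], [-1, 1], [2, 0]])  # 3 tokens, top-2 routing
print(preprocess_reference(topk_ids, num_experts=3))
# -> ids [0, 0, 1, 2, 2], src2dst [0, 3, -1, 2, 4, 1], seg_indptr [0, 2, 3, 5]
```

On this toy input, token 1's dropped slot maps to destination `-1`, which `deepep_permute_triton_kernel` later skips via its `if dst_idx >= 0` guard.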
python/sglang/srt/layers/moe/ep_moe/layer.py (+12, -15)

```diff
@@ -831,19 +831,23 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        tokens_per_expert: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
         forward_mode: ForwardMode,
     ):
         # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
         if True:  # not forward_mode.is_decode():
-            return self.forward_normal(hidden_states, tokens_per_expert)
+            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
         else:
-            return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)
+            return self.forward_deepgemm_masked(
+                hidden_states, reorder_topk_ids, seg_indptr
+            )

     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        tokens_per_expert: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -851,15 +855,7 @@ class DeepEPMoE(EPMoE):
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
                 use_flashinfer=False,  # TODO: use flashinfer
             )
-        seg_indptr_cur_rank = torch.cat(
-            [
-                torch.zeros(
-                    1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
-                ),
-                torch.cumsum(tokens_per_expert, dim=0),
-            ]
-        )
-        reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
@@ -881,6 +877,7 @@ class DeepEPMoE(EPMoE):
             device=hidden_states.device,
             dtype=hidden_states.dtype,
         )
+        if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
@@ -888,7 +885,7 @@ class DeepEPMoE(EPMoE):
             c=gateup_output,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
+            seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w13_input_scale,
@@ -946,7 +943,7 @@ class DeepEPMoE(EPMoE):
             c=down_output,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
+            seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w2_input_scale,
```
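`forward_normal` no longer rebuilds its GEMM offsets on every call: the deleted block derived `seg_indptr_cur_rank` and `reorder_topk_ids` from `tokens_per_expert`, while the dispatcher now hands both in directly from preprocessing. A small sketch of what the deleted reconstruction computed (toy values; equivalence holds on the assumption that both describe the same permuted buffer):

```python
import torch

tokens_per_expert = torch.tensor([2, 0, 3, 1])  # tokens routed to each local expert

# Deleted reconstruction of seg_indptr: exclusive prefix sum over the counts.
seg_indptr_cur_rank = torch.cat(
    [
        torch.zeros(1, dtype=tokens_per_expert.dtype),
        torch.cumsum(tokens_per_expert, dim=0),
    ]
)
print(seg_indptr_cur_rank)  # tensor([0, 2, 2, 5, 6])

# Deleted reconstruction of reorder_topk_ids: the expert id of each permuted row.
print(torch.repeat_interleave(tokens_per_expert))  # tensor([0, 0, 2, 2, 2, 3])
```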
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (+39, -164)

```diff
@@ -12,7 +12,6 @@ import torch
 import torch.distributed as dist

 from sglang.srt.layers.moe.ep_moe.kernels import (
-    compute_src2dst_triton_kernel,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
@@ -86,90 +85,6 @@ def get_buffer_low_latency(
     return _buffer_low_latency


-def permute(
-    tokens,
-    routing_map,
-    num_out_tokens: Optional[int] = None,
-    fused: bool = False,
-    drop_and_pad: bool = False,
-):
-    """
-    Copy from Megatron-Core moe for token permutation
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
-    """
-    num_tokens, _ = tokens.shape
-    num_experts = routing_map.shape[1]
-    if drop_and_pad and not (num_out_tokens is None):
-        capacity = num_out_tokens // num_experts
-        assert not routing_map.requires_grad
-        routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
-        sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
-            :, :capacity
-        ].contiguous()
-        sorted_indices = sorted_indices.view(-1)
-    else:
-        routing_map = routing_map.bool().T.contiguous()
-        token_indices = (
-            torch.arange(num_tokens, device=routing_map.device)
-            .unsqueeze(0)
-            .expand(num_experts, -1)
-        )
-        sorted_indices = token_indices.masked_select(routing_map)
-    permuted_input = tokens.index_select(0, sorted_indices)
-    return permuted_input, sorted_indices
-
-
-def unpermute(
-    permuted_tokens: torch.Tensor,
-    sorted_indices: torch.Tensor,
-    restore_shape: torch.Size,
-    probs: torch.Tensor = None,
-    routing_map: torch.Tensor = None,
-    fused: bool = False,
-    drop_and_pad: bool = False,
-):
-    """
-    Copy from Megatron-Core moe for token unpermutation
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
-    """
-    _, hidden = restore_shape
-
-    if probs is not None:
-        assert routing_map is not None, "Mask must be provided to permute the probs."
-        if drop_and_pad:
-            num_experts = routing_map.size(1)
-            num_permuted_tokens = sorted_indices.size(0)
-            capacity = num_permuted_tokens // num_experts
-            num_unpermuted_tokens = probs.size(0)
-
-            probs_T_1D = probs.T.contiguous().view(-1)
-
-            indices_dim0 = torch.arange(
-                num_experts, device=routing_map.device
-            ).unsqueeze(-1)
-            indices_dim1 = sorted_indices.view(num_experts, capacity)
-            indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
-
-            permuted_probs = probs_T_1D.index_select(0, indices_1D)
-        else:
-            permuted_probs = probs.T.contiguous().masked_select(
-                routing_map.T.contiguous()
-            )
-        permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
-
-    output_tokens = torch.zeros(
-        restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
-    )
-    output_tokens.scatter_add_(
-        0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens
-    )
-    return output_tokens


 class DeepEPDispatcher:
     """
     Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
@@ -228,16 +143,13 @@ class DeepEPDispatcher:
     def deepep_permute(
         self,
-        topk_ids,
         hidden_states,
-        num_experts,
-        top_k,
-        use_fp8_w8a8,
-        use_block_quant,
-        fp8_dtype,
+        fp8_dtype=None,
+        use_fp8_w8a8=False,
+        use_block_quant=False,
     ):
         reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_ids, num_experts
+            self.topk_idx, self.num_experts
         )
         num_total_tokens = reorder_topk_ids.numel()
         gateup_input = torch.empty(
@@ -254,9 +166,9 @@ class DeepEPDispatcher:
             hidden_states,
             gateup_input,
             src2dst,
-            topk_ids,
+            self.topk_idx,
             None,
-            top_k,
+            self.router_topk,
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
@@ -302,13 +214,21 @@ class DeepEPDispatcher:
             )
         )
         self.recv_expert_count = recv_expert_count
-        tokens_per_expert = self.get_number_of_tokens_per_expert()
         self.handle = handle
         self.topk_idx = topk_idx
         self.topk_weights = topk_weights
         if hidden_states.shape[0] > 0:
-            hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
-        return hidden_states, topk_idx, topk_weights, tokens_per_expert
+            reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
+                hidden_states, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+        return hidden_states, reorder_topk_ids, seg_indptr

     def dispatch_normal(
         self,
@@ -427,10 +347,29 @@ class DeepEPDispatcher:
         # Todo: enable low latency combine
         if True:  # not forward_mode.is_decode():
             if hidden_states.shape[0] > 0:
-                hidden_states = self.get_restored_hidden_states_by_experts(
-                    hidden_states
-                )
-            hidden_states, event = self.combine_normal(hidden_states, self.handle)
+                num_tokens = self.src2dst.shape[0] // self.router_topk
+                output = torch.empty(
+                    (num_tokens, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                deepep_post_reorder_triton_kernel[(num_tokens,)](
+                    hidden_states,
+                    output,
+                    self.src2dst,
+                    self.topk_idx,
+                    self.topk_weights,
+                    self.router_topk,
+                    hidden_states.shape[1],
+                    BLOCK_SIZE=512,
+                )
+            else:
+                output = torch.zeros(
+                    (0, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+            hidden_states, event = self.combine_normal(output, self.handle)
         else:
             hidden_states, event, hook = self.combine_low_latency(
                 hidden_states, self.topk_idx, self.topk_weights, self.handle
@@ -467,67 +406,3 @@ class DeepEPDispatcher:
         )
         # hook()
         return combined_hidden_states, event_overlap, hook
-
-    def _indices_to_multihot(self, indices, probs):
-        batch_size = indices.shape[0]
-        multihot_routing_map = torch.zeros(
-            (batch_size, self.num_local_experts),
-            dtype=torch.long,
-            device=indices.device,
-        )
-        multihot_probs = torch.zeros(
-            (batch_size, self.num_local_experts),
-            dtype=torch.float,
-            device=indices.device,
-        )
-        mask = indices != -1
-        valid_indices = indices[mask]
-        row_indices = torch.arange(
-            batch_size, device=indices.device
-        ).repeat_interleave(mask.sum(dim=1))
-        multihot_routing_map[row_indices, valid_indices] = 1
-        multihot_probs[row_indices, valid_indices] = probs[mask]
-        return multihot_routing_map.bool(), multihot_probs
-
-    def get_dispached_metadata(self) -> torch.Tensor:
-        return self.topk_idx, self.topk_weights
-
-    def get_number_of_tokens_per_expert(self) -> torch.Tensor:
-        """
-        Get the number of tokens per expert.
-        """
-        return self.tokens_per_expert
-
-    def get_permuted_hidden_states_by_experts(
-        self, hidden_states: torch.Tensor
-    ) -> torch.Tensor:
-        self.dispatched_routing_map, self.topk_weights = self._indices_to_multihot(
-            self.topk_idx, self.topk_weights
-        )
-        self.hidden_shape_before_permute = hidden_states.shape
-        hidden_states, self.reversed_mapping_for_combine = permute(
-            hidden_states,
-            self.dispatched_routing_map,
-            num_out_tokens=self.tokens_per_expert.sum(),
-            fused=self.permute_fusion,
-        )
-        return hidden_states
-
-    def get_restored_hidden_states_by_experts(
-        self, hidden_states: torch.Tensor
-    ) -> torch.Tensor:
-        input_dtype = hidden_states.dtype
-        assert (
-            self.topk_weights.dtype == torch.float32
-        ), "DeepEP only supports float32 probs"
-        hidden_states = unpermute(
-            hidden_states,
-            self.reversed_mapping_for_combine,
-            restore_shape=self.hidden_shape_before_permute,
-            routing_map=self.dispatched_routing_map,
-            probs=self.topk_weights,
-            fused=self.permute_fusion,
-        )
-        return hidden_states.to(input_dtype)
```
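`combine` now undoes the permutation with `deepep_post_reorder_triton_kernel` rather than the deleted `unpermute` path: each original token gathers its top-k expert-output rows through `src2dst` and reduces them with the routing weights. A loose eager restatement of that reduction follows; `post_reorder_reference` is an illustrative helper, and the exact masking inside the Triton kernel may differ:

```python
import torch

def post_reorder_reference(expert_out, src2dst, topk_weights, router_topk):
    """expert_out: (num_permuted_rows, hidden); src2dst: (num_tokens * router_topk,)."""
    num_tokens = src2dst.shape[0] // router_topk
    out = torch.zeros(num_tokens, expert_out.shape[1], dtype=expert_out.dtype)
    for i in range(num_tokens):
        for k in range(router_topk):
            dst = src2dst[i * router_topk + k]
            if dst >= 0:  # slots routed nowhere carry a negative destination
                out[i] += topk_weights[i, k] * expert_out[dst]
    return out
```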
python/sglang/srt/models/deepseek_v2.py (+3, -2)

```diff
@@ -294,7 +294,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.correction_bias,
         )
         if self.tp_size > 1:
-            recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
                 self.deepep_dispatcher.dispatch(
                     hidden_states,
                     topk_idx,
@@ -306,7 +306,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = (
             self.experts(
                 hidden_states=recv_hidden_states,
-                tokens_per_expert=tokens_per_expert,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
```
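Taken together, the dispatcher now returns `(hidden_states, reorder_topk_ids, seg_indptr)` and the experts consume `seg_indptr` directly. A toy round trip through the two reference helpers sketched above (`preprocess_reference` and `post_reorder_reference`, both illustrative and not part of the commit) exercises that contract end to end:

```python
import torch

topk_ids = torch.tensor([[0, 2], [-1, 1], [2, 0]])  # 3 tokens, top-2 routing
hidden = torch.randn(3, 4)
reorder_topk_ids, src2dst, seg_indptr = preprocess_reference(topk_ids, num_experts=3)

# Permute: copy each token into every destination row it owns (what
# deepep_permute_triton_kernel does on device).
permuted = torch.zeros(reorder_topk_ids.numel(), hidden.shape[1])
for i in range(3):
    for k in range(2):
        dst = src2dst[i * 2 + k]
        if dst >= 0:  # the dropped (-1) slot of token 1 is skipped
            permuted[dst] = hidden[i]

# Identity "experts" plus unit weights: each token comes back scaled by its
# number of valid expert slots (2, 1, 2 here).
restored = post_reorder_reference(permuted, src2dst, torch.ones(3, 2), router_topk=2)
```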