change / sglang · Commits

Commit 8f355853
Authored Nov 14, 2025 by yiqa

    Use grouped GEMM (groupgemm) to complete the high-throughput mode adaptation

Parent: 842b423a
Showing 3 changed files with 46 additions and 41 deletions.

- python/sglang/srt/layers/moe/ep_moe/layer.py (+5, -7)
- python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+33, -28)
- python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py (+8, -6)
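The commit message says the high-throughput (normal DeepEP) path is being adapted to grouped GEMM. As background, here is a minimal plain-PyTorch sketch of the grouped-GEMM idea behind MoE expert computation: tokens are sorted so each expert's tokens are contiguous, then one GEMM runs per expert over its slice. This only illustrates the concept; `grouped_gemm_reference`, its argument layout, and the shapes are assumptions, not the kernel this commit wires in.

```python
import torch

def grouped_gemm_reference(x, weights, tokens_per_expert):
    # Illustrative sketch only, NOT the commit's kernel.
    # x: [total_tokens, hidden], sorted so each expert's tokens are contiguous.
    # weights: [num_experts, hidden, out]. tokens_per_expert: list of ints.
    outs, start = [], 0
    for e, n in enumerate(tokens_per_expert):
        outs.append(x[start:start + n] @ weights[e])  # one GEMM per expert group
        start += n
    return torch.cat(outs, dim=0)

x = torch.randn(6, 16)
w = torch.randn(3, 16, 32)
y = grouped_gemm_reference(x, w, [2, 0, 4])  # expert 1 receives no tokens
print(y.shape)  # torch.Size([6, 32])
```

A real grouped-GEMM kernel fuses these per-expert GEMMs into one launch, which is where the throughput win comes from.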
python/sglang/srt/layers/moe/ep_moe/layer.py

```diff
@@ -576,7 +576,6 @@ class DeepEPMoE(EPMoE):
         if all_tokens <= 0:
             return hidden_states.bfloat16()
-        num_local_tokens = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int32, device="cuda")
         expert_output = self.quant_method.apply_ep(
             x=hidden_states,
             w1=self.w13_weight,
@@ -591,7 +590,6 @@ class DeepEPMoE(EPMoE):
             w1_scale=self.w13_weight_scale,
             w2_scale=self.w2_weight_scale,
             routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
-            # num_local_tokens=num_local_tokens,
         )
         return expert_output
```
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

```diff
@@ -4,7 +4,7 @@ import logging
 from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

+from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.dp_attention import get_is_extend_in_batch
```
```diff
@@ -357,18 +357,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     ):
         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         topk_ids = topk_ids.to(torch.int64)
-        if (
-            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            and not get_moe_runner_backend().is_cutlass()
-        ):
-            # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(
-                hidden_states,
-                128,
-                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            )
+        # if (
+        #     deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     and not get_moe_runner_backend().is_cutlass()
+        # ):
+        #     # TODO hard code 128 block quant,use fp8 communication
+        #     hidden_states = sglang_per_token_group_quant_fp8(
+        #         hidden_states,
+        #         128,
+        #         column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #     )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_ids, topk_weights, previous_event
```
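This hunk comments out the pre-dispatch FP8 quantization, so the normal dispatch path now sends unquantized hidden states. For context, below is a plain-torch sketch of the 128-wide per-token-group FP8 quantization the disabled `sglang_per_token_group_quant_fp8` call performed. `per_token_group_quant_fp8_sketch` is a hypothetical stand-in, and it omits the real kernel's scale-layout options (`column_major_scales`, `scale_tma_aligned`, `scale_ue8m0`).

```python
import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int = 128):
    # Split each token's hidden dim into groups of `group_size`, derive one
    # scale per group from the group absmax, and cast to float8_e4m3fn.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    t, h = x.shape
    xg = x.reshape(t, h // group_size, group_size).float()
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    x_q = (xg / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q.reshape(t, h), scales.squeeze(-1)

x = torch.randn(4, 256)
x_q, s = per_token_group_quant_fp8_sketch(x)
print(x_q.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([4, 2])
```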
```diff
@@ -435,18 +435,23 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
             num_tokens_per_expert=num_tokens_per_expert,
-            previous_event=previous_event,
-            async_finish=self.async_finish,
-            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
+            previous_event=None,
+            async_finish=False,
+            allocate_on_comm_stream=False,
+            expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
-        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert,
-            num_tokens_per_rank=num_tokens_per_rank,
-            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-            num_tokens_per_expert=num_tokens_per_expert,
-        )
+        # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+        #     num_recv_tokens_per_expert,
+        #     num_tokens_per_rank=num_tokens_per_rank,
+        #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+        #     num_tokens_per_expert=num_tokens_per_expert,
+        # )
+        self.rank_expert_offset = get_moe_expert_parallel_rank() * (
+            self.num_experts // get_moe_expert_parallel_world_size()
+        )
+        recv_topk_ids = torch.where(
+            recv_topk_ids == -1,
+            self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+            recv_topk_ids + self.rank_expert_offset,
+        )
         return (
             recv_x,
```
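The added lines replace the expert-distribution recording with a remap of the received top-k ids: valid local expert ids are shifted by this rank's expert offset into the global id space, and `-1` placeholders (slots with no routed expert) are parked on an expert id this rank does not own. A toy illustration of that `torch.where`, with hypothetical sizes:

```python
import torch

num_experts = 8                    # hypothetical global expert count
rank_expert_offset = 4             # e.g. EP rank 1 of 2, with 4 local experts each
recv_topk_ids = torch.tensor([0, 2, -1, 3])  # local ids; -1 marks an empty slot

remapped = torch.where(
    recv_topk_ids == -1,
    # park empty slots on an expert outside this rank's local range
    num_experts - 1 if rank_expert_offset == 0 else 0,
    recv_topk_ids + rank_expert_offset,
)
print(remapped)  # tensor([4, 6, 0, 7])
```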
```diff
@@ -505,9 +510,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         combined_x, _, event = buffer.combine(
             x,
             self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
+            async_finish=False,
+            previous_event=None,
+            allocate_on_comm_stream=False,
             config=DeepEPConfig.get_instance().normal_combine_config,
         )
         return combined_x, event
```
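Together with the dispatch hunk, this makes the normal combine fully synchronous: no prior event to wait on, no async completion, no allocation on a separate communication stream. A hypothetical helper mirroring the rewritten call, using only the names visible in the hunk, to show the call shape after the change:

```python
def combine_sync(buffer, x, handle, config):
    # Synchronous combine, as in the hunk above: the async knobs are all off.
    combined_x, _, event = buffer.combine(
        x,
        handle,
        async_finish=False,
        previous_event=None,
        allocate_on_comm_stream=False,
        config=config,
    )
    return combined_x, event
```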
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py

```diff
@@ -361,6 +361,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
             global_num_experts=global_num_experts,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             use_nn_moe=use_nn_moe,
             shared_output=shared_output,
             routed_scaling_factor=routed_scaling_factor,
```
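This hunk threads the activation scales `a1_scale` and `a2_scale` through to the fused MoE call. In W4A8 schemes such scales typically quantize the activations feeding the two expert GEMMs; a hedged sketch of applying one such scale (the helper name and the per-tensor granularity are assumptions, not the Marlin kernel's internals):

```python
import torch

def quantize_activation_int8(x: torch.Tensor, a_scale: torch.Tensor) -> torch.Tensor:
    # Static int8 activation quantization: x_q = round(x / a_scale), clamped to
    # the int8 range. a_scale plays the role of a1_scale / a2_scale above.
    return torch.clamp(torch.round(x / a_scale), -128, 127).to(torch.int8)

x = torch.randn(4, 8)
a1_scale = x.abs().max() / 127.0
print(quantize_activation_int8(x, a1_scale).dtype)  # torch.int8
```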