Commit 92f82dce
Authored Nov 14, 2025 by lizhigong
Merge branch 'v0.5.4_dev_yiqa' into 'v0.5.4_dev'

Use grouped GEMM to complete the high-throughput mode adaptation.

See merge request OpenDAS/sglang!24
Parents: a34b0d3d, 54abdab4
Showing 3 changed files with 47 additions and 44 deletions:

python/sglang/srt/layers/moe/ep_moe/layer.py (+5, -7)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+36, -31)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_marlin.py (+6, -6)
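The merge message describes switching the high-throughput (normal-dispatch, contiguous) MoE path over to grouped GEMM. As orientation only, the following hand-written sketch (assumed semantics, not the kernel used in this repository) shows what a contiguous grouped GEMM over experts computes: one GEMM per local expert over that expert's contiguous slice of the dispatched token buffer, with slice sizes taken from num_recv_tokens_per_expert.

    import torch

    def grouped_gemm_contiguous_sketch(hidden_states, w13, num_recv_tokens_per_expert):
        # hidden_states: (total_tokens, k), tokens already grouped by expert
        # w13: (num_local_experts, n, k) per-expert weights (e.g. w13_weight in the diff below)
        # num_recv_tokens_per_expert: list of slice sizes, one per local expert
        outputs, start = [], 0
        for e, count in enumerate(num_recv_tokens_per_expert):
            chunk = hidden_states[start : start + count]  # this expert's contiguous tokens
            outputs.append(chunk @ w13[e].t())            # one GEMM per expert
            start += count
        return torch.cat(outputs, dim=0)                  # (total_tokens, n)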
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -563,7 +563,7 @@ class DeepEPMoE(EPMoE):
         )

     def forward_deepgemm_w4a8_marlin_contiguous(
         self,
         dispatch_output: DeepEPNormalOutput,
     ):
         hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (

@@ -576,7 +576,6 @@ class DeepEPMoE(EPMoE):
         if all_tokens <= 0:
             return hidden_states.bfloat16()
-        num_local_tokens = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int32, device="cuda")
         expert_output = self.quant_method.apply_ep(
             x=hidden_states,
             w1=self.w13_weight,

@@ -591,10 +590,9 @@ class DeepEPMoE(EPMoE):
             w1_scale=self.w13_weight_scale,
             w2_scale=self.w2_weight_scale,
             routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
-            # num_local_tokens=num_local_tokens,
         )
         return expert_output

     def forward_deepgemm_contiguous(
         self,

@@ -807,11 +805,11 @@ class DeepEPMoE(EPMoE):
             masked_m,
             expected_m,
         )
         q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)

         # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
         m_grouped_w4a8_gemm_nt_masked(
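For the masked layout in the last hunk above, here is a rough reference for what a masked grouped GEMM such as m_grouped_w4a8_gemm_nt_masked is expected to compute (an assumption based on the call site; the real kernel also fuses the w4a8 dequantization): for each expert group g, only the first masked_m[g] rows of the preallocated (num_groups, m, n2) output are produced, and the padded tail is left untouched.

    import torch

    def masked_grouped_gemm_sketch(act, weight, masked_m, out):
        # act: (num_groups, m, k), weight: (num_groups, n, k) -- "nt": weight used transposed
        # masked_m: (num_groups,) valid row counts, out: preallocated (num_groups, m, n)
        for g in range(act.size(0)):
            rows = int(masked_m[g])
            # rows beyond masked_m[g] are padding and stay uninitialized,
            # like the diff's preallocated down_output
            out[g, :rows] = (act[g, :rows].float() @ weight[g].float().t()).to(out.dtype)
        return out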
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -4,7 +4,7 @@ import logging
 from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

 from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.dp_attention import get_is_extend_in_batch

@@ -357,18 +357,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     ):
         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         topk_ids = topk_ids.to(torch.int64)
-        if (
-            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            and not get_moe_runner_backend().is_cutlass()
-        ):
-            # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(
-                hidden_states,
-                128,
-                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            )
+        # if (
+        #     deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     and not get_moe_runner_backend().is_cutlass()
+        # ):
+        #     # TODO hard code 128 block quant,use fp8 communication
+        #     hidden_states = sglang_per_token_group_quant_fp8(
+        #         hidden_states,
+        #         128,
+        #         column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #     )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_ids, topk_weights, previous_event

@@ -380,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             num_recv_tokens_per_expert,
             event,
         ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
+        # event.current_stream_wait() if self.async_finish else ()
         if isinstance(hidden_states, tuple):
             hidden_states, hidden_states_scale = hidden_states

@@ -435,18 +435,23 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
             num_tokens_per_expert=num_tokens_per_expert,
-            previous_event=previous_event,
-            async_finish=self.async_finish,
-            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
+            previous_event=None,
+            async_finish=False,
+            allocate_on_comm_stream=False,
+            expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
-        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert,
-            num_tokens_per_rank=num_tokens_per_rank,
-            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-            num_tokens_per_expert=num_tokens_per_expert,
-        )
+        # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+        #     num_recv_tokens_per_expert,
+        #     num_tokens_per_rank=num_tokens_per_rank,
+        #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+        #     num_tokens_per_expert=num_tokens_per_expert,
+        # )
+        self.rank_expert_offset = get_moe_expert_parallel_rank() * (
+            self.num_experts // get_moe_expert_parallel_world_size()
+        )
+        recv_topk_ids = torch.where(
+            recv_topk_ids == -1,
+            self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+            recv_topk_ids + self.rank_expert_offset,
+        )
         return (
             recv_x,

@@ -495,7 +500,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def combine_b(self, output, previous_event):
         hidden_states, event = self._combine_core(output, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
+        # event.current_stream_wait() if self.async_finish else ()
         self.handle = None
         self.src2dst = None
         return hidden_states

@@ -505,9 +510,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         combined_x, _, event = buffer.combine(
             x,
             self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
+            async_finish=False,
+            previous_event=None,
+            allocate_on_comm_stream=False,
             config=DeepEPConfig.get_instance().normal_combine_config,
         )
         return combined_x, event

@@ -536,7 +541,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        self.return_recv_hook = return_recv_hook
+        self.return_recv_hook = False
         self.device_module = torch.get_device_module()
         self.quant_config = {}

@@ -693,7 +698,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             topk_idx=topk_ids,
             topk_weights=topk_weights,
             handle=self.handle,
             zero_copy=False,
             async_finish=not self.return_recv_hook,
             return_recv_hook=self.return_recv_hook,
         )

@@ -703,7 +708,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             topk_idx=topk_ids,
             topk_weights=topk_weights,
             handle=self.handle,
             zero_copy=False,
             async_finish=not self.return_recv_hook,
             return_recv_hook=self.return_recv_hook,
             **(
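Two changes dominate the normal dispatcher above: dispatch and combine are made fully synchronous (async_finish=False, previous_event=None, no allocation on the communication stream), and the received top-k ids are remapped from DeepEP's rank-local numbering into global expert ids, with -1 slots (tokens not routed to this rank) redirected to a dummy expert. A toy run of the added torch.where remapping, with made-up tensor values:

    import torch

    num_experts, world_size, rank = 8, 4, 1
    rank_expert_offset = rank * (num_experts // world_size)   # 2 on rank 1

    recv_topk_ids = torch.tensor([[0, 1, -1],
                                  [-1, 1, 0]])                # -1 = not routed to this rank
    remapped = torch.where(
        recv_topk_ids == -1,
        num_experts - 1 if rank_expert_offset == 0 else 0,    # dummy expert for -1 slots
        recv_topk_ids + rank_expert_offset,                   # local id -> global id
    )
    print(remapped)  # tensor([[2, 3, 0], [0, 3, 2]])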
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_marlin.py

@@ -42,7 +42,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
         sparsity_ignore_list: list[str],
         kv_cache_scheme: Optional[dict[str, Any]] = None,
         config: Optional[dict[str, Any]] = None,
         packed_modules_mapping: Optional[dict[str, list[str]]] = None,
     ):
         super().__init__(
             target_scheme_map,

@@ -52,10 +52,10 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
             sparsity_ignore_list,
             kv_cache_scheme,
             config,
             packed_modules_mapping,
         )

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:

@@ -73,7 +73,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE  # Avoid circular import
-        # from sglang.srt.layers.radix_attention import RadixAttention
+        from sglang.srt.layers.radix_attention import RadixAttention

         # Check if the layer is skipped for quantization.
         if should_ignore_layer(prefix,
             ignore=self.ignore,

@@ -85,8 +85,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
                 return UnquantizedEmbeddingMethod()  #UnquantizedLinearMethod()
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
-        # if isinstance(layer, RadixAttention):
-        #     return CompressedTensorsKVCacheMethod(self)
+        if isinstance(layer, RadixAttention):
+            return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
             return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
         return None
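The net effect of the last two hunks appears to be that RadixAttention layers now receive CompressedTensorsKVCacheMethod instead of falling through to None. A stripped-down sketch of the resulting per-layer dispatch (stub classes stand in for the real sglang types; not the actual implementation):

    class RadixAttention: ...
    class FusedMoE: ...

    def pick_quant_method(layer):
        # mirrors the order in the hunk above: attention -> KV-cache method, MoE -> Marlin MoE method
        if isinstance(layer, RadixAttention):
            return "CompressedTensorsKVCacheMethod"
        if isinstance(layer, FusedMoE):
            return "CompressedTensorsMarlinMoEMethod"
        return None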