change / sglang / Commit 0a71d6b1
Authored Nov 26, 2025 by yiqa; committed by lizhigong, Nov 26, 2025
Parent commit: 078de197

Commit message: V0.5.4 dev yiqa
Changes: 3 changed files with 40 additions and 48 deletions (+40 -48)
python/sglang/srt/layers/moe/ep_moe/layer.py                    +27 -29
python/sglang/srt/layers/moe/token_dispatcher/deepep.py          +8 -15
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py   +5  -4
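In outline, the three diffs below do the following: the slimquant_w4a8_marlin expert-id remap (a rank_expert_offset shift plus a torch.where over the top-k ids) moves out of the DeepEP normal dispatcher and into DeepEPMoE in layer.py; activations are quantized to int8 once per token at dispatch time via lmslim's per_token_quant_int8, so the packed buffer no longer gets a second quantization pass before the grouped w8a8 GEMM; and the per-token activation scale is carried alongside the hidden states (hidden_states_scale_packed, a1_scale) all the way into apply_ep. The Python-side pack/combine loop is also tightened (defaultdict-based counts, in-place addcmul_ accumulation).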
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 from collections import defaultdict
+from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
@@ -119,8 +120,8 @@ def fuse_silu_mul_quant_ep_wrapper(
         topk,
         expect_m
     )

 def fuse_silu_mul_quant_ep_fake(
     input: torch.Tensor,
     tokens_per_expert: Optional[torch.Tensor] = None,
@@ -695,6 +696,11 @@ class DeepEPMoE(EPMoE):
         if all_tokens <= 0:
             return hidden_states.bfloat16()
+        rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
+        topk_idx = torch.where(
+            topk_idx == -1,
+            self.num_experts - 1 if rank_expert_offset == 0 else 0,
+            topk_idx + rank_expert_offset)
         expert_output = self.quant_method.apply_ep(
             x=hidden_states,
             w1=self.w13_weight,
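The torch.where added here folds two things into one pass: invalid slots (topk_idx == -1) are parked on a dummy expert that this rank treats as padding, and valid local slots are shifted by the rank's expert offset to global expert ids. A self-contained sketch of the same arithmetic, with illustrative sizes (16 experts over 4 EP ranks):

import torch

num_experts, ep_size, ep_rank = 16, 4, 1                  # illustrative values
rank_expert_offset = ep_rank * (num_experts // ep_size)   # -> 4

topk_idx = torch.tensor([[0, 3, -1], [2, -1, 1]])         # -1 = unused slot
remapped = torch.where(
    topk_idx == -1,
    # Park unused slots on an expert this rank never computes:
    # the last expert on rank 0, expert 0 on every other rank.
    num_experts - 1 if rank_expert_offset == 0 else 0,
    topk_idx + rank_expert_offset,
)
print(remapped)  # tensor([[4, 7, 0], [6, 0, 5]])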
@@ -708,6 +714,7 @@ class DeepEPMoE(EPMoE):
             use_nn_moe=False,
             w1_scale=self.w13_weight_scale,
             w2_scale=self.w2_weight_scale,
+            a1_scale=hidden_states_scale,
             routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
         )
         return expert_output
@@ -740,10 +747,9 @@ class DeepEPMoE(EPMoE):
                 active_experts.add(e)
             token_expert_pos[t] = lst
-        active_experts = sorted(list(active_experts))
-        if not active_experts:
+        num_active = len(active_experts)
+        if num_active == 0:
             return hidden_states.bfloat16()
+        active_experts = sorted(list(active_experts))
         counts = defaultdict(int)
         for t in range(M):
@@ -752,12 +758,9 @@ class DeepEPMoE(EPMoE):
         per_expert_block = {}
         for e in active_experts:
-            cnt = counts.get(e, 0)
-            if cnt <= 0:
-                per_expert_block[e] = 0
-            else:
-                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
-                per_expert_block[e] = max(256, needed)
+            cnt = counts[e]
+            needed = ((cnt + 255) // 256) * 256  # same as ceil(cnt/256)*256
+            per_expert_block[e] = max(256, needed)
         expert_slot_offset = {}
         offset = 0
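The rewrite drops the dead cnt <= 0 branch (counts is a defaultdict(int) that is only incremented for experts in active_experts, so cnt >= 1 here) and keeps a single rounding expression. Both the old (cnt + 256 - 1) and new (cnt + 255) forms compute the same ceil-to-multiple-of-256, with max(256, ...) enforcing a one-block minimum; a quick check:

import math

for cnt in (1, 255, 256, 257, 1000):
    assert ((cnt + 255) // 256) * 256 == ((cnt + 256 - 1) // 256) * 256
    assert ((cnt + 255) // 256) * 256 == math.ceil(cnt / 256) * 256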
@@ -766,7 +769,8 @@ class DeepEPMoE(EPMoE):
             offset += per_expert_block[e]
         pad_M = offset
-        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
+        hidden_states_packed = torch.empty((pad_M, K), device=device, dtype=hidden_states.dtype)
+        hidden_states_scale_packed = torch.empty((pad_M,), device=device, dtype=hidden_states_scale.dtype)
         m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
         slot_counters = {e: 0 for e in active_experts}
@@ -776,26 +780,27 @@ class DeepEPMoE(EPMoE):
             for (e, pos) in token_expert_pos[t]:
                 start = expert_slot_offset[e]
                 slot = slot_counters[e]
-                if slot >= per_expert_block[e]:
-                    raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
                 row = start + slot
                 hidden_states_packed[row] = hidden_states[t]
-                m_indices[row] = int(e)
+                hidden_states_scale_packed[row] = hidden_states_scale[t]
+                m_indices[row] = e
                 slot_counters[e] += 1
-                w = topk_weights[t, pos].to(device=device)  # record weight (as float32 on device)
+                w = topk_weights[t, pos]
                 w_f = w.float() if w.dtype != torch.float32 else w
                 token_row_weight_list[t].append((row, w_f))
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
+        # q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
         N = self.w13_weight.size(1)
         gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
         m_grouped_w8a8_gemm_nt_contig_asm(
-            (q_a1_all, q_a1_scale),
+            (hidden_states_packed, hidden_states_scale_packed),
             (self.w13_weight, self.w13_weight_scale),
             gateup_output,
             m_indices,
         )
+        del hidden_states_packed, hidden_states_scale_packed
         q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
         down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
         down_output = m_grouped_w8a8_gemm_nt_contig_asm(
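For context on the layout this packing loop produces: rows belonging to the same expert sit in one contiguous, 256-aligned block, and m_indices labels each packed row with its expert id, with -1 marking padding rows. A toy illustration of the invariant (illustrative sizes, not the kernel's API):

import torch

# Toy example: expert 2 owns rows [0, 256), expert 5 owns rows [256, 512).
pad_M = 512
m_indices = torch.full((pad_M,), -1, dtype=torch.int32)
m_indices[0:3] = 2      # three real tokens routed to expert 2
m_indices[256:260] = 5  # four real tokens routed to expert 5
# Every row labeled e must be multiplied by expert e's weights; rows still
# labeled -1 are padding whose outputs are never read back by the combine loop.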
@@ -806,17 +811,10 @@ class DeepEPMoE(EPMoE):
         )
         result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
         for t in range(M):
-            pairs = token_row_weight_list[t]
-            if not pairs:
-                continue
-            acc = None
-            for (row, w) in pairs:
-                vec = down_output[row].float()
-                weighted = vec * w
-                acc = weighted if acc is None else (acc + weighted)
-            result[t] = acc.to(result.dtype)
-        return result
+            for (row, w) in token_row_weight_list[t]:
+                result[t].addcmul_(down_output[row].float(), w)
+        return result.to(down_output.dtype)

     def forward_deepgemm_contiguous(
         self,
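The rewritten combine loop is an in-place fused multiply-add: for every (row, w) recorded for token t, result[t] += down_output[row].float() * w, which matches the old explicit-accumulator formulation. A minimal equivalence check (float32 throughout, for the illustration):

import torch

down_output = torch.randn(4, 8)
pairs = [(0, torch.tensor(0.25)), (2, torch.tensor(0.75))]

# Old formulation: build an explicit accumulator, then store it.
acc = None
for row, w in pairs:
    weighted = down_output[row].float() * w
    acc = weighted if acc is None else acc + weighted

# New formulation: accumulate in place with addcmul_.
result = torch.zeros(1, 8)
for row, w in pairs:
    result[0].addcmul_(down_output[row].float(), w)

assert torch.allclose(acc, result[0])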
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -4,7 +4,6 @@ import logging
 from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
-from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.dp_attention import get_is_extend_in_batch
@@ -30,7 +29,7 @@ from sglang.srt.utils import (
     is_npu,
     load_json_config,
 )
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8

 _is_npu = is_npu()

 if TYPE_CHECKING:
@@ -369,6 +368,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         #     scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         #     scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         # )
+        hidden_states = per_token_quant_int8(hidden_states)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_ids, topk_weights, previous_event
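per_token_quant_int8 is imported from lmslim and its implementation is not part of this diff; a typical per-token symmetric int8 quantizer returns the int8 tensor together with one float scale per row, roughly as in this generic sketch (an assumption about the scheme, not lmslim's actual code):

import torch

def per_token_quant_int8_sketch(x: torch.Tensor):
    # One symmetric scale per token (row): s = max(|row|) / 127.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale.squeeze(-1)

x = torch.randn(4, 16)
q, s = per_token_quant_int8_sketch(x)
dequant = q.float() * s.unsqueeze(-1)
assert (x - dequant).abs().max() <= s.max()  # round-off is at most half a step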
@@ -441,19 +441,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
-        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
-            self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
-            recv_topk_ids = torch.where(
-                recv_topk_ids == -1,
-                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-                recv_topk_ids + self.rank_expert_offset)
-        else:
-            get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-                num_recv_tokens_per_expert,
-                num_tokens_per_rank=num_tokens_per_rank,
-                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-                num_tokens_per_expert=num_tokens_per_expert,
-            )
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
         return (
             recv_x,
             recv_topk_ids,
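The branch removed above is the counterpart of the @@ -695,6 +696,11 @@ hunk in layer.py: the rank_expert_offset computation and the torch.where remap of the received top-k ids now run inside DeepEPMoE, so the dispatcher records dispatch statistics unconditionally again.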
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
@@ -213,7 +213,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
     ):
         self.moe_runner_config = moe_runner_config
         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

     @torch._dynamo.disable()  # TODO: performance optimization requires lmslim/lightop support
     def apply(
         self,
@@ -307,7 +307,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
     #     use_nn_moe: Optional[bool] = False,
     #     routed_scaling_factor: Optional[float] = None,
     #     use_fused_gate: Optional[bool] = False,
     #     **_
     # ) -> torch.Tensor:
     #     from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
     #     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -351,8 +351,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
     #         a2_scale=layer.w2_input_scale,
     #         use_nn_moe=use_nn_moe,
     #     )
     def apply_ep(
         self,
         x: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
@@ -392,6 +392,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
             global_num_experts=global_num_experts,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
+            a1_scale=a1_scale,
             use_nn_moe=use_nn_moe,
             shared_output=shared_output,
             routed_scaling_factor=routed_scaling_factor,
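The one visible addition here forwards a1_scale, the per-token activation scale produced at dispatch, down to the fused kernel call. The scale has to travel with the int8 activations because the integer GEMM output is only meaningful after rescaling by both the activation and weight scales; schematically (illustrative names, not this file's API):

import torch

M, K, N = 4, 16, 8
x = torch.randn(M, K)

# Per-token activation scale (a1_scale) and per-channel weight scale.
a1_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
q_x = torch.round(x / a1_scale).clamp(-127, 127).to(torch.int8)
w_int = torch.randint(-8, 8, (N, K), dtype=torch.int8)  # stand-in for packed int4/int8 weights
w_scale = torch.rand(N) * 0.01

# Integer GEMM (simulated in float here), then rescale rows by a1_scale
# and columns by w_scale to recover the real-valued result.
y_int = q_x.float() @ w_int.float().t()
y = y_int * a1_scale * w_scale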