change/sglang · Commit 7211d743
Authored Nov 21, 2025 by yiqa; committed by lizhigong, Nov 21, 2025
Commit message: Adapt w8a8 to the contiguous ("continues") operator
Parent: 5533c538
Showing 2 changed files with 124 additions and 16 deletions.
  python/sglang/srt/layers/moe/ep_moe/layer.py (+111, -4)
  python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+13, -12)
File: python/sglang/srt/layers/moe/ep_moe/layer.py (view file @ 7211d743)
@@ -2,7 +2,7 @@ from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
        DeepEPNormalOutput,
        DispatchOutput,
    )
-from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
from lmslim.layers.gemm.int8_utils import per_token_quant_int8

_is_hip = is_hip()
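The new contiguous path feeds per_token_quant_int8 (imported above from lmslim.layers.gemm.int8_utils) with the packed activations and, judging from how its outputs are consumed later in this diff, it returns an int8 tensor plus a per-row scale. The following is a minimal pure-PyTorch sketch of that per-token symmetric quantization, an illustrative stand-in rather than the lmslim kernel; the function name and tolerance are assumptions.

import torch

def per_token_quant_int8_reference(x: torch.Tensor):
    # Symmetric per-row (per-token) int8 quantization:
    # each row is scaled so its largest magnitude maps to 127.
    x_f = x.float()
    row_max = x_f.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = row_max / 127.0
    q = torch.round(x_f / scale).clamp(-128, 127).to(torch.int8)
    return q, scale  # scale has shape (num_rows, 1)

# Example: quantize a (4, 8) activation block.
x = torch.randn(4, 8, dtype=torch.bfloat16)
q, s = per_token_quant_int8_reference(x)
print(q.shape, s.shape)  # torch.Size([4, 8]) torch.Size([4, 1])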
@@ -606,6 +606,8 @@ class DeepEPMoE(EPMoE):
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif self.use_w4a8_marlin:
            return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+       elif self.use_w8a8_marlin:
+           return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
        else:
            raise ValueError(f"Dispatch output is not supported"
@@ -710,6 +712,111 @@ class DeepEPMoE(EPMoE):
        )
        return expert_output

    def forward_groupgemm_w8a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"

        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()

        device = hidden_states.device
        M = hidden_states.shape[0]
        K = hidden_states.shape[1]
        topk = topk_idx.shape[1]

        active_experts = set()
        token_expert_pos = [None] * M
        for t in range(M):
            lst = []
            for pos in range(topk):
                e = int(topk_idx[t, pos].item())
                if e >= 0:
                    lst.append((e, pos))
                    active_experts.add(e)
            token_expert_pos[t] = lst

        active_experts = sorted(list(active_experts))
        num_active = len(active_experts)
        if num_active == 0:
            return hidden_states.bfloat16()

        counts = defaultdict(int)
        for t in range(M):
            for (e, pos) in token_expert_pos[t]:
                counts[e] += 1

        per_expert_block = {}
        for e in active_experts:
            cnt = counts.get(e, 0)
            if cnt <= 0:
                per_expert_block[e] = 0
            else:
                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
                per_expert_block[e] = max(256, needed)

        expert_slot_offset = {}
        offset = 0
        for e in active_experts:
            expert_slot_offset[e] = offset
            offset += per_expert_block[e]
        pad_M = offset

        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
        m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)

        slot_counters = {e: 0 for e in active_experts}
        token_row_weight_list = {t: [] for t in range(M)}
        for t in range(M):
            for (e, pos) in token_expert_pos[t]:
                start = expert_slot_offset[e]
                slot = slot_counters[e]
                if slot >= per_expert_block[e]:
                    raise RuntimeError(
                        f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}"
                    )
                row = start + slot
                hidden_states_packed[row] = hidden_states[t]
                m_indices[row] = int(e)
                slot_counters[e] += 1
                w = topk_weights[t, pos].to(device=device)
                w_f = w.float() if w.dtype != torch.float32 else w
                token_row_weight_list[t].append((row, w_f))

        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)

        N = self.w13_weight.size(1)
        gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
        m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a1_all, q_a1_scale),
            (self.w13_weight, self.w13_weight_scale),
            gateup_output,
            m_indices,
        )

        q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)

        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
        down_output = m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a2_all, q_a2_scale),
            (self.w2_weight, self.w2_weight_scale),
            down_output,
            m_indices,
        )

        result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
        for t in range(M):
            pairs = token_row_weight_list[t]
            if not pairs:
                continue
            acc = None
            for (row, w) in pairs:
                vec = down_output[row].float()
                weighted = vec * w
                acc = weighted if acc is None else (acc + weighted)
            result[t] = acc.to(result.dtype)
        return result

    def forward_deepgemm_contiguous(
        self,
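The core of the new forward_groupgemm_w8a8_marlin_contiguous path is the packing step: every (token, expert) assignment is copied into a padded buffer in which each active expert owns a contiguous block of rows rounded up to a multiple of 256, and m_indices records which expert each padded row belongs to (-1 for padding rows), presumably matching the row-block granularity expected by m_grouped_w8a8_gemm_nt_contig_asm. The sketch below mirrors that bookkeeping on toy data; the sizes and topk_idx values are hypothetical and it is not the production code.

import torch
from collections import defaultdict

torch.manual_seed(0)
M, K, topk, block = 5, 4, 2, 256          # toy sizes; the kernel path uses 256-row blocks
topk_idx = torch.tensor([[0, 2], [2, -1], [0, 0], [-1, -1], [2, 1]])
hidden_states = torch.randn(M, K)

# Count how many rows each expert receives (-1 means "not routed here").
counts = defaultdict(int)
for t in range(M):
    for pos in range(topk):
        e = int(topk_idx[t, pos])
        if e >= 0:
            counts[e] += 1

active = sorted(counts)
# Each active expert gets a block rounded up to the next multiple of 256.
per_expert_block = {e: max(block, ((counts[e] + block - 1) // block) * block) for e in active}
offsets, off = {}, 0
for e in active:
    offsets[e] = off
    off += per_expert_block[e]
pad_M = off

packed = torch.zeros(pad_M, K)
m_indices = torch.full((pad_M,), -1, dtype=torch.int32)
slot = {e: 0 for e in active}
for t in range(M):
    for pos in range(topk):
        e = int(topk_idx[t, pos])
        if e < 0:
            continue
        row = offsets[e] + slot[e]
        packed[row] = hidden_states[t]
        m_indices[row] = e
        slot[e] += 1

print(pad_M, per_expert_block)            # 768 padded rows for 3 active experts
print(m_indices[:4], m_indices[256:260])  # expert ids, then -1 padding

The production method additionally keeps (row, topk_weight) pairs per token so that, after the two grouped GEMMs, the rows belonging to each token can be gathered back and combined with the router weights into the final (M, K) output.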
@@ -900,7 +1007,7 @@ class DeepEPMoE(EPMoE):
        # base shapes
        num_groups, m, k = hidden_states.size()
        expected_m = min(m, expected_m)
        # ---- first quant: ensure float input for quantizer ----
        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
@@ -949,7 +1056,7 @@ class DeepEPMoE(EPMoE):
        assert self.moe_runner_config.activation == "silu"
        # base shapes
        num_groups, m, k = hidden_states.size()
        expected_m = min(m, expected_m)
        # ---- first quant: ensure float input for quantizer ----
        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
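Both masked paths above call per_token_quant_int8_triton_opt(hidden_states, masked_m) on a (num_groups, m, k) tensor where only the first masked_m[g] rows of each expert group carry real tokens. That Triton kernel is not part of this diff; as an assumption about its semantics, a plain-PyTorch equivalent might look like the sketch below (illustrative only, not the optimized kernel).

import torch

def per_token_quant_int8_masked_reference(hidden_states: torch.Tensor, masked_m: torch.Tensor):
    # Per-token symmetric int8 quantization over a (num_groups, m, k) tensor,
    # skipping the padded rows beyond masked_m[g] in each expert group.
    num_groups, m, k = hidden_states.size()
    q = torch.zeros(num_groups, m, k, dtype=torch.int8, device=hidden_states.device)
    scale = torch.zeros(num_groups, m, 1, dtype=torch.float32, device=hidden_states.device)
    for g in range(num_groups):
        rows = int(masked_m[g])
        if rows == 0:
            continue
        x = hidden_states[g, :rows].float()
        row_max = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
        s = row_max / 127.0
        q[g, :rows] = torch.round(x / s).clamp(-128, 127).to(torch.int8)
        scale[g, :rows] = s
    return q, scale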
File: python/sglang/srt/layers/moe/token_dispatcher/deepep.py (view file @ 7211d743)
@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            expert_alignment=1,
            config=DeepEPConfig.get_instance().normal_dispatch_config,
        )
-       # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-       #     num_recv_tokens_per_expert,
-       #     num_tokens_per_rank=num_tokens_per_rank,
-       #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-       #     num_tokens_per_expert=num_tokens_per_expert,
-       # )
-       self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
-       recv_topk_ids = torch.where(
-           recv_topk_ids == -1,
-           self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-           recv_topk_ids + self.rank_expert_offset,
-       )
+       if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
+           self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
+           recv_topk_ids = torch.where(
+               recv_topk_ids == -1,
+               self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+               recv_topk_ids + self.rank_expert_offset,
+           )
+       else:
+           get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+               num_recv_tokens_per_expert,
+               num_tokens_per_rank=num_tokens_per_rank,
+               num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+               num_tokens_per_expert=num_tokens_per_expert,
+           )
        return (
            recv_x,
            recv_topk_ids,
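This hunk makes the recv_topk_ids remap conditional on the slimquant_w4a8_marlin quant method: entries equal to -1 (top-k slots not routed to any local expert) are redirected to a sentinel expert id, while all other entries are shifted by this rank's expert offset; other quant methods fall back to the expert-distribution recorder. A toy illustration of exactly what that torch.where computes, with made-up rank and expert-count values:

import torch

num_experts = 8                    # hypothetical global expert count
ep_rank, ep_world_size = 1, 2      # hypothetical expert-parallel layout
rank_expert_offset = ep_rank * (num_experts // ep_world_size)   # = 4

recv_topk_ids = torch.tensor([[0, 3], [-1, 2], [-1, -1]])
remapped = torch.where(
    recv_topk_ids == -1,
    num_experts - 1 if rank_expert_offset == 0 else 0,   # sentinel for unrouted slots
    recv_topk_ids + rank_expert_offset,                  # shift ids by this rank's offset
)
print(remapped)
# tensor([[4, 7],
#         [0, 6],
#         [0, 0]])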