Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
53cbb488
Commit
53cbb488
authored
Nov 20, 2025
by
yiqa
Browse files
Adapt w8a8 quantization in normal mode
parent
3fc0ce15
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
17 deletions
+28
-17
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+28
-17
No files found.
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
53cbb488
...
@@ -2,7 +2,7 @@ from __future__ import annotations
...
@@ -2,7 +2,7 @@ from __future__ import annotations
import
logging
import
logging
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
collections
import
defaultdict
from
sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin
import
SlimQuantCompressedTensorsMarlinConfig
from
sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin
import
SlimQuantCompressedTensorsMarlinConfig
from
sglang.srt.layers.quantization.slimquant_w4a8_marlin
import
SlimQuantW4A8Int8MarlinConfig
from
sglang.srt.layers.quantization.slimquant_w4a8_marlin
import
SlimQuantW4A8Int8MarlinConfig
import
torch
import
torch
...
@@ -626,10 +626,11 @@ class DeepEPMoE(EPMoE):
...
@@ -626,10 +626,11 @@ class DeepEPMoE(EPMoE):
device
=
hidden_states
.
device
device
=
hidden_states
.
device
M
=
hidden_states
.
shape
[
0
]
M
=
hidden_states
.
shape
[
0
]
active_experts
=
set
()
K
=
hidden_states
.
shape
[
1
]
token_expert_pos
=
[
None
]
*
M
topk
=
topk_idx
.
shape
[
1
]
topk
=
topk_idx
.
shape
[
1
]
active_experts
=
set
()
token_expert_pos
=
[
None
]
*
M
for
t
in
range
(
M
):
for
t
in
range
(
M
):
lst
=
[]
lst
=
[]
for
pos
in
range
(
topk
):
for
pos
in
range
(
topk
):
...
@@ -644,13 +645,30 @@ class DeepEPMoE(EPMoE):
...
@@ -644,13 +645,30 @@ class DeepEPMoE(EPMoE):
if
num_active
==
0
:
if
num_active
==
0
:
return
hidden_states
.
bfloat16
()
return
hidden_states
.
bfloat16
()
block
=
256
counts
=
defaultdict
(
int
)
pad_M
=
block
*
num_active
for
t
in
range
(
M
):
K
=
hidden_states
.
shape
[
1
]
for
(
e
,
pos
)
in
token_expert_pos
[
t
]:
counts
[
e
]
+=
1
per_expert_block
=
{}
for
e
in
active_experts
:
cnt
=
counts
.
get
(
e
,
0
)
if
cnt
<=
0
:
per_expert_block
[
e
]
=
0
else
:
needed
=
((
cnt
+
256
-
1
)
//
256
)
*
256
# next multiple of 256
per_expert_block
[
e
]
=
max
(
256
,
needed
)
expert_slot_offset
=
{}
offset
=
0
for
e
in
active_experts
:
expert_slot_offset
[
e
]
=
offset
offset
+=
per_expert_block
[
e
]
pad_M
=
offset
hidden_states_packed
=
torch
.
zeros
((
pad_M
,
K
),
device
=
device
,
dtype
=
hidden_states
.
dtype
)
hidden_states_packed
=
torch
.
zeros
((
pad_M
,
K
),
device
=
device
,
dtype
=
hidden_states
.
dtype
)
m_indices
=
torch
.
full
((
pad_M
,),
-
1
,
device
=
device
,
dtype
=
torch
.
int32
)
m_indices
=
torch
.
full
((
pad_M
,),
-
1
,
device
=
device
,
dtype
=
torch
.
int32
)
expert_slot_offset
=
{
e
:
i
*
block
for
i
,
e
in
enumerate
(
active_experts
)}
slot_counters
=
{
e
:
0
for
e
in
active_experts
}
slot_counters
=
{
e
:
0
for
e
in
active_experts
}
token_row_weight_list
=
{
t
:
[]
for
t
in
range
(
M
)}
token_row_weight_list
=
{
t
:
[]
for
t
in
range
(
M
)}
...
@@ -658,8 +676,8 @@ class DeepEPMoE(EPMoE):
...
@@ -658,8 +676,8 @@ class DeepEPMoE(EPMoE):
for
(
e
,
pos
)
in
token_expert_pos
[
t
]:
for
(
e
,
pos
)
in
token_expert_pos
[
t
]:
start
=
expert_slot_offset
[
e
]
start
=
expert_slot_offset
[
e
]
slot
=
slot_counters
[
e
]
slot
=
slot_counters
[
e
]
if
slot
>=
block
:
if
slot
>=
per_expert_
block
[
e
]
:
raise
RuntimeError
(
f
"
Too many tokens f
or expert
{
e
}
(>block).
"
)
raise
RuntimeError
(
f
"
Internal err
or
:
expert
{
e
}
slot
{
slot
}
>= block
{
per_expert_block
[
e
]
}
"
)
row
=
start
+
slot
row
=
start
+
slot
hidden_states_packed
[
row
]
=
hidden_states
[
t
]
hidden_states_packed
[
row
]
=
hidden_states
[
t
]
m_indices
[
row
]
=
int
(
e
)
m_indices
[
row
]
=
int
(
e
)
...
@@ -672,16 +690,13 @@ class DeepEPMoE(EPMoE):
...
@@ -672,16 +690,13 @@ class DeepEPMoE(EPMoE):
q_a1_all
,
q_a1_scale
=
per_token_quant_int8
(
hidden_states_packed
)
q_a1_all
,
q_a1_scale
=
per_token_quant_int8
(
hidden_states_packed
)
N
=
self
.
w13_weight
.
size
(
1
)
N
=
self
.
w13_weight
.
size
(
1
)
gateup_output
=
torch
.
empty
((
pad_M
,
N
*
16
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
gateup_output
=
torch
.
empty
((
pad_M
,
N
*
16
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
m_grouped_w8a8_gemm_nt_contig_asm
(
m_grouped_w8a8_gemm_nt_contig_asm
(
(
q_a1_all
,
q_a1_scale
),
(
q_a1_all
,
q_a1_scale
),
(
self
.
w13_weight
,
self
.
w13_weight_scale
),
(
self
.
w13_weight
,
self
.
w13_weight_scale
),
gateup_output
,
gateup_output
,
m_indices
,
m_indices
,
)
)
q_a2_all
,
q_a2_scale
=
fuse_silu_mul_quant
(
gateup_output
)
q_a2_all
,
q_a2_scale
=
fuse_silu_mul_quant
(
gateup_output
)
down_output
=
torch
.
empty
((
pad_M
,
K
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
down_output
=
torch
.
empty
((
pad_M
,
K
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
down_output
=
m_grouped_w8a8_gemm_nt_contig_asm
(
down_output
=
m_grouped_w8a8_gemm_nt_contig_asm
(
(
q_a2_all
,
q_a2_scale
),
(
q_a2_all
,
q_a2_scale
),
...
@@ -690,7 +705,6 @@ class DeepEPMoE(EPMoE):
...
@@ -690,7 +705,6 @@ class DeepEPMoE(EPMoE):
m_indices
,
m_indices
,
)
)
result
=
torch
.
zeros
((
M
,
K
),
device
=
device
,
dtype
=
down_output
.
dtype
)
result
=
torch
.
zeros
((
M
,
K
),
device
=
device
,
dtype
=
down_output
.
dtype
)
for
t
in
range
(
M
):
for
t
in
range
(
M
):
pairs
=
token_row_weight_list
[
t
]
pairs
=
token_row_weight_list
[
t
]
if
not
pairs
:
if
not
pairs
:
...
@@ -699,10 +713,7 @@ class DeepEPMoE(EPMoE):
...
@@ -699,10 +713,7 @@ class DeepEPMoE(EPMoE):
for
(
row
,
w
)
in
pairs
:
for
(
row
,
w
)
in
pairs
:
vec
=
down_output
[
row
].
float
()
vec
=
down_output
[
row
].
float
()
weighted
=
vec
*
w
weighted
=
vec
*
w
if
acc
is
None
:
acc
=
weighted
if
acc
is
None
else
(
acc
+
weighted
)
acc
=
weighted
else
:
acc
=
acc
+
weighted
result
[
t
]
=
acc
.
to
(
result
.
dtype
)
result
[
t
]
=
acc
.
to
(
result
.
dtype
)
return
result
return
result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment