sglang · Commit ced3c07a

Support token-level quantization for EP MoE (#6782)

Unverified commit, authored May 30, 2025 by Cheng Wan, committed by GitHub on May 30, 2025.
Parent: f18b068f
Showing 2 changed files with 89 additions and 25 deletions:

python/sglang/srt/layers/moe/ep_moe/kernels.py (+18, -2)
python/sglang/srt/layers/moe/ep_moe/layer.py (+71, -23)
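For context on what "token-level" quantization means here: instead of one dynamic FP8 scale per expert (or per tensor), each input row (token) gets its own scale. Below is a rough reference sketch of that idea in plain PyTorch; it is illustrative only (abs-max scaling, hypothetical function name), not the fused `sglang_per_token_quant_fp8` kernel this commit wires in, and it assumes a PyTorch build with `torch.float8_e4m3fn`.

```python
import torch

def per_token_quant_fp8_reference(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    """Illustrative per-token (row-wise) dynamic FP8 quantization."""
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per row: row abs-max divided by the representable FP8 range.
    amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    scale = amax / fp8_max
    x_q = (x.to(torch.float32) / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    # Dequantization approximates x as x_q.float() * scale.
    return x_q, scale.squeeze(1)  # scale: shape [num_tokens]
```

The per-expert fallback kept by this diff instead repeats a single global max for every expert in the partition, so all rows routed to an expert share one scale.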
python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -178,6 +178,7 @@ def pre_reorder_triton_kernel(
     topk,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
+    use_per_token_if_dynamic: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
@@ -188,11 +189,15 @@ def pre_reorder_triton_kernel(
     vec = tl.arange(0, BLOCK_SIZE)
 
+    if a1_scales_ptr is not None and use_per_token_if_dynamic:
+        scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
+
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+                if not use_per_token_if_dynamic:
+                    scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
             else:
                 scale = 1.0
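Read the branch above as: with per-token scales the kernel indexes the scale array by the source row (`src_idx`) once, outside the top-k loop, and only falls back to the per-expert entry when `use_per_token_if_dynamic` is off. A plain-Python restatement of that selection (hypothetical helper, not part of the kernel):

```python
def pick_input_scale(a1_scales, src_idx, expert_id, start_expert_id,
                     use_per_token_if_dynamic):
    """Mirror of the scale selection in pre_reorder_triton_kernel (illustrative)."""
    if a1_scales is None:
        return 1.0
    if use_per_token_if_dynamic:
        # One scale per input row; it does not depend on the expert.
        return 1.0 / float(a1_scales[src_idx])
    # One scale per local expert.
    return 1.0 / float(a1_scales[expert_id - start_expert_id])
```

The kernel multiplies the row by this reciprocal scale before writing it into `gateup_input`, which is why the per-token value can be loaded before entering the loop over experts.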
@@ -558,6 +563,7 @@ def grouped_gemm_triton_kernel(
     bs_stride_0: tl.constexpr,
     bs_stride_2: tl.constexpr,
     bs_stride_1: tl.constexpr,
+    use_per_token_if_dynamic: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -621,7 +627,10 @@ def grouped_gemm_triton_kernel(
         b_ptr += BLOCK_SIZE_K
 
     if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
+        if use_per_token_if_dynamic:
+            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
+        else:
+            scale_a_value = tl.load(scale_a + expert_id)
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value
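In the non-block-quant w8a8 path, the accumulator epilogue becomes a row-wise rescale when per-token scales are used, versus a scalar rescale per expert otherwise. A dense PyTorch analogue of that epilogue (illustrative only; it ignores the segment/expert bookkeeping of the real kernel):

```python
import torch

def w8a8_epilogue(acc, scale_a, scale_b, use_per_token_if_dynamic):
    """Rescale one expert segment's accumulator (dense analogue, not the kernel)."""
    if use_per_token_if_dynamic:
        return acc * scale_a[:, None] * scale_b   # scale_a: [M], one entry per row
    return acc * scale_a * scale_b                # scale_a: scalar for the expert

acc = torch.randn(4, 8)              # stand-in for the low-precision matmul result
row_scales = torch.rand(4) + 0.5     # per-token activation scales
out = w8a8_epilogue(acc, row_scales, scale_b=0.02, use_per_token_if_dynamic=True)
```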
@@ -658,6 +667,7 @@ def grouped_gemm_triton(
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
     c_dtype=None,
+    use_per_token_if_dynamic: bool = True,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
@@ -698,6 +708,11 @@ def grouped_gemm_triton(
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
     )
 
+    if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
+        assert (
+            scale_a.shape[0] == a.shape[0]
+        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
+
     grouped_gemm_triton_kernel[grid](
         a,
         b,
@@ -721,6 +736,7 @@ def grouped_gemm_triton(
         scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
         scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
         scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
+        use_per_token_if_dynamic,
         **config,
     )
     return c
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -50,7 +50,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import (
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
@@ -65,10 +68,16 @@ logger = logging.getLogger(__name__)
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
 
-    def __init__(self, device, use_flashinfer: bool = False):
+    def __init__(
+        self,
+        device,
+        use_flashinfer: bool = False,
+        use_per_token_if_dynamic: bool = True,
+    ):
         super().__init__()
         self.device = device
         self.use_flashinfer = use_flashinfer
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
         if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
             GroupedGemmRunner._init_flashinfer_wrapper(device)
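With this constructor change, the flag is stored on the runner once instead of being threaded through every forward call. A usage sketch (assumes sglang is installed; "cuda" stands in for the hidden states' device, mirroring how EPMoE constructs the runner later in this diff):

```python
from sglang.srt.layers.moe.ep_moe.layer import GroupedGemmRunner

runner = GroupedGemmRunner(
    device="cuda",
    use_flashinfer=False,
    use_per_token_if_dynamic=True,
)
```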
@@ -124,6 +133,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_b,
                 block_shape=block_shape,
                 c_dtype=c_dtype,
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
             )
         return c
@@ -154,6 +164,7 @@ class EPMoE(torch.nn.Module):
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__()
@@ -184,6 +195,7 @@ class EPMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.activation = activation
         self.routed_scaling_factor = routed_scaling_factor
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -227,6 +239,7 @@ class EPMoE(torch.nn.Module):
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
                 use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
             )
 
         topk_weights, topk_ids = select_experts(
@@ -259,12 +272,16 @@ class EPMoE(torch.nn.Module):
             ),
         )
 
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+            if self.use_per_token_if_dynamic:
+                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
+                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+            else:
+                max_value = (
+                    torch.max(hidden_states)
+                    .repeat(self.num_experts_per_partition)
+                    .to(torch.float32)
+                )
+                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
 
         # PreReorder
         pre_reorder_triton_kernel[(hidden_states.shape[0],)](
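The two branches above produce scales of different shapes: the per-token path yields one scale per input row, while the fallback repeats a single global max once per local expert. A small shape check with made-up sizes (assumes a PyTorch build with `torch.float8_e4m3fn`):

```python
import torch

hidden_states = torch.randn(4, 8)    # [num_tokens, hidden_size]
num_experts_per_partition = 2
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# Per-token: one scale per row.
per_token_scale = torch.max(hidden_states, dim=1).values.to(torch.float32) / fp8_max
assert per_token_scale.shape == (4,)

# Fallback: one global max repeated for each local expert.
per_expert_scale = (
    torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32) / fp8_max
)
assert per_expert_scale.shape == (2,)
```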
@@ -278,9 +295,27 @@ class EPMoE(torch.nn.Module):
             self.top_k,
             hidden_states.shape[1],
             BLOCK_SIZE=512,
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
         dispose_tensor(hidden_states)
 
+        if (
+            self.activation_scheme == "dynamic"
+            and not self.use_block_quant
+            and self.use_per_token_if_dynamic
+        ):
+            scale = torch.empty(
+                hidden_states_shape[0] * self.top_k,
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            scale[src2dst] = (
+                self.w13_input_scale.unsqueeze(1)
+                .expand(hidden_states_shape[0], self.top_k)
+                .reshape(-1)
+            )
+            self.w13_input_scale = scale
+
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
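The scatter above keeps the scales aligned with the reordered activations: pre-reorder writes each token's row to `top_k` destination slots given by `src2dst`, so the per-token scales are expanded to `[num_tokens * top_k]` and scattered through the same mapping. A toy illustration with a made-up `src2dst`:

```python
import torch

num_tokens, top_k = 3, 2
w13_input_scale = torch.tensor([0.1, 0.2, 0.3])     # one scale per token
src2dst = torch.tensor([4, 0, 2, 5, 1, 3])          # flattened (token, k) -> destination row

scale = torch.empty(num_tokens * top_k, dtype=torch.float32)
scale[src2dst] = (
    w13_input_scale.unsqueeze(1)      # [num_tokens, 1]
    .expand(num_tokens, top_k)        # [num_tokens, top_k]
    .reshape(-1)                      # [num_tokens * top_k]
)
print(scale)  # tensor([0.1000, 0.3000, 0.2000, 0.3000, 0.1000, 0.2000])
# Row j of the reordered gateup_input now dequantizes with scale[j].
```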
@@ -310,21 +345,24 @@ class EPMoE(torch.nn.Module):
         del gateup_input
 
         # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            self.w2_input_scale = None
+            down_input = torch.empty(
+                gateup_output.shape[0],
+                gateup_output.shape[1] // 2,
+                device=gateup_output.device,
+                dtype=hidden_states_dtype,
+            )
+        else:
+            down_input = torch.empty(
+                gateup_output.shape[0],
+                gateup_output.shape[1] // 2,
+                device=gateup_output.device,
+                dtype=(
+                    self.fp8_dtype
+                    if (self.use_fp8_w8a8 and not self.use_block_quant)
+                    else hidden_states_dtype
+                ),
+            )
 
         if self.activation == "silu":
@@ -353,6 +391,16 @@ class EPMoE(torch.nn.Module):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
         del gateup_output
 
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            if self.use_per_token_if_dynamic:
+                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
+            else:
+                self.w2_input_scale = torch.ones(
+                    self.num_experts_per_partition,
+                    dtype=torch.float32,
+                    device=hidden_states_device,
+                )
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
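For the down projection, the per-token path re-quantizes the activation output on the fly. A hedged usage sketch, grounded only in the call shown above (the helper returns a quantized tensor plus scales; the shapes and dtypes below are assumptions for illustration, and it needs sglang plus a CUDA device):

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8

down_input = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
down_input_q, w2_input_scale = sglang_per_token_quant_fp8(down_input)
# down_input_q feeds the second grouped GEMM; w2_input_scale holds the per-row scales.
```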