sglang · Commit 5190ba7f (unverified)

Fuse two kernels of hidden states padding into quantization kernel (#9005)

Authored Aug 12, 2025 by fzyzcjy; committed by GitHub on Aug 12, 2025
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
Parent: 5438886c

Showing 2 changed files with 5 additions and 9 deletions (+5 -9)
- python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+1 -8)
- python/sglang/srt/layers/quantization/mxfp4.py (+4 -1)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

```diff
@@ -210,13 +210,13 @@ class FusedMoE(torch.nn.Module):
         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
             "enable_flashinfer_mxfp4_moe", False
         )
+        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
             and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
```
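The `round_up` helper used in the hunk above is not shown in this diff. A minimal sketch, assuming it rounds an integer up to the nearest multiple (the real sglang helper may live elsewhere and differ in detail):

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple using integer arithmetic.
    return (x + multiple - 1) // multiple * multiple

# A hidden size that is not a multiple of 256 gets rounded up;
# one that already is stays unchanged.
print(round_up(2880, 256))  # -> 3072
print(round_up(4096, 256))  # -> 4096
```

This is why `self.hidden_size` can end up larger than the model's actual hidden dimension, which the `forward` hunk below used to compensate for with explicit padding.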
```diff
@@ -796,13 +796,6 @@ class FusedMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
-        origin_hidden_states_dim = hidden_states.shape[-1]
-        if self.hidden_size != origin_hidden_states_dim:
-            hidden_states = torch.nn.functional.pad(
-                hidden_states,
-                (0, self.hidden_size - origin_hidden_states_dim),
-                mode="constant",
-                value=0.0,
-            )
         assert self.quant_method is not None
         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
```
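The removed lines zero-padded the last (hidden) dimension on the Python side before quantization, launching an extra kernel per forward pass. A torch-free sketch of what that padding did, using a hypothetical helper on a single row for illustration:

```python
def pad_hidden(row, target_size):
    # Zero-pad one row (a list of floats) up to target_size, mirroring
    # torch.nn.functional.pad(x, (0, target_size - dim), value=0.0),
    # which appends zeros to the last dimension.
    assert target_size >= len(row)
    return row + [0.0] * (target_size - len(row))

row = [1.0, 2.0, 3.0]
padded = pad_hidden(row, 8)
print(padded)  # -> [1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

After this commit, that work happens inside the quantization kernel instead (see the `mxfp4.py` hunk below's `alignment` argument).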
python/sglang/srt/layers/quantization/mxfp4.py

```diff
@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size
             top_k, router_logits = topk_output
```
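If, as the commit title and the new `assert` suggest, the `alignment` argument makes `mxfp8_quantize` pad the hidden dimension internally, then the old pad-then-quantize path and the new fused path should produce the same result. A pure-Python sketch under that assumption, with a stand-in `fake_quantize` in place of the real mxfp8 kernel:

```python
def fake_quantize(row, alignment=None):
    # Stand-in for mxfp8_quantize: "quantization" here is identity, and
    # alignment (if given) zero-pads the row to that length first. This is
    # an assumption about the kernel's semantics, not flashinfer's API.
    if alignment is not None and len(row) < alignment:
        row = row + [0.0] * (alignment - len(row))
    return row

row = [0.5, -1.25, 3.0]
# Old path: pad in Python, then quantize (two kernel launches).
padded_then_quantized = fake_quantize(row + [0.0] * (8 - len(row)))
# New path: let the kernel handle padding via alignment (one launch).
fused = fake_quantize(row, alignment=8)
assert fused == padded_then_quantized  # same length-8 result
```

The fusion saves the separate pad kernel (and the intermediate padded tensor) without changing the quantized output.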