Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0efdb5c3
Unverified
Commit
0efdb5c3
authored
Sep 09, 2025
by
Wei
Committed by
GitHub
Sep 10, 2025
Browse files
[gpt-oss] Cache permute indices for faster MXFP4 MoE layer loading (#24154)
Signed-off-by:
Wei Wei
<
wwei6@meta.com
>
parent
53b42f41
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
146 additions
and
35 deletions
+146
-35
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+86
-17
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+60
-18
No files found.
tests/kernels/moe/test_mxfp4_moe.py
View file @
0efdb5c3
...
...
@@ -24,6 +24,8 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
next_positive_power_of_2
,
reorder_rows_for_gated_act_gemm
,
shuffle_matrix_a
,
shuffle_matrix_sf_a
,
trtllm_fp4_block_scale_moe
)
from
flashinfer.fp4_quantization
import
nvfp4_block_scale_interleave
from
flashinfer.fused_moe.core
import
_maybe_get_cached_w2_permute_indices
@
dataclass
...
...
@@ -204,6 +206,7 @@ def tg_mxfp4_moe(
alpha
,
beta
,
limit
,
transpose_optimized
:
bool
=
False
,
)
->
torch
.
Tensor
:
sf_block_size
=
32
assert
(
w13_weight
.
dim
()
==
3
and
w13_weight
.
shape
[
0
]
==
num_experts
...
...
@@ -267,15 +270,78 @@ def tg_mxfp4_moe(
gemm1_bias_shuffled
=
[]
gemm2_bias_shuffled
=
[]
epilogue_tile_m
=
128
# FIXME: this depends on the kernel internals
_cache_permute_indices
:
dict
[
torch
.
Size
,
torch
.
Tensor
]
=
{}
if
transpose_optimized
:
for
i
in
range
(
num_experts
):
# w13 weight shuffling
permute_indices
=
_maybe_get_cached_w2_permute_indices
(
_cache_permute_indices
,
w13_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
)
gemm1_weights_shuffled
.
append
(
w13_weight
[
i
].
view
(
torch
.
uint8
)[
permute_indices
.
to
(
w13_weight
.
device
)].
contiguous
())
# w13 scale shuffling
permute_sf_indices
=
_maybe_get_cached_w2_permute_indices
(
_cache_permute_indices
,
w13_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
num_elts_per_sf
=
16
,
)
gemm1_scales_shuffled
.
append
(
nvfp4_block_scale_interleave
(
w13_weight_scale
[
i
].
view
(
torch
.
uint8
)[
permute_sf_indices
.
to
(
w13_weight_scale
.
device
)].
contiguous
()))
# w13 bias shuffling
permute_bias_indices
=
_maybe_get_cached_w2_permute_indices
(
_cache_permute_indices
,
w13_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
,
)
gemm1_bias_shuffled
.
append
(
w13_bias
[
i
].
clone
().
reshape
(
-
1
,
1
)[
permute_bias_indices
.
to
(
w13_bias
.
device
)].
contiguous
())
# w2 weight shuffling
permute_indices
=
_maybe_get_cached_w2_permute_indices
(
_cache_permute_indices
,
w2_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
)
gemm2_weights_shuffled
.
append
(
w2_weight
[
i
].
view
(
torch
.
uint8
)[
permute_indices
.
to
(
w2_weight
.
device
)].
contiguous
())
# w2 scale shuffling
permute_sf_indices
=
_maybe_get_cached_w2_permute_indices
(
_cache_permute_indices
,
w2_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
num_elts_per_sf
=
16
,
)
gemm2_scales_shuffled
.
append
(
nvfp4_block_scale_interleave
(
w2_weight_scale
[
i
].
view
(
torch
.
uint8
)[
permute_sf_indices
.
to
(
w2_weight_scale
.
device
)].
contiguous
()))
# w2 bias shuffling
permute_indices
=
_maybe_get_cached_w2_permute_indices
(
_cache_permute_indices
,
w2_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
,
)
gemm2_bias_shuffled
.
append
(
w2_bias
[
i
].
clone
().
reshape
(
-
1
,
1
)[
permute_indices
.
to
(
w2_bias
.
device
)].
contiguous
())
else
:
for
i
in
range
(
num_experts
):
gemm1_weights_shuffled
.
append
(
shuffle_matrix_a
(
w13_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
shuffle_matrix_a
(
w13_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
gemm1_scales_shuffled
.
append
(
shuffle_matrix_sf_a
(
w13_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
gemm2_weights_shuffled
.
append
(
shuffle_matrix_a
(
w2_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
shuffle_matrix_a
(
w2_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
gemm2_scales_shuffled
.
append
(
shuffle_matrix_sf_a
(
w2_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
...
...
@@ -356,6 +422,7 @@ def check_accuracy(a, b, atol, rtol, percent):
@
pytest
.
mark
.
parametrize
(
"alpha,beta,limit"
,
[(
1.0
,
1.0
,
None
),
(
1.702
,
1.0
,
7.0
)])
@
pytest
.
mark
.
parametrize
(
"act_type"
,
[
'mxfp8'
,
'bf16'
])
@
pytest
.
mark
.
parametrize
(
"transpose_optimized"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
TRTLLM_GEN_MXFP4_AVAILABLE
,
reason
=
"nvidia gpu and compute capability sm100 is required for this test"
)
...
...
@@ -369,6 +436,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
beta
:
float
,
limit
:
Optional
[
float
],
act_type
:
str
,
transpose_optimized
:
bool
,
):
seed
=
42
torch
.
manual_seed
(
seed
)
...
...
@@ -470,6 +538,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
act_type
,
alpha
=
alpha
,
beta
=
beta
,
limit
=
limit
)
limit
=
limit
,
transpose_optimized
=
transpose_optimized
)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy
(
ref_result
,
tg_result
,
atol
=
0
,
rtol
=
0.3
,
percent
=
0.8
)
vllm/model_executor/layers/quantization/mxfp4.py
View file @
0efdb5c3
...
...
@@ -122,6 +122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"MXFP4 MoE is enabled on Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results."
)
self
.
_cache_permute_indices
:
dict
[
torch
.
Size
,
torch
.
Tensor
]
=
{}
def
_should_use_marlin
(
self
):
if
envs
.
VLLM_MXFP4_USE_MARLIN
is
not
None
:
...
...
@@ -266,7 +267,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if
self
.
use_marlin
:
prepare_moe_fp4_layer_for_marlin
(
layer
)
elif
should_use_flashinfer_mxfp4
():
from
flashinfer
import
shuffle_matrix_a
,
shuffle_matrix_sf_a
from
flashinfer.fp4_quantization
import
(
nvfp4_block_scale_interleave
)
from
flashinfer.fused_moe.core
import
(
_maybe_get_cached_w2_permute_indices
)
layer
.
gemm1_alpha
=
Parameter
(
torch
.
tensor
(
[
1.702
]
*
self
.
num_experts
,
dtype
=
torch
.
float32
).
cuda
(),
requires_grad
=
False
)
...
...
@@ -343,25 +347,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
gemm2_bias_shuffled
=
[]
epilogue_tile_m
=
128
# FIXME: this depends on the kernel internals
for
i
in
range
(
self
.
num_experts
):
gemm1_weights_mxfp4_shuffled
.
append
(
shuffle_matrix_a
(
w13_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
# w13 weight shuffling
permute_indices
=
_maybe_get_cached_w2_permute_indices
(
self
.
_cache_permute_indices
,
w13_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
)
gemm1_weights_mxfp4_shuffled
.
append
(
w13_weight
[
i
].
view
(
torch
.
uint8
)[
permute_indices
.
to
(
w13_weight
.
device
)].
contiguous
())
# w13 scale shuffling
permute_sf_indices
=
_maybe_get_cached_w2_permute_indices
(
self
.
_cache_permute_indices
,
w13_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
num_elts_per_sf
=
16
,
)
gemm1_scales_mxfp4_shuffled
.
append
(
shuffle_matrix_sf_a
(
w13_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
gemm1_bias_shuffled
.
append
(
shuffle_matrix_a
(
w13_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
))
gemm2_weights_mxfp4_shuffled
.
append
(
shuffle_matrix_a
(
w2_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
nvfp4_block_scale_interleave
(
w13_weight_scale
[
i
].
view
(
torch
.
uint8
)[
permute_sf_indices
.
to
(
w13_weight_scale
.
device
)].
contiguous
()))
# w13 bias shuffling
permute_bias_indices
=
_maybe_get_cached_w2_permute_indices
(
self
.
_cache_permute_indices
,
w13_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
,
)
gemm1_bias_shuffled
.
append
(
w13_bias
[
i
].
clone
().
reshape
(
-
1
,
1
)[
permute_bias_indices
.
to
(
w13_bias
.
device
)].
contiguous
())
# w2 weight shuffling
permute_indices
=
_maybe_get_cached_w2_permute_indices
(
self
.
_cache_permute_indices
,
w2_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
)
gemm2_weights_mxfp4_shuffled
.
append
(
w2_weight
[
i
].
view
(
torch
.
uint8
)[
permute_indices
.
to
(
w2_weight
.
device
)].
contiguous
())
# w2 scale shuffling
permute_sf_indices
=
_maybe_get_cached_w2_permute_indices
(
self
.
_cache_permute_indices
,
w2_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
num_elts_per_sf
=
16
,
)
gemm2_scales_mxfp4_shuffled
.
append
(
shuffle_matrix_sf_a
(
w2_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
))
gemm2_bias_shuffled
.
append
(
shuffle_matrix_a
(
w2_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
))
nvfp4_block_scale_interleave
(
w2_weight_scale
[
i
].
view
(
torch
.
uint8
)[
permute_sf_indices
.
to
(
w2_weight_scale
.
device
)].
contiguous
()))
# w2 bias shuffling
permute_indices
=
_maybe_get_cached_w2_permute_indices
(
self
.
_cache_permute_indices
,
w2_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
,
)
gemm2_bias_shuffled
.
append
(
w2_bias
[
i
].
clone
().
reshape
(
-
1
,
1
)[
permute_indices
.
to
(
w2_bias
.
device
)].
contiguous
())
w13_weight
=
torch
.
stack
(
gemm1_weights_mxfp4_shuffled
)
w13_weight_scale
=
torch
.
stack
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment