Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
678b3c99
Unverified
Commit
678b3c99
authored
Mar 25, 2026
by
Yongye Zhu
Committed by
GitHub
Mar 25, 2026
Browse files
[MoE Kernel] Flashinfer nvfp4 cutedsl moe kernel integration (#38050)
parent
bf4cc9ed
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
577 additions
and
248 deletions
+577
-248
tests/kernels/moe/test_cutedsl_moe.py
tests/kernels/moe/test_cutedsl_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py
...ayers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py
+353
-0
vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py
...ecutor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py
+64
-244
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+46
-2
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+95
-1
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+18
-0
No files found.
tests/kernels/moe/test_cutedsl_moe.py
View file @
678b3c99
...
@@ -17,7 +17,7 @@ from flashinfer import fp4_quantize
...
@@ -17,7 +17,7 @@ from flashinfer import fp4_quantize
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe
import
(
from
vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_
batched_
moe
import
(
# noqa: E501
flashinfer_cutedsl_moe_masked
,
flashinfer_cutedsl_moe_masked
,
)
)
from
vllm.utils.flashinfer
import
(
from
vllm.utils.flashinfer
import
(
...
...
vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py
0 → 100644
View file @
678b3c99
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kNvfp4Dynamic
,
kNvfp4Static
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
flashinfer_cutedsl_grouped_gemm_nt_masked
,
has_flashinfer_cutedsl_grouped_gemm_nt_masked
,
scaled_fp4_grouped_quantize
,
silu_and_mul_scaled_nvfp4_experts_quantize
,
)
logger
=
init_logger
(
__name__
)
class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular):
    """NvFP4 MoE experts that run on FlashInfer's masked CuteDSL grouped GEMM.

    Activations use the ``BatchedExperts`` format. The actual computation is
    delegated to ``flashinfer_cutedsl_moe_masked`` defined in this module.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int,
        num_dispatchers: int,
    ):
        """Forward the configuration to the base class and check the dtype.

        Raises:
            AssertionError: if ``quant_config.quant_dtype`` is not ``"nvfp4"``.
        """
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
        )
        assert quant_config.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization is currently supported."
        )
        self.out_dtype = moe_config.in_dtype

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Multiply each second-level weight scale by its input scale, in place."""
        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation layout used by this implementation."""
        return mk.FusedMoEActivationFormat.BatchedExperts

    @staticmethod
    def _supports_current_device() -> bool:
        """Return True on CUDA devices of capability family 100 when the
        FlashInfer masked grouped-GEMM entry point is available."""
        p = current_platform
        return (
            p.is_cuda()
            and p.is_device_capability_family(100)
            and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """This implementation does not support the no-act-and-mul mode."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return True only for static NvFP4 weights with dynamic NvFP4
        activations."""
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Only SiLU is supported."""
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Every parallel configuration is accepted."""
        return True

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported."""
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Compute the shapes of the temporary and final buffers used by
        ``apply``.

        Returns a tuple of:
        - workspace13 shape tuple: identical to the output shape.
        - workspace2 shape tuple: ``(local_num_experts, M, N)``; ``apply``
          hands this buffer to the kernel wrapper as its ``workspace``.
        - output shape tuple: ``(local_num_experts, M, K_dim)``.

        ``K_dim`` is ``2 * K`` when ``VLLM_DEEPEPLL_NVFP4_DISPATCH`` is set
        and ``K`` otherwise.
        """
        # NOTE(review): presumably K counts packed fp4 bytes (two values per
        # byte) when the dispatch already quantized the input -- confirm.
        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
        output_shape = (local_num_experts, M, K_dim)
        workspace2 = (local_num_experts, M, N)
        workspace1 = output_shape
        return (workspace1, workspace2, output_shape)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,  # Not used
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        """Run the masked CuteDSL MoE computation, writing into ``output``.

        When ``VLLM_DEEPEPLL_NVFP4_DISPATCH`` is set, ``hidden_states`` and
        ``a1q_scale`` are forwarded together as an already-quantized pair and
        no input global scale is passed; otherwise ``hidden_states`` is
        forwarded alone together with ``self.a1_gscale``.

        Raises:
            AssertionError: if the quantization dtype is not ``"nvfp4"``, a
                weight scale, ``expert_tokens_meta`` or ``workspace2`` is
                missing, or ``hidden_states`` / the weight scales are not 3-D.
        """
        assert self.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization is currently supported."
        )
        # Ensure w1_scale and w2_scale are not None before reading .ndim
        assert self.w1_scale is not None and self.w2_scale is not None, (
            "w1_scale and w2_scale must not be None for "
            "FlashInferCuteDSLBatchedExperts"
        )
        assert expert_tokens_meta is not None
        # workspace2 is typed optional, but the kernel wrapper permutes it
        # unconditionally; fail here with a clear message instead.
        assert workspace2 is not None, (
            "workspace2 must be provided for FlashInferCuteDSLBatchedExperts"
        )
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
        assert hidden_states.ndim == 3
        assert self.w1_scale.ndim == 3
        assert self.w2_scale.ndim == 3

        input_global_scale = (
            None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
        )
        flashinfer_hidden_states = (
            (hidden_states, a1q_scale)
            if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
            else hidden_states
        )
        flashinfer_cutedsl_moe_masked(
            hidden_states=flashinfer_hidden_states,
            input_global_scale=input_global_scale,
            w1=w1,
            w1_blockscale=self.w1_scale,
            w1_alpha=self.g1_alphas,
            w2=w2,
            a2_global_scale=self.a2_gscale,
            w2_blockscale=self.w2_scale,
            w2_alpha=self.g2_alphas,
            masked_m=expert_num_tokens,
            workspace=workspace2,
            out=output,
        )
def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the dtype name string for ``input.dtype``.

    Only ``bfloat16``, ``float16`` and ``float32`` are recognised.

    Raises:
        ValueError: for any other dtype.
    """
    names_by_dtype = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    cute_name = names_by_dtype.get(input.dtype)
    if cute_name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return cute_name
def
flashinfer_cutedsl_moe_masked
(
hidden_states
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
input_global_scale
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1_blockscale
:
torch
.
Tensor
,
w1_alpha
,
w2
:
torch
.
Tensor
,
a2_global_scale
:
torch
.
Tensor
,
w2_blockscale
:
torch
.
Tensor
,
w2_alpha
,
masked_m
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.
Args:
hidden_states: Either of the following case
* torch.Tensor: [num_experts, m, k], bf16
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
uint8, [num_experts, m, k // 16], float8_e4m3fn
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): Masked dimension indices
workspace (torch.Tensor): For gateup_output
Notes:
- Assumes max(masked_m) <= m.
"""
# === Assertions on dtypes ===
assert
w1
.
dtype
==
torch
.
uint8
,
f
"w1 must be uint8, got
{
w1
.
dtype
}
"
assert
w1_blockscale
.
dtype
==
torch
.
float8_e4m3fn
,
(
f
"w1_blockscale must be float8_e4m3fn, got
{
w1_blockscale
.
dtype
}
"
)
assert
w1_alpha
.
dtype
==
torch
.
float32
,
(
f
"w1_alpha must be float32, got
{
w1_alpha
.
dtype
}
"
)
assert
w2
.
dtype
==
torch
.
uint8
,
f
"w2 must be uint8, got
{
w2
.
dtype
}
"
assert
a2_global_scale
.
dtype
==
torch
.
float32
,
(
f
"a2_global_scale must be float32, got
{
a2_global_scale
.
dtype
}
"
)
assert
w2_blockscale
.
dtype
==
torch
.
float8_e4m3fn
,
(
f
"w2_blockscale must be float8_e4m3fn, got
{
w2_blockscale
.
dtype
}
"
)
assert
w2_alpha
.
dtype
==
torch
.
float32
,
(
f
"w2_alpha must be float32, got
{
w2_alpha
.
dtype
}
"
)
# === Assertions on shapes ===
n
=
w2
.
shape
[
-
1
]
*
2
# intermediate dimension
if
isinstance
(
hidden_states
,
tuple
):
assert
input_global_scale
is
None
,
(
"input_global_scale is needed when input needs quant"
)
aq
=
hidden_states
[
0
].
view
(
torch
.
uint8
)
aq_sf
=
hidden_states
[
1
].
view
(
torch
.
float8_e4m3fn
)
# m, k_by_2, num_experts = aq.shape
num_experts
,
m
,
k_by_2
=
aq
.
shape
k
=
k_by_2
*
2
aq
=
aq
.
permute
(
1
,
2
,
0
)
else
:
num_experts
,
m
,
k
=
hidden_states
.
shape
assert
input_global_scale
.
dtype
==
torch
.
float32
,
(
f
"input_global_scale must be float32, got
{
input_global_scale
.
dtype
}
"
)
assert
input_global_scale
.
shape
==
(
num_experts
,),
(
f
"input_global_scale must be (l,), got
{
input_global_scale
.
shape
}
"
)
aq
,
aq_sf
=
scaled_fp4_grouped_quantize
(
hidden_states
,
masked_m
,
input_global_scale
,
)
assert
w1
.
shape
[
-
2
]
==
2
*
n
,
f
"w1 last-2 dim must be 2*n, got
{
w1
.
shape
}
"
assert
w1
.
shape
[
-
1
]
*
2
==
k
,
(
f
"w1 last dim * 2 must equal k, got
{
w1
.
shape
[
-
1
]
}
vs k=
{
k
}
"
)
assert
w2
.
shape
[
-
2
:]
==
(
k
,
n
//
2
,
),
f
"w2 shape mismatch, got
{
w2
.
shape
[
-
2
:]
}
, expected
{
(
k
,
n
//
2
)
}
"
assert
w1_alpha
.
shape
==
(
num_experts
,),
(
f
"w1_alpha must be (l,), got
{
w1_alpha
.
shape
}
"
)
assert
a2_global_scale
.
shape
==
(
num_experts
,),
(
f
"a2_global_scale must be (l,), got
{
a2_global_scale
.
shape
}
"
)
assert
w2_alpha
.
shape
==
(
num_experts
,),
(
f
"w2_alpha must be (l,), got
{
w2_alpha
.
shape
}
"
)
workspace
=
workspace
.
permute
(
1
,
2
,
0
)
# requirement of kernel
sf_vec_size
=
16
assert
aq_sf
.
dtype
==
torch
.
float8_e4m3fn
assert
aq
.
dtype
==
torch
.
uint8
ab_dtype
=
"float4_e2m1fn"
sf_dtype
=
"float8_e4m3fn"
if
isinstance
(
hidden_states
,
tuple
):
c_dtype
=
"bfloat16"
else
:
c_dtype
=
get_cute_dtype
(
hidden_states
)
# Gemm1
flashinfer_cutedsl_grouped_gemm_nt_masked
(
(
aq
,
aq_sf
),
(
w1
.
permute
(
1
,
2
,
0
),
w1_blockscale
),
workspace
,
masked_m
,
ab_dtype
=
ab_dtype
,
sf_dtype
=
sf_dtype
,
c_dtype
=
c_dtype
,
sf_vec_size
=
sf_vec_size
,
alpha
=
w1_alpha
.
view
(
1
,
1
,
num_experts
),
alpha_dtype
=
get_cute_dtype
(
w1_alpha
),
)
# in logical [m, n, l]
# SILU and quantization
diq
,
diq_sf
=
silu_and_mul_scaled_nvfp4_experts_quantize
(
workspace
.
permute
(
2
,
0
,
1
),
masked_m
,
a2_global_scale
,
)
# Gemm2
out
=
out
.
permute
(
1
,
2
,
0
)
# requirement of kernel
flashinfer_cutedsl_grouped_gemm_nt_masked
(
(
diq
,
diq_sf
),
(
w2
.
permute
(
1
,
2
,
0
),
w2_blockscale
),
out
,
masked_m
,
ab_dtype
=
ab_dtype
,
sf_dtype
=
sf_dtype
,
c_dtype
=
c_dtype
,
sf_vec_size
=
sf_vec_size
,
alpha
=
w2_alpha
.
view
(
1
,
1
,
num_experts
),
alpha_dtype
=
get_cute_dtype
(
w2_alpha
),
)
# in logical [m, k, l]
out
=
out
.
permute
(
2
,
0
,
1
)
vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py
View file @
678b3c99
...
@@ -4,8 +4,6 @@
...
@@ -4,8 +4,6 @@
import
torch
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEConfig
,
...
@@ -13,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -13,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
)
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduce
Delegate
,
TopKWeightAndReduce
NoOP
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
QuantKey
,
...
@@ -22,33 +20,42 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -22,33 +20,42 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
from
vllm.utils.flashinfer
import
(
flashinfer_cutedsl_grouped_gemm_nt_masked
,
flashinfer_cute_dsl_fused_moe_nvfp4
,
has_flashinfer_cutedsl_grouped_gemm_nt_masked
,
has_flashinfer_cutedsl_moe_nvfp4
,
scaled_fp4_grouped_quantize
,
silu_and_mul_scaled_nvfp4_experts_quantize
,
)
)
logger
=
init_logger
(
__name__
)
class
FlashInferCuteDSLExperts
(
mk
.
FusedMoEExpertsModular
):
class
FlashInferCuteDSLExperts
(
mk
.
FusedMoEExpertsModular
):
"""
CuteDSL NvFP4 MoE experts using the FlashInfer functional API.
Uses Standard activation format (non-batched). The kernel handles
routing, expert computation, and reduction internally.
Supports expert parallelism natively.
"""
def
__init__
(
def
__init__
(
self
,
self
,
moe_config
:
FusedMoEConfig
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
):
):
super
().
__init__
(
super
().
__init__
(
moe_config
=
moe_config
,
moe_config
=
moe_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
)
assert
quant_config
.
quant_dtype
==
"nvfp4"
,
(
assert
quant_config
.
quant_dtype
==
"nvfp4"
,
(
"Only nvfp4 quantization
are
currently supported."
"Only nvfp4 quantization
is
currently supported."
)
)
self
.
out_dtype
=
moe_config
.
in_dtype
self
.
out_dtype
=
moe_config
.
in_dtype
self
.
hidden_dim
=
moe_config
.
hidden_dim
self
.
intermediate_size_per_partition
=
(
moe_config
.
intermediate_size_per_partition
)
self
.
topk
=
moe_config
.
experts_per_token
self
.
local_num_experts
=
moe_config
.
num_local_experts
self
.
global_num_experts
=
moe_config
.
num_experts
self
.
ep_rank
=
moe_config
.
moe_parallel_config
.
ep_rank
self
.
local_expert_offset
=
self
.
ep_rank
*
self
.
local_num_experts
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
w13_weight_scale_2
.
data
.
mul_
(
layer
.
w13_input_scale
)
layer
.
w13_weight_scale_2
.
data
.
mul_
(
layer
.
w13_input_scale
)
...
@@ -56,7 +63,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
...
@@ -56,7 +63,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
@
staticmethod
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
return
mk
.
FusedMoEActivationFormat
.
Standard
@
staticmethod
@
staticmethod
def
_supports_current_device
()
->
bool
:
def
_supports_current_device
()
->
bool
:
...
@@ -64,7 +71,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
...
@@ -64,7 +71,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
return
(
return
(
p
.
is_cuda
()
p
.
is_cuda
()
and
p
.
is_device_capability_family
(
100
)
and
p
.
is_device_capability_family
(
100
)
and
has_flashinfer_cutedsl_
grouped_gemm_nt_masked
()
and
has_flashinfer_cutedsl_
moe_nvfp4
()
)
)
@
staticmethod
@
staticmethod
...
@@ -86,15 +93,16 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
...
@@ -86,15 +93,16 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
return
activation
==
MoEActivation
.
SILU
return
activation
==
MoEActivation
.
SILU
@
staticmethod
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
,
)
->
bool
:
return
True
return
True
def
supports_expert_map
(
self
)
->
bool
:
def
supports_expert_map
(
self
)
->
bool
:
return
False
return
False
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
return
TopKWeightAndReduceNoOP
()
return
TopKWeightAndReduceDelegate
()
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
...
@@ -107,29 +115,12 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
...
@@ -107,29 +115,12 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
activation
:
MoEActivation
,
activation
:
MoEActivation
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# We use global_num_experts due to how moe_align_block_size handles
workspace1
=
(
0
,)
# expert_maps.
workspace2
=
(
0
,)
"""
# K is packed (K//2 for uint8), so output uses hidden_dim.
Compute the shapes for the temporary and final outputs of the two gemms
assert
self
.
hidden_dim
==
K
*
2
and activation in the fused expert function. Since the gemms are
output
=
(
M
,
self
.
hidden_dim
)
independent, the workspace for the first gemm can be shared with the
return
(
workspace1
,
workspace2
,
output
)
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
K_dim
=
K
*
2
if
envs
.
VLLM_DEEPEPLL_NVFP4_DISPATCH
else
K
output_shape
=
(
local_num_experts
,
M
,
K_dim
)
workspace2
=
(
local_num_experts
,
M
,
N
)
workspace1
=
output_shape
return
(
workspace1
,
workspace2
,
output_shape
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -143,210 +134,39 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
...
@@ -143,210 +134,39 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
global_num_experts
:
int
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
expert_map
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
# Not used
a2_scale
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
|
None
,
workspace2
:
torch
.
Tensor
|
None
,
workspace2
:
torch
.
Tensor
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
|
None
,
apply_router_weight_on_input
:
bool
|
None
,
):
):
assert
self
.
quant_dtype
==
"nvfp4"
,
(
assert
self
.
quant_dtype
==
"nvfp4"
"Only nvfp4 quantization are currently supported."
assert
a1q_scale
is
not
None
)
assert
self
.
w1_scale
is
not
None
# Ensure w1_scale and w2_scale are not None before calling view
assert
self
.
w2_scale
is
not
None
assert
self
.
w1_scale
is
not
None
and
self
.
w2_scale
is
not
None
,
(
"w1_scale and w2_scale must not be None for FlashInferExperts"
# a1q_scale is (M, K//16) float8_e4m3fn from fp4_quantize.
)
# The functional API expects x_sf with trailing dim: (M, K//16, 1).
assert
expert_tokens_meta
is
not
None
x_sf
=
a1q_scale
.
unsqueeze
(
-
1
)
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
assert
hidden_states
.
ndim
==
3
from
vllm.utils.flashinfer
import
_is_fi_autotuning
,
autotune
assert
self
.
w1_scale
.
ndim
==
3
assert
self
.
w2_scale
.
ndim
==
3
with
autotune
(
_is_fi_autotuning
):
flashinfer_cute_dsl_fused_moe_nvfp4
(
input_global_scale
=
(
x
=
hidden_states
,
None
if
envs
.
VLLM_DEEPEPLL_NVFP4_DISPATCH
else
self
.
a1_gscale
x_sf
=
x_sf
,
)
token_selected_experts
=
topk_ids
.
to
(
torch
.
int32
),
flashinfer_hidden_states
=
(
token_final_scales
=
topk_weights
.
float
(),
(
hidden_states
,
a1q_scale
)
w1_weight
=
w1
,
if
envs
.
VLLM_DEEPEPLL_NVFP4_DISPATCH
w1_weight_sf
=
self
.
w1_scale
,
else
hidden_states
w1_alpha
=
self
.
g1_alphas
,
)
fc2_input_scale
=
self
.
a2_gscale
,
flashinfer_cutedsl_moe_masked
(
w2_weight
=
w2
,
hidden_states
=
flashinfer_hidden_states
,
w2_weight_sf
=
self
.
w2_scale
,
input_global_scale
=
input_global_scale
,
w2_alpha
=
self
.
g2_alphas
,
w1
=
w1
,
num_experts
=
self
.
global_num_experts
,
w1_blockscale
=
self
.
w1_scale
,
top_k
=
self
.
topk
,
w1_alpha
=
self
.
g1_alphas
,
num_local_experts
=
self
.
local_num_experts
,
w2
=
w2
,
local_expert_offset
=
self
.
local_expert_offset
,
a2_global_scale
=
self
.
a2_gscale
,
moe_output
=
output
,
w2_blockscale
=
self
.
w2_scale
,
)
w2_alpha
=
self
.
g2_alphas
,
masked_m
=
expert_num_tokens
,
workspace
=
workspace2
,
out
=
output
,
)
def
get_cute_dtype
(
input
:
torch
.
Tensor
)
->
str
:
if
input
.
dtype
==
torch
.
bfloat16
:
return
"bfloat16"
elif
input
.
dtype
==
torch
.
float16
:
return
"float16"
elif
input
.
dtype
==
torch
.
float32
:
return
"float32"
else
:
raise
ValueError
(
f
"Unsupported cute dtype
{
input
.
dtype
}
"
)
def
flashinfer_cutedsl_moe_masked
(
hidden_states
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
input_global_scale
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1_blockscale
:
torch
.
Tensor
,
w1_alpha
,
w2
:
torch
.
Tensor
,
a2_global_scale
:
torch
.
Tensor
,
w2_blockscale
:
torch
.
Tensor
,
w2_alpha
,
masked_m
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.
Args:
hidden_states: Either of the following case
* torch.Tensor: [num_experts, m, k], bf16
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
uint8, [num_experts, m, k // 16], float8_e4m3fn
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): Masked dimension indices
workspace (torch.Tensor): For gateup_output
Notes:
- Assumes max(masked_m) <= m.
"""
# === Assertions on dtypes ===
assert
w1
.
dtype
==
torch
.
uint8
,
f
"w1 must be uint8, got
{
w1
.
dtype
}
"
assert
w1_blockscale
.
dtype
==
torch
.
float8_e4m3fn
,
(
f
"w1_blockscale must be float8_e4m3fn, got
{
w1_blockscale
.
dtype
}
"
)
assert
w1_alpha
.
dtype
==
torch
.
float32
,
(
f
"w1_alpha must be float32, got
{
w1_alpha
.
dtype
}
"
)
assert
w2
.
dtype
==
torch
.
uint8
,
f
"w2 must be uint8, got
{
w2
.
dtype
}
"
assert
a2_global_scale
.
dtype
==
torch
.
float32
,
(
f
"a2_global_scale must be float32, got
{
a2_global_scale
.
dtype
}
"
)
assert
w2_blockscale
.
dtype
==
torch
.
float8_e4m3fn
,
(
f
"w2_blockscale must be float8_e4m3fn, got
{
w2_blockscale
.
dtype
}
"
)
assert
w2_alpha
.
dtype
==
torch
.
float32
,
(
f
"w2_alpha must be float32, got
{
w2_alpha
.
dtype
}
"
)
# === Assertions on shapes ===
n
=
w2
.
shape
[
-
1
]
*
2
# intermediate dimension
if
isinstance
(
hidden_states
,
tuple
):
assert
input_global_scale
is
None
,
(
"input_global_scale is needed when input needs quant"
)
aq
=
hidden_states
[
0
].
view
(
torch
.
uint8
)
aq_sf
=
hidden_states
[
1
].
view
(
torch
.
float8_e4m3fn
)
# m, k_by_2, num_experts = aq.shape
num_experts
,
m
,
k_by_2
=
aq
.
shape
k
=
k_by_2
*
2
aq
=
aq
.
permute
(
1
,
2
,
0
)
else
:
num_experts
,
m
,
k
=
hidden_states
.
shape
assert
input_global_scale
.
dtype
==
torch
.
float32
,
(
f
"input_global_scale must be float32, got
{
input_global_scale
.
dtype
}
"
)
assert
input_global_scale
.
shape
==
(
num_experts
,),
(
f
"input_global_scale must be (l,), got
{
input_global_scale
.
shape
}
"
)
aq
,
aq_sf
=
scaled_fp4_grouped_quantize
(
hidden_states
,
masked_m
,
input_global_scale
,
)
assert
w1
.
shape
[
-
2
]
==
2
*
n
,
f
"w1 last-2 dim must be 2*n, got
{
w1
.
shape
}
"
assert
w1
.
shape
[
-
1
]
*
2
==
k
,
(
f
"w1 last dim * 2 must equal k, got
{
w1
.
shape
[
-
1
]
}
vs k=
{
k
}
"
)
assert
w2
.
shape
[
-
2
:]
==
(
k
,
n
//
2
,
),
f
"w2 shape mismatch, got
{
w2
.
shape
[
-
2
:]
}
, expected
{
(
k
,
n
//
2
)
}
"
assert
w1_alpha
.
shape
==
(
num_experts
,),
(
f
"w1_alpha must be (l,), got
{
w1_alpha
.
shape
}
"
)
assert
a2_global_scale
.
shape
==
(
num_experts
,),
(
f
"a2_global_scale must be (l,), got
{
a2_global_scale
.
shape
}
"
)
assert
w2_alpha
.
shape
==
(
num_experts
,),
(
f
"w2_alpha must be (l,), got
{
w2_alpha
.
shape
}
"
)
workspace
=
workspace
.
permute
(
1
,
2
,
0
)
# requirement of kernel
sf_vec_size
=
16
assert
aq_sf
.
dtype
==
torch
.
float8_e4m3fn
assert
aq
.
dtype
==
torch
.
uint8
ab_dtype
=
"float4_e2m1fn"
sf_dtype
=
"float8_e4m3fn"
if
isinstance
(
hidden_states
,
tuple
):
c_dtype
=
"bfloat16"
else
:
c_dtype
=
get_cute_dtype
(
hidden_states
)
# Gemm1
flashinfer_cutedsl_grouped_gemm_nt_masked
(
(
aq
,
aq_sf
),
(
w1
.
permute
(
1
,
2
,
0
),
w1_blockscale
),
workspace
,
masked_m
,
ab_dtype
=
ab_dtype
,
sf_dtype
=
sf_dtype
,
c_dtype
=
c_dtype
,
sf_vec_size
=
sf_vec_size
,
alpha
=
w1_alpha
.
view
(
1
,
1
,
num_experts
),
alpha_dtype
=
get_cute_dtype
(
w1_alpha
),
)
# in logical [m, n, l]
# SILU and quantization
diq
,
diq_sf
=
silu_and_mul_scaled_nvfp4_experts_quantize
(
workspace
.
permute
(
2
,
0
,
1
),
masked_m
,
a2_global_scale
,
)
# Gemm2
out
=
out
.
permute
(
1
,
2
,
0
)
# requirement of kernel
flashinfer_cutedsl_grouped_gemm_nt_masked
(
(
diq
,
diq_sf
),
(
w2
.
permute
(
1
,
2
,
0
),
w2_blockscale
),
out
,
masked_m
,
ab_dtype
=
ab_dtype
,
sf_dtype
=
sf_dtype
,
c_dtype
=
c_dtype
,
sf_vec_size
=
sf_vec_size
,
alpha
=
w2_alpha
.
view
(
1
,
1
,
num_experts
),
alpha_dtype
=
get_cute_dtype
(
w2_alpha
),
)
# in logical [m, k, l]
out
=
out
.
permute
(
2
,
0
,
1
)
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
View file @
678b3c99
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import (
)
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
prepare_nvfp4_moe_layer_for_fi_or_cutlass
,
prepare_nvfp4_moe_layer_for_fi_or_cutlass
,
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl
,
)
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
FlashinferMoeBackend
,
...
@@ -38,6 +39,7 @@ class NvFp4MoeBackend(Enum):
...
@@ -38,6 +39,7 @@ class NvFp4MoeBackend(Enum):
FLASHINFER_TRTLLM
=
"FLASHINFER_TRTLLM"
FLASHINFER_TRTLLM
=
"FLASHINFER_TRTLLM"
FLASHINFER_CUTLASS
=
"FLASHINFER_CUTLASS"
FLASHINFER_CUTLASS
=
"FLASHINFER_CUTLASS"
FLASHINFER_CUTEDSL
=
"FLASHINFER_CUTEDSL"
FLASHINFER_CUTEDSL
=
"FLASHINFER_CUTEDSL"
FLASHINFER_CUTEDSL_BATCHED
=
"FLASHINFER_CUTEDSL_BATCHED"
VLLM_CUTLASS
=
"VLLM_CUTLASS"
VLLM_CUTLASS
=
"VLLM_CUTLASS"
MARLIN
=
"MARLIN"
MARLIN
=
"MARLIN"
...
@@ -46,6 +48,7 @@ FLASHINFER_NVFP4_MOE_BACKENDS = [
...
@@ -46,6 +48,7 @@ FLASHINFER_NVFP4_MOE_BACKENDS = [
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL_BATCHED
,
]
]
fi_2_vllm_backend_map
:
dict
[
FlashinferMoeBackend
,
NvFp4MoeBackend
]
=
{
fi_2_vllm_backend_map
:
dict
[
FlashinferMoeBackend
,
NvFp4MoeBackend
]
=
{
...
@@ -92,6 +95,13 @@ def backend_to_kernel_cls(
...
@@ -92,6 +95,13 @@ def backend_to_kernel_cls(
return
[
FlashInferCuteDSLExperts
]
return
[
FlashInferCuteDSLExperts
]
elif
backend
==
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL_BATCHED
:
from
vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe
import
(
# noqa: E501
FlashInferCuteDSLBatchedExperts
,
)
return
[
FlashInferCuteDSLBatchedExperts
]
elif
backend
==
NvFp4MoeBackend
.
VLLM_CUTLASS
:
elif
backend
==
NvFp4MoeBackend
.
VLLM_CUTLASS
:
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
CutlassExpertsFp4
,
CutlassExpertsFp4
,
...
@@ -140,6 +150,7 @@ def select_nvfp4_moe_backend(
...
@@ -140,6 +150,7 @@ def select_nvfp4_moe_backend(
AVAILABLE_BACKENDS
=
[
AVAILABLE_BACKENDS
=
[
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL_BATCHED
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
VLLM_CUTLASS
,
NvFp4MoeBackend
.
VLLM_CUTLASS
,
NvFp4MoeBackend
.
MARLIN
,
NvFp4MoeBackend
.
MARLIN
,
...
@@ -195,6 +206,12 @@ def select_nvfp4_moe_backend(
...
@@ -195,6 +206,12 @@ def select_nvfp4_moe_backend(
runner_backend
=
config
.
moe_backend
runner_backend
=
config
.
moe_backend
if
runner_backend
!=
"auto"
:
if
runner_backend
!=
"auto"
:
requested_backend
=
map_nvfp4_backend
(
runner_backend
)
requested_backend
=
map_nvfp4_backend
(
runner_backend
)
# For batched activation format, use batched variant if available.
if
(
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
and
requested_backend
==
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
):
requested_backend
=
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL_BATCHED
return
_return_or_raise
(
return
_return_or_raise
(
requested_backend
,
config
,
weight_key
,
activation_key
,
activation_format
requested_backend
,
config
,
weight_key
,
activation_key
,
activation_format
)
)
...
@@ -285,7 +302,28 @@ def convert_to_nvfp4_moe_kernel_format(
...
@@ -285,7 +302,28 @@ def convert_to_nvfp4_moe_kernel_format(
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
]:
]:
if
(
if
nvfp4_backend
==
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
:
(
w13
,
w13_scale
,
w13_scale_2
,
a13_scale
,
w2
,
w2_scale
,
w2_scale_2
,
a2_scale
,
)
=
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl
(
layer
=
layer
,
w13
=
w13
,
w13_scale
=
w13_scale
,
w13_scale_2
=
w13_scale_2
,
a13_scale
=
a13_scale
,
w2
=
w2
,
w2_scale
=
w2_scale
,
w2_scale_2
=
w2_scale_2
,
a2_scale
=
a2_scale
,
)
elif
(
nvfp4_backend
in
FLASHINFER_NVFP4_MOE_BACKENDS
nvfp4_backend
in
FLASHINFER_NVFP4_MOE_BACKENDS
or
nvfp4_backend
==
NvFp4MoeBackend
.
VLLM_CUTLASS
or
nvfp4_backend
==
NvFp4MoeBackend
.
VLLM_CUTLASS
):
):
...
@@ -377,7 +415,13 @@ def make_nvfp4_moe_quant_config(
...
@@ -377,7 +415,13 @@ def make_nvfp4_moe_quant_config(
# NOTE(rob): this is a hack until the MoE kernels
# NOTE(rob): this is a hack until the MoE kernels
# create their own quant configs. TRTLLM kernel
# create their own quant configs. TRTLLM kernel
# does not accept swizzled input quant scales.
# does not accept swizzled input quant scales.
is_nvfp4_scale_swizzled
=
(
backend
!=
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
),
is_nvfp4_scale_swizzled
=
(
backend
not
in
(
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
)
),
)
)
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
678b3c99
...
@@ -60,6 +60,100 @@ def reorder_w1w3_to_w3w1(
...
@@ -60,6 +60,100 @@ def reorder_w1w3_to_w3w1(
)
)
def interleave_linear_and_gate(
    x: torch.Tensor,
    group_size: int = 64,
    dim: int = -1,
) -> torch.Tensor:
    """Interleave the two halves of ``x`` along ``dim`` in fixed-size groups.

    The axis ``dim`` is treated as two equal halves, each cut into chunks of
    ``group_size`` entries. The result alternates chunks from the two halves:
    chunk 0 of the first half, chunk 0 of the second half, chunk 1 of the
    first half, chunk 1 of the second half, and so on. The output has the
    same shape as ``x``.

    Args:
        x: Tensor to permute.
        group_size: Number of consecutive entries kept together along
            ``dim``. Must be positive.
        dim: Axis to permute; negative values count from the last axis.

    Returns:
        A contiguous tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``group_size`` is not positive or ``x`` is 0-dim.
        AssertionError: If the size of ``dim`` is not a multiple of
            ``2 * group_size``.
    """
    # Reject inputs that would otherwise surface as a bare ZeroDivisionError
    # from the modulo operations below.
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if x.dim() == 0:
        raise ValueError("interleave_linear_and_gate requires at least 1 dim")

    sizes = x.size()
    dim = dim % x.dim()
    # Kept as an assertion so callers observe the same exception type as
    # before for a non-divisible axis.
    assert sizes[dim] % (group_size * 2) == 0, (
        f"dim {dim} size {sizes[dim]} must be divisible by {group_size * 2}"
    )

    prev_sizes = sizes[:dim]
    post_sizes = sizes[dim + 1 :]
    num_groups = sizes[dim] // (group_size * 2)

    # Split the axis into (half, group, within-group); this only subdivides a
    # single dimension of x.
    x = x.view(*prev_sizes, 2, num_groups, group_size, *post_sizes)
    # Swap "half" and "group" so the halves alternate per group, then
    # materialise the permutation and restore the original shape.
    x = x.transpose(dim, dim + 1).contiguous().view(*sizes)
    return x
def prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
    layer: "FusedMoE",
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Rewrite NvFP4 MoE weights/scales for the CuteDSL wrapper backend.

    Steps performed here:

    * ``a13_scale`` / ``a2_scale`` are collapsed to their maximum, cast to
      float32 and expanded to one entry per expert (``w13.shape[0]``).
    * The two halves of ``w13`` and ``w13_scale`` along dim 1 are swapped and
      then interleaved in groups of 64 rows.
    * ``w13_scale`` and ``w2_scale`` are swizzled and converted to the MMA
      layout via ``convert_sf_to_mma_layout``.

    ``w2``, ``w13_scale_2`` and ``w2_scale_2`` are returned untouched, and
    ``layer`` is not read by this function.

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale,
        w2_scale_2, a2_scale)``.
    """
    from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout

    def _to_mma_layout(block_scale: torch.Tensor) -> torch.Tensor:
        # linear → swizzled → MMA layout. The swizzled tensor is unpacked as
        # three dims: (experts, padded rows, padded scale columns).
        swizzled = swizzle_blockscale(block_scale)
        n_groups, rows_padded, sf_cols_padded = swizzled.shape
        flat = swizzled.reshape(n_groups * rows_padded, sf_cols_padded)
        return convert_sf_to_mma_layout(
            flat,
            m=rows_padded,
            k=sf_cols_padded * 16,
            num_groups=n_groups,
            sf_vec_size=16,
        )

    # Global scaling factors (same as other FlashInfer backends).
    expert_count = w13.shape[0]
    a13_scale = a13_scale.max().to(torch.float32).expand(expert_count)
    a2_scale = a2_scale.max().to(torch.float32).expand(expert_count)

    # Swap the two halves along dim 1; the split point comes from w13 and is
    # applied to both the weights and their block scales.
    split = w13.shape[1] // 2
    w13 = torch.cat((w13[:, split:], w13[:, :split]), dim=1)
    w13_scale = torch.cat((w13_scale[:, split:], w13_scale[:, :split]), dim=1)

    # Interleave up/gate rows for w13 weights and scales.
    w13 = interleave_linear_and_gate(w13, group_size=64, dim=1)
    w13_scale = interleave_linear_and_gate(w13_scale, group_size=64, dim=1)

    # Convert both sets of weight scale factors to the MMA layout.
    w13_scale = _to_mma_layout(w13_scale)
    w2_scale = _to_mma_layout(w2_scale)

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
def
prepare_static_weights_for_trtllm_fp4_moe
(
def
prepare_static_weights_for_trtllm_fp4_moe
(
# args_dequant,
# args_dequant,
# args,
# args,
...
@@ -221,7 +315,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
...
@@ -221,7 +315,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
NvFp4MoeBackend
.
VLLM_CUTLASS
,
NvFp4MoeBackend
.
VLLM_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
_BATCHED
,
]
]
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
...
...
vllm/utils/flashinfer.py
View file @
678b3c99
...
@@ -128,6 +128,12 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
...
@@ -128,6 +128,12 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
nvfp4_block_scale_interleave
=
_lazy_import_wrapper
(
nvfp4_block_scale_interleave
=
_lazy_import_wrapper
(
"flashinfer.fp4_quantization"
,
"block_scale_interleave"
"flashinfer.fp4_quantization"
,
"block_scale_interleave"
)
)
# Module-level handle for `flashinfer.cute_dsl_fused_moe_nvfp4`, built with
# _lazy_import_wrapper (presumably deferring the flashinfer import until
# first use — confirm against the helper's definition).
flashinfer_cute_dsl_fused_moe_nvfp4
=
_lazy_import_wrapper
(
"flashinfer"
,
"cute_dsl_fused_moe_nvfp4"
)
# Same pattern for `flashinfer.cute_dsl.utils.convert_sf_to_mma_layout`.
flashinfer_convert_sf_to_mma_layout
=
_lazy_import_wrapper
(
"flashinfer.cute_dsl.utils"
,
"convert_sf_to_mma_layout"
)
trtllm_fp4_block_scale_moe
=
_lazy_import_wrapper
(
trtllm_fp4_block_scale_moe
=
_lazy_import_wrapper
(
"flashinfer"
,
"trtllm_fp4_block_scale_moe"
"flashinfer"
,
"trtllm_fp4_block_scale_moe"
)
)
...
@@ -251,6 +257,15 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
...
@@ -251,6 +257,15 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
return
True
return
True
@functools.cache
def has_flashinfer_cutedsl_moe_nvfp4() -> bool:
    """Report whether FlashInfer exposes ``cute_dsl_fused_moe_nvfp4``.

    Returns ``False`` when ``has_flashinfer_cutedsl()`` is false, when the
    ``flashinfer`` module lookup yields ``None``, or when the module lacks
    the ``cute_dsl_fused_moe_nvfp4`` attribute. The result is memoised by
    ``functools.cache``, so the probe runs at most once per process.
    """
    if has_flashinfer_cutedsl():
        flashinfer_mod = _get_submodule("flashinfer")
        if flashinfer_mod is not None:
            return hasattr(flashinfer_mod, "cute_dsl_fused_moe_nvfp4")
    return False
@
functools
.
cache
@
functools
.
cache
def
has_nvidia_artifactory
()
->
bool
:
def
has_nvidia_artifactory
()
->
bool
:
"""Return `True` if NVIDIA's artifactory is accessible.
"""Return `True` if NVIDIA's artifactory is accessible.
...
@@ -767,6 +782,8 @@ __all__ = [
...
@@ -767,6 +782,8 @@ __all__ = [
"silu_and_mul_scaled_nvfp4_experts_quantize"
,
"silu_and_mul_scaled_nvfp4_experts_quantize"
,
"scaled_fp4_grouped_quantize"
,
"scaled_fp4_grouped_quantize"
,
"nvfp4_block_scale_interleave"
,
"nvfp4_block_scale_interleave"
,
"flashinfer_cute_dsl_fused_moe_nvfp4"
,
"flashinfer_convert_sf_to_mma_layout"
,
"trtllm_fp4_block_scale_moe"
,
"trtllm_fp4_block_scale_moe"
,
"autotune"
,
"autotune"
,
"has_flashinfer_moe"
,
"has_flashinfer_moe"
,
...
@@ -775,6 +792,7 @@ __all__ = [
...
@@ -775,6 +792,7 @@ __all__ = [
"has_flashinfer_nvlink_one_sided"
,
"has_flashinfer_nvlink_one_sided"
,
"has_flashinfer_cutlass_fused_moe"
,
"has_flashinfer_cutlass_fused_moe"
,
"has_flashinfer_cutedsl_grouped_gemm_nt_masked"
,
"has_flashinfer_cutedsl_grouped_gemm_nt_masked"
,
"has_flashinfer_cutedsl_moe_nvfp4"
,
"has_flashinfer_fp8_blockscale_gemm"
,
"has_flashinfer_fp8_blockscale_gemm"
,
"has_nvidia_artifactory"
,
"has_nvidia_artifactory"
,
"supports_trtllm_attention"
,
"supports_trtllm_attention"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment