Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3e0e933
Unverified
Commit
c3e0e933
authored
Jul 31, 2025
by
Wentao Ye
Committed by
GitHub
Jul 31, 2025
Browse files
[Feature] Add Flashinfer MoE Support for Compressed Tensor NVFP4 (#21639)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
6e672daf
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
287 additions
and
129 deletions
+287
-129
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+50
-3
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+24
-126
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+154
-0
vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
...l_executor/layers/quantization/utils/nvfp4_moe_support.py
+59
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
c3e0e933
...
...
@@ -17,9 +17,14 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
# noqa
FlashInferCutlassMoEPrepareAndFinalize
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16
import
(
# noqa
WNA16_SUPPORTED_BITS
,
WNA16_SUPPORTED_TYPES_MAP
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_kernel
,
flashinfer_fp4_cutlass_moe_forward
,
reorder_w1w3_to_w3w1
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_moe_marlin_supports_layer
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
)
...
...
@@ -28,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_moe_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
,
swizzle_blockscale
)
swizzle_blockscale
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
...
...
@@ -96,8 +101,14 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class
CompressedTensorsW4A4MoeMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
):
self
.
use_marlin
=
not
cutlass_fp4_supported
()
from
vllm.model_executor.layers.quantization.utils.nvfp4_moe_support
import
(
# noqa: E501
detect_nvfp4_moe_support
)
_nvfp4
=
detect_nvfp4_moe_support
(
self
.
__class__
.
__name__
)
self
.
cutlass_nvfp4_supported
=
_nvfp4
.
cutlass_supported
self
.
allow_flashinfer_cutlass
=
_nvfp4
.
allow_flashinfer_cutlass
self
.
use_marlin
=
_nvfp4
.
use_marlin
self
.
group_size
=
16
self
.
fused_experts
=
None
# type: ignore[assignment]
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
...
...
@@ -200,6 +211,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
layer
.
w2_weight_packed
.
data
,
requires_grad
=
False
)
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
if
self
.
allow_flashinfer_cutlass
:
w
,
s
=
reorder_w1w3_to_w3w1
(
layer
.
w13_weight
.
data
,
layer
.
w13_weight_scale
.
data
,
dim
=-
2
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
s
,
requires_grad
=
False
)
if
not
torch
.
allclose
(
layer
.
w13_weight_global_scale
[:,
0
],
layer
.
w13_weight_global_scale
[:,
1
]):
logger
.
warning_once
(
...
...
@@ -246,6 +265,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale_quant
=
torch
.
nn
.
Parameter
(
(
layer
.
w2_input_global_scale
),
requires_grad
=
False
)
def maybe_swap_experts_impl(self, moe_parallel_config):
    """Install the FlashInfer CUTLASS fused-experts kernel when allowed.

    Leaves ``self.fused_experts`` untouched unless
    ``self.allow_flashinfer_cutlass`` is truthy; in that case it is replaced
    by the modular kernel built from ``moe_parallel_config``.
    """
    if self.allow_flashinfer_cutlass:
        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
            moe_parallel_config)
def select_gemm_impl(self, prepare_finalize, moe):
    """Return the appropriate GEMM experts implementation."""
    # Both arguments are mandatory here; fail fast if either is missing.
    assert moe is not None
    assert prepare_finalize is not None

    # Imported lazily, at call time, exactly as the surrounding code does.
    from vllm.model_executor.layers.quantization.utils import (
        flashinfer_fp4_moe as _fp4_moe)

    return _fp4_moe.select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass,
                                           moe, logger)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -303,10 +337,23 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
# FlashInfer fused experts path
if
self
.
fused_experts
is
not
None
:
return
flashinfer_fp4_cutlass_moe_forward
(
self
.
fused_experts
,
layer
,
x
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
assert
expert_map
is
None
,
(
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod."
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp4
)
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
c3e0e933
...
...
@@ -10,11 +10,8 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.distributed
import
get_ep_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
# noqa: E501
FlashInferCutlassMoEPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
...
...
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_kernel
,
flashinfer_fp4_cutlass_moe_forward
,
reorder_w1w3_to_w3w1
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_flashinfer_per_tensor_scale_fp8
,
rotate_flashinfer_fp8_moe_weights
,
swap_w13_to_w31
)
...
...
@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.flashinfer
import
has_flashinfer_moe
...
...
@@ -869,28 +868,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_nvfp4_supported
=
cutlass_fp4_supported
()
self
.
use_marlin
=
False
self
.
allow_flashinfer_cutlass
=
False
if
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
:
if
self
.
cutlass_nvfp4_supported
and
current_platform
.
is_cuda
()
\
and
current_platform
.
is_device_capability
(
100
):
logger
.
info_once
(
"Using FlashInfer kernels for ModelOptNvFp4FusedMoE."
)
self
.
allow_flashinfer_cutlass
=
True
else
:
logger
.
warning_once
(
"Flashinfer CUTLASS Fused MoE not supported "
"or found on the current platform."
)
if
not
self
.
cutlass_nvfp4_supported
:
if
is_fp4_marlin_supported
():
self
.
use_marlin
=
True
else
:
raise
ValueError
(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above."
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_moe_support
import
(
# noqa: E501
detect_nvfp4_moe_support
)
_nvfp4
=
detect_nvfp4_moe_support
(
self
.
__class__
.
__name__
)
self
.
cutlass_nvfp4_supported
=
_nvfp4
.
cutlass_supported
self
.
allow_flashinfer_cutlass
=
_nvfp4
.
allow_flashinfer_cutlass
self
.
use_marlin
=
_nvfp4
.
use_marlin
self
.
fused_experts
=
None
# type: ignore
...
...
@@ -900,29 +883,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
):
if
not
self
.
allow_flashinfer_cutlass
:
return
logger
.
debug_once
(
"FlashInferExperts"
)
# default to TP/EP case only
experts_kwargs
:
dict
[
str
,
Any
]
=
{
"use_nvfp4_w4a4"
:
True
,
"use_dp"
:
moe_parallel_config
.
dp_size
>
1
,
"ep_rank"
:
moe_parallel_config
.
ep_rank
,
"ep_size"
:
moe_parallel_config
.
ep_size
,
"tp_rank"
:
moe_parallel_config
.
tp_rank
,
"tp_size"
:
moe_parallel_config
.
tp_size
,
}
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
# noqa: E501
FlashInferExperts
)
experts
=
FlashInferExperts
(
**
experts_kwargs
)
self
.
fused_experts
=
mk
.
FusedMoEModularKernel
(
FlashInferCutlassMoEPrepareAndFinalize
(
quant_dtype
=
torch
.
uint8
,
#meaning 2x e2m1 packed in one, kernel requirement
),
experts
,
)
self
.
fused_experts
=
build_flashinfer_fp4_cutlass_moe_kernel
(
moe_parallel_config
)
# This method updates self.fused_experts.
# select_gemm_impl is only called when prepare_finalize is not None.
...
...
@@ -931,32 +893,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def
select_gemm_impl
(
self
,
prepare_finalize
,
moe
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
assert
moe
is
not
None
assert
prepare_finalize
is
not
None
experts
=
None
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
if
self
.
allow_flashinfer_cutlass
:
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
# noqa: E501
FlashInferExperts
)
logger
.
debug_once
(
"Using FlashInferExperts"
)
experts
=
FlashInferExperts
(
use_nvfp4_w4a4
=
True
,
use_dp
=
moe
.
moe_parallel_config
.
dp_size
>
1
,
ep_rank
=
moe
.
moe_parallel_config
.
ep_rank
,
ep_size
=
moe
.
moe_parallel_config
.
ep_size
,
tp_rank
=
moe
.
moe_parallel_config
.
tp_rank
,
tp_size
=
moe
.
moe_parallel_config
.
tp_size
,
)
else
:
assert
moe
.
dp_size
>
1
logger
.
debug_once
(
"Using CutlassExpertsFp4"
)
# Currently CutlassExpertsFp4 doesn't support DP
raise
ValueError
(
"CutlassExpertsFp4 doesn't support DP. "
"Use flashinfer CUTLASS FusedMoE backend instead "
"(set VLLM_USE_FLASHINFER_MOE_FP4=1)"
)
assert
moe
is
not
None
and
prepare_finalize
is
not
None
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
# noqa: E501
select_nvfp4_gemm_impl
)
return
experts
return
select_nvfp4_gemm_impl
(
self
.
allow_flashinfer_cutlass
,
moe
,
logger
)
def
uses_weight_scale_2_pattern
(
self
)
->
bool
:
"""
...
...
@@ -1062,18 +1004,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight_scale
=
layer
.
w13_weight_scale
.
data
if
self
.
allow_flashinfer_cutlass
:
dim
=
-
2
size
=
gemm1_weight
.
size
(
dim
)
assert
size
%
2
==
0
,
f
"Expected even size in dim
{
dim
}
, got
{
size
}
"
half
=
size
//
2
# Reorder weight
w1
,
w3
=
gemm1_weight
.
split
(
half
,
dim
=
dim
)
gemm1_weight
=
torch
.
cat
([
w3
,
w1
],
dim
=
dim
).
contiguous
()
# Reorder scale
s1
,
s3
=
gemm1_weight_scale
.
split
(
half
,
dim
=
dim
)
gemm1_weight_scale
=
torch
.
cat
([
s3
,
s1
],
dim
=
dim
).
contiguous
()
gemm1_weight
,
gemm1_weight_scale
=
reorder_w1w3_to_w3w1
(
gemm1_weight
,
gemm1_weight_scale
,
dim
=-
2
)
layer
.
w13_weight
=
Parameter
(
gemm1_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
Parameter
(
gemm1_weight_scale
,
...
...
@@ -1217,49 +1149,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
)
else
:
# TP or DP case
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
# noqa: E501
is_valid_flashinfer_cutlass_fused_moe
)
assert
is_valid_flashinfer_cutlass_fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
),
(
"Flashinfer CUTLASS Fused MoE not applicable!"
)
a1_gscale
=
layer
.
w13_input_scale_quant
a2_gscale
=
layer
.
w2_input_scale_quant
extra_expert_args
=
{
'g1_alphas'
:
layer
.
g1_alphas
,
'g2_alphas'
:
layer
.
g2_alphas
,
'out_dtype'
:
x
.
dtype
,
# Avoid confusion with a1_scale and a2_scale,
# which are batch size related.
'a1_gscale'
:
a1_gscale
,
'a2_gscale'
:
a2_gscale
,
}
extra_prepare_args
=
{
'use_dp'
:
layer
.
dp_size
>
1
,
'local_tokens'
:
x
.
shape
[
0
],
'a1_gscale'
:
a1_gscale
,
}
extra_finalize_args
=
{
'use_dp'
:
layer
.
dp_size
>
1
,
'local_tokens'
:
x
.
shape
[
0
],
}
out
=
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
# TODO(shuw): fix later, now output is high prec
out
=
flashinfer_fp4_cutlass_moe_forward
(
self
.
fused_experts
,
layer
,
x
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_blockscale_swizzled
,
w2_scale
=
layer
.
w2_blockscale_swizzled
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
extra_expert_args
=
extra_expert_args
,
extra_prepare_args
=
extra_prepare_args
,
extra_finalize_args
=
extra_finalize_args
,
)
return
out
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
0 → 100644
View file @
c3e0e933
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
from
__future__
import
annotations
from
typing
import
Optional
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
is_valid_flashinfer_cutlass_fused_moe
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
# noqa: E501
FlashInferCutlassMoEPrepareAndFinalize
)
from
vllm.platforms
import
current_platform
logger = init_logger(__name__)

# Public API of this module. ``select_nvfp4_gemm_impl`` is listed as well
# because the NVFP4 MoE quantization methods import it from here.
__all__ = [
    "is_flashinfer_fp4_cutlass_moe_available",
    "reorder_w1w3_to_w3w1",
    "build_flashinfer_fp4_cutlass_moe_kernel",
    "flashinfer_fp4_cutlass_moe_forward",
    "select_nvfp4_gemm_impl",
]
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
    """Tell whether the FlashInfer CUTLASS NV-FP4 MoE path may be used.

    All three conditions must hold, checked in this order with
    short-circuiting: the ``VLLM_USE_FLASHINFER_MOE_FP4`` env flag is set,
    the current platform is CUDA, and the device capability is 100.
    """
    requested = envs.VLLM_USE_FLASHINFER_MOE_FP4
    if not requested:
        # Hand back the falsy flag itself, as an ``and`` chain would.
        return requested

    on_cuda = current_platform.is_cuda()
    if not on_cuda:
        return on_cuda

    return current_platform.is_device_capability(100)
def _swap_halves(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Return a contiguous copy of ``tensor`` with its halves along ``dim`` swapped."""
    size = tensor.size(dim)
    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
    first, second = tensor.split(size // 2, dim=dim)
    return torch.cat([second, first], dim=dim).contiguous()


def reorder_w1w3_to_w3w1(weight: torch.Tensor,
                         scale: torch.Tensor,
                         dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]:
    """Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`.

    Args:
        weight: Tensor holding ``w1`` and ``w3`` concatenated along ``dim``.
        scale: Scale tensor laid out the same way (``[s1, s3]``) along
            ``dim``.
        dim: Dimension along which the two halves are concatenated.

    Returns:
        ``(weight, scale)`` as new contiguous tensors with the two halves
        swapped along ``dim``.

    Raises:
        AssertionError: If either tensor has an odd extent along ``dim``.
    """
    # Each tensor is halved using its *own* extent along ``dim``. Re-using
    # the weight's half-size for the scale mis-splits it (or fails to
    # unpack) whenever the two extents differ along ``dim``.
    return _swap_halves(weight, dim), _swap_halves(scale, dim)
def build_flashinfer_fp4_cutlass_moe_kernel(
    moe_parallel_config: FusedMoEParallelConfig,
) -> mk.FusedMoEModularKernel:
    """Assemble a FlashInfer CUTLASS fused-MoE modular kernel for NV-FP4.

    The experts implementation is configured from the DP/EP/TP fields of
    ``moe_parallel_config`` and paired with a FlashInfer CUTLASS
    prepare/finalize stage that uses ``torch.uint8`` as quant dtype.
    """
    cfg = moe_parallel_config
    nvfp4_experts = FlashInferExperts(
        use_nvfp4_w4a4=True,
        use_dp=cfg.dp_size > 1,
        ep_rank=cfg.ep_rank,
        ep_size=cfg.ep_size,
        tp_rank=cfg.tp_rank,
        tp_size=cfg.tp_size,
    )
    logger.debug_once("FlashInferExperts (util)")

    prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
        quant_dtype=torch.uint8)
    return mk.FusedMoEModularKernel(prepare_finalize, nvfp4_experts)
def flashinfer_fp4_cutlass_moe_forward(
    fused_experts: mk.FusedMoEModularKernel,
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    """Run the shared FlashInfer NV-FP4 fused-MoE forward pass.

    Checks that the FlashInfer CUTLASS path applies to ``x`` and the layer's
    weights, gathers the per-layer scales into the extra-argument dicts, and
    invokes ``fused_experts`` (not in place).
    """
    assert is_valid_flashinfer_cutlass_fused_moe(
        x, layer.w13_weight,
        layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")

    # Kept under *_gscale names to avoid confusion with a1_scale and
    # a2_scale, which are batch size related.
    a1_gscale = layer.w13_input_scale_quant
    a2_gscale = layer.w2_input_scale_quant

    extra_expert_args = dict(
        g1_alphas=layer.g1_alphas,
        g2_alphas=layer.g2_alphas,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        out_dtype=x.dtype,
    )

    # The prepare and finalize stages share the DP flag and token count.
    uses_dp = layer.dp_size > 1
    num_local_tokens = x.shape[0]
    extra_finalize_args = dict(use_dp=uses_dp, local_tokens=num_local_tokens)
    extra_prepare_args = dict(extra_finalize_args, a1_gscale=a1_gscale)

    return fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,  # TODO(shuw): fix later, now output is high prec
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_blockscale_swizzled,
        w2_scale=layer.w2_blockscale_swizzled,
        apply_router_weight_on_input=apply_router_weight_on_input,
        extra_expert_args=extra_expert_args,
        extra_prepare_args=extra_prepare_args,
        extra_finalize_args=extra_finalize_args,
    )
def select_nvfp4_gemm_impl(
        allow_flashinfer_cutlass: bool,
        moe,  # FusedMoEConfig
        logger):
    """Pick the GEMM *experts* implementation for an NV-FP4 fused-MoE layer.

    Returns a ``FlashInferExperts`` configured from ``moe.moe_parallel_config``
    when ``allow_flashinfer_cutlass`` is set.

    Raises:
        ValueError: If the FlashInfer CUTLASS path is not allowed.
    """
    # lazy import
    from vllm.distributed import get_ep_group

    ep_communicator = get_ep_group().device_communicator
    assert ep_communicator.all2all_manager is not None

    if not allow_flashinfer_cutlass:
        # native cutlass experts currently don't support DP; TP case won't
        # call this
        raise ValueError(
            "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
            "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)")

    logger.debug_once("Using FlashInferExperts")
    parallel_cfg = moe.moe_parallel_config
    return FlashInferExperts(
        use_nvfp4_w4a4=True,
        use_dp=parallel_cfg.dp_size > 1,
        ep_rank=parallel_cfg.ep_rank,
        ep_size=parallel_cfg.ep_size,
        tp_rank=parallel_cfg.tp_rank,
        tp_size=parallel_cfg.tp_size,
    )
vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
0 → 100644
View file @
c3e0e933
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
is_flashinfer_fp4_cutlass_moe_available
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
is_fp4_marlin_supported
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
)
# Module-level logger used by the capability probe below.
_logger = init_logger(__name__)

# Names exported by ``from ... import *``.
__all__ = [
    "detect_nvfp4_moe_support",
    "NvFp4Support",
]
@dataclass(frozen=True)
class NvFp4Support:
    """Immutable record of the NV-FP4 capability probe outcome."""

    # CUTLASS FP4 support as reported by the probe.
    cutlass_supported: bool
    # Whether the FlashInfer CUTLASS MoE kernels may be used.
    allow_flashinfer_cutlass: bool
    # Whether the Marlin FP4 MoE kernel is the fallback in use.
    use_marlin: bool
def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
    """Probe which NV-FP4 fused-MoE backends the current platform offers.

    Args:
        class_name: Name used in log messages; ``"NVFP4 path"`` is used when
            it is empty.

    Returns:
        An ``NvFp4Support`` with the CUTLASS, FlashInfer and Marlin flags.

    Raises:
        ValueError: If neither CUTLASS FP4 nor Marlin FP4 is supported.
    """
    label = class_name or "NVFP4 path"

    cutlass_supported = cutlass_fp4_supported()
    # FlashInfer is only considered when CUTLASS FP4 itself is supported.
    allow_flashinfer = (cutlass_supported
                        and is_flashinfer_fp4_cutlass_moe_available())

    if allow_flashinfer:
        _logger.info_once("Using FlashInfer kernels for %s.", label)
    elif envs.VLLM_USE_FLASHINFER_MOE_FP4:
        # The user asked for FlashInfer but it cannot be used here.
        _logger.warning_once(
            "FlashInfer kernels unavailable for %s on current platform.",
            label,
        )

    use_marlin = False
    if not cutlass_supported:
        if not is_fp4_marlin_supported():
            raise ValueError(
                "Current platform does not support NVFP4 quantization. "
                "Please use Blackwell GPUs or enable FlashInfer.")
        use_marlin = True
        _logger.info_once("Falling back to Marlin FP4 MoE kernel.")

    return NvFp4Support(
        cutlass_supported=cutlass_supported,
        allow_flashinfer_cutlass=allow_flashinfer,
        use_marlin=use_marlin,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment