Unverified Commit 9f6dcb71 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][16/N] Apply Refactor to NVFP4 (#31692)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarPavani Majety <pmajety@nvidia.com>
parent 8dd2419f
...@@ -11,14 +11,20 @@ import nvtx ...@@ -11,14 +11,20 @@ import nvtx
import torch import torch
import torch.utils.benchmark as benchmark import torch.utils.benchmark as benchmark
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
...@@ -188,19 +194,24 @@ def bench_run( ...@@ -188,19 +194,24 @@ def bench_run(
g1_alphas=w1_gs, g1_alphas=w1_gs,
g2_alphas=w2_gs, g2_alphas=w2_gs,
) )
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
max_experts_per_worker=e,
quant_config=quant_config,
),
)
for _ in range(num_repeats): for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"): with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4( kernel(
a=a, hidden_states=a,
w1_fp4=w1_fp4, w1=w1_fp4,
w2_fp4=w2_fp4, w2=w2_fp4,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
quant_config=quant_config,
) )
def run_cutlass_from_graph( def run_cutlass_from_graph(
...@@ -230,20 +241,24 @@ def bench_run( ...@@ -230,20 +241,24 @@ def bench_run(
g2_alphas=w2_gs, g2_alphas=w2_gs,
) )
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
max_experts_per_worker=e,
quant_config=quant_config,
),
)
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
): ):
return cutlass_moe_fp4( return kernel(
a=a, hidden_states=a,
w1_fp4=w1_fp4, w1=w1_fp4,
w2_fp4=w2_fp4, w2=w2_fp4,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
quant_config=quant_config,
) )
def run_triton_from_graph( def run_triton_from_graph(
......
...@@ -86,7 +86,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels ...@@ -86,7 +86,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| triton | standard | all<sup>1</sup> | G,A,T | silu, gelu,</br>swigluoai,</br>silu_no_mul,</br>gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],</br>[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] | | triton | standard | all<sup>1</sup> | G,A,T | silu, gelu,</br>swigluoai,</br>silu_no_mul,</br>gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],</br>[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
| triton (batched) | batched | all<sup>1</sup> | G,A,T | silu, gelu | <sup>6</sup> | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] | | triton (batched) | batched | all<sup>1</sup> | G,A,T | silu, gelu | <sup>6</sup> | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | | deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],</br>[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | | cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | | cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import pytest import pytest
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import ( from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX,
...@@ -13,8 +14,13 @@ from tests.kernels.utils import torch_moe ...@@ -13,8 +14,13 @@ from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -83,17 +89,21 @@ def test_cutlass_fp4_moe_no_graph( ...@@ -83,17 +89,21 @@ def test_cutlass_fp4_moe_no_graph(
w2_scale=w2_blockscale, w2_scale=w2_blockscale,
) )
cutlass_output = cutlass_moe_fp4( kernel = mk.FusedMoEModularKernel(
a=a, MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
w1_fp4=w1_q, CutlassExpertsFp4(
w2_fp4=w2_q, out_dtype=dtype,
max_experts_per_worker=e,
quant_config=quant_config,
),
)
cutlass_output = kernel(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=quant_config,
m=m,
n=n,
k=k,
e=e,
) )
# Reference check: # Reference check:
......
...@@ -72,7 +72,6 @@ if HAS_TRITON: ...@@ -72,7 +72,6 @@ if HAS_TRITON:
CutlassBatchedExpertsFp8, CutlassBatchedExpertsFp8,
CutlassExpertsFp8, CutlassExpertsFp8,
CutlassExpertsW4A8Fp8, CutlassExpertsW4A8Fp8,
cutlass_moe_fp4,
cutlass_moe_w4a8_fp8, cutlass_moe_w4a8_fp8,
) )
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
...@@ -95,7 +94,6 @@ if HAS_TRITON: ...@@ -95,7 +94,6 @@ if HAS_TRITON:
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
"GroupedTopk", "GroupedTopk",
"cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8", "cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8", "CutlassExpertsFp8",
"CutlassBatchedExpertsFp8", "CutlassBatchedExpertsFp8",
......
...@@ -336,6 +336,10 @@ class FusedMoEQuantConfig: ...@@ -336,6 +336,10 @@ class FusedMoEQuantConfig:
def use_int4_w4a16(self) -> bool: def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4" return self._a1.dtype is None and self._w1.dtype == "int4"
@property
def use_nvfp4_w4a16(self) -> bool:
    """Return True when `_a1.dtype` is None and `_w1.dtype` is "nvfp4"."""
    # No activation dtype is a precondition; bail out before looking at
    # the weight dtype (mirrors the short-circuit of an `and` expression).
    if self._a1.dtype is not None:
        return False
    return self._w1.dtype == "nvfp4"
@property @property
def ocp_mx_scheme(self) -> str | None: def ocp_mx_scheme(self) -> str | None:
if not hasattr(self, "_ocp_mx_scheme"): if not hasattr(self, "_ocp_mx_scheme"):
...@@ -690,6 +694,25 @@ def nvfp4_moe_quant_config( ...@@ -690,6 +694,25 @@ def nvfp4_moe_quant_config(
) )
def nvfp4_w4a16_moe_quant_config(
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for 16-bit activations and nvfp4 weights.

    All four tensors are forwarded unchanged to `FusedMoEQuantConfig.make`.
    """
    return FusedMoEQuantConfig.make(
        # No activation quantization dtype: only the weights are nvfp4.
        quant_dtype=None,
        weight_dtype="nvfp4",
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )
def int4_w4a16_moe_quant_config( def int4_w4a16_moe_quant_config(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
......
...@@ -706,68 +706,6 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -706,68 +706,6 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
) )
def cutlass_moe_fp4(
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w2_fp4: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    m: int,
    n: int,
    k: int,
    e: int,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run an NVFP4 MoE forward pass through a modular CUTLASS kernel.

    Builds a `FusedMoEModularKernel` from `MoEPrepareAndFinalizeNoEP` and
    `CutlassExpertsFp4` and invokes it on `a` with "silu" activation.

    `m`, `n` and `k` are accepted but not read by this function. `e` is
    used both as `max_experts_per_worker` and as `global_num_experts`.
    `expert_map` must be None; expert parallelism is not supported here.
    """
    assert expert_map is None, (
        "Expert Parallelism / expert_map "
        "is currently not supported for "
        "ModelOptNvFp4FusedMoE's cutlass_moe_fp4."
    )

    # TODO(bnell): this feels a bit hacky
    # NVFP4 requires two levels of quantization, which involves
    # computing some scaling factors dynamically. This makes it
    # incompatible with the typical prepare -> MoE -> finalize
    # pipeline. Move the quantization logic into the MoE body.
    src_cfg = quant_config
    deferred_cfg = FusedMoEQuantConfig.make(
        quant_dtype=None,  # skip quantization in prepare/finalize
        per_act_token_quant=src_cfg.per_act_token_quant,
        per_out_ch_quant=src_cfg.per_out_ch_quant,
        block_shape=src_cfg.block_shape,
        g1_alphas=src_cfg.g1_alphas,
        g2_alphas=src_cfg.g2_alphas,
        a1_gscale=src_cfg.a1_gscale,
        a2_gscale=src_cfg.a2_gscale,
        w1_scale=src_cfg.w1_scale,
        w2_scale=src_cfg.w2_scale,
    )

    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = CutlassExpertsFp4(
        max_experts_per_worker=e,
        out_dtype=a.dtype,
        quant_config=deferred_cfg,
        use_batched_format=False,
    )
    moe_kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)

    return moe_kernel(
        hidden_states=a,
        w1=w1_fp4,
        w2=w2_fp4,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        activation="silu",
        global_num_experts=e,
        expert_map=None,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
# W4A8 # W4A8
def run_cutlass_moe_w4a8_fp8( def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor, output: torch.Tensor,
......
...@@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked( ...@@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked(
alpha_dtype=get_cute_dtype(w2_alpha), alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l] ) # in logical [m, k, l]
out = out.permute(2, 0, 1) out = out.permute(2, 0, 1)
def flashinfer_cutedsl_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run an FP4 MoE forward pass through the FlashInfer CuteDSL experts.

    Builds a `FusedMoEModularKernel` from the prepare/finalize object
    returned by `create_flashinfer_prepare_finalize(use_dp=False)` and a
    `FlashInferCuteDSLExperts` instance whose output dtype matches
    `hidden_states`, then forwards every argument to that kernel.
    """
    # Imported here rather than at module scope, as in the original.
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        create_flashinfer_prepare_finalize,
    )

    # could be swapped later
    prepare_finalize = create_flashinfer_prepare_finalize(use_dp=False)
    experts = FlashInferCuteDSLExperts(
        out_dtype=hidden_states.dtype,
        quant_config=quant_config,
    )
    moe_kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)

    return moe_kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
...@@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize( ...@@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize(
use_deepseek_fp8_block_scale: bool = False, use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP: ) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation.""" """Factory function to create the appropriate FlashInfer implementation."""
# TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
# once we complete the FP8 refactor.
if use_nvfp4:
if enable_alltoallv:
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
# FP8 DP path currently supported via AllGather.
if use_dp: if use_dp:
if enable_alltoallv:
assert use_nvfp4
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
return FlashInferAllGatherMoEPrepareAndFinalize( return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=True, use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
) )
else: else:
# NOTE(rob): CUTLASS FP8 block quant executes the input # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
# quantzation and grouped gemm in a single kernel. # in a single call with the MoE experts kernel.
return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale) defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
...@@ -540,9 +540,10 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -540,9 +540,10 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (varun) : Enable activation quantization # TODO (varun) : Enable activation quantization
assert ( assert (
quant_config.use_mxfp4_w4a16 quant_config.use_mxfp4_w4a16
or quant_config.use_nvfp4_w4a16
or quant_config.use_int4_w4a16 or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16 or quant_config.use_fp8_w8a16
), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16" ), "Supports only {mxfp,nvfp,int}4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
...@@ -555,7 +556,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -555,7 +556,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4 # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if self.quant_config.use_int4_w4a16: if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id return scalar_types.uint4b8.id
elif self.quant_config.use_mxfp4_w4a16: elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:
return scalar_types.float4_e2m1f.id return scalar_types.float4_e2m1f.id
elif ( elif (
self.quant_config.use_fp8_w8a16 self.quant_config.use_fp8_w8a16
...@@ -692,6 +693,8 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -692,6 +693,8 @@ class MarlinExperts(MarlinExpertsBase):
gating_output=None, gating_output=None,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
global_scale1=self.g1_alphas,
global_scale2=self.g2_alphas,
quant_type_id=self.quant_type_id, quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
......
...@@ -38,9 +38,6 @@ from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimula ...@@ -38,9 +38,6 @@ from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimula
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
...@@ -1125,14 +1122,9 @@ class FusedMoE(CustomOp): ...@@ -1125,14 +1122,9 @@ class FusedMoE(CustomOp):
global_expert_id = expert_id global_expert_id = expert_id
expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id) expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)
allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)
use_global_sf = ( use_global_sf = (
allow_flashinfer getattr(self.quant_method, "use_global_sf", False)
and is_flashinfer_supporting_global_sf(moe_backend)
and "input_scale" in weight_name and "input_scale" in weight_name
and quant_method_name == "ModelOptNvFp4FusedMoE"
) )
if expert_id == -1 and not use_global_sf: if expert_id == -1 and not use_global_sf:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
)
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
    """Kernel backends that can be selected for NvFp4 MoE layers.

    Each member's value is the human-readable backend name that is
    interpolated into the backend-selection log messages.
    """

    FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
    FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
    FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
    # Was misspelled "vLLM CUTASS", which leaked into the log output.
    VLLM_CUTLASS = "vLLM CUTLASS"
    MARLIN = "vLLM MARLIN"
# The NvFp4 MoE backends that are provided by FlashInfer. Used to tell
# FlashInfer backends apart from the vLLM CUTLASS and Marlin ones, both when
# logging a fallback and when choosing the weight-preparation routine.
FLASHINFER_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.FLASHINFER_CUTLASS,
    NvFp4MoeBackend.FLASHINFER_TRTLLM,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
# Maps the FlashInfer backend enum (as returned by
# `get_flashinfer_moe_backend()`) onto the matching `NvFp4MoeBackend` member.
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
    FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
    FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
    FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Check whether `backend` supports global scaling factors.

    That is, whether it supports quantizing with the scaling factors of
    all experts in Expert Parallel mode when all experts are not on the
    same rank.
    """
    global_sf_backends = (
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
    )
    return backend in global_sf_backends
def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
    """Choose the kernel backend for NvFp4 MoE layers and log the choice.

    Selection order:
      1. If CUTLASS FP4 is supported and `VLLM_TEST_FORCE_FP8_MARLIN` is
         not set: the FlashInfer backend reported by
         `get_flashinfer_moe_backend()` when a FlashInfer FP4 MoE kernel
         is available and `VLLM_USE_FLASHINFER_MOE_FP4` is set, otherwise
         the vLLM CUTLASS backend.
      2. Otherwise Marlin, if FP4 Marlin is supported.

    Raises:
        ValueError: if neither path yields a usable backend.
    """
    if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
        flashinfer_available = (
            is_flashinfer_fp4_cutlass_moe_available()
            or is_flashinfer_fp4_cutedsl_moe_available()
        )
        # Start from the vLLM CUTLASS kernels and upgrade to FlashInfer
        # only when it is both available and requested.
        chosen = NvFp4MoeBackend.VLLM_CUTLASS
        if flashinfer_available and envs.VLLM_USE_FLASHINFER_MOE_FP4:
            chosen = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
    elif is_fp4_marlin_supported():
        chosen = NvFp4MoeBackend.MARLIN
    else:
        raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")

    # Log warning if FI backend requested but not available.
    fell_back_from_flashinfer = (
        chosen not in FLASHINFER_NVFP4_MOE_BACKENDS
        and envs.VLLM_USE_FLASHINFER_MOE_FP4
    )
    if fell_back_from_flashinfer:
        logger.warning_once(
            "Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
            "Falling back to %s for NvFp4 MoE",
            chosen.value,
            scope="local",
        )
    else:
        logger.info_once(
            f"Using {chosen.value} backend for NvFp4 MoE", scope="local"
        )
    return chosen
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
]:
    """Convert NvFp4 MoE weights and scales to the selected backend's format.

    FlashInfer and vLLM CUTLASS backends are handled by
    `prepare_nvfp4_moe_layer_for_fi_or_cutlass`; the Marlin backend by
    `prepare_nvfp4_moe_layer_for_marlin`.

    Returns:
        The tuple `(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale,
        w2_scale_2, a2_scale)`. For the Marlin backend `a13_scale` and
        `a2_scale` are always None, because the activation scales passed
        in are discarded on that path.

    Raises:
        ValueError: if `nvfp4_backend` is not a known backend.
    """
    if (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    ):
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
            is_act_and_mul=is_act_and_mul,
        )
    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # The Marlin preparation routine takes no activation scales, so
        # they are dropped here and returned as None.
        a13_scale = None
        a2_scale = None
        (
            w13,
            w13_scale,
            w13_scale_2,
            w2,
            w2_scale,
            w2_scale_2,
        ) = prepare_nvfp4_moe_layer_for_marlin(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
        )
    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
) -> FusedMoEQuantConfig | None:
    """Build the `FusedMoEQuantConfig` for an NvFp4 MoE backend.

    Args:
        backend: The selected NvFp4 MoE backend.
        w13_scale: Passed through as `w1_scale`.
        w2_scale: Passed through as `w2_scale`.
        w13_scale_2: Second-level scale of w13; used as `g1_alphas`
            (multiplied by `a13_scale` on the non-Marlin path).
        w2_scale_2: Second-level scale of w2; used as `g2_alphas`
            (multiplied by `a2_scale` on the non-Marlin path).
        a13_scale: Activation scale for the first GEMM. Not read for the
            Marlin or TRTLLM backends, so it may be None there.
        a2_scale: Activation scale for the second GEMM. Same rules as
            `a13_scale`.

    Returns:
        None for the FlashInfer TRTLLM backend, a W4A16 config for Marlin,
        and a config with activation scales for every other backend.

    Raises:
        ValueError: if a backend that needs activation scales is given
            None for `a13_scale` or `a2_scale`.
    """
    UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
    if backend in UNSUPPORTED:
        return None

    elif backend == NvFp4MoeBackend.MARLIN:
        return nvfp4_w4a16_moe_quant_config(
            g1_alphas=w13_scale_2,
            g2_alphas=w2_scale_2,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
        )

    # The remaining backends need both activation scales. Fail with a clear
    # message rather than a TypeError from `None * Tensor` below.
    if a13_scale is None or a2_scale is None:
        raise ValueError(
            "a13_scale and a2_scale are required to build the NvFp4 MoE "
            f"quant config for the {backend.value} backend."
        )

    g1_alphas = a13_scale * w13_scale_2
    g2_alphas = a2_scale * w2_scale_2
    return nvfp4_moe_quant_config(
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        a1_gscale=(1.0 / a13_scale),
        a2_gscale=(1.0 / a2_scale),
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )
def make_nvfp4_moe_kernel(
    backend: NvFp4MoeBackend,
    quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None:
    """Build the modular MoE kernel for `backend`.

    Requires `moe_config.dp_size == 1`. Returns None for the FlashInfer
    TRTLLM and CUTEDSL backends, for which no modular kernel is built here.

    Raises:
        ValueError: if `backend` is not a known NvFp4 MoE backend.
    """
    assert moe_config.dp_size == 1

    backends_without_kernel = (
        # TRTLLM does not use the modular kernel abstraction.
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        # CUTEDSL is used with BATCHED (masked) format only.
        # TODO: add here once we support dp/ep via the oracle.
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
    )
    if backend in backends_without_kernel:
        return None

    if backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
        fi_prepare_finalize = MoEPrepareAndFinalizeNoEP(defer_input_quant=True)
        fi_experts = FlashInferExperts(
            out_dtype=moe_config.in_dtype,
            quant_config=quant_config,
            ep_rank=moe_config.ep_rank,
            ep_size=moe_config.ep_size,
            tp_rank=moe_config.tp_rank,
            tp_size=moe_config.tp_size,
            use_dp=False,
            use_deepseek_fp8_block_scale=False,
        )
        return mk.FusedMoEModularKernel(fi_prepare_finalize, fi_experts)

    if backend == NvFp4MoeBackend.VLLM_CUTLASS:
        cutlass_prepare_finalize = MoEPrepareAndFinalizeNoEP(defer_input_quant=True)
        cutlass_experts = CutlassExpertsFp4(
            out_dtype=moe_config.in_dtype,
            # TODO(rob): see what impact this has on expert map?
            max_experts_per_worker=moe_config.num_experts,
            quant_config=quant_config,
        )
        return mk.FusedMoEModularKernel(cutlass_prepare_finalize, cutlass_experts)

    if backend == NvFp4MoeBackend.MARLIN:
        marlin_prepare_finalize = MoEPrepareAndFinalizeNoEP()
        marlin_experts = MarlinExperts(quant_config=quant_config)
        return mk.FusedMoEModularKernel(marlin_prepare_finalize, marlin_experts)

    raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")
...@@ -11,7 +11,6 @@ from compressed_tensors.quantization import ( ...@@ -11,7 +11,6 @@ from compressed_tensors.quantization import (
QuantizationArgs, QuantizationArgs,
QuantizationStrategy, QuantizationStrategy,
) )
from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -34,12 +33,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -34,12 +33,8 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4afp8_moe_quant_config, int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, int8_w8a16_moe_quant_config,
nvfp4_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts, BatchedMarlinExperts,
MarlinExperts, MarlinExperts,
...@@ -51,6 +46,15 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -51,6 +46,15 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_kernel, make_fp8_moe_kernel,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP, WNA16_SUPPORTED_TYPES_MAP,
...@@ -58,14 +62,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress ...@@ -58,14 +62,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
prepare_static_weights_for_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl, select_nvfp4_gemm_impl,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe,
...@@ -77,20 +76,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -77,20 +76,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8, convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace, convert_packed_uint4b8_to_signed_int4_inplace,
swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -218,31 +212,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -218,31 +212,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None): def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 if not moe.is_act_and_mul:
detect_nvfp4_moe_support, raise ValueError(
) "CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
"support non gated MoE models."
)
super().__init__(moe) super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.layer_name = layer_name self.nvfp4_backend = select_nvfp4_moe_backend()
self.marlin_input_dtype = ( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
get_marlin_input_dtype(layer_name) if self.use_marlin else None self.nvfp4_backend
) )
self.flashinfer_moe_backend = None self.kernel: mk.FusedMoEModularKernel | None = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoEMethod."
)
elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
def create_weights( def create_weights(
self, self,
...@@ -355,7 +337,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -355,7 +337,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# From packed to weight """
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
# NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod
# requires this naming convention. However, the name change breaks
# reloading because the state dict no longer matches disk. Once we
# remove MKM, we should revert this change to ensure compatibility.
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False layer.w13_weight_packed.data, requires_grad=False
) )
...@@ -366,144 +354,79 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -366,144 +354,79 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
delattr(layer, "w2_weight_packed") delattr(layer, "w2_weight_packed")
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. # Use a single gscale for w13.
if self.allow_flashinfer: if self.moe.is_act_and_mul and not torch.allclose(
w, s = reorder_w1w3_to_w3w1(
layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2
)
layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
if not torch.allclose(
layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
): ):
logger.warning_once( logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. " "w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected." "Accuracy may be affected.",
) )
w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous()
# Take inverse of global scale saved to disk
layer.w13_weight_scale_2 = torch.nn.Parameter( # Shuffle weights into the NvFp4 kernel format.
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False (
) w13,
w13_scale,
layer.w2_weight_scale_2 = torch.nn.Parameter( w13_scale_2,
1 / layer.w2_weight_global_scale.data, requires_grad=False a13_scale,
) w2,
w2_scale,
if self.use_marlin: w2_scale_2,
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) a2_scale,
return ) = convert_to_nvfp4_moe_kernel_format(
# w13 nvfp4_backend=self.nvfp4_backend,
if ( layer=layer,
self.allow_flashinfer w13=layer.w13_weight,
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM w13_scale=layer.w13_weight_scale,
): w13_scale_2=(1.0 / w13_weight_global_scale),
w13_input_global_scale = ( a13_scale=(1.0 / layer.w13_input_global_scale),
layer.w13_input_global_scale.min() w2=layer.w2_weight,
.to(torch.float32) w2_scale=layer.w2_weight_scale,
.expand(layer.num_experts) w2_scale_2=(1.0 / layer.w2_weight_global_scale),
) a2_scale=(1.0 / layer.w2_input_global_scale),
else: is_act_and_mul=self.moe.is_act_and_mul,
w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
torch.float32
)
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False,
)
layer.w13_input_scale_quant = torch.nn.Parameter(
(w13_input_global_scale), requires_grad=False
)
# w2
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w2_input_global_scale = (
layer.w2_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w2_input_global_scale = layer.w2_input_global_scale
layer.g2_alphas = torch.nn.Parameter(
((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
layer.w2_input_scale_quant = torch.nn.Parameter(
(w2_input_global_scale), requires_grad=False
) )
# TensorRT-LLM specific processing replace_parameter(layer, "w13_weight", w13)
if ( replace_parameter(layer, "w13_weight_scale", w13_scale)
self.allow_flashinfer replace_parameter(layer, "w2_weight", w2)
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM replace_parameter(layer, "w2_weight_scale", w2_scale)
): layer.w13_weight_scale_2 = w13_scale_2
# Prepare static weights for TRT-LLM kernel layer.w2_weight_scale_2 = w2_scale_2
# alternate: prepare_static_weight_layouts_for_trtllm_moe layer.w13_input_scale = a13_scale
( layer.w2_input_scale = a2_scale
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter( # Initialize the kernel that will be called in apply().
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1
if self.moe_quant_config is not None and not use_dp:
self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or ( UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
self.allow_flashinfer if self.nvfp4_backend in UNSUPPORTED:
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None return None
elif not self.allow_flashinfer: elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables) return super().maybe_make_prepare_finalize(routing_tables)
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
...@@ -514,7 +437,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -514,7 +437,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
experts = select_nvfp4_gemm_impl( experts = select_nvfp4_gemm_impl(
self.moe, self.moe,
self.moe_quant_config, self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer, allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS),
) )
logger.debug_once("Using %s", experts.__class__.__name__) logger.debug_once("Using %s", experts.__class__.__name__)
return experts return experts
...@@ -522,19 +445,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -522,19 +445,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
if ( return make_nvfp4_moe_quant_config(
self.use_marlin backend=self.nvfp4_backend,
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM w13_scale=layer.w13_weight_scale,
):
return None
return nvfp4_moe_quant_config(
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
w13_scale_2=layer.w13_weight_scale_2,
w2_scale_2=layer.w2_weight_scale_2,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
) )
def apply( def apply(
...@@ -546,14 +464,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -546,14 +464,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
if ( if (
self.allow_flashinfer self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and not layer.enable_eplb
): ):
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
)
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
layer=layer, layer=layer,
x=x, x=x,
...@@ -566,79 +479,41 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -566,79 +479,41 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=layer.e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = layer.select_experts(
hidden_states=x, hidden_states=x_routing,
router_logits=router_logits, router_logits=router_logits,
) )
if self.use_marlin: # EPLB path
return fused_marlin_moe( if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
else:
assert self.kernel is not None
return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_scale1=layer.w13_weight_scale_2, inplace=False,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
# FlashInfer fused experts path
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight
), "Flashinfer CUTLASS Fused MoE not applicable!"
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
).to(x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
......
...@@ -2,10 +2,13 @@ ...@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path""" """Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
from typing import TYPE_CHECKING
import torch import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -20,12 +23,23 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( ...@@ -20,12 +23,23 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize, create_flashinfer_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked, has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe,
) )
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
)
logger = init_logger(__name__)
__all__ = [ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available", "is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available", "is_flashinfer_fp4_cutedsl_moe_available",
...@@ -273,10 +287,9 @@ def flashinfer_trtllm_fp4_moe( ...@@ -273,10 +287,9 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else: else:
# hidden_states is the already quantized # hidden_states is the already quantized
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x, x,
a1_gscale, layer.a1_gscale,
is_sf_swizzled_layout=False, is_sf_swizzled_layout=False,
) )
...@@ -369,10 +382,9 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -369,10 +382,9 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else: else:
# Quantize input to FP4 # Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x, x,
a1_gscale, layer.a1_gscale,
is_sf_swizzled_layout=False, is_sf_swizzled_layout=False,
) )
...@@ -410,3 +422,93 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -410,3 +422,93 @@ def flashinfer_trtllm_fp4_routed_moe(
)[0] )[0]
return out return out
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
    backend: "NvFp4MoeBackend",
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Convert NVFP4 MoE weights and scales into the layout a backend expects.

    Handles the vLLM CUTLASS, FlashInfer CUTLASS and FlashInfer TRT-LLM
    backends. Depending on ``backend`` this:

    * reorders the fused ``[w1, w3]`` projection to ``[w3, w1]`` (FlashInfer
      backends, gated MoE only),
    * reduces the activation scales either to one value shared by every
      expert or to one value per expert,
    * for TRT-LLM, shuffles weights/block scales and stores ``g1_scale_c``,
      ``a1_gscale``, ``g1_alphas`` and ``g2_alphas`` on ``layer``,
    * otherwise swizzles the block scales and pads the weights up to the
      swizzled scale size when required.

    Args:
        backend: Selected NVFP4 MoE backend.
        layer: MoE layer; only mutated on the TRT-LLM path (extra attributes).
        w13: Fused first-GEMM weights.
        w13_scale: Block scales for ``w13``.
        w13_scale_2: Second-level (global) scale for ``w13``.
        a13_scale: Input scale for the first GEMM.
        w2: Second-GEMM weights.
        w2_scale: Block scales for ``w2``.
        w2_scale_2: Second-level (global) scale for ``w2``.
        a2_scale: Input scale for the second GEMM.
        is_act_and_mul: Whether the MoE is gated (fused w1/w3).

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2,
        a2_scale)`` in kernel format.

    Raises:
        NotImplementedError: If padding of the intermediate size is required
            for a gated MoE on a non TRT-LLM backend.
    """
    # Delayed import for circular dependency avoidance.
    from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
        NvFp4MoeBackend,
        is_global_sf_supported_for_nvfp4_backend,
    )

    # NOTE(review): confirm this is the complete set of backends that should
    # be routed through this helper.
    assert backend in [
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
    ]

    # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
    if is_act_and_mul and backend in [
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
    ]:
        w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale)

    # For some FI kernels, the input scales are shared by all experts.
    if is_global_sf_supported_for_nvfp4_backend(backend):
        num_experts = w13.shape[0]
        a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
        a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
    else:
        # Per-expert scale: reduce over dim 1 (presumably the w1/w3 pair --
        # TODO confirm against the caller). a2_scale is passed through as-is.
        a13_scale = a13_scale.max(dim=1).values.to(torch.float32)

    # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
            w13,
            w2,
            w13_scale,
            w2_scale,
            w2.size(-2),  # hidden_size
            w13.size(-2) // 2,  # intermediate_size
            w13.size(0),  # num_experts
        )

        # We do not need to make this a parameter, because
        # it is not used during the weight (re)-loading process.
        layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
        layer.a1_gscale = 1.0 / a13_scale
        layer.g1_alphas = a13_scale * w13_scale_2
        layer.g2_alphas = a2_scale * w2_scale_2
    else:
        # Swizzle the block scales for other FI NVFP4 MoE kernels.
        w13_scale = swizzle_blockscale(w13_scale)

        # Apply padding if needed: the swizzled scale may have more rows
        # than the weight it belongs to.
        pad_size = w13_scale.size(1) - w13.size(1)
        if pad_size > 0:
            if is_act_and_mul:
                # Exceptions do not interpolate "%s" templates the way the
                # logger does, so the message is formatted explicitly.
                raise NotImplementedError(
                    "Intermediate size padding for w1 and w3, for "
                    f"{backend.value} "
                    "NvFp4 backend, but this is not currently supported"
                )
            w13 = torch.nn.functional.pad(w13, (0, 0, 0, pad_size))
            # presumably // 2 because two FP4 values are packed per byte and
            # // 16 because of the NVFP4 block size -- TODO confirm.
            w2 = torch.nn.functional.pad(w2, (0, pad_size // 2, 0, 0))
            w2_scale = torch.nn.functional.pad(w2_scale, (0, pad_size // 16))

        w2_scale = swizzle_blockscale(w2_scale)

    return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
...@@ -8,6 +8,7 @@ import vllm._custom_ops as ops ...@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, USE_FP32_REDUCE_DEFAULT,
get_marlin_input_dtype,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
...@@ -226,6 +227,106 @@ def prepare_fp4_layer_for_marlin( ...@@ -226,6 +227,106 @@ def prepare_fp4_layer_for_marlin(
return return
def prepare_nvfp4_moe_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    """Repack NVFP4 MoE weights and scales into the Marlin kernel layout.

    Each expert's weight is repacked with ``ops.gptq_marlin_repack`` and each
    expert's block scale is permuted with ``marlin_permute_scales`` followed
    by ``nvfp4_marlin_process_scales``. The global scales go through
    ``nvfp4_marlin_process_global_scale``. A Marlin workspace is attached to
    ``layer.workspace`` as a side effect.

    Args:
        layer: MoE layer providing ``num_experts``, ``hidden_size``,
            ``intermediate_size_per_partition`` and ``params_dtype``.
        w13: Fused first-GEMM weights, shape ``(E, 2 * N, K // 2)``.
        w13_scale: Block scales for ``w13``.
        w13_scale_2: Global scale for ``w13``.
        w2: Second-GEMM weights, shape ``(E, K, N // 2)``.
        w2_scale: Block scales for ``w2``.
        w2_scale_2: Global scale for ``w2``.

    Returns:
        ``(w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2)`` in Marlin
        format.

    Raises:
        RuntimeError: If the Marlin input dtype is a 1-byte type.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )

    input_dtype = get_marlin_input_dtype(prefix="")
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")

    GROUP_SIZE = 16
    E = layer.num_experts
    K = layer.hidden_size
    N = layer.intermediate_size_per_partition

    device = w13.device
    param_dtype = layer.params_dtype
    # Always False at this point because 1-byte input dtypes were rejected
    # above; kept as an explicit flag for the Marlin helpers.
    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    def gemm_dims(name: str) -> tuple[int, int]:
        # (size_n, size_k) of the GEMM the named projection takes part in.
        return (N * 2, K) if "w13" in name else (K, N)

    # WEIGHT
    # Repack weights to marlin format
    def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
        size_n, size_k = gemm_dims(name)
        assert weight.shape == (E, size_n, size_k // 2)

        repacked = [
            ops.gptq_marlin_repack(
                b_q_weight=weight[expert].view(torch.int32).T.contiguous(),
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4,
                is_a_8bit=is_a_8bit,
            )
            for expert in range(E)
        ]
        return torch.stack(repacked, 0)

    w13 = repack_weight(w13, "w13")
    w2 = repack_weight(w2, "w2")

    # WEIGHT SCALES
    # Permute scales
    def permute_scales(
        scales: torch.Tensor, g_scales: torch.Tensor, name: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        scales = scales.to(param_dtype)
        g_scales = g_scales.to(param_dtype)
        size_n, size_k = gemm_dims(name)

        per_expert = [
            nvfp4_marlin_process_scales(
                marlin_permute_scales(
                    s=scales[expert].T,
                    size_k=size_k,
                    size_n=size_n,
                    group_size=GROUP_SIZE,
                    is_a_8bit=is_a_8bit,
                )
            )
            for expert in range(E)
        ]
        return torch.stack(per_expert, 0), nvfp4_marlin_process_global_scale(g_scales)

    w13_scale, w13_scale_2 = permute_scales(w13_scale, w13_scale_2, "w13")
    w2_scale, w2_scale_2 = permute_scales(w2_scale, w2_scale_2, "w2")

    return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2
def prepare_moe_fp4_layer_for_marlin( def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None: ) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment