Unverified Commit c7d8724e authored by Shu Wang's avatar Shu Wang Committed by GitHub
Browse files

[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)


Signed-off-by: default avatarshuw <shuw@nvidia.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent b38baabc
...@@ -7,9 +7,15 @@ import torch ...@@ -7,9 +7,15 @@ import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import (cutlass_scaled_fp4_mm, from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.distributed import get_ep_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...@@ -713,6 +719,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -713,6 +719,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported() self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False self.use_marlin = False
self.allow_flashinfer_cutlass = False
if envs.VLLM_USE_FLASHINFER_MOE:
if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
and current_platform.is_device_capability(100):
logger.info_once(
"Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
self.allow_flashinfer_cutlass = True
else:
logger.warning_once(
"Flashinfer CUTLASS Fused MoE not supported "
"or found on the current platform.")
if not self.cutlass_nvfp4_supported: if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported(): if is_fp4_marlin_supported():
...@@ -722,6 +740,73 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -722,6 +740,73 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and" " quantization. Please use Blackwell and"
" above.") " above.")
self.fused_experts = None # type: ignore
def maybe_swap_experts_impl(
    self,
    moe_parallel_config: FusedMoEParallelConfig,
):
    """Replace ``self.fused_experts`` with a FlashInfer CUTLASS kernel.

    Nothing happens unless ``self.allow_flashinfer_cutlass`` is set.
    Otherwise a ``FlashInferExperts`` object is built from the DP/EP/TP
    fields of ``moe_parallel_config`` and paired with a
    ``FlashInferCutlassMoEPrepareAndFinalize`` inside a
    ``mk.FusedMoEModularKernel``, which is stored on ``self.fused_experts``.

    Args:
        moe_parallel_config: Source of ``dp_size``, ``ep_rank``,
            ``ep_size``, ``tp_rank`` and ``tp_size``.
    """
    if not self.allow_flashinfer_cutlass:
        return

    logger.debug_once("FlashInferExperts")

    # Default to the TP/EP case only. The parallel-config fields are read
    # before the lazy import below.
    flashinfer_args: dict[str, Any] = dict(
        use_nvfp4_w4a4=True,
        use_dp=moe_parallel_config.dp_size > 1,
        ep_rank=moe_parallel_config.ep_rank,
        ep_size=moe_parallel_config.ep_size,
        tp_rank=moe_parallel_config.tp_rank,
        tp_size=moe_parallel_config.tp_size,
    )

    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
        FlashInferExperts)

    flashinfer_experts = FlashInferExperts(**flashinfer_args)

    # uint8 means two e2m1 values packed into one byte (kernel requirement).
    prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
        quant_dtype=torch.uint8)

    self.fused_experts = mk.FusedMoEModularKernel(
        prepare_finalize,
        flashinfer_experts,
    )
# This method updates self.fused_experts.
# select_gemm_impl is only called when prepare_finalize is not None,
# so with native cutlass fp4 the fused_expert is the one in fuse_moe.py.
# When it is not called (TP case), we still have 2 kernels to use.
def select_gemm_impl(self, prepare_finalize,
                     moe) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Build the experts implementation for the modular MoE kernel.

    Args:
        prepare_finalize: Must not be ``None`` (asserted).
        moe: MoE configuration; must not be ``None`` (asserted). Its
            ``moe_parallel_config`` supplies the DP/EP/TP ranks and sizes.

    Returns:
        A ``FlashInferExperts`` instance when
        ``self.allow_flashinfer_cutlass`` is set.

    Raises:
        ValueError: If FlashInfer CUTLASS is not enabled, because
            ``CutlassExpertsFp4`` does not support DP.
    """
    assert moe is not None
    assert prepare_finalize is not None

    # The EP group's communicator must provide an all2all manager.
    all2all_manager = get_ep_group().device_communicator.all2all_manager
    assert all2all_manager is not None

    if not self.allow_flashinfer_cutlass:
        assert moe.dp_size > 1
        logger.debug_once("Using CutlassExpertsFp4")
        # Currently CutlassExpertsFp4 doesn't support DP
        raise ValueError(
            "CutlassExpertsFp4 doesn't support DP. "
            "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
            " backend instead.")

    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
        FlashInferExperts)

    logger.debug_once("Using FlashInferExperts")
    parallel_cfg = moe.moe_parallel_config
    return FlashInferExperts(
        use_nvfp4_w4a4=True,
        use_dp=parallel_cfg.dp_size > 1,
        ep_rank=parallel_cfg.ep_rank,
        ep_size=parallel_cfg.ep_size,
        tp_rank=parallel_cfg.tp_rank,
        tp_size=parallel_cfg.tp_size,
    )
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
""" """
FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
...@@ -842,8 +927,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -842,8 +927,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1 # GEMM 1
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
# to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer_cutlass:
dim = -2
size = gemm1_weight.size(dim)
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
half = size // 2
# Reorder weight
w1, w3 = gemm1_weight.split(half, dim=dim)
gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()
# Reorder scale
s1, s3 = gemm1_weight_scale.split(half, dim=dim)
gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()
layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
requires_grad=False)
if not torch.allclose(layer.w13_weight_scale_2[:, 0], if not torch.allclose(layer.w13_weight_scale_2[:, 0],
layer.w13_weight_scale_2[:, 1]): layer.w13_weight_scale_2[:, 1]):
logger.warning_once( logger.warning_once(
...@@ -874,9 +981,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -874,9 +981,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_input_scale_quant = Parameter( layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False) (1 / w13_input_scale).to(torch.float32), requires_grad=False)
layer.w13_weight = Parameter(layer.w13_weight.data,
requires_grad=False)
# GEMM 2 # GEMM 2
layer.g2_alphas = Parameter( layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
...@@ -961,31 +1065,74 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -961,31 +1065,74 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
assert expert_map is None, ("Expert Parallelism / expert_map " if self.fused_experts is None:
"is currently not supported for " # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
"ModelOptNvFp4FusedMoE.") # only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4) cutlass_moe_fp4)
out = cutlass_moe_fp4(
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x, a=x,
w1_fp4=layer.w13_weight, w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight, w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled, w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas, g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device, device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to( expert_map=expert_map,
x.dtype) apply_router_weight_on_input=apply_router_weight_on_input)
else:
# TP or DP case
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
is_valid_flashinfer_cutlass_fused_moe)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
a1_gscale = torch.min(layer.w13_input_scale_quant)
a2_gscale = torch.min(layer.w2_input_scale_quant)
extra_expert_args = {
'g1_alphas': layer.g1_alphas,
'g2_alphas': layer.g2_alphas,
'out_dtype': x.dtype,
# Avoid confusion with a1_scale and a2_scale,
# which are batch-size related.
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
}
extra_prepare_args = {
'use_dp': layer.dp_size > 1,
'local_tokens': x.shape[0],
'a1_gscale': a1_gscale,
}
extra_finalize_args = {
'use_dp': layer.dp_size > 1,
'local_tokens': x.shape[0],
}
out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
return out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.
Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations
import contextlib
import functools
import importlib
import importlib.util
from typing import Any, Callable, NoReturn
from vllm.logger import init_logger
logger = init_logger(__name__)
@functools.cache
def has_flashinfer() -> bool:
    """Tell whether the ``flashinfer`` package can be found.

    The check uses ``importlib.util.find_spec`` so the package itself is not
    imported here, which avoids potential CUDA initialization side effects.
    The result is memoized by ``functools.cache``.
    """
    flashinfer_spec = importlib.util.find_spec("flashinfer")
    return flashinfer_spec is not None
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for a FlashInfer function that is unavailable.

    Accepts and ignores any arguments.

    Raises:
        RuntimeError: Always, with installation instructions.
    """
    message = ("FlashInfer backend is not available. Please install the package "
               "to enable FlashInfer kernels: "
               "https://github.com/flashinfer-ai/flashinfer")
    raise RuntimeError(message)
def _get_submodule(module_name: str) -> Any | None:
    """Import ``module_name`` and return the module, or ``None`` on failure.

    Only import failures are swallowed; any other exception raised while the
    module is being imported propagates to the caller.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so it is
        # covered by this handler as well.
        return None
    return module
# General lazy import wrapper
def _lazy_import_wrapper(module_name: str,
                         attr_name: str,
                         fallback_fn: Callable[..., Any] = _missing):
    """Return a callable that resolves ``module_name.attr_name`` on first use.

    Args:
        module_name: Dotted name of the module that should hold the function.
        attr_name: Attribute to fetch from that module.
        fallback_fn: Called with the same arguments when FlashInfer is not
            available, the module cannot be imported, or the attribute is
            missing. Defaults to ``_missing``.

    Returns:
        A wrapper that forwards its arguments to the resolved function or to
        ``fallback_fn``.
    """

    @functools.cache
    def _resolve():
        # The lookup runs once; its result, including None, is cached.
        if has_flashinfer():
            module = _get_submodule(module_name)
            if module:
                return getattr(module, attr_name, None)
        return None

    def wrapper(*args, **kwargs):
        target = _resolve()
        chosen = fallback_fn if target is None else target
        return chosen(*args, **kwargs)

    return wrapper
# Lazily-resolved FlashInfer entry points. Each name is a wrapper built by
# _lazy_import_wrapper, so the target is only looked up when first called.
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    module_name="flashinfer.fused_moe",
    attr_name="cutlass_fused_moe",
)
fp4_quantize = _lazy_import_wrapper(
    module_name="flashinfer",
    attr_name="fp4_quantize",
)
fp4_swizzle_blockscale = _lazy_import_wrapper(
    module_name="flashinfer",
    attr_name="fp4_swizzle_blockscale",
)

# autotune returns a context manager, so its fallback hands back a no-op
# context instead of raising.
autotune = _lazy_import_wrapper(
    module_name="flashinfer.autotuner",
    attr_name="autotune",
    fallback_fn=lambda *_args, **_kwargs: contextlib.nullcontext(),
)
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Tell whether every FlashInfer function needed for CUTLASS MoE exists.

    Returns ``False`` when FlashInfer cannot be found, or when any required
    module fails to import or lacks the required attribute. The result is
    memoized by ``functools.cache``.
    """
    if not has_flashinfer():
        return False

    # (module, attribute) pairs that must all be present.
    required = (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "fp4_swizzle_blockscale"),
    )

    def _exposes(module_name, attr_name):
        module = _get_submodule(module_name)
        return bool(module) and hasattr(module, attr_name)

    # all() stops at the first missing entry, like the explicit early return.
    return all(_exposes(module_name, attr_name)
               for module_name, attr_name in required)
# Public names of this module: the two availability checks and the lazy
# FlashInfer wrappers defined above.
__all__ = [
    "has_flashinfer",
    "has_flashinfer_cutlass_fused_moe",
    "flashinfer_cutlass_fused_moe",
    "fp4_quantize",
    "fp4_swizzle_blockscale",
    "autotune",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment