Unverified commit 6160ba41, authored by Duncan Moss, committed by GitHub
Browse files

feat: BF16 FlashInfer Fused Cutlass MOE for Hopper and Blackwell Expert Parallel (#25503)


Signed-off-by: Duncan Moss <djm.moss@gmail.com>
parent fea80060
...@@ -144,6 +144,7 @@ if TYPE_CHECKING: ...@@ -144,6 +144,7 @@ if TYPE_CHECKING:
VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput",
...@@ -1145,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1145,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": "VLLM_USE_FUSED_MOE_GROUPED_TOPK":
lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
# Allow use of FlashInfer CUTLASS MoE kernels for unquantized (FP16/BF16) fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP16":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))),
# Allow use of FlashInfer MoE kernels for fused moe ops. # Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8": "VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
...@@ -1516,6 +1521,7 @@ def compute_hash() -> str: ...@@ -1516,6 +1521,7 @@ def compute_hash() -> str:
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
"VLLM_USE_TRTLLM_FP4_GEMM", "VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP16",
"VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
......
...@@ -52,8 +52,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -52,8 +52,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
tp_size: int = 1, tp_size: int = 1,
): ):
super().__init__(quant_config) super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( assert quant_config.quant_dtype in (
"Only nvfp4,fp8 quantization are currently supported.") "nvfp4", torch.float8_e4m3fn,
None), ("Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported.")
self.ep_rank = ep_rank self.ep_rank = ep_rank
self.ep_size = ep_size self.ep_size = ep_size
self.tp_rank = tp_rank self.tp_rank = tp_rank
...@@ -109,8 +111,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -109,8 +111,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
aq_m, aq_n = aq.shape aq_m, aq_n = aq.shape
workspace2 = (0, ) workspace2 = (0, )
output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ output_shape = (aq_m,
torch.float8_e4m3fn else (aq_m, aq_n) aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m,
aq_n)
workspace_dtype = a.dtype workspace_dtype = a.dtype
workspace1 = output_shape workspace1 = output_shape
# The workspace is determined by `aq`, since it comes after any # The workspace is determined by `aq`, since it comes after any
...@@ -135,6 +138,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -135,6 +138,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool], apply_router_weight_on_input: Optional[bool],
): ):
assert activation == "silu", ("Only activation silu is supported in "
"FlashInferExperts")
if self.quant_dtype == torch.float8_e4m3fn: if self.quant_dtype == torch.float8_e4m3fn:
quant_scales = [ quant_scales = [
self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
...@@ -143,7 +150,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -143,7 +150,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q_scale = None # not passing input_sf in fp8 a1q_scale = None # not passing input_sf in fp8
fc1_expert_weights = w1 fc1_expert_weights = w1
fc2_expert_weights = w2 fc2_expert_weights = w2
else: elif self.quant_dtype == "nvfp4":
# Ensure w1_scale and w2_scale are not None before calling view # Ensure w1_scale and w2_scale are not None before calling view
assert self.w1_scale is not None and self.w2_scale is not None, ( assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not " "w1_scale and w2_scale must not "
...@@ -161,6 +168,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -161,6 +168,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4 # FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights = w1.view(torch.long) fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long) fc2_expert_weights = w2.view(torch.long)
else:
quant_scales = None
a1q_scale = None
fc1_expert_weights = w1
fc2_expert_weights = w2
_ = flashinfer_cutlass_fused_moe( _ = flashinfer_cutlass_fused_moe(
input=hidden_states, input=hidden_states,
...@@ -211,3 +223,46 @@ def flashinfer_cutlass_moe_fp4( ...@@ -211,3 +223,46 @@ def flashinfer_cutlass_moe_fp4(
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
def flashinfer_cutlass_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    tp_rank: int = 0,
    tp_size: int = 1,
    ep_rank: int = 0,
    ep_size: int = 1,
    use_dp: bool = False,
) -> torch.Tensor:
    """Run fused MoE through a modular kernel backed by FlashInferExperts.

    A fresh ``mk.FusedMoEModularKernel`` is assembled on every call from a
    FlashInfer prepare/finalize stage and a ``FlashInferExperts`` stage, and
    is then invoked with the routing inputs.

    Args:
        hidden_states: Input activations; their dtype is also used as the
            experts' ``out_dtype``.
        w1: First expert weight tensor, forwarded unchanged to the kernel.
        w2: Second expert weight tensor, forwarded unchanged to the kernel.
        topk_weights: Router weights, forwarded unchanged to the kernel.
        topk_ids: Selected expert ids, forwarded unchanged to the kernel.
        quant_config: Quantization config handed to ``FlashInferExperts``.
        inplace: Forwarded to the kernel call.
        activation: Activation name forwarded to the kernel call.
        global_num_experts: Forwarded to the kernel call.
        expert_map: Optional expert mapping forwarded to the kernel call.
        apply_router_weight_on_input: Forwarded to the kernel call.
        tp_rank: Tensor-parallel rank handed to ``FlashInferExperts``.
        tp_size: Tensor-parallel size handed to ``FlashInferExperts``.
        ep_rank: Expert-parallel rank handed to ``FlashInferExperts``.
        ep_size: Expert-parallel size handed to ``FlashInferExperts``.
        use_dp: Passed to ``create_flashinfer_prepare_finalize`` to pick the
            prepare/finalize implementation.

    Returns:
        Whatever the assembled modular kernel returns for these inputs.
    """
    # Stage 1: prepare/finalize step, selected by the data-parallel flag.
    prepare_finalize = create_flashinfer_prepare_finalize(use_dp=use_dp)

    # Stage 2: the expert computation, emitting the same dtype as the input.
    experts = FlashInferExperts(
        out_dtype=hidden_states.dtype,
        quant_config=quant_config,
        tp_rank=tp_rank,
        tp_size=tp_size,
        ep_rank=ep_rank,
        ep_size=ep_size,
    )

    modular_kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
    return modular_kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
...@@ -183,7 +183,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize( ...@@ -183,7 +183,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(
dim=0, dim=0,
sizes=get_local_sizes(), sizes=get_local_sizes(),
) )
a1q_scale = nvfp4_block_scale_interleave(a1q_scale) if quant_config.quant_dtype == "nvfp4":
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
......
...@@ -39,6 +39,7 @@ from vllm.platforms import current_platform ...@@ -39,6 +39,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up) round_up)
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
...@@ -296,6 +297,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -296,6 +297,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
self.rocm_aiter_fused_experts = None # type: ignore self.rocm_aiter_fused_experts = None # type: ignore
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs
self.flashinfer_cutlass_moe_enabled = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
and current_platform.get_device_capability()[0] >= 9)
if self.flashinfer_cutlass_moe_enabled:
logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
from functools import partial
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
self.flashinfer_cutlass_moe = partial(
flashinfer_cutlass_moe,
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size)
else:
if (self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1):
logger.info_once(
"FlashInfer CUTLASS MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.")
elif self.moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP."
)
self.flashinfer_cutlass_moe = None # type: ignore
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self) -> Optional[FusedMoEPrepareAndFinalize]: self) -> Optional[FusedMoEPrepareAndFinalize]:
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
...@@ -367,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -367,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_pad = 256 // weight.element_size() num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache() torch.cuda.empty_cache()
return weight return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -386,6 +422,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -386,6 +422,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = shuffled_w13 layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2 layer.w2_weight.data = shuffled_w2
if self.flashinfer_cutlass_moe_enabled:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
if current_platform.is_xpu(): if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
...@@ -536,6 +578,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -536,6 +578,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map, expert_map=expert_map,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None: elif self.fused_experts is not None:
if self.moe.has_bias: if self.moe.has_bias:
raise ValueError( raise ValueError(
......
...@@ -598,6 +598,8 @@ class SharedResizableBuffer: ...@@ -598,6 +598,8 @@ class SharedResizableBuffer:
def get(self, shape: tuple[int, ...], device: torch.device, def get(self, shape: tuple[int, ...], device: torch.device,
dtype: torch.dtype): dtype: torch.dtype):
if shape == () or shape is None:
return None
shape_numel = prod(shape) shape_numel = prod(shape)
if (self.buffer is None or self.buffer.numel() < shape_numel if (self.buffer is None or self.buffer.numel() < shape_numel
or self.buffer.device != device or self.buffer.dtype != dtype): or self.buffer.device != device or self.buffer.dtype != dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment