Unverified commit 5963b98b, authored by bnellnm and committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent e6585ddb
......@@ -7,7 +7,8 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
......@@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
a1_gscale: torch.Tensor,
) -> mk.FusedMoEPrepareAndFinalize:
moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
def select_nvfp4_gemm_impl(
moe: FusedMoEConfig,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
moe_quant_config: FusedMoEQuantConfig,
allow_flashinfer: bool,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer:
return FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=moe.in_dtype,
quant_dtype="nvfp4",
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
......
......@@ -8,7 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
......@@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8(
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_scalar to be initialized")
assert layer.output1_scales_scalar is not None, (
......@@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: Optional[FusedMoEConfig],
layer: torch.nn.Module,
) -> mk.FusedMoEPrepareAndFinalize:
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp, a1_gscale=layer.w13_input_scale)
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
def select_cutlass_fp8_gemm_impl(
moe: Optional[FusedMoEConfig],
layer: torch.nn.Module,
quant_config: FusedMoEQuantConfig,
out_dtype: Optional[torch.dtype] = None,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for fused-MoE layers"""
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
"FusedMoE flashinfer kernels are only supported for Llama4"
if moe is not None:
return FlashInferExperts(
g1_alphas=layer.output1_scales_gate_scalar,
g2_alphas=layer.output2_scales_scalar,
a1_gscale=layer.w13_input_scale,
a2_gscale=layer.w2_input_scale_inv,
out_dtype=moe.in_dtype,
quant_dtype=torch.float8_e4m3fn,
quant_config=quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
......@@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl(
assert out_dtype is not None, (
"If moe config is None, out_dtype must be passed")
return FlashInferExperts(
g1_alphas=layer.output1_scales_gate_scalar,
g2_alphas=layer.output2_scales_scalar,
a1_gscale=layer.w13_input_scale,
a2_gscale=layer.w2_input_scale_inv,
out_dtype=out_dtype,
quant_dtype=torch.float8_e4m3fn,
quant_config=quant_config,
)
......@@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8(
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
layer=layer),
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
select_cutlass_fp8_gemm_impl(moe=None,
layer=layer,
quant_config=quant_config,
out_dtype=hidden_states.dtype))
return fused_experts(
......
......@@ -411,6 +411,7 @@ def per_token_group_quant_fp8(
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
# prefer CUDA kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
fp8_min, fp8_max, use_ue8m0)
......
......@@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, torch_vllm_outplace_fused_experts)
from vllm.model_executor.layers.fused_moe import (activation_without_mul,
fused_topk)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -230,7 +230,7 @@ class NomicMoE(nn.Module):
self.hidden_size = hidden_size
self.total_intermediate_size = intermediate_size
self.intermediate_size = divide(intermediate_size, self.tp_size)
self.hidden_act = hidden_act
self.hidden_act = activation_without_mul(hidden_act)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
......@@ -297,14 +297,14 @@ class NomicMoE(nn.Module):
router_logits,
self.top_k,
renormalize=False)
final_hidden_states = torch_vllm_outplace_fused_experts(
final_hidden_states = torch.ops.vllm.outplace_fused_experts(
hidden_states=hidden_states,
w1=self.w1,
w2=self.w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=self.hidden_act,
is_act_and_mul=False,
)
if self.tp_size > 1:
......
......@@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
topk_weights, topk_ids, _ = fused_topk(
hidden_states,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob)
final_hidden_states = fused_experts(hidden_states,
self.w1,
self.w2,
topk_weights,
topk_ids,
inplace=True)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
......
......@@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True)
topk_weights, topk_ids, _ = fused_topk(hidden_states,
router_logits,
self.top_k,
renormalize=True)
final_hidden_states = fused_experts(hidden_states,
self.ws,
self.w2s,
topk_weights,
topk_ids,
inplace=True)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
......
......@@ -702,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
return self.model.get_expert_mapping()
\ No newline at end of file
......@@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
if not (isinstance(module, FusedMoE)
and module.moe_config.quant_dtype == torch.float8_e4m3fn
and module.moe_config.block_shape == deep_gemm_block_shape()):
if not isinstance(module, FusedMoE):
return False
moe_quant_config = module.quant_method.get_fused_moe_quant_config(module)
if (moe_quant_config is None
or moe_quant_config.quant_dtype != torch.float8_e4m3fn
or moe_quant_config.block_shape != deep_gemm_block_shape()):
return False
if not isinstance(module.quant_method.fused_experts,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment