Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts) FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
...@@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, ...@@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
def build_flashinfer_fp4_cutlass_moe_prepare_finalize( def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig, moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
a1_gscale: torch.Tensor,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 use_dp = moe.moe_parallel_config.dp_size > 1
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
def select_nvfp4_gemm_impl( def select_nvfp4_gemm_impl(
moe: FusedMoEConfig, moe: FusedMoEConfig,
g1_alphas: torch.Tensor, moe_quant_config: FusedMoEQuantConfig,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
allow_flashinfer: bool, allow_flashinfer: bool,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer: if allow_flashinfer:
return FlashInferExperts( return FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=moe.in_dtype, out_dtype=moe.in_dtype,
quant_dtype="nvfp4", quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank, ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size, ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank, tp_rank=moe.moe_parallel_config.tp_rank,
......
...@@ -8,7 +8,8 @@ import torch ...@@ -8,7 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts) FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
...@@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8( ...@@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8(
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
) -> torch.Tensor: ) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert layer.output1_scales_scalar is not None, ( assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_scalar to be initialized") "Expected output1_scales_scalar to be initialized")
assert layer.output1_scales_scalar is not None, ( assert layer.output1_scales_scalar is not None, (
...@@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: ...@@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
def build_flashinfer_fp8_cutlass_moe_prepare_finalize( def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: Optional[FusedMoEConfig], moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
layer: torch.nn.Module,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
return FlashInferCutlassMoEPrepareAndFinalize( return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
use_dp, a1_gscale=layer.w13_input_scale)
def select_cutlass_fp8_gemm_impl( def select_cutlass_fp8_gemm_impl(
moe: Optional[FusedMoEConfig], moe: Optional[FusedMoEConfig],
layer: torch.nn.Module, quant_config: FusedMoEQuantConfig,
out_dtype: Optional[torch.dtype] = None, out_dtype: Optional[torch.dtype] = None,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for fused-MoE layers""" """Return a GEMM *experts* implementation for fused-MoE layers"""
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
"FusedMoE flashinfer kernels are only supported for Llama4"
if moe is not None: if moe is not None:
return FlashInferExperts( return FlashInferExperts(
g1_alphas=layer.output1_scales_gate_scalar,
g2_alphas=layer.output2_scales_scalar,
a1_gscale=layer.w13_input_scale,
a2_gscale=layer.w2_input_scale_inv,
out_dtype=moe.in_dtype, out_dtype=moe.in_dtype,
quant_dtype=torch.float8_e4m3fn, quant_config=quant_config,
ep_rank=moe.moe_parallel_config.ep_rank, ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size, ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank, tp_rank=moe.moe_parallel_config.tp_rank,
...@@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl( ...@@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl(
assert out_dtype is not None, ( assert out_dtype is not None, (
"If moe config is None, out_dtype must be passed") "If moe config is None, out_dtype must be passed")
return FlashInferExperts( return FlashInferExperts(
g1_alphas=layer.output1_scales_gate_scalar,
g2_alphas=layer.output2_scales_scalar,
a1_gscale=layer.w13_input_scale,
a2_gscale=layer.w2_input_scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
quant_dtype=torch.float8_e4m3fn, quant_config=quant_config,
) )
...@@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8( ...@@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8(
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None
fused_experts = mk.FusedMoEModularKernel( fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
layer=layer),
select_cutlass_fp8_gemm_impl(moe=None, select_cutlass_fp8_gemm_impl(moe=None,
layer=layer, quant_config=quant_config,
out_dtype=hidden_states.dtype)) out_dtype=hidden_states.dtype))
return fused_experts( return fused_experts(
......
...@@ -411,6 +411,7 @@ def per_token_group_quant_fp8( ...@@ -411,6 +411,7 @@ def per_token_group_quant_fp8(
x_s = torch.empty(shape, device=x.device, dtype=torch.float32) x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
# prefer CUDA kernel if available # prefer CUDA kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous(): if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps, torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
fp8_min, fp8_max, use_ue8m0) fp8_min, fp8_max, use_ue8m0)
......
...@@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn, from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn) get_act_fn)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe import (activation_without_mul,
fused_topk, torch_vllm_outplace_fused_experts) fused_topk)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -230,7 +230,7 @@ class NomicMoE(nn.Module): ...@@ -230,7 +230,7 @@ class NomicMoE(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.total_intermediate_size = intermediate_size self.total_intermediate_size = intermediate_size
self.intermediate_size = divide(intermediate_size, self.tp_size) self.intermediate_size = divide(intermediate_size, self.tp_size)
self.hidden_act = hidden_act self.hidden_act = activation_without_mul(hidden_act)
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
...@@ -297,14 +297,14 @@ class NomicMoE(nn.Module): ...@@ -297,14 +297,14 @@ class NomicMoE(nn.Module):
router_logits, router_logits,
self.top_k, self.top_k,
renormalize=False) renormalize=False)
final_hidden_states = torch_vllm_outplace_fused_experts(
final_hidden_states = torch.ops.vllm.outplace_fused_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=self.w1, w1=self.w1,
w2=self.w2, w2=self.w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=self.hidden_act, activation=self.hidden_act,
is_act_and_mul=False,
) )
if self.tp_size > 1: if self.tp_size > 1:
......
...@@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, ...@@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module): ...@@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module):
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1, topk_weights, topk_ids, _ = fused_topk(
self.w2, hidden_states,
router_logits, router_logits,
self.top_k, self.top_k,
renormalize=self.config.norm_topk_prob, renormalize=self.config.norm_topk_prob)
inplace=True)
final_hidden_states = fused_experts(hidden_states,
self.w1,
self.w2,
topk_weights,
topk_ids,
inplace=True)
if self.config.n_shared_experts is not None: if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
......
...@@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, ...@@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module): ...@@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws, topk_weights, topk_ids, _ = fused_topk(hidden_states,
self.w2s, router_logits,
router_logits, self.top_k,
self.top_k, renormalize=True)
renormalize=True,
inplace=True) final_hidden_states = fused_experts(hidden_states,
self.ws,
self.w2s,
topk_weights,
topk_ids,
inplace=True)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -702,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, ...@@ -702,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights) return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
\ No newline at end of file
...@@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: ...@@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
if not (isinstance(module, FusedMoE) if not isinstance(module, FusedMoE):
and module.moe_config.quant_dtype == torch.float8_e4m3fn return False
and module.moe_config.block_shape == deep_gemm_block_shape()):
moe_quant_config = module.quant_method.get_fused_moe_quant_config(module)
if (moe_quant_config is None
or moe_quant_config.quant_dtype != torch.float8_e4m3fn
or moe_quant_config.block_shape != deep_gemm_block_shape()):
return False return False
if not isinstance(module.quant_method.fused_experts, if not isinstance(module.quant_method.fused_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment