Unverified Commit 8ad7285e authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra...


[Kernels] Clean up FusedMoeMethodBase and modular kernel setup.  Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 48b01fd4
...@@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig): ...@@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig):
} }
awq_marlin_config = AWQMarlinConfig.from_config( awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict) marlin_compatible_config_dict)
return AWQMoEMethod(awq_marlin_config) return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return None return None
......
...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa ...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
"Falling back to Moe WNA16 kernels.") "Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config( return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
return AWQMoEMethod(self) return AWQMoEMethod(self, layer.moe_config)
return None return None
@classmethod @classmethod
...@@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
class AWQMoEMethod(FusedMoEMethodBase): class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig): def __init__(
self,
quant_config: AWQMarlinConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.weight_bits != 4: if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.") raise ValueError("AWQMoEMethod only supports 4bit now.")
...@@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `AWQMoEMethod` yet.") "EPLB not supported for `AWQMoEMethod` yet.")
...@@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from packaging import version from packaging import version
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self) return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self) return BitsAndBytesMoEMethod(self, layer.moe_config)
return None return None
...@@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
quant_config: The BitsAndBytes quantization config. quant_config: The BitsAndBytes quantization config.
""" """
def __init__(self, quant_config: BitsAndBytesConfig): def __init__(
self,
quant_config: BitsAndBytesConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
try: try:
import bitsandbytes import bitsandbytes
if version.parse( if version.parse(
...@@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
raise ImportError("Please install bitsandbytes>=0.46.1 via " raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use " "`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err "bitsandbytes quantizer.") from err
self.topk_indices_dtype = None
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
...@@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering, ...@@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy) QuantizationStrategy)
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferCutlassMoEPrepareAndFinalize) is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel, build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new, check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales) marlin_moe_permute_scales)
...@@ -58,6 +59,9 @@ __all__ = [ ...@@ -58,6 +59,9 @@ __all__ = [
class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def __init_(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
...@@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic." "WNA16MoE is not supported with actorder=group/dynamic."
) )
logger.info_once("Using CompressedTensorsWNA16MoEMethod") logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config) return CompressedTensorsWNA16MoEMethod(quant_config,
layer.moe_config)
else: else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config) return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod() return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)): or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config) return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
layer.moe_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config) return CompressedTensorsW8A8Int8MoEMethod(quant_config,
layer.moe_config)
else: else:
raise RuntimeError( raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
...@@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self): def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support) detect_nvfp4_moe_support)
super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.fused_experts = None # type: ignore[assignment] self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
...@@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale_quant = torch.nn.Parameter( layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False) (layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config): def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer: if not self.allow_flashinfer:
return return super().maybe_make_prepare_finalize(moe)
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
def select_gemm_impl(self, prepare_finalize, moe): prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
"""Return the appropriate GEMM experts implementation.""" moe,
assert moe is not None and prepare_finalize is not None a1_gscale=self.layer.w13_input_scale_quant,
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 )
select_nvfp4_gemm_impl) logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def apply( def apply(
self, self,
...@@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB not supported for " raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.") "`CompressedTensorsW4A4MoeMethod` yet.")
...@@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.use_marlin: if self.use_marlin:
...@@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
# FlashInfer fused experts path # FlashInfer fused experts path
if self.fused_experts is not None: if self.fused_experts is not None:
return flashinfer_fp4_cutlass_moe_forward( assert is_valid_flashinfer_cutlass_fused_moe(
self.fused_experts, x, layer.w13_weight, layer.w2_weight), (
layer, "Flashinfer CUTLASS Fused MoE not applicable!")
x,
topk_weights, return self.fused_experts(
topk_ids, hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to( apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype) x.dtype)
...@@ -385,14 +420,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -385,14 +420,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights") "weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get( self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations") "input_activations")
self.topk_indices_dtype = None
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy and self.input_quant.strategy
...@@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.weight_quant, self.input_quant) self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.fused_experts = None # type: ignore[assignment]
self.disable_expert_map = False self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -614,24 +649,30 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -614,24 +649,30 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path # cutlass path
if self.use_cutlass: if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
use_batched_format = (prepare_finalize.activation_format == experts: FusedMoEPermuteExpertsUnpermute
FusedMoEActivationFormat.BatchedExperts)
num_dispatchers = prepare_finalize.num_dispatchers() num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("CutlassBatchedExpertsFp8(%s)",
self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
moe.num_local_experts,
num_dispatchers,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8( experts = CutlassExpertsFp8(
num_experts,
moe.in_dtype, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN, self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL, self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format,
) )
self.disable_expert_map = (num_dispatchers > 1 self.disable_expert_map = (num_dispatchers > 1
...@@ -835,8 +876,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -835,8 +876,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights") "weights")
...@@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
...@@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
...@@ -976,8 +1022,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -976,8 +1022,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
...@@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
...@@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
...@@ -1280,8 +1331,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1280,8 +1331,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
...@@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB not supported for " raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsWNA16MoEMethod` yet.") "`CompressedTensorsWNA16MoEMethod` yet.")
...@@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
x, x,
......
...@@ -6,7 +6,8 @@ from typing import Any, Callable, Optional ...@@ -6,7 +6,8 @@ from typing import Any, Callable, Optional
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self) return ExpertsInt8MoEMethod(self, layer.moe_config)
return None return None
class ExpertsInt8MoEMethod(FusedMoEMethodBase): class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ExpertsInt8Config): def __init__(
self,
quant_config: ExpertsInt8Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet.") "EPLB not supported for `ExpertsInt8MoEMethod` yet.")
...@@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
x, x,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
import torch import torch
...@@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig): ...@@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return Fp8LinearMethod(self) return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self) return Fp8MoEMethod(self, layer.moe_config)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self) return Fp8KVCacheMethod(self)
return None return None
...@@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config: The quantization config. quant_config: The quantization config.
""" """
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
super().__init__(moe)
from vllm.model_executor.layers.fused_moe import fused_experts
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
...@@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current " "CutlassBlockScaledGroupedGemm not supported on the current "
"platform.") "platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
else: elif self.fused_experts is not None:
return self.fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
......
...@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter ...@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig):
elif isinstance(layer, VocabParallelEmbedding): elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self) return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self) return GGUFMoEMethod(self, layer.moe_config)
return None return None
...@@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase):
quant_config: The GGUF quantization config. quant_config: The GGUF quantization config.
""" """
def __init__(self, quant_config: GGUFConfig): def __init__(
self,
quant_config: GGUFConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ):
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `GGUFMoEMethod` yet.") "EPLB not supported for `GGUFMoEMethod` yet.")
...@@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids, topk_weights, topk_ids,
layer.w13_qweight_type.weight_type, layer.w13_qweight_type.weight_type,
......
...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa ...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
...@@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
class GPTQMarlinMoEMethod(FusedMoEMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization.""" """MoE Marlin method with quantization."""
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(
self,
quant_config: GPTQMarlinConfig,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4: if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8 self.quant_type = scalar_types.uint4b8
...@@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.") "EPLB not supported for `GPTQMarlinMoEMethod` yet.")
...@@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
...@@ -12,7 +12,9 @@ import vllm.envs as envs ...@@ -12,7 +12,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...@@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel, build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
...@@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig):
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self) return ModelOptFp8MoEMethod(self, layer.moe_config)
return None return None
...@@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
quant_config: The ModelOpt quantization config. quant_config: The ModelOpt quantization config.
""" """
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(
self,
quant_config: ModelOptFp8Config,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported) cutlass_fp8_supported)
...@@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.") "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
...@@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts) fused_experts)
...@@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoE(self) return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
return None return None
...@@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: NVFP4 Quant Config quant_config: NVFP4 Quant Config
""" """
def __init__(self, quant_config: ModelOptNvFp4Config) -> None: def __init__(
self.quant_config = quant_config self,
quant_config: ModelOptNvFp4Config,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support) detect_nvfp4_moe_support)
super().__init__(moe)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
...@@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.fused_experts: Optional[ self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore[assignment] mk.FusedMoEModularKernel] = None # type: ignore[assignment]
def maybe_swap_experts_impl( def maybe_make_prepare_finalize(
self, self,
moe_parallel_config: FusedMoEParallelConfig, moe: FusedMoEConfig,
): ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer: if not self.allow_flashinfer:
return return super().maybe_make_prepare_finalize(moe)
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
# This method update self.fused_experts
# only prepare_finalize is not None call select_gemm_impl
# so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert
# when it's not called(TP case), we still have 2 kernels to use.
def select_gemm_impl(self, prepare_finalize,
moe) -> mk.FusedMoEPermuteExpertsUnpermute:
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
""" """
...@@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.use_marlin: if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
...@@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
device=x.device,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
else: else:
assert self.allow_flashinfer and \ assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts, assert is_valid_flashinfer_cutlass_fused_moe(
layer, x, layer.w13_weight, layer.w2_weight), (
x, "Flashinfer CUTLASS Fused MoE not applicable!")
topk_weights,
topk_ids, out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig):
else: else:
raise ValueError("moe_wna16 only support gptq and awq.") raise ValueError("moe_wna16 only support gptq and awq.")
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return MoeWNA16Method(self) return MoeWNA16Method(self, layer.moe_config)
return None return None
...@@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
""" """
def __init__(self, quant_config: MoeWNA16Config): def __init__(
self,
quant_config: MoeWNA16Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `MoeWNA16Method` yet.") "EPLB not supported for `MoeWNA16Method` yet.")
...@@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
......
...@@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig): ...@@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase): class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__() super().__init__(moe)
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.use_marlin = self._should_use_marlin() self.use_marlin = self._should_use_marlin()
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE) OCP_MX_BLOCK_SIZE)
...@@ -25,6 +26,9 @@ __all__ = [ ...@@ -25,6 +26,9 @@ __all__ = [
class QuarkMoEMethod(FusedMoEMethodBase): class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
...@@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase): ...@@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
input_config = layer_quant_config.get("input_tensors") input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w8a8(weight_config, input_config): if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config) return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
module.moe_config)
elif quant_config._is_mx_fp4(weight_config, input_config): elif quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
module.moe_config)
else: else:
raise RuntimeError("Unsupported FusedMoe scheme") raise RuntimeError("Unsupported FusedMoe scheme")
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, def __init__(
Any]): self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config self.weight_quant = weight_config
self.input_quant = input_config self.input_quant = input_config
...@@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
...@@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
x, x,
...@@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, def __init__(
Any]): self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config self.weight_quant = weight_config
self.input_quant = input_config self.input_quant = input_config
...@@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): ...@@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): ...@@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
out = fused_experts( out = fused_experts(
x, x,
......
...@@ -10,7 +10,8 @@ import torch.nn.functional as F ...@@ -10,7 +10,8 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig): ...@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return RTNLinearMethod(self) return RTNLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return RTNMoEMethod(self) return RTNMoEMethod(self, layer.moe_config)
return None return None
...@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase): ...@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
class RTNMoEMethod(FusedMoEMethodBase): class RTNMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: RTNConfig): def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `RTNMoEMethod` yet.") "EPLB not supported for `RTNMoEMethod` yet.")
...@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size group_size = self.quant_config.group_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment