Unverified commit 5963b98b, authored by bnellnm, committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent e6585ddb
...@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
repeat_cols = 4 repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if quant_config.per_act_token_quant else a1_scale), a1, (None if quant_config.per_act_token_quant else
quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant, per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape) block_shape=quant_config.block_shape)
...@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
a1_scale,
a2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
num_experts, num_experts,
......
...@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
a1.mul_(topk_weights.to(a1.dtype)) a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype, a1, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None return a1q, a1q_scale, None, None, None
......
...@@ -7,6 +7,8 @@ from typing import Optional ...@@ -7,6 +7,8 @@ from typing import Optional
import torch import torch
from vllm import envs from vllm import envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -312,14 +314,11 @@ def rocm_aiter_fused_experts( ...@@ -312,14 +314,11 @@ def rocm_aiter_fused_experts(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, expert_map: Optional[torch.Tensor] = None,
per_channel_quant: bool = False, quant_config: Optional[FusedMoEQuantConfig] = None,
w1_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor:
w2_scale: Optional[torch.Tensor] = None, if quant_config is None:
a1_scale: Optional[torch.Tensor] = None, quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
activation_method = (ActivationMethod.SILU activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU) if activation == "silu" else ActivationMethod.GELU)
...@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts( ...@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
expert_mask = None expert_mask = None
# w8a8 per-channel quantization # w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: if (quant_config.per_act_token_quant and apply_router_weight_on_input
and quant_config.use_fp8_w8a8):
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer # This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC. # rather than the second FC.
...@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts( ...@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
fc1_scale=w1_scale, fc1_scale=quant_config.w1_scale,
fc2_scale=w2_scale, fc2_scale=quant_config.w2_scale,
fc1_smooth_scale=None, fc1_smooth_scale=None,
fc2_smooth_scale=None, fc2_smooth_scale=None,
a16=False, a16=False,
...@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts( ...@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
quant_method = QuantMethod.NO.value quant_method = QuantMethod.NO.value
# w8a8 block-scaled # w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8: if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is\ "apply_router_weight_on_input is\
not supported for block scaled moe") not supported for block scaled moe")
assert w1_scale is not None assert quant_config.w1_scale is not None
assert w2_scale is not None assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value quant_method = QuantMethod.BLOCK_128x128.value
elif use_fp8_w8a8: elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled. # Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value quant_method = QuantMethod.PER_TENSOR.value
...@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts( ...@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
expert_mask=expert_mask, expert_mask=expert_mask,
quant_method=quant_method, quant_method=quant_method,
activation_method=activation_method, activation_method=activation_method,
w1_scale=w1_scale, w1_scale=quant_config.w1_scale,
w2_scale=w2_scale, w2_scale=quant_config.w2_scale,
a1_scale=a1_scale, a1_scale=quant_config.a1_scale,
a2_scale=a2_scale, a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input) doweight_stage1=apply_router_weight_on_input)
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape) deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
...@@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
use_fp8_w8a8: bool = False, quant_config: FusedMoEQuantConfig,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
): ):
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8, self.triton_expert = TritonExperts(quant_config)
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.triton_expert = TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape()) self.block_shape == deep_gemm_block_shape())
self.deep_gemm_expert = DeepGemmExperts( self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None self.quant_config) if self.allow_deep_gemm else None
@property @property
def activation_formats( def activation_formats(
...@@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, activation,
global_num_experts, global_num_experts,
expert_map, expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale, a1q_scale,
a2_scale,
workspace13, workspace13,
workspace2, workspace2,
expert_tokens_meta, expert_tokens_meta,
......
...@@ -268,3 +268,7 @@ def _validate_scale_shape( ...@@ -268,3 +268,7 @@ def _validate_scale_shape(
assert block_shape is not None assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
    """Append the ``"_no_mul"`` suffix to an activation name.

    Args:
        activation: Activation name, e.g. ``"silu"``.

    Returns:
        ``activation`` followed by ``"_no_mul"``.
    """
    suffix = "_no_mul"
    return activation + suffix
...@@ -9,8 +9,10 @@ from torch.nn import Parameter ...@@ -9,8 +9,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None: if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
    self,
    layer: torch.nn.Module,
) -> Optional[FusedMoEQuantConfig]:
    """Return the ``FusedMoEQuantConfig`` for ``layer``, if there is one.

    This implementation builds no config: ``layer`` is not inspected and
    the result is always ``None``.

    Args:
        layer: The module a quant config is being requested for (unused).

    Returns:
        Always ``None``.
    """
    return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment