Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
......@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if quant_config.per_act_token_quant else a1_scale),
a1, (None if quant_config.per_act_token_quant else
quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape)
......@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
......
......@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
a1, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None
......
......@@ -7,6 +7,8 @@ from typing import Optional
import torch
from vllm import envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -312,14 +314,11 @@ def rocm_aiter_fused_experts(
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU)
......@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
expert_mask = None
# w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
if (quant_config.per_act_token_quant and apply_router_weight_on_input
and quant_config.use_fp8_w8a8):
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC.
......@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
w2,
topk_weights,
topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_scale=quant_config.w1_scale,
fc2_scale=quant_config.w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False,
......@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
quant_method = QuantMethod.NO.value
# w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is\
not supported for block scaled moe")
assert w1_scale is not None
assert w2_scale is not None
assert quant_config.w1_scale is not None
assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value
elif use_fp8_w8a8:
elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value
......@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
expert_mask=expert_mask,
quant_method=quant_method,
activation_method=activation_method,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input)
......
......@@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
......@@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.triton_expert = TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
super().__init__(quant_config)
self.triton_expert = TritonExperts(quant_config)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None
self.quant_config) if self.allow_deep_gemm else None
@property
def activation_formats(
......@@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
......
......@@ -268,3 +268,7 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"
......@@ -9,8 +9,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
......@@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment