"vscode:/vscode.git/clone" did not exist on "88dbf92cfb32bf3670a02f5c58e296ff39e40cd9"
Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
...@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
repeat_cols = 4 repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if quant_config.per_act_token_quant else a1_scale), a1, (None if quant_config.per_act_token_quant else
quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant, per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape) block_shape=quant_config.block_shape)
...@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
a1_scale,
a2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
num_experts, num_experts,
......
...@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
a1.mul_(topk_weights.to(a1.dtype)) a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype, a1, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None return a1q, a1q_scale, None, None, None
......
...@@ -7,6 +7,8 @@ from typing import Optional ...@@ -7,6 +7,8 @@ from typing import Optional
import torch import torch
from vllm import envs from vllm import envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk( ...@@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk(
def rocm_aiter_fused_experts( def rocm_aiter_fused_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, expert_map: Optional[torch.Tensor] = None,
per_channel_quant: bool = False, quant_config: Optional[FusedMoEQuantConfig] = None,
w1_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor:
w2_scale: Optional[torch.Tensor] = None, if quant_config is None:
a1_scale: Optional[torch.Tensor] = None, quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
activation_method = (ActivationMethod.SILU activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU) if activation == "silu" else ActivationMethod.GELU)
...@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts( ...@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
expert_mask = None expert_mask = None
# w8a8 per-channel quantization # w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: if (quant_config.per_act_token_quant and apply_router_weight_on_input
and quant_config.use_fp8_w8a8):
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer # This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC. # rather than the second FC.
...@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts( ...@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
fc1_scale=w1_scale, fc1_scale=quant_config.w1_scale,
fc2_scale=w2_scale, fc2_scale=quant_config.w2_scale,
fc1_smooth_scale=None, fc1_smooth_scale=None,
fc2_smooth_scale=None, fc2_smooth_scale=None,
a16=False, a16=False,
...@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts( ...@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
quant_method = QuantMethod.NO.value quant_method = QuantMethod.NO.value
# w8a8 block-scaled # w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8: if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is\ "apply_router_weight_on_input is\
not supported for block scaled moe") not supported for block scaled moe")
assert w1_scale is not None assert quant_config.w1_scale is not None
assert w2_scale is not None assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value quant_method = QuantMethod.BLOCK_128x128.value
elif use_fp8_w8a8: elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled. # Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value quant_method = QuantMethod.PER_TENSOR.value
...@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts( ...@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
expert_mask=expert_mask, expert_mask=expert_mask,
quant_method=quant_method, quant_method=quant_method,
activation_method=activation_method, activation_method=activation_method,
w1_scale=w1_scale, w1_scale=quant_config.w1_scale,
w2_scale=w2_scale, w2_scale=quant_config.w2_scale,
a1_scale=a1_scale, a1_scale=quant_config.a1_scale,
a2_scale=a2_scale, a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input) doweight_stage1=apply_router_weight_on_input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment