Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
......@@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
......@@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional quantized + dispatched a1_scales.
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
......@@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
......@@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
self,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: FusedMoEQuantConfig,
):
if quant_config is not None:
self.quant_config = quant_config
else:
self.quant_config = FusedMoEQuantConfig()
"""
quant_config: Quantization parameters for this experts instance.
"""
self.quant_config = quant_config
@property
@abstractmethod
......@@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
#
# Various helpers for accessing quantization parameters from the
# quant_config.
#
@property
def quant_dtype(self) -> Optional[torch.dtype]:
return self.quant_config.quant_dtype
......@@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant
@property
def a1_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.a1_scale
@property
def a2_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.a2_scale
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
return self.quant_config.a1_gscale
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
return self.quant_config.a2_gscale
@property
def w1_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_scale
@property
def w2_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_scale
@property
def w1_zp(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_zp
@property
def w2_zp(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_zp
@property
def w1_bias(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_bias
@property
def w2_bias(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_bias
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
return self.quant_config.g1_alphas
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
return self.quant_config.g2_alphas
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
......@@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
......@@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations
choose to do weight application.
choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
......@@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
used for a1. Result of quantization from prepare/finalize and not
from the FusedMoEQuantConfig.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
......@@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
......@@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
......@@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
......@@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
_chunk_scales(a2_scale, s,
e), topk_ids[s:e], topk_weights[s:e])
return (
a1q[s:e],
_chunk_scales(a1q_scale, s, e),
_chunk_scales(self.fused_experts.a2_scale, s, e),
topk_ids[s:e],
topk_weights[s:e],
)
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, (
......@@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
......@@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
......@@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
......@@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
......@@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......
......@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if quant_config.per_act_token_quant else a1_scale),
a1, (None if quant_config.per_act_token_quant else
quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape)
......@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
......
......@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
a1, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None
......
......@@ -7,6 +7,8 @@ from typing import Optional
import torch
from vllm import envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk(
def rocm_aiter_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU)
......@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
expert_mask = None
# w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
if (quant_config.per_act_token_quant and apply_router_weight_on_input
and quant_config.use_fp8_w8a8):
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC.
......@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
w2,
topk_weights,
topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_scale=quant_config.w1_scale,
fc2_scale=quant_config.w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False,
......@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
quant_method = QuantMethod.NO.value
# w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is\
not supported for block scaled moe")
assert w1_scale is not None
assert w2_scale is not None
assert quant_config.w1_scale is not None
assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value
elif use_fp8_w8a8:
elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value
......@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
expert_mask=expert_mask,
quant_method=quant_method,
activation_method=activation_method,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input)
......
......@@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
......@@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.triton_expert = TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
super().__init__(quant_config)
self.triton_expert = TritonExperts(quant_config)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None
self.quant_config) if self.allow_deep_gemm else None
@property
def activation_formats(
......@@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
......
......@@ -5,7 +5,8 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils import next_power_of_2
......@@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
w13_bias,
w2_bias,
max_capture_size,
):
super().__init__(moe.quant_config)
super().__init__(quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.w13_bias = w13_bias
self.w2_bias = w2_bias
self.max_capture_size = max_capture_size
@property
......@@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16).view(torch.int16)
assert w1_scale is not None
assert w2_scale is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
kwargs = {
"topk_ids":
packed_tensor,
......@@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"gemm1_weights":
w1,
"gemm1_weights_scale":
w1_scale,
self.w1_scale,
"gemm1_bias":
self.w13_bias,
self.w1_bias,
"gemm1_alpha":
self.gemm1_alpha,
"gemm1_beta":
......@@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"gemm2_weights":
w2,
"gemm2_weights_scale":
w2_scale,
self.w2_scale,
"gemm2_bias":
self.w2_bias,
"output1_scale_scalar":
......
......@@ -268,3 +268,7 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"
......@@ -9,8 +9,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
......@@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......
......@@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union
import torch
from packaging import version
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
......@@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
**extra_weight_attrs,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......@@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
def _create_weights_4bit(
......
......@@ -16,8 +16,11 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
......@@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
return CompressedTensorsW4A4MoeMethod(layer.moe_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
......@@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
def __init__(self, moe: FusedMoEConfig):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
super().__init__(moe)
......@@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
......@@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False)
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer:
return super().maybe_make_prepare_finalize(moe)
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.use_marlin:
return None
elif not self.allow_flashinfer:
return super().maybe_make_prepare_finalize()
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.use_marlin:
return None
return nvfp4_moe_quant_config(
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
def apply(
self,
layer: torch.nn.Module,
......@@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")
......@@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
#
if self.use_marlin:
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
......@@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_map=expert_map,
workspace=layer.workspace)
# FlashInfer fused experts path
if self.fused_experts is not None:
elif self.fused_experts is not None:
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
......@@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
# FlashInfer fused experts path
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4)
......@@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
apply_router_weight_on_input=apply_router_weight_on_input,
)
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
assert self.moe_quant_config is not None
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
).to(x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
......@@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
self.fused_experts_func = None
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
if self.use_cutlass:
device = layer.w13_weight.device
......@@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
device=device,
dtype=torch.int64)
def maybe_make_prepare_finalize(
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.use_marlin or self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
assert self.moe_quant_config is not None
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
......@@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logger.debug("CutlassBatchedExpertsFp8(%s)",
self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
moe.num_local_experts,
self.moe.num_local_experts,
num_dispatchers,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
)
self.disable_expert_map = (num_dispatchers > 1
......@@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
)
assert max_num_tokens_per_rank is not None
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
quant_config=self.moe_quant_config,
)
else:
return TritonExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
)
logger.debug("TritonExperts(%s)", self.__class__.__name__)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.use_marlin:
return None
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_channel_quant,
)
def apply(
self,
......@@ -841,16 +861,74 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
#
# Note: the order here is important. self.fused_experts can override
# cutlass fp8 or fused_experts but not marlin or rocm.
#
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace)
elif self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
assert self.fused_experts is None
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
elif self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
)
# cutlass path
if self.use_cutlass:
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
elif self.use_cutlass:
assert self.moe_quant_config is not None
# small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
......@@ -860,110 +938,48 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
if self.fused_experts is None:
quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
quant_config=self.moe_quant_config,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts_func(
else:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace)
assert self.fused_experts_func is not None
return self.fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
quant_config=self.moe_quant_config,
)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
......@@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return int8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
def apply(
self,
layer: torch.nn.Module,
......@@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
quant_config=self.moe_quant_config,
)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......@@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.workspace = marlin_make_workspace_new(device, 4)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......@@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale.transpose(1, 2).contiguous(),
requires_grad=False)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
assert self.num_bits == 4 or self.num_bits == 8
config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4
else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size],
)
def apply(
self,
layer: torch.nn.Module,
......@@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=self.num_bits == 4,
use_int8_w8a16=self.num_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size])
quant_config=self.moe_quant_config,
)
......@@ -8,6 +8,8 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None)
def apply(
self,
layer: torch.nn.Module,
......@@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
quant_config=self.moe_quant_config,
)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):
......
......@@ -14,9 +14,11 @@ from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
return super().maybe_make_prepare_finalize(moe)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe,
layer=self.layer,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
......@@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv)
def maybe_make_prepare_finalize(
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if (self.rocm_aiter_moe_enabled or self.use_marlin
or self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
prepare_finalize = (
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
......@@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
assert self.moe_quant_config is not None
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
......@@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=False,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl(
moe,
self.layer,
self.moe,
self.moe_quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
......@@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.__class__.__name__, self.quant_config.weight_block_size,
False)
return TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.use_marlin:
return None
return fp8_w8a8_moe_quant_config(
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
def apply(
self,
layer: torch.nn.Module,
......@@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and self.fused_experts is None):
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
......@@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_replica_count=logical_replica_count,
)
#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
assert self.fused_experts is None
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
......@@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
use_fp8_w8a8=True,
apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
expert_map=expert_map)
expert_map=expert_map,
quant_config=self.moe_quant_config)
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
......@@ -1105,6 +1121,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace)
elif self.fused_experts:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None
assert (not renormalize and custom_routing_function is not None)
......@@ -1112,33 +1141,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
common_kwargs = dict(
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -1149,26 +1166,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
if self.fused_experts is not None:
return self.fused_experts(**common_kwargs)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
**common_kwargs,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm),
)
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8KVCacheMethod(BaseKVCacheMethod):
......
......@@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
......@@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......
......@@ -9,8 +9,10 @@ import torch
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
......@@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......
......@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
use_prepack=True,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
......
......@@ -11,7 +11,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import (
......@@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
......@@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.fused_experts is not None or \
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
return super().maybe_make_prepare_finalize(moe)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe,
layer=self.layer,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
prepare_finalize = (
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
experts = select_cutlass_fp8_gemm_impl(
moe,
self.layer,
self.moe,
self.moe_quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
......@@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=False,
)
def apply(
self,
layer: torch.nn.Module,
......@@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert self.fused_experts is None
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert not renormalize
......@@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
#
# Note: the order here is important. self.fused_experts can override
# cutlass or fused_experts.
#
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
assert self.moe_quant_config is not None
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
quant_config=self.moe_quant_config,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class ModelOptNvFp4Config(QuantizationConfig):
......@@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.CUTLASS):
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if (self.use_marlin
or (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM)):
return None
elif (self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = (
build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
))
build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe))
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(moe)
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
......@@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_weight = Parameter(layer.w2_weight.data,
requires_grad=False)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if (self.use_marlin or self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
return None
return nvfp4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
)
def apply(
self,
layer: torch.nn.Module,
......@@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
assert self.fused_experts is None
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4,
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
......@@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
# trtllm.
#
if self.use_marlin:
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
......@@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
workspace=layer.workspace)
if self.fused_experts is not None:
elif self.fused_experts is not None:
assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
......@@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
out = self.fused_experts(
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4)
assert self.moe_quant_config is not None
out = flashinfer_cutlass_moe_fp4(
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
inplace=False, # TODO(shuw): fix later, now output is high prec
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
......@@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
out = cutlass_moe_fp4(
assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
# TODO: derive from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
return out
)
......@@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
......@@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
assert weight_bits == 4 or weight_bits == 8
config_builder = (int4_w4a16_moe_quant_config
if weight_bits == 4 else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
)
def apply(
self,
layer: torch.nn.Module,
......@@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts(
x,
layer.w13_qweight,
......@@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size])
quant_config=self.moe_quant_config,
)
@staticmethod
def get_weight_loader(layer, weight_loader):
......
......@@ -12,6 +12,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
......@@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return None
if self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = layer.w13_precision_config
w2_scale = layer.w2_precision_config
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
return mxfp4_w4a4_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
......@@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
"w13_bias": layer.w13_bias,
"w2_bias": layer.w2_bias,
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(moe, **kwargs)
assert self.moe_quant_config is not None
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
else:
# Use matmul_ogs from triton_kernels here!
raise NotImplementedError(
......@@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
renormalize=renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_precision=self.w13_precision_config,
w2_precision=self.w2_precision_config,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
......
......@@ -11,6 +11,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
......@@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=self.weight_qscheme == "per_channel",
)
def apply(
self,
layer: torch.nn.Module,
......@@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
quant_config=self.moe_quant_config,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
......@@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
quant_config=self.moe_quant_config)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
......@@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return mxfp4_w4a4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
)
def apply(
self,
layer: torch.nn.Module,
......@@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_mxfp4_w4a4=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
activation=activation,
quant_config=self.moe_quant_config,
)
return out
......@@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
assert weight_bits == 4 or weight_bits == 8
config_builder = (int4_w4a16_moe_quant_config
if weight_bits == 4 else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, group_size],
)
def apply(
self,
layer: torch.nn.Module,
......@@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
ret = fused_experts(
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
block_shape=[0, group_size])
return ret
quant_config=self.moe_quant_config,
)
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment