Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
...@@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
""" """
Perform any quantization (and/or) dispatching needed for this kernel. Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer. - a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids. - topk_ids: The topk ids.
- topk_weights: The topk weights. - topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space. - num_experts: The total number of experts in the global expert space.
...@@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard. space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
- quantized + dispatched a1_scales. - Optional quantized + dispatched a1_scales.
- Optional ExpertTokensMetadata containing gpu/cpu tensors - Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the as big as the number of local experts with the information about the
number of tokens assigned to each local expert. number of tokens assigned to each local expert.
...@@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare_async( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError raise NotImplementedError
# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC): class FusedMoEPermuteExpertsUnpermute(ABC):
""" """
An abstract base class for the [Permute-Experts-Unpermute] step described An abstract base class for the [Permute-Experts-Unpermute] step described
...@@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__( def __init__(
self, self,
quant_config: Optional[FusedMoEQuantConfig], quant_config: FusedMoEQuantConfig,
): ):
if quant_config is not None: """
self.quant_config = quant_config quant_config: Quantization parameters for this experts instance.
else: """
self.quant_config = FusedMoEQuantConfig() self.quant_config = quant_config
@property @property
@abstractmethod @abstractmethod
...@@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
""" """
raise NotImplementedError raise NotImplementedError
#
# Various helpers for accessing quantization parameters from the
# quant_config.
#
@property @property
def quant_dtype(self) -> Optional[torch.dtype]: def quant_dtype(self) -> Optional[torch.dtype]:
return self.quant_config.quant_dtype return self.quant_config.quant_dtype
...@@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def per_out_ch_quant(self) -> bool: def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant return self.quant_config.per_out_ch_quant
@property
def a1_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.a1_scale
@property
def a2_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.a2_scale
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
return self.quant_config.a1_gscale
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
return self.quant_config.a2_gscale
@property
def w1_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_scale
@property
def w2_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_scale
@property
def w1_zp(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_zp
@property
def w2_zp(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_zp
@property
def w1_bias(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_bias
@property
def w2_bias(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_bias
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
return self.quant_config.g1_alphas
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
return self.quant_config.g2_alphas
# TODO (bnell): make this return a CHUNK_SIZE or None instead? # TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod @abstractmethod
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
...@@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
...@@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations - topk_weights: A map of row to expert weights. Some implementations
choose to do weight application. choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id. - topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first - activation (str): The activation function to apply after the first
MoE layer. MoE layer.
...@@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert from the global expert space to the local expert space of the expert
parallel shard. parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1. used for a1. Result of quantization from prepare/finalize and not
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. from the FusedMoEQuantConfig.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm. must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation - workspace2 (torch.Tensor): A scratch tensor used for the activation
...@@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
...@@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
local_num_experts=local_num_experts, local_num_experts=local_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
Optional[torch.Tensor], torch.Tensor, torch.Tensor]: Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M) e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e), return (
_chunk_scales(a2_scale, s, a1q[s:e],
e), topk_ids[s:e], topk_weights[s:e]) _chunk_scales(a1q_scale, s, e),
_chunk_scales(self.fused_experts.a2_scale, s, e),
topk_ids[s:e],
topk_weights[s:e],
)
def slice_output_tensor(chunk_idx: int) -> torch.Tensor: def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, ( assert fused_out.size(0) % M == 0, (
...@@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
local_num_experts=local_num_experts, local_num_experts=local_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale, a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
activation: str = "silu", activation: str = "silu",
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
""" """
...@@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert from the global expert space to the local expert space of the expert
parallel shard. parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are - apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is applied directly on the inputs. This is only applicable when topk is
1. 1.
...@@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare( _expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1,
a1_scale,
a2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts, global_num_experts,
...@@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
dbo_maybe_run_recv_hook() dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async( hook, receiver = self.prepare_finalize.prepare_async(
a1, a1,
a1_scale,
a2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts, global_num_experts,
...@@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
local_num_experts=local_num_experts, local_num_experts=local_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
......
...@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
repeat_cols = 4 repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if quant_config.per_act_token_quant else a1_scale), a1, (None if quant_config.per_act_token_quant else
quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant, per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape) block_shape=quant_config.block_shape)
...@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
a1_scale,
a2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
num_experts, num_experts,
......
...@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
a1.mul_(topk_weights.to(a1.dtype)) a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype, a1, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None return a1q, a1q_scale, None, None, None
......
...@@ -7,6 +7,8 @@ from typing import Optional ...@@ -7,6 +7,8 @@ from typing import Optional
import torch import torch
from vllm import envs from vllm import envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk( ...@@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk(
def rocm_aiter_fused_experts( def rocm_aiter_fused_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, expert_map: Optional[torch.Tensor] = None,
per_channel_quant: bool = False, quant_config: Optional[FusedMoEQuantConfig] = None,
w1_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor:
w2_scale: Optional[torch.Tensor] = None, if quant_config is None:
a1_scale: Optional[torch.Tensor] = None, quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
activation_method = (ActivationMethod.SILU activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU) if activation == "silu" else ActivationMethod.GELU)
...@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts( ...@@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
expert_mask = None expert_mask = None
# w8a8 per-channel quantization # w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: if (quant_config.per_act_token_quant and apply_router_weight_on_input
and quant_config.use_fp8_w8a8):
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer # This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC. # rather than the second FC.
...@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts( ...@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
fc1_scale=w1_scale, fc1_scale=quant_config.w1_scale,
fc2_scale=w2_scale, fc2_scale=quant_config.w2_scale,
fc1_smooth_scale=None, fc1_smooth_scale=None,
fc2_smooth_scale=None, fc2_smooth_scale=None,
a16=False, a16=False,
...@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts( ...@@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
quant_method = QuantMethod.NO.value quant_method = QuantMethod.NO.value
# w8a8 block-scaled # w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8: if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is\ "apply_router_weight_on_input is\
not supported for block scaled moe") not supported for block scaled moe")
assert w1_scale is not None assert quant_config.w1_scale is not None
assert w2_scale is not None assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value quant_method = QuantMethod.BLOCK_128x128.value
elif use_fp8_w8a8: elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled. # Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value quant_method = QuantMethod.PER_TENSOR.value
...@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts( ...@@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
expert_mask=expert_mask, expert_mask=expert_mask,
quant_method=quant_method, quant_method=quant_method,
activation_method=activation_method, activation_method=activation_method,
w1_scale=w1_scale, w1_scale=quant_config.w1_scale,
w2_scale=w2_scale, w2_scale=quant_config.w2_scale,
a1_scale=a1_scale, a1_scale=quant_config.a1_scale,
a2_scale=a2_scale, a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input) doweight_stage1=apply_router_weight_on_input)
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape) deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
...@@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
use_fp8_w8a8: bool = False, quant_config: FusedMoEQuantConfig,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
): ):
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8, self.triton_expert = TritonExperts(quant_config)
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.triton_expert = TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape()) self.block_shape == deep_gemm_block_shape())
self.deep_gemm_expert = DeepGemmExperts( self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None self.quant_config) if self.allow_deep_gemm else None
@property @property
def activation_formats( def activation_formats(
...@@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, activation,
global_num_experts, global_num_experts,
expert_map, expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale, a1q_scale,
a2_scale,
workspace13, workspace13,
workspace2, workspace2,
expert_tokens_meta, expert_tokens_meta,
......
...@@ -5,7 +5,8 @@ from typing import Optional ...@@ -5,7 +5,8 @@ from typing import Optional
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP) TopKWeightAndReduceNoOP)
from vllm.utils import next_power_of_2 from vllm.utils import next_power_of_2
...@@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
gemm1_alpha, gemm1_alpha,
gemm1_beta, gemm1_beta,
gemm1_clamp_limit, gemm1_clamp_limit,
w13_bias,
w2_bias,
max_capture_size, max_capture_size,
): ):
super().__init__(moe.quant_config) super().__init__(quant_config)
self.moe = moe self.moe = moe
self.gemm1_alpha = gemm1_alpha self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit self.gemm1_clamp_limit = gemm1_clamp_limit
self.w13_bias = w13_bias
self.w2_bias = w2_bias
self.max_capture_size = max_capture_size self.max_capture_size = max_capture_size
@property @property
...@@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16).view(torch.int16) torch.bfloat16).view(torch.int16)
assert w1_scale is not None assert self.w1_scale is not None
assert w2_scale is not None assert self.w2_scale is not None
kwargs = { kwargs = {
"topk_ids": "topk_ids":
packed_tensor, packed_tensor,
...@@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"gemm1_weights": "gemm1_weights":
w1, w1,
"gemm1_weights_scale": "gemm1_weights_scale":
w1_scale, self.w1_scale,
"gemm1_bias": "gemm1_bias":
self.w13_bias, self.w1_bias,
"gemm1_alpha": "gemm1_alpha":
self.gemm1_alpha, self.gemm1_alpha,
"gemm1_beta": "gemm1_beta":
...@@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"gemm2_weights": "gemm2_weights":
w2, w2,
"gemm2_weights_scale": "gemm2_weights_scale":
w2_scale, self.w2_scale,
"gemm2_bias": "gemm2_bias":
self.w2_bias, self.w2_bias,
"output1_scale_scalar": "output1_scale_scalar":
......
...@@ -268,3 +268,7 @@ def _validate_scale_shape( ...@@ -268,3 +268,7 @@ def _validate_scale_shape(
assert block_shape is not None assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"
...@@ -9,8 +9,10 @@ from torch.nn import Parameter ...@@ -9,8 +9,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None: if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union ...@@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union
import torch import torch
from packaging import version from packaging import version
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
**extra_weight_attrs, **extra_weight_attrs,
) )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
quant_config=self.moe_quant_config,
) )
def _create_weights_4bit( def _create_weights_4bit(
......
...@@ -16,8 +16,11 @@ from vllm import _custom_ops as ops ...@@ -16,8 +16,11 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported)
FusedMoeWeightScaleSupported) from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe) is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
...@@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsWNA16MarlinMoEMethod( return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config) quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) return CompressedTensorsW4A4MoeMethod(layer.moe_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)): or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
...@@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): def __init__(self, moe: FusedMoEConfig):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support) detect_nvfp4_moe_support)
super().__init__(moe) super().__init__(moe)
...@@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
...@@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False) (layer.w2_input_global_scale), requires_grad=False)
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
moe: FusedMoEConfig, if self.use_marlin:
) -> Optional[mk.FusedMoEPrepareAndFinalize]: return None
if not self.allow_flashinfer: elif not self.allow_flashinfer:
return super().maybe_make_prepare_finalize(moe) return super().maybe_make_prepare_finalize()
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe, self.moe)
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__) logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize return prepare_finalize
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
"""Return the appropriate GEMM experts implementation.""" """Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl( experts = select_nvfp4_gemm_impl(
moe, self.moe,
g1_alphas=self.layer.g1_alphas, self.moe_quant_config,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer, allow_flashinfer=self.allow_flashinfer,
) )
logger.debug_once("Using %s", experts.__class__.__name__) logger.debug_once("Using %s", experts.__class__.__name__)
return experts return experts
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.use_marlin:
return None
return nvfp4_moe_quant_config(
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB not supported for " raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.") "`CompressedTensorsW4A4MoeMethod` yet.")
...@@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype, indices_type=self.topk_indices_dtype,
) )
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
#
if self.use_marlin: if self.use_marlin:
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_map=expert_map, expert_map=expert_map,
workspace=layer.workspace) workspace=layer.workspace)
# FlashInfer fused experts path elif self.fused_experts is not None:
if self.fused_experts is not None:
assert is_valid_flashinfer_cutlass_fused_moe( assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), ( x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!") "Flashinfer CUTLASS Fused MoE not applicable!")
...@@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
# FlashInfer fused experts path
elif self.allow_flashinfer: elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4) flashinfer_cutlass_moe_fp4)
...@@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
x, layer.w13_weight, layer.w2_weight), ( x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!") "Flashinfer CUTLASS Fused MoE not applicable!")
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4( return flashinfer_cutlass_moe_fp4(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
else:
assert expert_map is None, ("Expert Parallelism / expert_map " from vllm.model_executor.layers.fused_moe.cutlass_moe import (
"is currently not supported for " cutlass_moe_fp4)
"CompressedTensorsW4A4MoeMethod.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( assert expert_map is None, ("Expert Parallelism / expert_map "
cutlass_moe_fp4) "is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
# Cutlass moe takes in activations in BF16/Half precision assert self.moe_quant_config is not None
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4( # Cutlass moe takes in activations in BF16/Half precision
a=x, # and fp4 quantized weights loaded from the checkpoint
w1_fp4=layer.w13_weight, return cutlass_moe_fp4(
w2_fp4=layer.w2_weight, a=x,
w1_blockscale=layer.w13_weight_scale, w1_fp4=layer.w13_weight,
w2_blockscale=layer.w2_weight_scale, w2_fp4=layer.w2_weight,
g1_alphas=layer.g1_alphas, topk_weights=topk_weights,
g2_alphas=layer.g2_alphas, topk_ids=topk_ids,
a1_gscale=layer.w13_input_scale_quant, quant_config=self.moe_quant_config,
a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input,
topk_weights=topk_weights, # TODO(bnell): derive these from arguments
topk_ids=topk_ids, m=x.shape[0],
m=x.shape[0], n=layer.w2_weight.shape[2] * 2,
n=layer.w2_weight.shape[2] * 2, k=x.shape[1],
k=x.shape[1], e=layer.w13_weight.shape[0],
e=layer.w13_weight.shape[0], ).to(x.dtype)
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...@@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False) requires_grad=False)
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
elif self.use_marlin: elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False) prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
self.fused_experts_func = None
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
if self.use_cutlass: if self.use_cutlass:
device = layer.w13_weight.device device = layer.w13_weight.device
...@@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
device=device, device=device,
dtype=torch.int64) dtype=torch.int64)
def maybe_make_prepare_finalize(
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.use_marlin or self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl( def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize, self,
moe: FusedMoEConfig, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute: layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path # cutlass path
assert self.moe_quant_config is not None
if self.use_cutlass: if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8) CutlassBatchedExpertsFp8, CutlassExpertsFp8)
...@@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logger.debug("CutlassBatchedExpertsFp8(%s)", logger.debug("CutlassBatchedExpertsFp8(%s)",
self.__class__.__name__) self.__class__.__name__)
experts = CutlassBatchedExpertsFp8( experts = CutlassBatchedExpertsFp8(
moe.num_local_experts, self.moe.num_local_experts,
num_dispatchers, num_dispatchers,
moe.in_dtype, self.moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2, ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2, ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1, c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2, c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
) )
else: else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8( experts = CutlassExpertsFp8(
moe.in_dtype, self.moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2, ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2, ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1, c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2, c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
) )
self.disable_expert_map = (num_dispatchers > 1 self.disable_expert_map = (num_dispatchers > 1
...@@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert not self.rocm_aiter_moe_enabled and not self.use_marlin assert not self.rocm_aiter_moe_enabled and not self.use_marlin
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
if (prepare_finalize.activation_format == if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts): FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
) )
assert max_num_tokens_per_rank is not None assert max_num_tokens_per_rank is not None
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts( return BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True, quant_config=self.moe_quant_config,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
) )
else: else:
return TritonExperts( logger.debug("TritonExperts(%s)", self.__class__.__name__)
use_fp8_w8a8=True, return TritonExperts(self.moe_quant_config)
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=( def get_fused_moe_quant_config(
self.input_quant.strategy == QuantizationStrategy.TOKEN), self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
) if self.use_marlin:
return None
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_channel_quant,
)
def apply( def apply(
self, self,
...@@ -841,16 +861,74 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -841,16 +861,74 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype, indices_type=self.topk_indices_dtype,
) )
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
#
# Note: the order here is important. self.fused_experts can override
# cutlass fp8 or fused_experts but not marlin or rocm.
#
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace)
elif self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
assert self.fused_experts is None
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
elif self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
)
# cutlass path # cutlass path
if self.use_cutlass: elif self.use_cutlass:
per_act_token = ( assert self.moe_quant_config is not None
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
# small-batch fallback on SM100 # small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -860,110 +938,48 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -860,110 +938,48 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_weight_scale, )
a1_scale=layer.w13_input_scale, else:
a2_scale=layer.w2_input_scale)
if self.fused_experts is None:
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8) cutlass_moe_fp8)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return cutlass_moe_fp8( return cutlass_moe_fp8(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
per_act_token=per_act_token, quant_config=self.moe_quant_config,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2, ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2, ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1, c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2, c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
) )
if self.rocm_aiter_moe_enabled: else:
return self.rocm_aiter_fused_experts_func( from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
workspace=layer.workspace) quant_config=self.moe_quant_config,
)
assert self.fused_experts_func is not None
return self.fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...@@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass pass
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return int8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_weight_scale, )
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...@@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale.transpose(1, 2).contiguous(), layer.w2_weight_scale.transpose(1, 2).contiguous(),
requires_grad=False) requires_grad=False)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
assert self.num_bits == 4 or self.num_bits == 8
config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4
else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size],
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
use_int4_w4a16=self.num_bits == 4,
use_int8_w8a16=self.num_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_weight_scale, )
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size])
...@@ -8,6 +8,8 @@ import torch ...@@ -8,6 +8,8 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_scale", w2_scale) layer.register_parameter("w2_scale", w2_scale)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_scale) )
@staticmethod @staticmethod
def quantizing_weight_loader(layer, weight_loader): def quantizing_weight_loader(layer, weight_loader):
......
...@@ -14,9 +14,11 @@ from vllm import _custom_ops as ops ...@@ -14,9 +14,11 @@ from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current " "CutlassBlockScaledGroupedGemm not supported on the current "
"platform.") "platform.")
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
return super().maybe_make_prepare_finalize(moe)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe,
layer=self.layer,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv) layer.w2_weight_scale_inv)
def maybe_make_prepare_finalize(
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if (self.rocm_aiter_moe_enabled or self.use_marlin
or self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
prepare_finalize = (
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
...@@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.") "Marlin and ROCm AITER are not supported with all2all yet.")
assert self.moe_quant_config is not None
if (prepare_finalize.activation_format == if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts): FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = ( max_num_tokens_per_rank = (
...@@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return BatchedTritonOrDeepGemmExperts( return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True, quant_config=self.moe_quant_config,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=False,
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
) )
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl( experts = select_cutlass_fp8_gemm_impl(
moe, self.moe,
self.layer, self.moe_quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__) logger.debug_once("Using %s", experts.__class__.__name__)
return experts return experts
...@@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.__class__.__name__, self.quant_config.weight_block_size, self.__class__.__name__, self.quant_config.weight_block_size,
False) False)
return TritonOrDeepGemmExperts( return TritonOrDeepGemmExperts(
use_fp8_w8a8=True, quant_config=self.moe_quant_config,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
) )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.use_marlin:
return None
return fp8_w8a8_moe_quant_config(
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert logical_replica_count is not None assert logical_replica_count is not None
assert isinstance(layer, FusedMoE) assert isinstance(layer, FusedMoE)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and self.fused_experts is None):
assert activation == 'silu', ( assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}") f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', ( assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}") f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (renormalize and use_grouped_topk assert (renormalize and use_grouped_topk
and custom_routing_function is None) and custom_routing_function is None)
...@@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_replica_count=logical_replica_count, logical_replica_count=logical_replica_count,
) )
#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts) rocm_aiter_fused_experts)
assert self.fused_experts is None
return rocm_aiter_fused_experts( return rocm_aiter_fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
use_fp8_w8a8=True,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=(layer.w13_weight_scale_inv expert_map=expert_map,
if self.block_quant else layer.w13_weight_scale), quant_config=self.moe_quant_config)
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
expert_map=expert_map)
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", ( assert activation == "silu", (
f"{activation} not supported for Marlin MoE.") f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -1105,6 +1121,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1105,6 +1121,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
workspace=layer.workspace) workspace=layer.workspace)
elif self.fused_experts:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None assert self.block_quant is None
assert (not renormalize and custom_routing_function is not None) assert (not renormalize and custom_routing_function is not None)
...@@ -1112,33 +1141,21 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1112,33 +1141,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"Expected 'silu' activation but got {activation}") f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', ( assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}") f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.fused_experts is not None:
return self.fused_experts( return flashinfer_cutlass_moe_fp8(
x, x,
layer.w13_weight, layer,
layer.w2_weight, topk_weights,
topk_weights, topk_ids,
topk_ids, inplace=False,
inplace=False, activation=activation,
activation=activation, global_num_experts=global_num_experts,
global_num_experts=global_num_experts, expert_map=expert_map,
expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=apply_router_weight_on_input, )
)
else:
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else: else:
common_kwargs = dict( from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -1149,26 +1166,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1149,26 +1166,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv quant_config=self.moe_quant_config,
if self.block_quant else layer.w13_weight_scale), allow_deep_gemm=self.allow_deep_gemm,
w2_scale=(layer.w2_weight_scale_inv allow_cutlass_block_scaled_grouped_gemm=(
if self.block_quant else layer.w2_weight_scale), self.allow_cutlass_block_scaled_grouped_gemm))
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
if self.fused_experts is not None:
return self.fused_experts(**common_kwargs)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
**common_kwargs,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm),
)
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
......
...@@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter ...@@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
...@@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight_type, extra_weight_attrs) set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type) layer.register_parameter("w2_qweight_type", w2_qweight_type)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -9,8 +9,10 @@ import torch ...@@ -9,8 +9,10 @@ import torch
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
...@@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None: if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter ...@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): ...@@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
use_prepack=True, use_prepack=True,
) )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -11,7 +11,9 @@ import vllm.envs as envs ...@@ -11,7 +11,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe) is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
...@@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
cutlass_fp8_supported) cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once( logger.info_once(
...@@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
moe: FusedMoEConfig, # TRT LLM not supported with all2all yet.
) -> Optional[mk.FusedMoEPrepareAndFinalize]: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if self.fused_experts is not None or \ return None
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
return super().maybe_make_prepare_finalize(moe) prepare_finalize = (
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( logger.debug_once("%s", prepare_finalize.__class__.__name__)
moe, return prepare_finalize
layer=self.layer, else:
) return super().maybe_make_prepare_finalize()
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
experts = select_cutlass_fp8_gemm_impl( experts = select_cutlass_fp8_gemm_impl(
moe, self.moe,
self.layer, self.moe_quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__) logger.debug_once("Using %s", experts.__class__.__name__)
return experts return experts
...@@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight) layer.w2_weight)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=False,
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"EPLB not supported for `ModelOptFp8MoEMethod` yet.") "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert self.fused_experts is None
assert activation == 'silu', ( assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}") f"Expected 'silu' activation but got {activation}")
assert not renormalize assert not renormalize
...@@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype, indices_type=self.topk_indices_dtype,
) )
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: #
# Note: the order here is important. self.fused_experts can override
# cutlass or fused_experts.
#
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize assert not renormalize
assert activation == 'silu', ( assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}") f"Expected 'silu' activation but got {activation}")
if self.fused_experts is not None: return flashinfer_cutlass_moe_fp8(
return self.fused_experts( x,
x, layer,
layer.w13_weight, topk_weights,
layer.w2_weight, topk_ids,
topk_weights, inplace=False,
topk_ids, activation=activation,
inplace=False, global_num_experts=global_num_experts,
activation=activation, expert_map=expert_map,
global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, )
apply_router_weight_on_input=apply_router_weight_on_input, else:
) from vllm.model_executor.layers.fused_moe.fused_moe import (
else: fused_experts)
return flashinfer_cutlass_moe_fp8( assert self.moe_quant_config is not None
x,
layer, return fused_experts(
topk_weights, x,
topk_ids, layer.w13_weight,
inplace=False, layer.w2_weight,
activation=activation, topk_weights=topk_weights,
global_num_experts=global_num_experts, topk_ids=topk_ids,
expert_map=expert_map, inplace=True,
apply_router_weight_on_input=apply_router_weight_on_input, activation=activation,
) quant_config=self.moe_quant_config,
from vllm.model_executor.layers.fused_moe.fused_moe import ( global_num_experts=global_num_experts,
fused_experts) expert_map=expert_map,
return fused_experts( apply_router_weight_on_input=apply_router_weight_on_input,
x, )
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class ModelOptNvFp4Config(QuantizationConfig): class ModelOptNvFp4Config(QuantizationConfig):
...@@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" for ModelOptNvFp4FusedMoE.") " for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
moe: FusedMoEConfig, if (self.use_marlin
) -> Optional[mk.FusedMoEPrepareAndFinalize]: or (self.allow_flashinfer and self.flashinfer_moe_backend
if (self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM)):
== FlashinferMoeBackend.CUTLASS): return None
elif (self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = ( prepare_finalize = (
build_flashinfer_fp4_cutlass_moe_prepare_finalize( build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe))
moe,
a1_gscale=self.layer.w13_input_scale_quant,
))
logger.debug_once("%s", prepare_finalize.__class__.__name__) logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize return prepare_finalize
else:
return super().maybe_make_prepare_finalize(moe) return super().maybe_make_prepare_finalize()
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
experts = select_nvfp4_gemm_impl( experts = select_nvfp4_gemm_impl(
moe, self.moe,
g1_alphas=self.layer.g1_alphas, self.moe_quant_config,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer, allow_flashinfer=self.allow_flashinfer,
) )
logger.debug_once("Using %s", experts.__class__.__name__) logger.debug_once("Using %s", experts.__class__.__name__)
...@@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_weight = Parameter(layer.w2_weight.data, layer.w2_weight = Parameter(layer.w2_weight.data,
requires_grad=False) requires_grad=False)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if (self.use_marlin or self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
return None
return nvfp4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if self.allow_flashinfer and \ if (self.allow_flashinfer and self.flashinfer_moe_backend
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: == FlashinferMoeBackend.TENSORRT_LLM):
import flashinfer import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
assert self.fused_experts is None
a1_gscale = layer.w13_input_scale_quant a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, (hidden_states_fp4,
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
...@@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
# trtllm.
#
if self.use_marlin: if self.use_marlin:
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map, expert_map=expert_map,
workspace=layer.workspace) workspace=layer.workspace)
if self.fused_experts is not None: elif self.fused_experts is not None:
assert self.allow_flashinfer and \ assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
...@@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x, layer.w13_weight, layer.w2_weight), ( x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!") "Flashinfer CUTLASS Fused MoE not applicable!")
out = self.fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
elif (self.allow_flashinfer elif (self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4) flashinfer_cutlass_moe_fp4)
assert self.moe_quant_config is not None
out = flashinfer_cutlass_moe_fp4( return flashinfer_cutlass_moe_fp4(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_weight_scale, inplace=False,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# only (no EP). # only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4) cutlass_moe_fp4)
out = cutlass_moe_fp4( assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x, a=x,
w1_fp4=layer.w13_weight, w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight, w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
# TODO: derive from arguments
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
expert_map=expert_map, )
apply_router_weight_on_input=apply_router_weight_on_input)
return out
...@@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union ...@@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
...@@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.register_parameter(key, param) layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs) set_weight_attrs(param, extra_weight_attrs)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
assert weight_bits == 4 or weight_bits == 8
config_builder = (int4_w4a16_moe_quant_config
if weight_bits == 4 else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts( return fused_experts(
x, x,
layer.w13_qweight, layer.w13_qweight,
...@@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_scales, quant_config=self.moe_quant_config,
w2_scale=layer.w2_scales, )
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size])
@staticmethod @staticmethod
def get_weight_loader(layer, weight_loader): def get_weight_loader(layer, weight_loader):
......
...@@ -12,6 +12,8 @@ from vllm.logger import init_logger ...@@ -12,6 +12,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
...@@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim return tile_tokens_dim
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return None
if self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = layer.w13_precision_config
w2_scale = layer.w2_precision_config
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
return mxfp4_w4a4_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format == if (prepare_finalize.activation_format ==
...@@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"gemm1_alpha": layer.gemm1_alpha, "gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta, "gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit, "gemm1_clamp_limit": layer.gemm1_clamp_limit,
"w13_bias": layer.w13_bias, # TODO(bnell): part of quant_config
"w2_bias": layer.w2_bias,
"max_capture_size": self.max_capture_size, "max_capture_size": self.max_capture_size,
} }
return TrtLlmGenExperts(moe, **kwargs) assert self.moe_quant_config is not None
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
else: else:
# Use matmul_ogs from triton_kernels here! # Use matmul_ogs from triton_kernels here!
raise NotImplementedError( raise NotImplementedError(
...@@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_bias=layer.w13_bias, quant_config=self.moe_quant_config,
w2_bias=layer.w2_bias,
w1_precision=self.w13_precision_config,
w2_precision=self.w2_precision_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
else: else:
......
...@@ -11,6 +11,9 @@ from vllm.logger import init_logger ...@@ -11,6 +11,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled) is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
...@@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts self.fused_experts_func = fused_experts
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=self.weight_qscheme == "per_channel",
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True, quant_config=self.moe_quant_config,
per_channel_quant=self.weight_qscheme == "per_channel",
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
expert_map=expert_map) expert_map=expert_map)
if self.use_marlin: if self.use_marlin:
assert activation == "silu", ( assert activation == "silu", (
...@@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config)
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
...@@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): ...@@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return mxfp4_w4a4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): ...@@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_mxfp4_w4a4=True, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
activation=activation,
) )
return out return out
...@@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter ...@@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
fix_weights(layer, "w13_weight", weight_bits == 4) fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4) fix_weights(layer, "w2_weight", weight_bits == 4)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
assert weight_bits == 4 or weight_bits == 8
config_builder = (int4_w4a16_moe_quant_config
if weight_bits == 4 else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, group_size],
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits return fused_experts(
group_size = self.quant_config.group_size
ret = fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
block_shape=[0, group_size]) quant_config=self.moe_quant_config,
)
return ret
def rtn_quantize(tensor: torch.Tensor, num_bits: int, def rtn_quantize(tensor: torch.Tensor, num_bits: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment