Commit eefa41c1 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.18.0

parent 82155c76
...@@ -58,10 +58,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -58,10 +58,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
), ),
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb return self.old_quant_method.supports_eplb
......
...@@ -35,7 +35,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( ...@@ -35,7 +35,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
) )
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# init_aiter_topK_meta_data, # init_aiter_topK_meta_data,
) # )
from vllm.model_executor.layers.fused_moe.router.router_factory import ( from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router, create_fused_moe_router,
) )
...@@ -676,6 +676,10 @@ class FusedMoE(CustomOp): ...@@ -676,6 +676,10 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method. # should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None: def maybe_init_modular_kernel(self) -> None:
# NOTE(rob): WIP refactor. For quant methods that own the MK
# we create the MK during process_weights_after_loading.
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
return None
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with # routing_tables only needed for round-robin expert placement with
......
...@@ -252,6 +252,7 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): ...@@ -252,6 +252,7 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType: ) -> PrepareResultType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel. Perform any quantization (and/or) dispatching needed for this kernel.
...@@ -295,6 +296,7 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): ...@@ -295,6 +296,7 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType: ) -> tuple[Callable, ReceiverType] | ReceiverType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
...@@ -1106,6 +1108,7 @@ class FusedMoEKernelModularImpl: ...@@ -1106,6 +1108,7 @@ class FusedMoEKernelModularImpl:
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
else: else:
# Overlap shared expert compute with all2all dispatch. # Overlap shared expert compute with all2all dispatch.
...@@ -1118,6 +1121,7 @@ class FusedMoEKernelModularImpl: ...@@ -1118,6 +1121,7 @@ class FusedMoEKernelModularImpl:
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
# TODO(lucas): refactor this in the alternative schedules followup # TODO(lucas): refactor this in the alternative schedules followup
......
...@@ -3,70 +3,6 @@ ...@@ -3,70 +3,6 @@
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): ...@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
""" """
Returns a tuple of: Returns a tuple of:
...@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): ...@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
""" """
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now." "mori does not support apply_router_weight_on_input=True now."
) )
......
...@@ -9,6 +9,9 @@ from vllm import envs ...@@ -9,6 +9,9 @@ from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -31,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -31,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
kFp8Static128BlockSym, kFp8Static128BlockSym,
) )
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -179,13 +183,6 @@ def backend_to_kernel_cls( ...@@ -179,13 +183,6 @@ def backend_to_kernel_cls(
return [XPUExpertsFp8] return [XPUExpertsFp8]
elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExpertsFp8,
)
return XPUExpertsFp8
else: else:
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
...@@ -527,7 +524,8 @@ def make_fp8_moe_quant_config( ...@@ -527,7 +524,8 @@ def make_fp8_moe_quant_config(
) )
def make_fp8_moe_kernel_for_mkm( def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEExperts], experts_cls: type[mk.FusedMoEExperts],
fp8_backend: Fp8MoeBackend, fp8_backend: Fp8MoeBackend,
...@@ -548,49 +546,15 @@ def make_fp8_moe_kernel_for_mkm( ...@@ -548,49 +546,15 @@ def make_fp8_moe_kernel_for_mkm(
# Create Experts. # Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None assert max_num_tokens is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=moe_quant_config,
......
...@@ -8,6 +8,9 @@ import vllm.envs as envs ...@@ -8,6 +8,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config.kernel import MoEBackend from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -15,9 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,9 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
) )
...@@ -392,7 +392,8 @@ def make_nvfp4_moe_quant_config( ...@@ -392,7 +392,8 @@ def make_nvfp4_moe_quant_config(
) )
def make_nvfp4_moe_kernel_for_mkm( def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEExperts], experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
...@@ -412,48 +413,15 @@ def make_nvfp4_moe_kernel_for_mkm( ...@@ -412,48 +413,15 @@ def make_nvfp4_moe_kernel_for_mkm(
# Create Experts. # Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None assert max_num_tokens is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=moe_quant_config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import (
_validate_scale_shape,
moe_kernel_quantize_input,
)
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: torch.dtype | str | None,
per_act_token_quant: bool,
block_shape: list[int] | None,
):
# All pplx byte sizes must be 16-byte aligned.
align = 16
# For blocked per token: set to
# cdiv(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert isinstance(quant_dtype, torch.dtype)
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
if per_act_token_quant:
# per-token (M x 1)
assert block_shape is None
hidden_scale_bytes = elem_size
elif block_shape is not None:
# per-group (M x K_tiles)
block_size = block_shape[1]
num_blocks = cdiv(hidden_dim, block_size)
hidden_scale_bytes = num_blocks * elem_size
else:
# per-tensor (1 x 1)
hidden_scale_bytes = elem_size
else:
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
hidden_scale_bytes = 0
return (
round_up(hidden_dim_bytes, align),
round_up(hidden_scale_bytes, align),
)
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
a2a: pplx.AllToAll,
max_num_tokens: int,
num_local_experts: int,
num_dispatchers: int,
):
super().__init__()
assert max_num_tokens > 0
assert num_local_experts > 0
self.a2a = a2a
self.max_num_tokens = max_num_tokens
self.num_local_experts = num_local_experts
self.num_dispatchers_ = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.uint32
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert topk_ids.size(0) == num_tokens
# expert_map should be None because with expert map, -1 id is used for
# non-local token; this causes error when casting ids to the
# topk_indices_dtype() int32
#
if expert_map is not None:
logger.warning_once(
"The PPLX backend does not support expert mapping. "
"The provided `expert_map` will be ignored."
)
expert_map = None # noqa: F841
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
(None if quant_config.per_act_token_quant else quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
_validate_scale_shape(
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
)
orig_a_scale_block_shape: int | None = None
if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1
# pplx requires 2-d scales even for scalar scales
if a1q_scale.dim() <= 1:
assert scalar_scales
a1q_scale = a1q_scale.view(1, 1)
orig_a_scale_block_shape = a1q_scale.shape[-1]
if not quant_config.is_block_quantized:
# TODO (bnell): use group_broadcast instead?
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
assert a1q_scale is None or a1q_scale.ndim == 2, (
f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
)
expert_num_tokens = torch.empty(
self.num_local_experts,
dtype=torch.int32,
device=device,
)
expert_x = torch.empty(
(
self.num_local_experts,
self.max_num_tokens * self.num_dispatchers(),
hidden_dim,
),
dtype=a1q.dtype,
device=device,
)
expert_x_scale: torch.Tensor | None = None
if a1q.dtype.itemsize == 1:
if quant_config.is_per_act_token:
# (M x 1) -> (E x M x K)
final_dim = expert_x.size(2)
elif quant_config.is_per_tensor:
# (1 x 1) -> (E x 1 x 1)
final_dim = 1
else:
# (M x K_tiles) -> (E x M x K_tiles)
assert quant_config.block_shape is not None
num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
final_dim = num_blocks
expert_x_scale_shape = (
self.num_local_experts,
expert_x.size(1),
round_up(final_dim, 4), # round up for alignment
)
expert_x_scale = torch.empty(
expert_x_scale_shape,
dtype=torch.float32,
device=expert_x.device,
)
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: torch.Tensor | None = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
hook = lambda: self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
return (
hook,
lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
orig_a_scale_block_shape,
),
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: torch.Tensor | None,
orig_a_scale_block_shape: int | None,
) -> mk.PrepareResultType:
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
hook()
return receiver()
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
"Weight application and reduction happens in the combine kernel."
)
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: torch.Tensor | None = None
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
# num_tokens = output.size(0) # M
# assert topk_ids.size(0) == num_tokens, (
# f"{topk_ids.size(0)} == {num_tokens}")
assert topk_ids.size() == topk_weights.size(), (
f"{topk_ids.size()} == {topk_weights.size()}"
)
assert output.size(0) <= self.max_num_tokens, (
f"{output.size(0)} <= {self.max_num_tokens}"
)
assert output.size(1) == fused_expert_output.size(-1)
# Set weights to 1 if we did them in dispatch. This is hacky.
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None:
super().__init__()
self.defer_input_quant = defer_input_quant
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if self.defer_input_quant:
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
return a1q, a1q_scale, None, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
\ No newline at end of file
...@@ -301,13 +301,6 @@ class AiterExperts(mk.FusedMoEExpertsModular): ...@@ -301,13 +301,6 @@ class AiterExperts(mk.FusedMoEExpertsModular):
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled() return rocm_aiter_ops.is_fused_moe_enabled()
......
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -52,7 +51,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -52,7 +51,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -65,7 +63,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ...@@ -65,7 +63,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import
is_flashinfer_mxint4_moe_available, is_flashinfer_mxint4_moe_available,
prepare_static_weights_for_trtllm_mxint4_moe, prepare_static_weights_for_trtllm_mxint4_moe,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe,
...@@ -240,7 +237,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -240,7 +237,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
self, self,
...@@ -317,7 +313,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -317,7 +313,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False layer.w13_weight_packed.data, requires_grad=False
) )
...@@ -342,6 +338,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -342,6 +338,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def apply( def apply(
...@@ -384,19 +382,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -384,19 +382,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic, activation_key=None if use_a16 else kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -510,7 +499,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -510,7 +499,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -711,15 +700,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -711,15 +700,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -879,7 +859,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -879,7 +859,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way. # Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight
...@@ -941,16 +921,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -941,16 +921,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel( self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
......
...@@ -32,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod ...@@ -32,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -683,15 +682,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -683,15 +682,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False, allow_vllm_cutlass=False,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: Module, layer: Module,
...@@ -836,7 +826,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -836,7 +826,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: Module, layer: FusedMoE,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -864,16 +854,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -864,16 +854,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel( self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
......
...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -39,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -39,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -746,15 +744,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -746,15 +744,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym, activation_key=kFp8StaticTensorSym,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
...@@ -855,7 +844,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -855,7 +844,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -1213,19 +1202,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1213,19 +1202,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic, activation_key=kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
...@@ -1357,7 +1337,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1357,7 +1337,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -2229,7 +2209,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase): ...@@ -2229,7 +2209,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
return ModelOptFp8LinearMethod(self.fp8_config) return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4": if quant_algo == "NVFP4":
return ModelOptNvFp4LinearMethod(self.nvfp4_config) return ModelOptNvFp4LinearMethod(self.nvfp4_config)
# Layer not in quantized_layers leave unquantized # Layer not in quantized_layers 鈥?leave unquantized
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
......
...@@ -1134,32 +1134,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1134,32 +1134,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe( trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16), routing_logits=router_logits.to(torch.bfloat16),
None, # routing_bias routing_bias=None,
x_quant, hidden_states=x_quant,
x_scale, hidden_states_scale=x_scale,
layer.w13_weight, # uint8 (e2m1 x 2) gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2) gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_bias, # fp32 per expert per channel gemm1_bias=layer.w13_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert gemm1_beta=layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2) gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0 gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
layer.w2_bias, # fp32 per expert per channel gemm2_bias=layer.w2_bias, # fp32 per expert per channel
None, # output1_scale_scalar output1_scale_scalar=None,
None, # output1_scale_gate_scalar output1_scale_gate_scalar=None,
None, # output2_scale_scalar output2_scale_scalar=None,
layer.global_num_experts, num_experts=layer.global_num_experts,
layer.top_k, top_k=layer.top_k,
None, # n_group n_group=None,
None, # topk_group topk_group=None,
self.intermediate_size, # padded to multiple of 256 intermediate_size=self.intermediate_size, # padded to multiple of 256
layer.ep_rank * layer.local_num_experts, # local_expert_offset local_expert_offset=layer.ep_rank * layer.local_num_experts,
self.num_experts, # local num experts local_num_experts=self.num_experts,
None, # routed_scaling_factor routed_scaling_factor=None,
1 if layer.renormalize else 0, # routing_method_type, renormalize routing_method_type=1 if layer.renormalize else 0,
True, # do finalize do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
......
...@@ -30,7 +30,6 @@ logger = init_logger(__name__) ...@@ -30,7 +30,6 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
...@@ -61,17 +60,6 @@ def reorder_w1w3_to_w3w1( ...@@ -61,17 +60,6 @@ def reorder_w1w3_to_w3w1(
) )
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
)
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
......
...@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING ...@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
......
...@@ -431,7 +431,7 @@ class AfmoeModel(nn.Module, EagleModelMixin): ...@@ -431,7 +431,7 @@ class AfmoeModel(nn.Module, EagleModelMixin):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -691,7 +691,7 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): ...@@ -691,7 +691,7 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -542,7 +542,7 @@ class ApertusForCausalLM( ...@@ -542,7 +542,7 @@ class ApertusForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -397,7 +397,7 @@ class ArceeForCausalLM( ...@@ -397,7 +397,7 @@ class ArceeForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -406,7 +406,7 @@ class ArcticModel(nn.Module): ...@@ -406,7 +406,7 @@ class ArcticModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -460,7 +460,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -460,7 +460,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment