Unverified Commit 97995f63 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
...@@ -44,7 +44,8 @@ steps: ...@@ -44,7 +44,8 @@ steps:
- vllm/envs.py - vllm/envs.py
- vllm/config - vllm/config
commands: commands:
- pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT - pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
- pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2 parallelism: 2
- label: Kernels Mamba Test - label: Kernels Mamba Test
......
...@@ -12,12 +12,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -12,12 +12,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
...@@ -137,15 +137,21 @@ def bench_run( ...@@ -137,15 +137,21 @@ def bench_run(
per_out_ch_quant=per_out_ch, per_out_ch_quant=per_out_ch,
) )
fn = mk.FusedMoEModularKernel( moe_config = make_dummy_moe_config(
MoEPrepareAndFinalizeNoEP(), num_experts=num_experts,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=a.dtype,
)
fn = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8( CutlassExpertsFp8(
moe_config=make_dummy_moe_config( moe_config=moe_config,
num_experts=num_experts,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=a.dtype,
),
quant_config=quant_config, quant_config=quant_config,
), ),
) )
......
...@@ -15,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -15,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
...@@ -23,9 +26,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import ( ...@@ -23,9 +26,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4, CutlassExpertsFp4,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
...@@ -196,10 +196,21 @@ def bench_run( ...@@ -196,10 +196,21 @@ def bench_run(
g2_alphas=w2_gs, g2_alphas=w2_gs,
) )
kernel = mk.FusedMoEModularKernel( moe_config = make_dummy_moe_config(
MoEPrepareAndFinalizeNoEP(), num_experts=num_experts,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=a.dtype,
)
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
) )
...@@ -240,11 +251,17 @@ def bench_run( ...@@ -240,11 +251,17 @@ def bench_run(
g1_alphas=w1_gs, g1_alphas=w1_gs,
g2_alphas=w2_gs, g2_alphas=w2_gs,
) )
moe_config = make_dummy_moe_config()
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
) )
......
...@@ -9,15 +9,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -9,15 +9,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_experts,
fused_topk, fused_topk,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
...@@ -131,16 +131,22 @@ def bench_run( ...@@ -131,16 +131,22 @@ def bench_run(
w2_scale=w2_scale, w2_scale=w2_scale,
per_act_token_quant=per_act_token, per_act_token_quant=per_act_token,
) )
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
)
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8( CutlassExpertsFp8(
moe_config=make_dummy_moe_config( moe_config=moe_config,
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
),
quant_config=quant_config, quant_config=quant_config,
), ),
) )
...@@ -163,16 +169,22 @@ def bench_run( ...@@ -163,16 +169,22 @@ def bench_run(
w2_scale=w2_scale, w2_scale=w2_scale,
per_act_token_quant=per_act_token, per_act_token_quant=per_act_token,
) )
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
)
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8( CutlassExpertsFp8(
moe_config=make_dummy_moe_config( moe_config=moe_config,
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
),
quant_config=quant_config, quant_config=quant_config,
), ),
) )
......
...@@ -17,6 +17,9 @@ from ray.experimental.tqdm_ray import tqdm ...@@ -17,6 +17,9 @@ from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -242,24 +245,33 @@ def benchmark_config( ...@@ -242,24 +245,33 @@ def benchmark_config(
deep_gemm_experts = None deep_gemm_experts = None
if use_deep_gemm: if use_deep_gemm:
deep_gemm_experts = mk.FusedMoEModularKernel( moe_config = (
prepare_finalize=MoEPrepareAndFinalizeNoEP(), FusedMoEConfig(
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
num_logical_experts=num_experts,
activation=MoEActivation.SILU,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
device="cuda",
),
)
deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts( fused_experts=TritonOrDeepGemmExperts(
moe_config=FusedMoEConfig( moe_config=moe_config,
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
num_logical_experts=num_experts,
activation=MoEActivation.SILU,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
device="cuda",
),
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=not disable_inplace(),
) )
with override_config(config): with override_config(config):
...@@ -269,8 +281,16 @@ def benchmark_config( ...@@ -269,8 +281,16 @@ def benchmark_config(
inplace = not disable_inplace() inplace = not disable_inplace()
if use_deep_gemm: if use_deep_gemm:
return deep_gemm_experts( return deep_gemm_experts.apply(
x, w1, w2, topk_weights, topk_ids, inplace=inplace x,
w1,
w2,
topk_weights,
topk_ids,
activation=MoEActivation.SILU,
global_num_experts=num_experts,
apply_router_weight_on_input=False,
expert_map=False,
) )
return fused_experts( return fused_experts(
x, x,
......
...@@ -81,7 +81,7 @@ The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` cal ...@@ -81,7 +81,7 @@ The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` cal
The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization. The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization.
The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalize` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel. The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalizeModular` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists. The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists.
......
...@@ -37,31 +37,31 @@ The rest of the document will focus on the Contiguous / Non-Batched case. Extrap ...@@ -37,31 +37,31 @@ The rest of the document will focus on the Contiguous / Non-Batched case. Extrap
FusedMoEModularKernel splits the FusedMoE operation into 3 parts, FusedMoEModularKernel splits the FusedMoE operation into 3 parts,
1. TopKWeightAndReduce 1. TopKWeightAndReduce
2. FusedMoEPrepareAndFinalize 2. FusedMoEPrepareAndFinalizeModular
3. FusedMoEPermuteExpertsUnpermute 3. FusedMoEExpertsModular
### TopKWeightAndReduce ### TopKWeightAndReduce
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEPermuteExpertsUnpermute` is responsible for the Unpermute and `FusedMoEPrepareAndFinalize` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEPermuteExpertsUnpermute`. But some implementations choose to do it `FusedMoEPrepareAndFinalize`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class. The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEExpertsModular` is responsible for the Unpermute and `FusedMoEPrepareAndFinalizeModular` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEExpertsModular`. But some implementations choose to do it `FusedMoEPrepareAndFinalizeModular`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py). Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py).
`FusedMoEPrepareAndFinalize::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method. `FusedMoEPrepareAndFinalizeModular::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method.
The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExpertsUnpermute` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens. The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEExpertsModular` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens.
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEPermuteExpertsUnpermute` implementation does the weight application and reduction itself. * `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEExpertsModular` implementation does the weight application and reduction itself.
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEPermuteExpertsUnpermute` implementation needs the `FusedMoEPrepareAndFinalize::finalize()` to do the weight application and reduction. * `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEExpertsModular` implementation needs the `FusedMoEPrepareAndFinalizeModular::finalize()` to do the weight application and reduction.
### FusedMoEPrepareAndFinalize ### FusedMoEPrepareAndFinalizeModular
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions. The `FusedMoEPrepareAndFinalizeModular` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalizeModular` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
![FusedMoEPrepareAndFinalize Blocks](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png) ![FusedMoEPrepareAndFinalizeModular Blocks](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png)
### FusedMoEPermuteExpertsUnpermute ### FusedMoEExpertsModular
The `FusedMoEPermuteExpertsUnpermute` class is where the crux of the MoE operations happen. The `FusedMoEPermuteExpertsUnpermute` abstract class exposes a few important functions, The `FusedMoEExpertsModular` class is where the crux of the MoE operations happen. The `FusedMoEExpertsModular` abstract class exposes a few important functions,
* apply() * apply()
* workspace_shapes() * workspace_shapes()
...@@ -81,25 +81,25 @@ The `apply` method is where the implementations perform ...@@ -81,25 +81,25 @@ The `apply` method is where the implementations perform
#### workspace_shapes() #### workspace_shapes()
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEPermuteExpertsUnpermute::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation. The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEExpertsModular::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation.
#### finalize_weight_and_reduce_impl() #### finalize_weight_and_reduce_impl()
It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEPermuteExpertsUnpermute::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section. It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEExpertsModular::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section.
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalize::finalize()` to use. `FusedMoEExpertsModular::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalizeModular::finalize()` to use.
![FusedMoEPermuteExpertsUnpermute Blocks](../assets/design/fused_moe_modular_kernel/fused_experts_blocks.png) ![FusedMoEExpertsModular Blocks](../assets/design/fused_moe_modular_kernel/fused_experts_blocks.png)
### FusedMoEModularKernel ### FusedMoEModularKernel
`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` objects. `FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` objects.
`FusedMoEModularKernel` pseudocode/sketch, `FusedMoEModularKernel` pseudocode/sketch,
```py ```py
class FusedMoEModularKernel: class FusedMoEModularKernel:
def __init__(self, def __init__(self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
fused_experts: FusedMoEPermuteExpertsUnpermute): fused_experts: FusedMoEExpertsModular):
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
...@@ -128,53 +128,53 @@ class FusedMoEModularKernel: ...@@ -128,53 +128,53 @@ class FusedMoEModularKernel:
## How-To ## How-To
### How To Add a FusedMoEPrepareAndFinalize Type ### How To Add a FusedMoEPrepareAndFinalizeModular Type
Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, Typically a FusedMoEPrepareAndFinalizeModular type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and * DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and
* DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels.
#### Step 1: Add an All2All manager #### Step 1: Add an All2All manager
The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py). The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalizeModular` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py).
#### Step 2: Add a FusedMoEPrepareAndFinalize Type #### Step 2: Add a FusedMoEPrepareAndFinalizeModular Type
This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize` abstract class. This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalizeModular` abstract class.
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. `FusedMoEPrepareAndFinalizeModular::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False. `FusedMoEPrepareAndFinalizeModular::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked. `FusedMoEPrepareAndFinalizeModular::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. `FusedMoEPrepareAndFinalizeModular::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. `FusedMoEPrepareAndFinalizeModular::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
`FusedMoEPrepareAndFinalize::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoe::select_experts` function so it could be respected. If there are no strict requirements return None. `FusedMoEPrepareAndFinalizeModular::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoe::select_experts` function so it could be respected. If there are no strict requirements return None.
`FusedMoEPrepareAndFinalize::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once. `FusedMoEPrepareAndFinalizeModular::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once.
`FusedMoEPrepareAndFinalize::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank(). `FusedMoEPrepareAndFinalizeModular::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank().
We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementation that matches your All2All implementation closely and using it as a reference. We suggest picking an already existing `FusedMoEPrepareAndFinalizeModular` implementation that matches your All2All implementation closely and using it as a reference.
### How To Add a FusedMoEPermuteExpertsUnpermute Type ### How To Add a FusedMoEExpertsModular Type
FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows, FusedMoEExpertsModular performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows,
`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format. `FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
`FusedMoEPermuteExpertsUnpermute::supports_chunking()`: Return True if the implementation supports chunking. Typically `FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not. implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
`FusedMoEPermuteExpertsUnpermute::supports_expert_map()`: Return True if the implementation supports expert map. `FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.
`FusedMoEPermuteExpertsUnpermute::workspace_shapes()` / `FusedMoEExpertsModular::workspace_shapes()` /
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` / `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` /
`FusedMoEPermuteExpertsUnpermute::apply`: Refer to `FusedMoEPermuteExpertsUnpermute` section above. `FusedMoEExpertsModular::apply`: Refer to `FusedMoEExpertsModular` section above.
### FusedMoEModularKernel Initialization ### FusedMoEModularKernel Initialization
...@@ -186,14 +186,14 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking ...@@ -186,14 +186,14 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
#### maybe_make_prepare_finalize #### maybe_make_prepare_finalize
The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalizeModular` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalizeModular` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
Please refer to the implementations in, Please refer to the implementations in,
* `ModelOptNvFp4FusedMoE` * `ModelOptNvFp4FusedMoE`
#### select_gemm_impl #### select_gemm_impl
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object. The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEExpertsModular` object.
Please refer to the implementations in, Please refer to the implementations in,
* `UnquantizedFusedMoEMethod` * `UnquantizedFusedMoEMethod`
...@@ -205,7 +205,7 @@ derived classes. ...@@ -205,7 +205,7 @@ derived classes.
#### init_prepare_finalize #### init_prepare_finalize
Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalize` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEPermuteExpertsUnpermute` object and builds the `FusedMoEModularKernel` object Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalizeModular` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEExpertsModular` object and builds the `FusedMoEModularKernel` object
Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188). Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188).
**Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used. **Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used.
...@@ -214,9 +214,9 @@ Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vl ...@@ -214,9 +214,9 @@ Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vl
We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py). We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py).
The unit test iterates through all combinations of `FusedMoEPrepareAndFinalize` and `FusedMoEPremuteExpertsUnpermute` types and if they are The unit test iterates through all combinations of `FusedMoEPrepareAndFinalizeModular` and `FusedMoEPremuteExpertsUnpermute` types and if they are
compatible, runs some correctness tests. compatible, runs some correctness tests.
If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnpermute` implementations, If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsModular` implementations,
1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively. 1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`, 2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
...@@ -225,24 +225,24 @@ If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnp ...@@ -225,24 +225,24 @@ If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnp
Doing this will add the new implementation to the test suite. Doing this will add the new implementation to the test suite.
### How To Check `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` Compatibility ### How To Check `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` Compatibility
The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts` Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts`
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked As a side effect, this script can be used to test `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` compatibility. When invoked
with incompatible types, the script will error. with incompatible types, the script will error.
### How To Profile ### How To Profile
Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py) Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py)
The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible
`FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types. `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` types.
Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts` Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts`
## FusedMoEPrepareAndFinalize Implementations ## FusedMoEPrepareAndFinalizeModular Implementations
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses. See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses.
## FusedMoEPermuteExpertsUnpermute ## FusedMoEExpertsModular
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts. See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts.
...@@ -4,17 +4,17 @@ The purpose of this document is to provide an overview of the various MoE kernel ...@@ -4,17 +4,17 @@ The purpose of this document is to provide an overview of the various MoE kernel
## Fused MoE Modular All2All backends ## Fused MoE Modular All2All backends
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` subclasses provide an interface for each all2all backend. There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalizeModular` subclasses provide an interface for each all2all backend.
The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support. The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support.
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document. The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalizeModular` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16. The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalizeModular` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step). Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step).
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag. Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalizeModular` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP. Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
...@@ -36,8 +36,6 @@ th { ...@@ -36,8 +36,6 @@ th {
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] | | flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
!!! info "Table key" !!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8 1. All types: mxfp4, nvfp4, int4, int8, fp8
...@@ -75,9 +73,9 @@ Each experts kernel supports one or more activation functions, e.g. silu or gelu ...@@ -75,9 +73,9 @@ Each experts kernel supports one or more activation functions, e.g. silu or gelu
As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts. As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts.
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`. Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEExpertsModular`.
To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats. To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source | | Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------| |--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
...@@ -106,7 +104,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels ...@@ -106,7 +104,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts. The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts.
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | | backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses |
|---------|-----------------------------------------|----------------------------------------------| |---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` | | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` | | deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
......
...@@ -17,13 +17,13 @@ from .mk_objects import ( ...@@ -17,13 +17,13 @@ from .mk_objects import (
def make_config_arg_parser(description: str): def make_config_arg_parser(description: str):
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalizeModular:
for pf in MK_ALL_PREPARE_FINALIZE_TYPES: for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
if pf.__name__ == s: if pf.__name__ == s:
return pf return pf
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}") raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: def to_experts_class_type(s: str) -> mk.FusedMoEExpertsModular:
for fe in MK_FUSED_EXPERT_TYPES: for fe in MK_FUSED_EXPERT_TYPES:
if fe.__name__ == s: if fe.__name__ == s:
return fe return fe
......
...@@ -66,7 +66,7 @@ class Config: ...@@ -66,7 +66,7 @@ class Config:
quant_config: TestMoEQuantConfig | None quant_config: TestMoEQuantConfig | None
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute fused_experts_type: mk.FusedMoEExperts
fused_moe_chunk_size: int | None fused_moe_chunk_size: int | None
world_size: int world_size: int
...@@ -566,7 +566,7 @@ def make_modular_kernel( ...@@ -566,7 +566,7 @@ def make_modular_kernel(
config: Config, config: Config,
vllm_config: VllmConfig, vllm_config: VllmConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEKernel:
def next_power_of_2(x): def next_power_of_2(x):
import math import math
...@@ -613,7 +613,7 @@ def make_modular_kernel( ...@@ -613,7 +613,7 @@ def make_modular_kernel(
config.N, config.N,
) )
modular_kernel = mk.FusedMoEModularKernel( modular_kernel = mk.FusedMoEKernel(
prepare_finalize=prepare_finalize, prepare_finalize=prepare_finalize,
fused_experts=fused_experts, fused_experts=fused_experts,
inplace=False, inplace=False,
...@@ -667,6 +667,7 @@ def run_modular_kernel( ...@@ -667,6 +667,7 @@ def run_modular_kernel(
"w2": rank_weights.w2, "w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights, "topk_weights": rank_tensors.topk_weights,
"topk_ids": topk_ids, "topk_ids": topk_ids,
"activation": MoEActivation.SILU,
"expert_map": rank_tensors.expert_map, "expert_map": rank_tensors.expert_map,
"global_num_experts": config.E, "global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1 "apply_router_weight_on_input": config.topk == 1
...@@ -684,6 +685,6 @@ def run_modular_kernel( ...@@ -684,6 +685,6 @@ def run_modular_kernel(
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
): ):
out = mk.forward(**mk_kwargs) out = mk.apply(**mk_kwargs)
return out return out
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
NaiveBatchedExperts, NaiveBatchedExperts,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoDPEPModular,
) )
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
...@@ -71,12 +71,14 @@ class ExpertInfo: ...@@ -71,12 +71,14 @@ class ExpertInfo:
needs_aiter: bool = False needs_aiter: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {} PREPARE_FINALIZE_INFO: dict[
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} mk.FusedMoEPrepareAndFinalizeModular, PrepareFinalizeInfo
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] ] = {}
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] EXPERT_INFO: dict[mk.FusedMoEExpertsModular, ExpertInfo] = {}
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = [] MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEExpertsModular] = []
standard_format = mk.FusedMoEActivationFormat.Standard standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts batched_format = mk.FusedMoEActivationFormat.BatchedExperts
...@@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo: ...@@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo:
register_prepare_and_finalize( register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoDPEPModular,
standard_format, standard_format,
common_float_types, common_float_types,
blocked_quantization_support=True, blocked_quantization_support=True,
...@@ -239,14 +241,14 @@ if has_mori(): ...@@ -239,14 +241,14 @@ if has_mori():
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize, FlashInferA2APrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
) )
register_prepare_and_finalize( register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize, FlashInferA2APrepareAndFinalize,
standard_format, standard_format,
nvfp4_types + fp8_types, nvfp4_types + fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
...@@ -430,12 +432,12 @@ def make_cutlass_strides( ...@@ -430,12 +432,12 @@ def make_cutlass_strides(
def make_fused_experts( def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, fused_experts_type: mk.FusedMoEExpertsModular,
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
num_dispatchers: int, num_dispatchers: int,
N: int, N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
if ( if (
fused_experts_type.activation_format() fused_experts_type.activation_format()
== mk.FusedMoEActivationFormat.BatchedExperts == mk.FusedMoEActivationFormat.BatchedExperts
......
...@@ -72,7 +72,7 @@ def profile_modular_kernel( ...@@ -72,7 +72,7 @@ def profile_modular_kernel(
"apply_router_weight_on_input": config.topk == 1, "apply_router_weight_on_input": config.topk == 1,
} }
do_profile(mk.forward, mk_kwargs, pgi, config) do_profile(mk.apply, mk_kwargs, pgi, config)
def rank_worker( def rank_worker(
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts, BatchedDeepGemmExperts,
) )
...@@ -12,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( ...@@ -12,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedPrepareAndFinalize,
BatchedTritonExperts, BatchedTritonExperts,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights from .test_deepgemm import make_block_quant_fp8_weights
...@@ -74,19 +75,22 @@ def test_batched_deepgemm_vs_triton( ...@@ -74,19 +75,22 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config, quant_config=quant_config,
moe_config=make_dummy_moe_config(), moe_config=make_dummy_moe_config(),
) )
mk_triton = FusedMoEModularKernel( mk_triton = FusedMoEKernel(
prep_finalize, prep_finalize,
triton_experts, triton_experts,
inplace=False, inplace=False,
) )
out_triton = mk_triton( out_triton = mk_triton.apply(
hidden_states=a, hidden_states=a,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=MoEActivation.SILU,
global_num_experts=E, global_num_experts=E,
expert_map=None,
apply_router_weight_on_input=False,
) )
# deepgemm # deepgemm
...@@ -96,19 +100,22 @@ def test_batched_deepgemm_vs_triton( ...@@ -96,19 +100,22 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config, quant_config=quant_config,
moe_config=make_dummy_moe_config(), moe_config=make_dummy_moe_config(),
) )
mk_deepgemm = FusedMoEModularKernel( mk_deepgemm = FusedMoEKernel(
prep_finalize, prep_finalize,
deepgemm_experts, deepgemm_experts,
inplace=False, inplace=False,
) )
out_deepgemm = mk_deepgemm( out_deepgemm = mk_deepgemm.apply(
hidden_states=a, hidden_states=a,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=MoEActivation.SILU,
global_num_experts=E, global_num_experts=E,
expert_map=None,
apply_router_weight_on_input=False,
) )
diff = calc_diff(out_deepgemm, out_triton) diff = calc_diff(out_deepgemm, out_triton)
......
...@@ -21,15 +21,16 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -21,15 +21,16 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts, fused_experts,
fused_topk, fused_topk,
) )
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, _valid_deep_gemm_shape,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
...@@ -193,7 +194,17 @@ def test_w8a8_block_fp8_fused_moe( ...@@ -193,7 +194,17 @@ def test_w8a8_block_fp8_fused_moe(
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
) )
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids) m_out = m_fused_moe.apply(
a,
w1,
w2,
topk_weights,
topk_ids,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
global_num_experts=w1.shape[0],
)
# 0.039 only needed for M >= 8192 # 0.039 only needed for M >= 8192
tol = 0.035 if M < 8192 else 0.039 tol = 0.035 if M < 8192 else 0.039
...@@ -252,23 +263,33 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) ...@@ -252,23 +263,33 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
w2_scale=w2_s, w2_scale=w2_s,
block_shape=block_size, block_shape=block_size,
) )
moe_config = make_dummy_moe_config()
deep_gemm_experts = mk.FusedMoEModularKernel( deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(), prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts( fused_experts=TritonOrDeepGemmExperts(
moe_config=make_dummy_moe_config(), moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=False, inplace=False,
) )
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
return deep_gemm_experts( return deep_gemm_experts.apply(
hidden_states=a, hidden_states=a,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
global_num_experts=E,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=False,
) )
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
......
...@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops ...@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -22,9 +25,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import ( ...@@ -22,9 +25,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8, CutlassExpertsFp8,
run_cutlass_moe_fp8, run_cutlass_moe_fp8,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -197,20 +197,26 @@ def run_with_expert_maps( ...@@ -197,20 +197,26 @@ def run_with_expert_maps(
for kwargs, new_quant_config in slice_experts(): for kwargs, new_quant_config in slice_experts():
w2 = kwargs["w2"] w2 = kwargs["w2"]
a = kwargs["hidden_states"] a = kwargs["hidden_states"]
kernel = mk.FusedMoEModularKernel( moe_config = make_dummy_moe_config(
MoEPrepareAndFinalizeNoEP(), num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
)
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=new_quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8( CutlassExpertsFp8(
moe_config=make_dummy_moe_config( moe_config=moe_config,
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
),
quant_config=new_quant_config, quant_config=new_quant_config,
), ),
inplace=False, inplace=False,
) )
out_tensor = out_tensor + kernel(**kwargs) out_tensor = out_tensor + kernel.apply(**kwargs)
return out_tensor return out_tensor
...@@ -252,25 +258,35 @@ def run_8_bit( ...@@ -252,25 +258,35 @@ def run_8_bit(
"w2": moe_tensors.w2_q, # type: ignore[union-attr] "w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights, "topk_weights": topk_weights,
"topk_ids": topk_ids, "topk_ids": topk_ids,
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
"activation": MoEActivation.SILU,
"expert_map": None,
"apply_router_weight_on_input": False,
} }
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined] num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep: if not with_ep:
kernel = mk.FusedMoEModularKernel( moe_config = make_dummy_moe_config(
MoEPrepareAndFinalizeNoEP(), num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
)
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8( CutlassExpertsFp8(
moe_config=make_dummy_moe_config( moe_config=moe_config,
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
),
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=False, inplace=False,
) )
return kernel(**kwargs) return kernel.apply(**kwargs)
assert num_local_experts is not None assert num_local_experts is not None
return run_with_expert_maps( return run_with_expert_maps(
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
...@@ -170,7 +170,7 @@ def make_ll_modular_kernel( ...@@ -170,7 +170,7 @@ def make_ll_modular_kernel(
q_dtype: torch.dtype | None, q_dtype: torch.dtype | None,
test_config: TestConfig, test_config: TestConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel: ) -> FusedMoEKernel:
assert test_config.low_latency assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None assert test_config.use_fp8_dispatch is not None
...@@ -195,7 +195,7 @@ def make_ll_modular_kernel( ...@@ -195,7 +195,7 @@ def make_ll_modular_kernel(
quant_config=quant_config, quant_config=quant_config,
moe_config=make_dummy_moe_config(), moe_config=make_dummy_moe_config(),
) )
return FusedMoEModularKernel( return FusedMoEKernel(
prepare_finalize=a2a, prepare_finalize=a2a,
fused_experts=fused_experts, fused_experts=fused_experts,
inplace=False, inplace=False,
...@@ -210,7 +210,7 @@ def make_ht_modular_kernel( ...@@ -210,7 +210,7 @@ def make_ht_modular_kernel(
q_dtype: torch.dtype | None, q_dtype: torch.dtype | None,
test_config: TestConfig, test_config: TestConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel: ) -> FusedMoEKernel:
assert not test_config.low_latency assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None assert test_config.use_fp8_dispatch is None
...@@ -228,7 +228,7 @@ def make_ht_modular_kernel( ...@@ -228,7 +228,7 @@ def make_ht_modular_kernel(
moe_config=make_dummy_moe_config(), moe_config=make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
) )
return FusedMoEModularKernel( return FusedMoEKernel(
prepare_finalize=a2a, prepare_finalize=a2a,
fused_experts=fused_experts, fused_experts=fused_experts,
inplace=False, inplace=False,
...@@ -242,11 +242,11 @@ def make_modular_kernel( ...@@ -242,11 +242,11 @@ def make_modular_kernel(
num_local_experts: int, num_local_experts: int,
test_tensors: TestTensors, test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel: ) -> FusedMoEKernel:
q_dtype = torch.float8_e4m3fn q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config test_config = test_tensors.config
mk: FusedMoEModularKernel mk: FusedMoEKernel
# Make modular kernel # Make modular kernel
if test_config.low_latency: if test_config.low_latency:
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0))) max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
...@@ -307,7 +307,7 @@ def deepep_deepgemm_moe_impl( ...@@ -307,7 +307,7 @@ def deepep_deepgemm_moe_impl(
) )
# Make modular kernel # Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel( mk: FusedMoEKernel = make_modular_kernel(
pg=pg, pg=pg,
pgi=pgi, pgi=pgi,
dp_size=dp_size, dp_size=dp_size,
...@@ -319,7 +319,7 @@ def deepep_deepgemm_moe_impl( ...@@ -319,7 +319,7 @@ def deepep_deepgemm_moe_impl(
with with_dp_metadata( with with_dp_metadata(
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
): ):
out = mk.forward( out = mk.apply(
hidden_states=test_tensors.rank_tokens, hidden_states=test_tensors.rank_tokens,
w1=w1, w1=w1,
w2=w2, w2=w2,
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
...@@ -135,7 +135,7 @@ def make_modular_kernel( ...@@ -135,7 +135,7 @@ def make_modular_kernel(
q_dtype: torch.dtype | None, q_dtype: torch.dtype | None,
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel: ) -> FusedMoEKernel:
ht_args: DeepEPHTArgs | None = None ht_args: DeepEPHTArgs | None = None
ll_args: DeepEPLLArgs | None = None ll_args: DeepEPLLArgs | None = None
...@@ -180,7 +180,7 @@ def make_modular_kernel( ...@@ -180,7 +180,7 @@ def make_modular_kernel(
quant_config=quant_config, quant_config=quant_config,
) )
mk = FusedMoEModularKernel( mk = FusedMoEKernel(
prepare_finalize=a2a, prepare_finalize=a2a,
fused_experts=fused_experts, fused_experts=fused_experts,
inplace=False, inplace=False,
...@@ -242,7 +242,7 @@ def deep_ep_moe_impl( ...@@ -242,7 +242,7 @@ def deep_ep_moe_impl(
) )
# Make modular kernel # Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel( mk: FusedMoEKernel = make_modular_kernel(
pg, pg,
pgi, pgi,
low_latency_mode, low_latency_mode,
...@@ -255,7 +255,7 @@ def deep_ep_moe_impl( ...@@ -255,7 +255,7 @@ def deep_ep_moe_impl(
quant_config, quant_config,
) )
out = mk.forward( out = mk.apply(
hidden_states=rank_tokens_chunk, hidden_states=rank_tokens_chunk,
w1=w1, w1=w1,
w2=w2, w2=w2,
......
...@@ -14,13 +14,16 @@ import torch ...@@ -14,13 +14,16 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option) # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
...@@ -108,11 +111,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -108,11 +111,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
a1_scale=a1_scale, a1_scale=a1_scale,
block_shape=block_size, block_shape=block_size,
) )
moe_config = make_dummy_moe_config()
deep_gemm_experts = mk.FusedMoEModularKernel( deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(), prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts( fused_experts=TritonOrDeepGemmExperts(
moe_config=make_dummy_moe_config(), moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=False, inplace=False,
...@@ -130,12 +139,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -130,12 +139,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
) )
# DeepGemm # DeepGemm
out_deepgemm = deep_gemm_experts( out_deepgemm = deep_gemm_experts.apply(
hidden_states=tokens_bf16, hidden_states=tokens_bf16,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
global_num_experts=num_experts,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
) )
diff = calc_diff(out_deepgemm, out_triton) diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}" assert diff < 0.001, f"Diff exceeded 1%: {diff}"
......
...@@ -8,6 +8,9 @@ import torch ...@@ -8,6 +8,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -15,16 +18,14 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,16 +18,14 @@ from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType, RoutingMethodType,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
TrtLlmFp8Experts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe, rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
swap_w13_to_w31, swap_w13_to_w31,
) )
...@@ -115,6 +116,7 @@ class TestData: ...@@ -115,6 +116,7 @@ class TestData:
e: int, e: int,
is_trtllm: bool, is_trtllm: bool,
activation: MoEActivation = MoEActivation.SILU, activation: MoEActivation = MoEActivation.SILU,
topk: int = 1,
) -> "TestData": ) -> "TestData":
is_gated = activation.is_gated is_gated = activation.is_gated
...@@ -152,13 +154,6 @@ class TestData: ...@@ -152,13 +154,6 @@ class TestData:
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight, is_gated layer.w13_weight, layer.w2_weight, is_gated
) )
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
layer.custom_routing_function = Llama4MoE.custom_routing_function layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.routing_method_type = RoutingMethodType.Llama4 layer.routing_method_type = RoutingMethodType.Llama4
layer.renormalize = False layer.renormalize = False
...@@ -166,6 +161,21 @@ class TestData: ...@@ -166,6 +161,21 @@ class TestData:
layer.ep_rank = 0 layer.ep_rank = 0
layer.local_num_experts = e layer.local_num_experts = e
layer.moe = FusedMoEConfig(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
num_logical_experts=e,
moe_parallel_config=layer.moe_parallel_config,
in_dtype=hidden_states.dtype,
is_act_and_mul=is_gated,
routing_method=layer.routing_method_type,
activation=activation,
device=w13_quantized.device,
)
return TestData( return TestData(
hidden_states=hidden_states, hidden_states=hidden_states,
w13_quantized=w13_quantized, w13_quantized=w13_quantized,
...@@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config=quant_config, quant_config=quant_config,
) )
flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe( kernel = mk.FusedMoEKernel(
layer=td.layer, maybe_make_prepare_finalize(
moe=td.layer.moe,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=True,
),
TrtLlmFp8Experts(
moe_config=td.layer.moe,
quant_config=quant_config,
),
)
flashinfer_output = kernel.apply_monolithic(
hidden_states=td.hidden_states, hidden_states=td.hidden_states,
w1=td.layer.w13_weight,
w2=td.layer.w2_weight,
router_logits=score, router_logits=score,
routing_bias=None, activation=activation,
global_num_experts=e, global_num_experts=e,
top_k=topk, expert_map=None,
num_expert_group=None,
topk_group=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
routed_scaling_factor=1.0,
) )
check_accuracy( check_accuracy(
...@@ -329,8 +352,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -329,8 +352,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
routing_method=RoutingMethodType.TopK, routing_method=RoutingMethodType.TopK,
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
FlashInferExperts( FlashInferExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -338,7 +366,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -338,7 +366,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
inplace=False, inplace=False,
) )
flashinfer_cutlass_output = kernel( flashinfer_cutlass_output = kernel.apply(
td.hidden_states, td.hidden_states,
td.layer.w13_weight, td.layer.w13_weight,
td.layer.w2_weight, td.layer.w2_weight,
......
...@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops ...@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( ...@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe, is_valid_flashinfer_cutlass_fused_moe,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -107,19 +107,27 @@ def test_flashinfer_fp4_moe_no_graph( ...@@ -107,19 +107,27 @@ def test_flashinfer_fp4_moe_no_graph(
routing_method=RoutingMethodType.TopK, routing_method=RoutingMethodType.TopK,
) )
flashinfer_experts = FusedMoEModularKernel( flashinfer_experts = FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config), FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
inplace=False, inplace=False,
) )
flashinfer_output = flashinfer_experts( flashinfer_output = flashinfer_experts.apply(
hidden_states=a, hidden_states=a,
w1=w1_q, w1=w1_q,
w2=w2_q, w2=w2_q,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=False,
) )
# Reference check: # Reference check:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment