Unverified commit 97995f63, authored by Robert Shaw, committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
......@@ -44,7 +44,8 @@ steps:
- vllm/envs.py
- vllm/config
commands:
- pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
- pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
- pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2
- label: Kernels Mamba Test
......
......@@ -12,12 +12,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
......@@ -137,15 +137,21 @@ def bench_run(
per_out_ch_quant=per_out_ch,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
moe_config=make_dummy_moe_config(
moe_config = make_dummy_moe_config(
num_experts=num_experts,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=a.dtype,
)
fn = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8(
moe_config=moe_config,
quant_config=quant_config,
),
)
......
......@@ -15,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
......@@ -23,9 +26,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
......@@ -196,10 +196,21 @@ def bench_run(
g2_alphas=w2_gs,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
moe_config = make_dummy_moe_config(
num_experts=num_experts,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=a.dtype,
)
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp4(
make_dummy_moe_config(),
moe_config=moe_config,
quant_config=quant_config,
),
)
......@@ -240,11 +251,17 @@ def bench_run(
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
moe_config = make_dummy_moe_config()
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp4(
make_dummy_moe_config(),
moe_config=moe_config,
quant_config=quant_config,
),
)
......
......@@ -9,15 +9,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
......@@ -131,16 +131,22 @@ def bench_run(
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
moe_config=make_dummy_moe_config(
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
)
fn = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8(
moe_config=moe_config,
quant_config=quant_config,
),
)
......@@ -163,16 +169,22 @@ def bench_run(
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
moe_config=make_dummy_moe_config(
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
)
fn = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8(
moe_config=moe_config,
quant_config=quant_config,
),
)
......
......@@ -17,6 +17,9 @@ from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -242,10 +245,8 @@ def benchmark_config(
deep_gemm_experts = None
if use_deep_gemm:
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
fused_experts=TritonOrDeepGemmExperts(
moe_config=FusedMoEConfig(
moe_config = (
FusedMoEConfig(
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
......@@ -258,8 +259,19 @@ def benchmark_config(
routing_method=RoutingMethodType.TopK,
device="cuda",
),
)
deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not disable_inplace(),
)
with override_config(config):
......@@ -269,8 +281,16 @@ def benchmark_config(
inplace = not disable_inplace()
if use_deep_gemm:
return deep_gemm_experts(
x, w1, w2, topk_weights, topk_ids, inplace=inplace
return deep_gemm_experts.apply(
x,
w1,
w2,
topk_weights,
topk_ids,
activation=MoEActivation.SILU,
global_num_experts=num_experts,
apply_router_weight_on_input=False,
expert_map=False,
)
return fused_experts(
x,
......
......@@ -81,7 +81,7 @@ The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` cal
The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization.
The `dbo_register_recv_hook` method registers, in the other UBatch thread’s `UBatchContext`, a callback that can be returned by the `FusedMoEPrepareAndFinalize` class. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
The `dbo_register_recv_hook` method registers, in the other UBatch thread’s `UBatchContext`, a callback that can be returned by the `FusedMoEPrepareAndFinalizeModular` class. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists.
......
This diff is collapsed.
......@@ -4,17 +4,17 @@ The purpose of this document is to provide an overview of the various MoE kernel
## Fused MoE Modular All2All backends
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` subclasses provide an interface for each all2all backend.
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalizeModular` subclasses provide an interface for each all2all backend.
The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support.
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalizeModular` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalizeModular` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step).
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalizeModular` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. The `flashinfer` backend can work with EP, or with DP without EP.
......@@ -36,8 +36,6 @@ th {
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
......@@ -75,9 +73,9 @@ Each experts kernel supports one or more activation functions, e.g. silu or gelu
As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts.
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`.
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEExpertsModular`.
To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
......@@ -106,7 +104,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts.
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
| backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses |
|---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
......
......@@ -17,13 +17,13 @@ from .mk_objects import (
def make_config_arg_parser(description: str):
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalizeModular:
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
if pf.__name__ == s:
return pf
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
def to_experts_class_type(s: str) -> mk.FusedMoEExpertsModular:
for fe in MK_FUSED_EXPERT_TYPES:
if fe.__name__ == s:
return fe
......
......@@ -66,7 +66,7 @@ class Config:
quant_config: TestMoEQuantConfig | None
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
fused_experts_type: mk.FusedMoEExperts
fused_moe_chunk_size: int | None
world_size: int
......@@ -566,7 +566,7 @@ def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
) -> mk.FusedMoEKernel:
def next_power_of_2(x):
import math
......@@ -613,7 +613,7 @@ def make_modular_kernel(
config.N,
)
modular_kernel = mk.FusedMoEModularKernel(
modular_kernel = mk.FusedMoEKernel(
prepare_finalize=prepare_finalize,
fused_experts=fused_experts,
inplace=False,
......@@ -667,6 +667,7 @@ def run_modular_kernel(
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": topk_ids,
"activation": MoEActivation.SILU,
"expert_map": rank_tensors.expert_map,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1
......@@ -684,6 +685,6 @@ def run_modular_kernel(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)
out = mk.apply(**mk_kwargs)
return out
......@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
......@@ -71,12 +71,14 @@ class ExpertInfo:
needs_aiter: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
PREPARE_FINALIZE_INFO: dict[
mk.FusedMoEPrepareAndFinalizeModular, PrepareFinalizeInfo
] = {}
EXPERT_INFO: dict[mk.FusedMoEExpertsModular, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEExpertsModular] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
......@@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo:
register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
standard_format,
common_float_types,
blocked_quantization_support=True,
......@@ -239,14 +241,14 @@ if has_mori():
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
FlashInferA2APrepareAndFinalize,
standard_format,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
......@@ -430,12 +432,12 @@ def make_cutlass_strides(
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
fused_experts_type: mk.FusedMoEExpertsModular,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
num_dispatchers: int,
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if (
fused_experts_type.activation_format()
== mk.FusedMoEActivationFormat.BatchedExperts
......
......@@ -72,7 +72,7 @@ def profile_modular_kernel(
"apply_router_weight_on_input": config.topk == 1,
}
do_profile(mk.forward, mk_kwargs, pgi, config)
do_profile(mk.apply, mk_kwargs, pgi, config)
def rank_worker(
......
......@@ -4,6 +4,7 @@
import pytest
import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
......@@ -12,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights
......@@ -74,19 +75,22 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(
mk_triton = FusedMoEKernel(
prep_finalize,
triton_experts,
inplace=False,
)
out_triton = mk_triton(
out_triton = mk_triton.apply(
hidden_states=a,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=MoEActivation.SILU,
global_num_experts=E,
expert_map=None,
apply_router_weight_on_input=False,
)
# deepgemm
......@@ -96,19 +100,22 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(
mk_deepgemm = FusedMoEKernel(
prep_finalize,
deepgemm_experts,
inplace=False,
)
out_deepgemm = mk_deepgemm(
out_deepgemm = mk_deepgemm.apply(
hidden_states=a,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=MoEActivation.SILU,
global_num_experts=E,
expert_map=None,
apply_router_weight_on_input=False,
)
diff = calc_diff(out_deepgemm, out_triton)
......
......@@ -21,15 +21,16 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
......@@ -193,7 +194,17 @@ def test_w8a8_block_fp8_fused_moe(
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
)
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
m_out = m_fused_moe.apply(
a,
w1,
w2,
topk_weights,
topk_ids,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
global_num_experts=w1.shape[0],
)
# 0.039 only needed for M >= 8192
tol = 0.035 if M < 8192 else 0.039
......@@ -252,23 +263,33 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
w2_scale=w2_s,
block_shape=block_size,
)
moe_config = make_dummy_moe_config()
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts(
moe_config=make_dummy_moe_config(),
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
return deep_gemm_experts(
return deep_gemm_experts.apply(
hidden_states=a,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=E,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=False,
)
# Set the context to avoid lots of warning spam.
......
......@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
......@@ -22,9 +25,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8,
run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
......@@ -197,20 +197,26 @@ def run_with_expert_maps(
for kwargs, new_quant_config in slice_experts():
w2 = kwargs["w2"]
a = kwargs["hidden_states"]
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
moe_config=make_dummy_moe_config(
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
)
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=new_quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8(
moe_config=moe_config,
quant_config=new_quant_config,
),
inplace=False,
)
out_tensor = out_tensor + kernel(**kwargs)
out_tensor = out_tensor + kernel.apply(**kwargs)
return out_tensor
......@@ -252,25 +258,35 @@ def run_8_bit(
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
"activation": MoEActivation.SILU,
"expert_map": None,
"apply_router_weight_on_input": False,
}
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
moe_config=make_dummy_moe_config(
moe_config = make_dummy_moe_config(
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
)
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp8(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
return kernel(**kwargs)
return kernel.apply(**kwargs)
assert num_local_experts is not None
return run_with_expert_maps(
......
......@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
......@@ -170,7 +170,7 @@ def make_ll_modular_kernel(
q_dtype: torch.dtype | None,
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
......@@ -195,7 +195,7 @@ def make_ll_modular_kernel(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
return FusedMoEModularKernel(
return FusedMoEKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
......@@ -210,7 +210,7 @@ def make_ht_modular_kernel(
q_dtype: torch.dtype | None,
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
......@@ -228,7 +228,7 @@ def make_ht_modular_kernel(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
)
return FusedMoEModularKernel(
return FusedMoEKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
......@@ -242,11 +242,11 @@ def make_modular_kernel(
num_local_experts: int,
test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
mk: FusedMoEModularKernel
mk: FusedMoEKernel
# Make modular kernel
if test_config.low_latency:
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
......@@ -307,7 +307,7 @@ def deepep_deepgemm_moe_impl(
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
mk: FusedMoEKernel = make_modular_kernel(
pg=pg,
pgi=pgi,
dp_size=dp_size,
......@@ -319,7 +319,7 @@ def deepep_deepgemm_moe_impl(
with with_dp_metadata(
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
):
out = mk.forward(
out = mk.apply(
hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
......
......@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
......@@ -135,7 +135,7 @@ def make_modular_kernel(
q_dtype: torch.dtype | None,
use_fp8_dispatch: bool,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
ht_args: DeepEPHTArgs | None = None
ll_args: DeepEPLLArgs | None = None
......@@ -180,7 +180,7 @@ def make_modular_kernel(
quant_config=quant_config,
)
mk = FusedMoEModularKernel(
mk = FusedMoEKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
......@@ -242,7 +242,7 @@ def deep_ep_moe_impl(
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
mk: FusedMoEKernel = make_modular_kernel(
pg,
pgi,
low_latency_mode,
......@@ -255,7 +255,7 @@ def deep_ep_moe_impl(
quant_config,
)
out = mk.forward(
out = mk.apply(
hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
......
......@@ -14,13 +14,16 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
......@@ -108,11 +111,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
a1_scale=a1_scale,
block_shape=block_size,
)
moe_config = make_dummy_moe_config()
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts(
moe_config=make_dummy_moe_config(),
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
......@@ -130,12 +139,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
)
# DeepGemm
out_deepgemm = deep_gemm_experts(
out_deepgemm = deep_gemm_experts.apply(
hidden_states=tokens_bf16,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=num_experts,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
......
......@@ -8,6 +8,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -15,16 +18,14 @@ from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
TrtLlmFp8Experts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
swap_w13_to_w31,
)
......@@ -115,6 +116,7 @@ class TestData:
e: int,
is_trtllm: bool,
activation: MoEActivation = MoEActivation.SILU,
topk: int = 1,
) -> "TestData":
is_gated = activation.is_gated
......@@ -152,13 +154,6 @@ class TestData:
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight, is_gated
)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.routing_method_type = RoutingMethodType.Llama4
layer.renormalize = False
......@@ -166,6 +161,21 @@ class TestData:
layer.ep_rank = 0
layer.local_num_experts = e
layer.moe = FusedMoEConfig(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
num_logical_experts=e,
moe_parallel_config=layer.moe_parallel_config,
in_dtype=hidden_states.dtype,
is_act_and_mul=is_gated,
routing_method=layer.routing_method_type,
activation=activation,
device=w13_quantized.device,
)
return TestData(
hidden_states=hidden_states,
w13_quantized=w13_quantized,
......@@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config=quant_config,
)
flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
layer=td.layer,
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=td.layer.moe,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=True,
),
TrtLlmFp8Experts(
moe_config=td.layer.moe,
quant_config=quant_config,
),
)
flashinfer_output = kernel.apply_monolithic(
hidden_states=td.hidden_states,
w1=td.layer.w13_weight,
w2=td.layer.w2_weight,
router_logits=score,
routing_bias=None,
activation=activation,
global_num_experts=e,
top_k=topk,
num_expert_group=None,
topk_group=None,
expert_map=None,
apply_router_weight_on_input=True,
routed_scaling_factor=1.0,
)
check_accuracy(
......@@ -329,8 +352,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
routing_method=RoutingMethodType.TopK,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
......@@ -338,7 +366,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
inplace=False,
)
flashinfer_cutlass_output = kernel(
flashinfer_cutlass_output = kernel.apply(
td.hidden_states,
td.layer.w13_weight,
td.layer.w2_weight,
......
......@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed
......@@ -107,19 +107,27 @@ def test_flashinfer_fp4_moe_no_graph(
routing_method=RoutingMethodType.TopK,
)
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
flashinfer_experts = FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
inplace=False,
)
flashinfer_output = flashinfer_experts(
flashinfer_output = flashinfer_experts.apply(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=False,
)
# Reference check:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment