Unverified commit 42135d68, authored by Robert Shaw and committed by GitHub
Browse files

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

parent e14467be
......@@ -3,3 +3,5 @@ accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "0"
......@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
......@@ -574,10 +575,14 @@ def make_modular_kernel(
num_experts=config.E,
experts_per_token=config.topk,
hidden_dim=config.K,
intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
activation="silu",
device=vllm_config.device_config.device,
routing_method=RoutingMethodType.DeepSeekV3,
)
# make modular kernel
......
......@@ -425,84 +425,26 @@ def make_fused_experts(
num_dispatchers: int,
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"quant_config": quant_config,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"out_dtype": moe.in_dtype,
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
if (
fused_experts_type.activation_format()
== mk.FusedMoEActivationFormat.BatchedExperts
):
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"moe_config": moe,
"quant_config": quant_config,
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
} | quant_kwargs
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
kwargs = {
"out_dtype": moe.in_dtype,
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
} | quant_kwargs
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
}
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
kwargs = {
"moe_config": moe,
"quant_config": quant_config,
}
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
print(f"Making {fused_experts_type.__class__.__name__} {kwargs} ...")
experts = fused_experts_type(**kwargs)
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights
from .utils import make_dummy_moe_config
BLOCK_SIZE = [128, 128]
......@@ -71,6 +72,7 @@ def test_batched_deepgemm_vs_triton(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
......@@ -89,6 +91,7 @@ def test_batched_deepgemm_vs_triton(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
......
......@@ -4,7 +4,12 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import (
make_dummy_moe_config,
make_test_quant_config,
make_test_weights,
)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
......@@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
......@@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe(
block_shape=block_size,
)
m_fused_moe = modular_triton_fused_moe(quant_config)
m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
......@@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
fused_experts=TritonOrDeepGemmExperts(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
)
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
    """Run the ``deep_gemm_experts`` modular kernel on one set of inputs.

    ``w1_s`` and ``w2_s`` are accepted in the signature but are not
    forwarded to the kernel call; only the activations, the two weight
    tensors and the top-k routing results are passed through.
    """
    kernel_inputs = {
        "hidden_states": a,
        "w1": w1,
        "w2": w2,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
    }
    return deep_gemm_experts(**kernel_inputs)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(
......
......@@ -8,6 +8,7 @@ import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
......@@ -193,16 +194,18 @@ def run_with_expert_maps(
out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])
for kwargs, new_quant_config in slice_experts():
w2 = kwargs["w2"]
a = kwargs["hidden_states"]
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=kwargs["hidden_states"].dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=kwargs["w2"].shape[0], # type: ignore[union-attr]
n=kwargs["w2"].shape[2], # type: ignore[union-attr]
k=kwargs["w2"].shape[1], # type: ignore[union-attr]
moe_config=make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
),
quant_config=new_quant_config,
device="cuda",
),
)
out_tensor = out_tensor + kernel(**kwargs)
......@@ -249,19 +252,19 @@ def run_8_bit(
"topk_ids": topk_ids,
}
num_experts = moe_tensors.w1.size(0)
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=moe_tensors.a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
n=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
moe_config=make_dummy_moe_config(
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
),
quant_config=quant_config,
device="cuda",
),
)
return kernel(**kwargs)
......
......@@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
from .utils import make_dummy_moe_config, make_test_weights
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
......@@ -192,6 +192,7 @@ def make_ll_modular_kernel(
max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
......@@ -219,7 +220,10 @@ def make_ht_modular_kernel(
block_shape=test_config.block_size,
)
fused_experts = DeepGemmExperts(quant_config)
fused_experts = DeepGemmExperts(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
......@@ -349,9 +353,6 @@ def triton_impl(
topk_ids=topk_ids,
inplace=False,
quant_config=quant_config,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm=False,
)
......
......@@ -10,11 +10,14 @@ import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
......@@ -160,15 +163,21 @@ def make_modular_kernel(
num_dispatchers = pgi.world_size // dp_size
moe_config = make_dummy_moe_config()
if low_latency_mode:
assert not quant_config.per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
num_dispatchers=num_dispatchers,
moe_config=moe_config,
quant_config=quant_config,
)
else:
fused_experts = TritonExperts(quant_config=quant_config)
fused_experts = TritonExperts(
moe_config=moe_config,
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
......
......@@ -11,10 +11,19 @@ import math
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
......@@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape=block_size,
)
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
fused_experts=TritonOrDeepGemmExperts(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
)
# triton reference
out_triton = fused_experts(
hidden_states=tokens_bf16,
......@@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_ids=topk_ids,
inplace=False,
quant_config=quant_config,
allow_deep_gemm=False,
)
# DeepGemm
out_deepgemm = fused_experts(
out_deepgemm = deep_gemm_experts(
hidden_states=tokens_bf16,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
quant_config=quant_config,
allow_deep_gemm=True,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
......@@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe"
)
_DeepGemmExperts = importlib.import_module(
"vllm.model_executor.layers.fused_moe.deep_gemm_moe"
).DeepGemmExperts
call_counter = {"cnt": 0}
orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
orig_fn = _DeepGemmExperts.apply
def _spy_deep_gemm_moe_fp8(*args, **kwargs):
def _spy_apply(*args, **kwargs):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
monkeypatch.setattr(_DeepGemmExperts, "apply", _spy_apply)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
......
......@@ -8,7 +8,10 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
......@@ -116,18 +119,7 @@ class TestData:
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
# Setup dummy config.
layer.moe_parallel_config = mk.FusedMoEParallelConfig(
tp_size=1,
pcp_size=1,
dp_size=1,
ep_size=1,
tp_rank=0,
pcp_rank=0,
dp_rank=0,
ep_rank=0,
use_ep=False,
all2all_backend="naive",
)
layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
......@@ -238,6 +230,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
assert activation in ["silu", "relu2_no_mul"]
is_act_and_mul = activation == "silu_and_mul"
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=False, activation=activation
......@@ -285,19 +279,30 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer
moe_config = FusedMoEConfig(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
activation=activation,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=torch.bfloat16,
is_act_and_mul=is_act_and_mul,
routing_method=RoutingMethodType.TopK,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=quant_config.is_block_quantized
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
FlashInferExperts(
out_dtype=td.layer.orig_dtype,
moe_config=moe_config,
quant_config=quant_config,
ep_rank=td.layer.moe_parallel_config.ep_rank,
ep_size=td.layer.moe_parallel_config.ep_size,
tp_rank=td.layer.moe_parallel_config.tp_rank,
tp_size=td.layer.moe_parallel_config.tp_size,
use_dp=False,
use_deepseek_fp8_block_scale=False,
),
)
......
......@@ -13,14 +13,19 @@ from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed
......@@ -86,9 +91,28 @@ def test_flashinfer_fp4_moe_no_graph(
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
moe_config = FusedMoEConfig(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
activation=activation,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype,
is_act_and_mul=is_gated_act,
routing_method=RoutingMethodType.TopK,
)
flashinfer_experts = FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
......
......@@ -36,6 +36,8 @@ from vllm.model_executor.layers.utils import shuffle_weight
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from .utils import make_dummy_moe_config
MNK = [
(1, 512, 384),
(1, 2880, 2880),
......@@ -174,9 +176,9 @@ def oai_triton_moe_impl(
)
if unfused:
fused_experts = UnfusedOAITritonExperts(quant_config)
fused_experts = UnfusedOAITritonExperts(make_dummy_moe_config(), quant_config)
else:
fused_experts = OAITritonExperts(quant_config)
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
......
......@@ -18,7 +18,7 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
......@@ -332,7 +332,7 @@ def test_fused_moe(
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
m_fused_moe_fn = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
def m_fused_moe(
a: torch.Tensor,
......@@ -437,7 +437,7 @@ def test_naive_block_assignment_moe(
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
m_fused_moe_fn = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
def m_fused_moe(
a: torch.Tensor,
......
......@@ -4,7 +4,7 @@ import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_dummy_moe_config, make_test_weights
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
......@@ -92,8 +92,7 @@ def test_cutlass_fp4_moe_no_graph(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
max_experts_per_worker=e,
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
)
......
......@@ -9,12 +9,18 @@ from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
......@@ -79,6 +85,8 @@ def pplx_cutlass_moe(
PplxPrepareAndFinalize,
)
init_workspace_manager(torch.cuda.current_device())
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
......@@ -132,28 +140,23 @@ def pplx_cutlass_moe(
num_dispatchers=num_dispatchers,
)
ab_strides1 = torch.full(
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
)
ab_strides2 = torch.full(
(num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
)
c_strides1 = torch.full(
(num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
)
c_strides2 = torch.full(
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
)
def make_moe_config() -> FusedMoEConfig:
    """Assemble a FusedMoEConfig from values captured from the surrounding scope.

    The expert counts, top-k and layer sizes are read from the enclosing
    scope; the remaining fields are fixed: no parallelism, "silu"
    activation, bfloat16 inputs, "cuda" device and Llama4 routing.
    """
    parallel_cfg = FusedMoEParallelConfig.make_no_parallel()
    moe_cfg = FusedMoEConfig(
        # Sizes taken from the surrounding scope.
        num_experts=num_experts,
        num_local_experts=num_local_experts,
        experts_per_token=topk,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_dim,
        # Fixed settings.
        moe_parallel_config=parallel_cfg,
        activation="silu",
        in_dtype=torch.bfloat16,
        device="cuda",
        routing_method=RoutingMethodType.Llama4,
    )
    return moe_cfg
experts = CutlassBatchedExpertsFp8(
num_local_experts,
num_dispatchers,
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
fp8_w8a8_moe_quant_config(
moe_config=make_moe_config(),
quant_config=fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
......@@ -162,6 +165,8 @@ def pplx_cutlass_moe(
if per_act_token
else a1_scale[rank],
),
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
fused_cutlass_experts = FusedMoEModularKernel(
......
......@@ -29,6 +29,7 @@ except ImportError:
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
make_dummy_moe_config,
make_shared_experts,
make_test_weights,
naive_batched_moe,
......@@ -584,6 +585,7 @@ def pplx_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
fused_experts = FusedMoEModularKernel(
......
......@@ -6,7 +6,6 @@ import pytest
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
......@@ -385,17 +384,11 @@ def test_grouped_topk(
global_num_experts,
)
routing_method_type = None
if scoring_func == "llama4":
routing_method_type = RoutingMethodType.Llama4
scoring_func = "sigmoid"
router = create_fused_moe_router(
use_grouped_topk=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routing_method_type=routing_method_type,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
top_k=top_k,
......
......@@ -10,6 +10,7 @@ equals N (not N // 2 like gated activations).
import pytest
import torch
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
)
......@@ -78,7 +79,10 @@ def test_triton_experts_no_mul_activation(
m, n, k, NUM_EXPERTS, topk
)
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
experts = TritonExperts(
moe_config=make_dummy_moe_config(),
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
)
ws1_shape, ws2_shape, out_shape = experts.workspace_shapes(
M=m,
......@@ -151,7 +155,10 @@ def test_workspace_shapes_no_mul_vs_gated():
M, N, K, topk = 64, 256, 128, 2
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
experts = TritonExperts(
moe_config=make_dummy_moe_config(),
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
)
ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, SILU_NO_MUL
......@@ -187,7 +194,10 @@ def test_adjust_n_for_activation():
"""Test the adjust_N_for_activation method."""
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
experts = TritonExperts(
moe_config=make_dummy_moe_config(),
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
)
N = 256
......
......@@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts,
......@@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import round_up
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """
    Build a placeholder FusedMoEConfig for the mk constructor interface.

    Per the original author's note, most kernels (DeepGEMM, CUTLASSFp4,
    Triton, MARLIN) do not actually use this config, while CUTLASSFp8
    needs some of the parameters set for its workspace shapes.

    All experts are treated as local (``num_local_experts`` is set to
    ``num_experts``). The remaining fields are fixed: no parallelism,
    "silu" activation, "cuda" device and TopK routing.
    """
    no_parallel = FusedMoEParallelConfig.make_no_parallel()
    return FusedMoEConfig(
        # Caller-controlled sizes and dtype.
        num_experts=num_experts,
        num_local_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        in_dtype=in_dtype,
        # Fixed settings shared by every dummy config.
        moe_parallel_config=no_parallel,
        activation="silu",
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )
def triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
......@@ -81,6 +114,7 @@ def batched_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
)
......@@ -121,6 +155,7 @@ def naive_batched_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
)
......
......@@ -11,10 +11,11 @@ This ensures that 'pip install vllm' automatically installs the correct custom w
instead of allowing pip to download different versions from PyPI.
"""
import re
import sys
from pathlib import Path
import regex as re
def extract_version_from_wheel(wheel_name: str) -> str:
"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment