Unverified Commit 5dcd7ef1 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

parent ffc0a279
......@@ -1422,3 +1422,10 @@ steps:
num_gpus: 2
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor/config-b200.txt
- label: MoE Refactor Integration Test (B200 DP - TEMPORARY) # optional
gpu: b200
optional: true
num_gpus: 2
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
......@@ -6,13 +6,16 @@ kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
but use different quantization strategies and backends.
"""
import nvtx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
......@@ -59,6 +62,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
init_workspace_manager(torch.cuda.current_device())
(m, k, n) = mkn
dtype = torch.half
......@@ -121,24 +125,7 @@ def bench_run(
# Force per-tensor quantization for all cases
per_act_token = False
# Create stride tensors for CUTLASS
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
a2_scale: torch.Tensor,
num_repeats: int,
):
# Pre-create quantization config to avoid creating it inside CUDA graph
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
......@@ -148,66 +135,16 @@ def bench_run(
per_out_ch_quant=per_out_ch,
)
for _ in range(num_repeats):
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
e=num_experts,
n=n,
k=k,
quant_config=quant_config,
)
def run_cutlass_moe_fp8(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
a2_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
)
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp8", color="blue"):
cutlass_moe_fp8(
a=a,
w1_q=w1,
w2_q=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
activation="silu",
global_num_experts=num_experts,
)
# Pre-create quantization config to avoid creating it inside CUDA graph
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
device=w1.device,
),
)
# Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
......@@ -216,17 +153,12 @@ def bench_run(
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
# Capture 10 invocations like benchmark_moe.py
for _ in range(10):
cutlass_moe_fp8(
a=a,
w1_q=w1_fp8q_cutlass,
w2_q=w2_fp8q_cutlass,
topk_weights=topk_weights,
topk_ids=topk_ids,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
fn(
a,
w1_fp8q_cutlass,
w2_fp8q_cutlass,
topk_weights,
topk_ids,
activation="silu",
global_num_experts=num_experts,
)
......
......@@ -5,14 +5,18 @@ import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
......@@ -45,6 +49,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
init_workspace_manager(torch.cuda.current_device())
label = "Quant Matmul"
sub_label = (
......@@ -82,11 +87,6 @@ def bench_run(
a, score, topk, renormalize=False
)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
......@@ -120,10 +120,6 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
......@@ -135,31 +131,29 @@ def bench_run(
per_act_token_quant=per_act_token,
)
for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
topk_weights,
topk_ids,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=w2.shape[0],
n=w2.shape[2],
k=w2.shape[1],
quant_config=quant_config,
device=w1.device,
),
)
for _ in range(num_repeats):
fn(a, w1, w2, topk_weights, topk_ids)
def run_cutlass_from_graph(
a: torch.Tensor,
a_scale: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
......@@ -169,21 +163,23 @@ def bench_run(
per_act_token_quant=per_act_token,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=w2.shape[0],
n=w2.shape[2],
k=w2.shape[1],
quant_config=quant_config,
device=w1.device,
),
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp8(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config=quant_config,
)
return fn(a, w1, w2, topk_weights, topk_ids)
def run_triton_from_graph(
a: torch.Tensor,
......@@ -227,10 +223,6 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
......@@ -268,10 +260,6 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
......@@ -330,10 +318,6 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
......@@ -342,7 +326,7 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
......
......@@ -48,8 +48,6 @@ def clear_triton_cache():
# Try to clear Triton's runtime cache
try:
import triton
if (
hasattr(triton, "runtime")
and hasattr(triton.runtime, "cache")
......
......@@ -87,7 +87,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| triton (batched) | batched | all<sup>1</sup> | G,A,T | silu, gelu | <sup>6</sup> | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],</br>[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
......
model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --disable-uvicorn-access-log"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
VLLM_USE_DEEP_GEMM_E8M0: "0"
model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
accuracy_threshold: 0.85
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
accuracy_threshold: 0.85
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --disable-uvicorn-access-log"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
VLLM_USE_DEEP_GEMM_E8M0: "0"
model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
accuracy_threshold: 0.85
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass-fi-a2av.yaml
model_name: "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
......
......@@ -5,7 +5,7 @@ Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml
Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml
Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
Qwen3-30B-A3B-Fp8-CT-Channel-marlin.yaml
Qwen3-30B-A3B-Fp8-CT-Channel-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
......
......@@ -61,6 +61,7 @@ def test_gsm8k_correctness(config_filename):
server_args.extend(
[
"--trust-remote-code",
"--disable-uvicorn-access-log",
]
)
......
......@@ -7,17 +7,22 @@ from math import prod
import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8,
CutlassExpertsFp8,
run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
......@@ -150,16 +155,15 @@ class MOETensors8Bit(MOETensors):
def run_with_expert_maps(
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
num_experts: int,
num_local_experts: int,
quant_config: FusedMoEQuantConfig,
**cutlass_moe_kwargs,
):
def slice_experts():
slice_params = [
"w1_q",
"w2_q",
"ab_strides1",
"ab_strides2",
"c_strides1",
"c_strides2",
"w1",
"w2",
]
full_tensors = {
k: v
......@@ -167,8 +171,6 @@ def run_with_expert_maps(
if k in slice_params and k in cutlass_moe_kwargs
}
quant_config = cutlass_moe_kwargs["quant_config"]
for i in range(0, num_experts, num_local_experts):
s, e = i, i + num_local_experts
......@@ -187,13 +189,23 @@ def run_with_expert_maps(
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
cutlass_moe_kwargs["quant_config"] = new_quant_config
yield cutlass_moe_kwargs
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
for kwargs in slice_experts():
out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)
yield cutlass_moe_kwargs, new_quant_config
out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])
for kwargs, new_quant_config in slice_experts():
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=kwargs["hidden_states"].dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=kwargs["w2"].shape[0], # type: ignore[union-attr]
n=kwargs["w2"].shape[2], # type: ignore[union-attr]
k=kwargs["w2"].shape[1], # type: ignore[union-attr]
quant_config=new_quant_config,
device="cuda",
),
)
out_tensor = out_tensor + kernel(**kwargs)
return out_tensor
......@@ -230,27 +242,35 @@ def run_8_bit(
)
kwargs = {
"a": moe_tensors.a,
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
"hidden_states": moe_tensors.a,
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"ab_strides1": moe_tensors.ab_strides1,
"ab_strides2": moe_tensors.ab_strides2,
"c_strides1": moe_tensors.c_strides1,
"c_strides2": moe_tensors.c_strides2,
"quant_config": quant_config,
}
num_experts = moe_tensors.w1.size(0)
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
return cutlass_moe_fp8(**kwargs)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=moe_tensors.a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
n=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
quant_config=quant_config,
device="cuda",
),
)
return kernel(**kwargs)
assert num_local_experts is not None
return run_with_expert_maps(
num_experts,
num_local_experts, # type: ignore[arg-type]
quant_config,
**kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment