"vscode:/vscode.git/clone" did not exist on "d2af67441ddf5965aaebf129802a0a9d38f0e225"
Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
...@@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark ...@@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -140,6 +144,12 @@ def bench_run( ...@@ -140,6 +144,12 @@ def bench_run(
a_fp8_scale: torch.Tensor, a_fp8_scale: torch.Tensor,
num_repeats: int, num_repeats: int,
): ):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
for _ in range(num_repeats): for _ in range(num_repeats):
fused_experts( fused_experts(
a, a,
...@@ -147,10 +157,7 @@ def bench_run( ...@@ -147,10 +157,7 @@ def bench_run(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
) )
def run_cutlass_moe_fp4( def run_cutlass_moe_fp4(
...@@ -172,25 +179,27 @@ def bench_run( ...@@ -172,25 +179,27 @@ def bench_run(
device: torch.device, device: torch.device,
num_repeats: int, num_repeats: int,
): ):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
for _ in range(num_repeats): for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"): with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4( cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4, w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
w2_fp4=w2_fp4, w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=m, m=m,
n=n, n=n,
k=k, k=k,
e=num_experts, e=num_experts,
device=device, quant_config=quant_config,
) )
def run_cutlass_from_graph( def run_cutlass_from_graph(
...@@ -211,26 +220,29 @@ def bench_run( ...@@ -211,26 +220,29 @@ def bench_run(
e: int, e: int,
device: torch.device, device: torch.device,
): ):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
): ):
return cutlass_moe_fp4( return cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4, w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_alphas,
a2_gscale=a2_gs,
w2_fp4=w2_fp4, w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_alphas,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=m, m=m,
n=n, n=n,
k=k, k=k,
e=num_experts, e=num_experts,
device=device, quant_config=quant_config,
) )
def run_triton_from_graph( def run_triton_from_graph(
...@@ -246,16 +258,18 @@ def bench_run( ...@@ -246,16 +258,18 @@ def bench_run(
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
): ):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
return fused_experts( return fused_experts(
a, a,
w1, w1,
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
) )
def replay_graph(graph, num_repeats): def replay_graph(graph, num_repeats):
......
...@@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE ...@@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_experts,
...@@ -96,6 +97,11 @@ def bench_run( ...@@ -96,6 +97,11 @@ def bench_run(
a_scale: torch.Tensor, a_scale: torch.Tensor,
num_repeats: int, num_repeats: int,
): ):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
)
for _ in range(num_repeats): for _ in range(num_repeats):
fused_experts( fused_experts(
a, a,
...@@ -103,10 +109,7 @@ def bench_run( ...@@ -103,10 +109,7 @@ def bench_run(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
) )
def run_cutlass_moe( def run_cutlass_moe(
...@@ -125,6 +128,12 @@ def bench_run( ...@@ -125,6 +128,12 @@ def bench_run(
per_act_token: bool, per_act_token: bool,
num_repeats: int, num_repeats: int,
): ):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)
for _ in range(num_repeats): for _ in range(num_repeats):
cutlass_moe_fp8( cutlass_moe_fp8(
a, a,
...@@ -132,14 +141,11 @@ def bench_run( ...@@ -132,14 +141,11 @@ def bench_run(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale,
w2_scale,
ab_strides1, ab_strides1,
ab_strides2, ab_strides2,
c_strides1, c_strides1,
c_strides2, c_strides2,
per_act_token, quant_config=quant_config,
a1_scale=None,
) )
def run_cutlass_from_graph( def run_cutlass_from_graph(
...@@ -156,6 +162,12 @@ def bench_run( ...@@ -156,6 +162,12 @@ def bench_run(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
): ):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
): ):
...@@ -165,14 +177,11 @@ def bench_run( ...@@ -165,14 +177,11 @@ def bench_run(
w2_q, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale,
w2_scale,
ab_strides1, ab_strides1,
ab_strides2, ab_strides2,
c_strides1, c_strides1,
c_strides2, c_strides2,
per_act_token, quant_config=quant_config,
a1_scale=None,
) )
def run_triton_from_graph( def run_triton_from_graph(
...@@ -185,6 +194,11 @@ def bench_run( ...@@ -185,6 +194,11 @@ def bench_run(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a_scale: torch.Tensor, a_scale: torch.Tensor,
): ):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
)
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
): ):
...@@ -194,10 +208,7 @@ def bench_run( ...@@ -194,10 +208,7 @@ def bench_run(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
) )
def replay_graph(graph, num_repeats): def replay_graph(graph, num_repeats):
......
...@@ -14,6 +14,10 @@ import ray ...@@ -14,6 +14,10 @@ import ray
import torch import torch
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
...@@ -134,42 +138,35 @@ def benchmark_config( ...@@ -134,42 +138,35 @@ def benchmark_config(
def run(): def run():
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import override_config
with override_config(config): if use_fp8_w8a8:
if use_deep_gemm: quant_dtype = torch.float8_e4m3fn
topk_weights, topk_ids, token_expert_indices = fused_topk( elif use_int8_w8a16:
x, input_gating, topk, False quant_dtype = torch.int8
) else:
return fused_experts( quant_dtype = None
x,
w1, quant_config = FusedMoEQuantConfig.make(
w2, quant_dtype=quant_dtype,
topk_weights,
topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_quant_shape, block_shape=block_quant_shape,
allow_deep_gemm=True,
) )
else:
fused_moe( with override_config(config):
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
return fused_experts(
x, x,
w1, w1,
w2, w2,
input_gating, topk_weights,
topk, topk_ids,
renormalize=True,
inplace=True, inplace=True,
use_fp8_w8a8=use_fp8_w8a8, quant_config=quant_config,
use_int8_w8a16=use_int8_w8a16, allow_deep_gemm=use_deep_gemm,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
) )
# JIT compilation & warmup # JIT compilation & warmup
...@@ -414,7 +411,7 @@ class BenchmarkWorker: ...@@ -414,7 +411,7 @@ class BenchmarkWorker:
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
...@@ -547,7 +544,7 @@ def save_configs( ...@@ -547,7 +544,7 @@ def save_configs(
block_quant_shape: list[int], block_quant_shape: list[int],
save_dir: str, save_dir: str,
) -> None: ) -> None:
dtype_str = get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
) )
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts, from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info) make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo from .parallel_utils import ProcessGroupInfo
...@@ -40,7 +40,7 @@ class Config: ...@@ -40,7 +40,7 @@ class Config:
E: int E: int
topks: Union[list[int], int] topks: Union[list[int], int]
dtype: torch.dtype dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig] quant_config: Optional[TestMoEQuantConfig]
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
...@@ -52,7 +52,7 @@ class Config: ...@@ -52,7 +52,7 @@ class Config:
def __post_init__(self): def __post_init__(self):
if self.quant_config is None: if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig() self.quant_config = TestMoEQuantConfig(None, False, False, None)
def describe(self) -> str: def describe(self) -> str:
s = "" s = ""
...@@ -275,21 +275,19 @@ class WeightTensors: ...@@ -275,21 +275,19 @@ class WeightTensors:
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self): def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device()) device = torch.cuda.current_device()
self.w2 = self.w2.to(device=torch.cuda.current_device()) self.w1 = self.w1.to(device=device)
self.w2 = self.w2.to(device=device)
if self.is_quantized(): if self.w1_scale is not None:
assert self.w1_scale is not None self.w1_scale = self.w1_scale.to(device=device)
assert self.w2_scale is not None if self.w2_scale is not None:
self.w1_scale = self.w1_scale.to( self.w2_scale = self.w2_scale.to(device=device)
device=torch.cuda.current_device())
self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device())
if self.w1_gs is not None: if self.w1_gs is not None:
assert self.w2_gs is not None self.w1_gs = self.w1_gs.to(device=device)
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) if self.w2_gs is not None:
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) self.w2_gs = self.w2_gs.to(device=device)
def slice_weights(self, rank: int, def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors": num_local_experts: int) -> "WeightTensors":
...@@ -297,20 +295,12 @@ class WeightTensors: ...@@ -297,20 +295,12 @@ class WeightTensors:
e = s + num_local_experts e = s + num_local_experts
w1 = self.w1[s:e, :, :] w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :] w2 = self.w2[s:e, :, :]
w1_scale = self.w1_scale[
w1_scale, w2_scale = (None, None) s:e, :, :] if self.w1_scale is not None else None
if self.is_quantized(): w2_scale = self.w2_scale[
assert self.w1_scale is not None s:e, :, :] if self.w2_scale is not None else None
assert self.w2_scale is not None w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
w1_scale = self.w1_scale[s:e, :, :] w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
w2_scale = self.w2_scale[s:e, :, :]
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
...@@ -323,7 +313,8 @@ class WeightTensors: ...@@ -323,7 +313,8 @@ class WeightTensors:
in_dtype=config.dtype, in_dtype=config.dtype,
quant_dtype=config.quant_dtype, quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape, block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant, per_out_ch_quant=config.
is_per_act_token_quant, # or config.is_per_out_ch_quant
) )
return WeightTensors(w1=w1, return WeightTensors(w1=w1,
w2=w2, w2=w2,
...@@ -342,8 +333,6 @@ class RankTensors: ...@@ -342,8 +333,6 @@ class RankTensors:
topk_ids: torch.Tensor topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor] expert_map: Optional[torch.Tensor]
quant_config: Optional[FusedMoEQuantConfig]
def describe(self): def describe(self):
s = "" s = ""
s += "== Rank Tensors: \n" s += "== Rank Tensors: \n"
...@@ -426,7 +415,6 @@ class RankTensors: ...@@ -426,7 +415,6 @@ class RankTensors:
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
expert_map=expert_map, expert_map=expert_map,
quant_config=config.quant_config,
) )
...@@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors, ...@@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
and config.supports_apply_weight_on_input()) and config.supports_apply_weight_on_input())
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_modular_kernel( def make_modular_kernel(
config: Config, config: Config,
vllm_config: VllmConfig, vllm_config: VllmConfig,
weights: WeightTensors, quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
def next_power_of_2(x): def next_power_of_2(x):
...@@ -548,20 +542,20 @@ def make_modular_kernel( ...@@ -548,20 +542,20 @@ def make_modular_kernel(
num_local_experts=config.num_local_experts, num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config, moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype, in_dtype=config.dtype,
quant_config=config.quant_config,
max_num_tokens=next_power_of_2(config.M), max_num_tokens=next_power_of_2(config.M),
) )
# make modular kernel # make modular kernel
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe) config.all2all_backend(), moe,
quant_config)
fused_experts = make_fused_experts( fused_experts = make_fused_experts(
config.fused_experts_type, config.fused_experts_type,
moe, moe,
quant_config,
prepare_finalize.num_dispatchers(), prepare_finalize.num_dispatchers(),
weights.w1_gs, config.N,
weights.w2_gs,
) )
modular_kernel = mk.FusedMoEModularKernel( modular_kernel = mk.FusedMoEModularKernel(
...@@ -583,12 +577,38 @@ def run_modular_kernel( ...@@ -583,12 +577,38 @@ def run_modular_kernel(
# weights for rank # weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config, weights) if config.quant_dtype == "nvfp4":
gscale = _make_gscale(config.num_local_experts)
else:
gscale = None
quant_config = FusedMoEQuantConfig.make(
config.quant_dtype,
w1_scale=rank_weights.w1_scale,
w2_scale=rank_weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
g1_alphas=(1 / rank_weights.w1_gs)
if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs)
if rank_weights.w2_gs is not None else None,
a1_gscale=gscale,
a2_gscale=gscale,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_act_token_quant,
per_out_ch_quant=config.is_per_out_ch_quant,
)
mk = make_modular_kernel(config, vllm_config, quant_config)
# impls might update the tensor in place
hidden_states = rank_tensors.hidden_states.clone()
topk_ids = rank_tensors.topk_ids.to(
mk.prepare_finalize.topk_indices_dtype())
mk_kwargs = { mk_kwargs = {
"hidden_states": "hidden_states":
rank_tensors.hidden_states.clone( hidden_states,
), # impls might update the tensor in place
"w1": "w1":
rank_weights.w1, rank_weights.w1,
"w2": "w2":
...@@ -596,15 +616,9 @@ def run_modular_kernel( ...@@ -596,15 +616,9 @@ def run_modular_kernel(
"topk_weights": "topk_weights":
rank_tensors.topk_weights, rank_tensors.topk_weights,
"topk_ids": "topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), topk_ids,
"expert_map": "expert_map":
rank_tensors.expert_map, rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts": "global_num_experts":
config.E, config.E,
"apply_router_weight_on_input": "apply_router_weight_on_input":
......
...@@ -10,7 +10,8 @@ import torch ...@@ -10,7 +10,8 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
...@@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str): ...@@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
quant_config_dict = config_dict['quant_config'] quant_config_dict = config_dict['quant_config']
del config_dict['quant_config'] del config_dict['quant_config']
if quant_config_dict is None: if quant_config_dict is None:
quant_config = FusedMoEQuantConfig(None) quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
quant_config_dict = asdict(quant_config) quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict config_dict |= quant_config_dict
......
...@@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported ...@@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@dataclass
class TestMoEQuantConfig:
quant_dtype: Union[torch.dtype, str, None]
per_out_ch_quant: bool
per_act_token_quant: bool
block_shape: Optional[list[int]]
@dataclass @dataclass
class PrepareFinalizeInfo: class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat activation_format: mk.FusedMoEActivationFormat
...@@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [ ...@@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
] ]
common_float_and_int_types = common_float_types + [torch.int8] common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"] nvfp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn] fp8_types = [torch.float8_e4m3fn]
...@@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe() ...@@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_prepare_and_finalize( register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize, FlashInferCutlassMoEPrepareAndFinalize,
standard_format, standard_format,
nv_fp4_types, nvfp4_types,
blocked_quantization_support=True, blocked_quantization_support=True,
backend=None, backend=None,
force_multigpu=True, force_multigpu=True,
...@@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe() ...@@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_experts( register_experts(
FlashInferExperts, FlashInferExperts,
standard_format, standard_format,
nv_fp4_types, nvfp4_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True, supports_chunking=True,
# Note: this is a hack to get it to run for now # Note: this is a hack to get it to run for now
...@@ -306,36 +314,36 @@ if cutlass_fp4_supported(): ...@@ -306,36 +314,36 @@ if cutlass_fp4_supported():
register_experts( register_experts(
CutlassExpertsFp4, CutlassExpertsFp4,
standard_format, standard_format,
nv_fp4_types, nvfp4_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True, supports_chunking=True,
supports_expert_map=False, supports_expert_map=False,
) )
MK_QUANT_CONFIGS = [ MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
None, None,
# per-channel / per-column weights and per-tensor activations # per-channel / per-column weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True, per_out_ch_quant=True,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=None), block_shape=None),
# per-channel / per-column weights and per-token activations # per-channel / per-column weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True, per_out_ch_quant=True,
per_act_token_quant=True, per_act_token_quant=True,
block_shape=None), block_shape=None),
# per-tensor weights and per-tensor activations # per-tensor weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False, per_out_ch_quant=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=None), block_shape=None),
# per-tensor weights and per-token activations # per-tensor weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False, per_out_ch_quant=False,
per_act_token_quant=True, per_act_token_quant=True,
block_shape=None), block_shape=None),
# block-quantized weights and 128 block per-token activations # block-quantized weights and 128 block per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False, per_out_ch_quant=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=[128, 128]), block_shape=[128, 128]),
...@@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [ ...@@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [ MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4", TestMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False, per_out_ch_quant=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=None), block_shape=None),
] ]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize( def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str], backend: Optional[str],
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize: ) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None: if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config)
assert prepare_finalize is not None assert prepare_finalize is not None
return prepare_finalize return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize( return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1, use_dp=moe.moe_parallel_config.dp_size > 1)
a1_gscale=_make_gscale(moe.num_local_experts),
)
else: else:
return MoEPrepareAndFinalizeNoEP() return MoEPrepareAndFinalizeNoEP()
...@@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: ...@@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
return t[s:e] return t[s:e]
def make_cutlass_strides(
e: int,
n: int,
k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
return ab_strides1, ab_strides2, c_strides1, c_strides2
def make_fused_experts( def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
num_dispatchers: int, num_dispatchers: int,
w1_gs: Optional[torch.Tensor], N: int,
w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = { batch_kwargs = {
"max_num_tokens": moe.max_num_tokens, "max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers, "num_dispatchers": num_dispatchers,
} }
quant_kwargs = { quant_kwargs = {
"use_fp8_w8a8": use_fp8, "quant_config": quant_config,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
} }
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
if fused_experts_type == BatchedDeepGemmExperts: if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | { kwargs = batch_kwargs | quant_kwargs
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...") print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs) experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts: elif fused_experts_type == BatchedTritonExperts:
...@@ -422,8 +429,8 @@ def make_fused_experts( ...@@ -422,8 +429,8 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs) experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts: elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...") print("Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts() experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts: elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...") print(f"Making TritonExperts {kwargs} ...")
...@@ -437,62 +444,50 @@ def make_fused_experts( ...@@ -437,62 +444,50 @@ def make_fused_experts(
print(f"Making NaiveBatchedExperts {kwargs} ...") print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs) experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8: elif fused_experts_type == CutlassExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = { kwargs = {
"out_dtype": moe.in_dtype, "out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant, "ab_strides1": strides[0],
"per_out_ch_quant": moe.per_out_ch_quant, "ab_strides2": strides[1],
"block_shape": moe.block_shape, "c_strides1": strides[2],
} "c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassExpertsFp8 {kwargs} ...") print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs) experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8: elif fused_experts_type == CutlassBatchedExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = { kwargs = {
"max_experts_per_worker": moe.num_local_experts, "max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers, "num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype, "out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant, "ab_strides1": strides[0],
"per_out_ch_quant": moe.per_out_ch_quant, "ab_strides2": strides[1],
"block_shape": moe.block_shape, "c_strides1": strides[2],
} "c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs) experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4: elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = { kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), "max_experts_per_worker": moe.num_local_experts,
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"num_dispatchers": num_dispatchers, "num_dispatchers": num_dispatchers,
} "out_dtype": moe.in_dtype,
} | quant_kwargs
print(f"Making CutlassExpertsFp4 {kwargs} ...") print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs) experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts: elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = { kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype, "out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank, "ep_rank": moe.ep_rank,
"ep_size": moe.ep_size, "ep_size": moe.ep_size,
"tp_rank": moe.tp_rank, "tp_rank": moe.tp_rank,
"tp_size": moe.tp_size, "tp_size": moe.tp_size,
} } | quant_kwargs
print(f"Making FlashInferExperts {kwargs} ...") print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs) experts = FlashInferExperts(**kwargs)
else: else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
return experts return experts
...@@ -6,6 +6,8 @@ import torch ...@@ -6,6 +6,8 @@ import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts) BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
...@@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, ...@@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
rank=0, rank=0,
) )
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
per_act_token_quant=False,
block_shape=BLOCK_SIZE,
)
# triton (reference) # triton (reference)
triton_experts = BatchedTritonExperts( triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
use_fp8_w8a8=True, quant_config=quant_config,
per_act_token_quant=False,
block_shape=BLOCK_SIZE,
) )
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts) mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
...@@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, ...@@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
w1_scale=w1_s,
w2_scale=w2_s,
global_num_experts=E, global_num_experts=E,
) )
...@@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, ...@@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
deepgemm_experts = BatchedDeepGemmExperts( deepgemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
block_shape=BLOCK_SIZE, quant_config=quant_config,
per_act_token_quant=False,
) )
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts) mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
...@@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, ...@@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
w1_scale=w1_s,
w2_scale=w2_s,
global_num_experts=E, global_num_experts=E,
) )
......
...@@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ...@@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype=act_dtype, in_dtype=act_dtype,
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
block_shape=block_shape, block_shape=block_shape,
per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_act_token_quant,
) )
out_shape = (num_experts, max_tokens_per_expert, N) out_shape = (num_experts, max_tokens_per_expert, N)
...@@ -250,7 +250,7 @@ def test_fused_moe_batched_experts( ...@@ -250,7 +250,7 @@ def test_fused_moe_batched_experts(
block_shape=block_shape, block_shape=block_shape,
in_dtype=act_dtype, in_dtype=act_dtype,
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_act_token_quant,
) )
if input_scales and quant_dtype is not None: if input_scales and quant_dtype is not None:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import make_test_weights from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul) native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
...@@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, ...@@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s, w1, w2, quant_config = make_test_quant_config(
_) = make_test_weights(E, E,
N, N,
K, K,
dtype, dtype,
torch.float8_e4m3fn, quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=block_size) block_shape=block_size,
)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, m_fused_moe = modular_triton_fused_moe(quant_config)
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=block_size)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
...@@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, ...@@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a, a,
w1, w1,
w2, w2,
w1_s, quant_config.w1_scale,
w2_s, quant_config.w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
block_size, block_size,
) )
out = fused_experts( out = fused_experts(a,
a,
w1, w1,
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
use_fp8_w8a8=True, quant_config=quant_config)
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
m_out = m_fused_moe( m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0] # 0.039 only needed for M >= 8192
tol = 0.035 if M < 40000 else 0.039 tol = 0.035 if M < 8192 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol) torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol) torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
...@@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, ...@@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s, (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
_) = make_test_weights(E, E,
N, N,
K, K,
dtype, dtype,
torch.float8_e4m3fn, torch.float8_e4m3fn,
per_act_token_quant=False, per_out_ch_quant=False,
block_shape=block_size) block_shape=block_size,
)
# Note: for now use_compile will error out if the problem size is # Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and # large enough to trigger chunking. I'm leaving the flag and
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import make_test_weights from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quant_utils import (native_per_token_group_quant_int8, from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul) native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0): if current_platform.get_device_capability() < (7, 0):
...@@ -50,7 +50,7 @@ MNK_FACTORS = [ ...@@ -50,7 +50,7 @@ MNK_FACTORS = [
(2048, 128, 128), (2048, 128, 128),
(2048, 1024, 7168), (2048, 1024, 7168),
(2048, 4096, 512), (2048, 4096, 512),
(2048, 4096, 7168), (2048, 4096, 4096),
] ]
E = [8, 24] E = [8, 24]
...@@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
(_, w1, w1_s, _), (_, w2, w2_s, w1, w2, quant_config = make_test_quant_config(
_) = make_test_weights(E, E,
N, N,
K, K,
dtype, dtype,
torch.int8, quant_dtype=torch.int8,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=block_size) block_shape=block_size,
)
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
out = fused_moe( out = fused_experts(a,
a,
w1, w1,
w2, w2,
score, topk_weights,
topk, topk_ids,
renormalize=False, quant_config=quant_config)
use_int8_w8a8=True, ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
w1_scale=w1_s, quant_config.w2_scale, score, topk,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size) block_size)
# Check results # Check results
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses import dataclasses
from math import prod from math import prod
from typing import Optional from typing import Optional
...@@ -9,6 +10,8 @@ import torch ...@@ -9,6 +10,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8) cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
...@@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, ...@@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def slice_experts(): def slice_experts():
slice_params = [ slice_params = [
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
"c_strides2", "w1_scale", "w2_scale" "c_strides2"
] ]
full_tensors = { full_tensors = {
k: v k: v
...@@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, ...@@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
if k in slice_params and k in cutlass_moe_kwargs if k in slice_params and k in cutlass_moe_kwargs
} }
quant_config = cutlass_moe_kwargs["quant_config"]
for i in range(0, num_experts, num_local_experts): for i in range(0, num_experts, num_local_experts):
s, e = i, i + num_local_experts s, e = i, i + num_local_experts
...@@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, ...@@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
for k, t in full_tensors.items(): for k, t in full_tensors.items():
cutlass_moe_kwargs[k] = t[s:e] cutlass_moe_kwargs[k] = t[s:e]
new_quant_config = copy.deepcopy(quant_config)
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
cutlass_moe_kwargs["quant_config"] = new_quant_config
yield cutlass_moe_kwargs yield cutlass_moe_kwargs
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
...@@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor: num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([ assert not any([
t is None for t in [ t is None for t in [
...@@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
] ]
]) ])
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=moe_tensors.w1_scale,
w2_scale=moe_tensors.w2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
# Set to moe_tensors.a_scale iff static scales + per tensor.
# This is not currently being tested.
a1_scale=None,
)
kwargs = { kwargs = {
'a': moe_tensors.a, 'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights, 'topk_weights': topk_weights,
'topk_ids': topk_ids, 'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1, 'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2, 'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1, 'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2, 'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token, 'quant_config': quant_config,
'a1_scale': None #moe_tensors.a_scale
} }
num_experts = moe_tensors.w1.size(0) num_experts = moe_tensors.w1.size(0)
...@@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
# Note that we are using the dequantized versions of the tensors. # Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences. # Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids) quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
if ep_size is not None: if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly" assert e % ep_size == 0, "Cannot distribute experts evenly"
number_local_experts = e // ep_size number_local_experts = e // ep_size
else: else:
number_local_experts = None number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
number_local_experts) per_out_ch, number_local_experts)
# Note 5.5 only needed for larger problem sizes, 5 works ok for # Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest. # the rest.
...@@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
# Note that we are using the dequantized versions of the tensors. # Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences. # Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
topk_ids) triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream): with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token) per_act_token, per_out_ch)
torch.cuda.synchronize() torch.cuda.synchronize()
graph.replay() graph.replay()
......
...@@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup ...@@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
...@@ -71,9 +73,12 @@ def make_block_quant_fp8_weights( ...@@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
Return weights w1q, w2q, w1_scale, w2_scale Return weights w1q, w2q, w1_scale, w2_scale
""" """
(_, w1q, w1_scale, _), (_, w2q, w2_scale, (_, w1q, w1_scale, _), (_, w2q, w2_scale,
_) = make_test_weights(e, n, k, torch.bfloat16, _) = make_test_weights(e,
n,
k,
torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fn,
block_size) block_shape=block_size)
return w1q, w2q, w1_scale, w2_scale return w1q, w2q, w1_scale, w2_scale
...@@ -130,10 +135,11 @@ class TestTensors: ...@@ -130,10 +135,11 @@ class TestTensors:
config=config) config=config)
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, def make_ll_modular_kernel(
max_tokens_per_rank: int, dp_size: int, pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
hidden_size: int, q_dtype: Optional[torch.dtype], dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel: test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
assert test_config.low_latency assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None assert test_config.use_fp8_dispatch is not None
...@@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
fused_experts = BatchedDeepGemmExperts( fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank, max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size, num_dispatchers=pgi.world_size // dp_size,
block_shape=test_config.block_size, quant_config=quant_config,
per_act_token_quant=test_config.per_act_token_quant) )
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
return mk return mk
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, def make_ht_modular_kernel(
dp_size: int, num_local_experts: int, pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
q_dtype: Optional[torch.dtype], num_local_experts: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel: test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
assert not test_config.low_latency assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None assert test_config.use_fp8_dispatch is None
...@@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype, q_dtype=q_dtype,
block_shape=test_config.block_size) block_shape=test_config.block_size)
fused_experts = DeepGemmExperts() fused_experts = DeepGemmExperts(quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
return mk return mk
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, def make_modular_kernel(
num_local_experts: int, pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
test_tensors: TestTensors) -> FusedMoEModularKernel: num_local_experts: int, test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
q_dtype = torch.float8_e4m3fn q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config test_config = test_tensors.config
...@@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, ...@@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
dp_size=dp_size, dp_size=dp_size,
hidden_size=hidden_size, hidden_size=hidden_size,
q_dtype=q_dtype, q_dtype=q_dtype,
test_config=test_config) test_config=test_config,
quant_config=quant_config)
else: else:
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts, mk = make_ht_modular_kernel(pg,
q_dtype, test_config) pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config)
return mk return mk
...@@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return expert_map.to(device=torch.cuda.current_device(), return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32) dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
# Low-Latency kernels can't dispatch scales.
a1_scale=(None if test_config.low_latency else
test_tensors.rank_token_scales),
block_shape=test_config.block_size,
)
# Make modular kernel # Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel( mk: FusedMoEModularKernel = make_modular_kernel(
pg=pg, pg=pg,
pgi=pgi, pgi=pgi,
dp_size=dp_size, dp_size=dp_size,
num_local_experts=num_local_experts, num_local_experts=num_local_experts,
test_tensors=test_tensors) test_tensors=test_tensors,
quant_config=quant_config)
# Low-Latency kernels can't dispatch scales.
a1_scale = (None
if test_config.low_latency else test_tensors.rank_token_scales)
out = mk.forward(hidden_states=test_tensors.rank_tokens, out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1, w1=w1,
...@@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
activation="silu", activation="silu",
global_num_experts=num_experts, global_num_experts=num_experts,
expert_map=build_expert_map(), expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=a1_scale,
a2_scale=None,
apply_router_weight_on_input=False) apply_router_weight_on_input=False)
return out return out
...@@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, ...@@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, block_shape: list[int]): a1_scale: torch.Tensor, block_shape: list[int]):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
)
return fused_experts( return fused_experts(
hidden_states=a, hidden_states=a,
w1=w1, w1=w1,
...@@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, ...@@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
# Make sure this is set to False so we # Make sure this is set to False so we
# don't end up comparing the same implementation. # don't end up comparing the same implementation.
allow_deep_gemm=False) allow_deep_gemm=False)
......
...@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops ...@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts) BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
...@@ -129,11 +130,9 @@ def make_modular_kernel( ...@@ -129,11 +130,9 @@ def make_modular_kernel(
num_local_experts: int, num_local_experts: int,
q_dtype: Optional[torch.dtype], q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
per_act_token_quant: bool, quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel: ) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None
ht_args: Optional[DeepEPHTArgs] = None ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None ll_args: Optional[DeepEPLLArgs] = None
...@@ -159,24 +158,14 @@ def make_modular_kernel( ...@@ -159,24 +158,14 @@ def make_modular_kernel(
num_dispatchers = pgi.world_size // dp_size num_dispatchers = pgi.world_size // dp_size
if low_latency_mode: if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode" assert not quant_config.per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts( fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK, max_num_tokens=MAX_TOKENS_PER_RANK,
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
use_fp8_w8a8=is_quantized, quant_config=quant_config,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=False,
) )
else: else:
fused_experts = TritonExperts( fused_experts = TritonExperts(quant_config=quant_config)
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
...@@ -217,11 +206,6 @@ def deep_ep_moe_impl( ...@@ -217,11 +206,6 @@ def deep_ep_moe_impl(
if is_quantized: if is_quantized:
q_dtype = torch.float8_e4m3fn q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens) out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0) total_num_tokens = test_tensors.rank_tokens.size(0)
...@@ -236,6 +220,19 @@ def deep_ep_moe_impl( ...@@ -236,6 +220,19 @@ def deep_ep_moe_impl(
rank_token_scales_chunk = rank_token_scales_chunk[ rank_token_scales_chunk = rank_token_scales_chunk[
chunk_start:chunk_end] chunk_start:chunk_end]
quant_config = FusedMoEQuantConfig.make(
q_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token_quant,
a1_scale=rank_token_scales_chunk,
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
out = mk.forward(hidden_states=rank_tokens_chunk, out = mk.forward(hidden_states=rank_tokens_chunk,
w1=w1, w1=w1,
w2=w2, w2=w2,
...@@ -245,12 +242,6 @@ def deep_ep_moe_impl( ...@@ -245,12 +242,6 @@ def deep_ep_moe_impl(
activation="silu", activation="silu",
global_num_experts=num_experts, global_num_experts=num_experts,
expert_map=build_expert_map(), expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=rank_token_scales_chunk,
a2_scale=None,
apply_router_weight_on_input=False) apply_router_weight_on_input=False)
if not skip_result_store: if not skip_result_store:
...@@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] ...@@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs) @pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
...@@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] ...@@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@requires_deep_ep @requires_deep_ep
def test_deep_ep_moe( def test_deep_ep_moe(
dtype: torch.dtype, dtype: torch.dtype,
mnk: tuple[int, int, int], m: int,
n: int,
k: int,
num_experts: int, num_experts: int,
topk: int, topk: int,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
...@@ -424,7 +417,6 @@ def test_deep_ep_moe( ...@@ -424,7 +417,6 @@ def test_deep_ep_moe(
): ):
low_latency_mode = False low_latency_mode = False
use_fp8_dispatch = False use_fp8_dispatch = False
m, n, k = mnk
current_platform.seed_everything(7) current_platform.seed_everything(7)
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
...@@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False] ...@@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs) @pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], def test_low_latency_deep_ep_moe(
num_experts: int, topk: int, dtype: torch.dtype,
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_fp8_dispatch: bool): use_fp8_dispatch: bool,
):
low_latency_mode = True low_latency_mode = True
m, n, k = mnk
if (low_latency_mode if (low_latency_mode
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
......
...@@ -11,6 +11,8 @@ import math ...@@ -11,6 +11,8 @@ import math
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
# vLLM fused-expert reference (Triton fallback + DeepGEMM option) # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
)
# triton reference # triton reference
out_triton = fused_experts( out_triton = fused_experts(
hidden_states=tokens_bf16, hidden_states=tokens_bf16,
...@@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
allow_deep_gemm=False, allow_deep_gemm=False,
) )
...@@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
use_fp8_w8a8=True, quant_config=quant_config,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
allow_deep_gemm=True, allow_deep_gemm=True,
) )
diff = calc_diff(out_deepgemm, out_triton) diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}" assert diff < 0.001, f"Diff exceeded 1%: {diff}"
# Note: W1 has shape (E, 2N, K), so N = 512 # Note: N <= 512 will disable the deepgemm path due to performance issues.
# can trigger the deepgemm path.
MNKs = [ MNKs = [
(1024, 768, 128), (1024, 768, 128),
(1024, 768, 512), (1024, 768, 512),
...@@ -144,15 +144,15 @@ TOPKS = [2, 6] ...@@ -144,15 +144,15 @@ TOPKS = [2, 6]
NUM_EXPERTS = [32] NUM_EXPERTS = [32]
@pytest.mark.parametrize("mnk", MNKs) @pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), @pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels") reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
with monkeypatch.context() as m: with monkeypatch.context() as mp:
m.setenv("VLLM_USE_DEEP_GEMM", "1") mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module( _fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe") "vllm.model_executor.layers.fused_moe.fused_moe")
...@@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): ...@@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
_spy_deep_gemm_moe_fp8) _spy_deep_gemm_moe_fp8)
m, n, k = mnk
if topk > num_experts: if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}") pytest.skip(f"topk={topk} > num_experts={num_experts}")
......
...@@ -6,6 +6,8 @@ import pytest ...@@ -6,6 +6,8 @@ import pytest
import torch import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
...@@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
custom_routing_function=Llama4MoE.custom_routing_function, custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax") scoring_func="softmax")
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts( output = fused_experts(
td.hidden_states, td.hidden_states,
td.w13_quantized, td.w13_quantized,
...@@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
activation="silu", activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
quant_config=quant_config,
) )
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
...@@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
custom_routing_function=Llama4MoE.custom_routing_function, custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax") scoring_func="softmax")
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts( output = fused_experts(
td.hidden_states, td.hidden_states,
td.w13_quantized, td.w13_quantized,
...@@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
activation="silu", activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
quant_config=quant_config,
) )
td.layer.dp_size = 1 td.layer.dp_size = 1
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import make_test_weights from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype) dequantize_nvfp4_to_dtype)
...@@ -41,7 +41,6 @@ MNK_FACTORS = [ ...@@ -41,7 +41,6 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256]) @pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
...@@ -56,14 +55,13 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -56,14 +55,13 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
quant_blocksize = 16 quant_blocksize = 16
(_, w1_q, w1_blockscale, w1_q, w2_q, quant_config = make_test_quant_config(
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e, e,
n, n,
k, k,
in_dtype=dtype, in_dtype=dtype,
quant_dtype="nvfp4", quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize? block_shape=None,
per_act_token_quant=False, per_act_token_quant=False,
) )
...@@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
topk, topk,
renormalize=False) renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel( flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
FlashInferExperts( FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
a1_gscale=a1_gs, )
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
flashinfer_output = flashinfer_experts( flashinfer_output = flashinfer_experts(
hidden_states=a, hidden_states=a,
w1=w1_q, w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q, w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
) )
...@@ -122,15 +102,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -122,15 +102,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e): for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_blockscale[idx], w1_q[idx],
w1_gs[idx], quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
dtype=dtype, dtype=dtype,
device=w1_q.device, device=w1_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_blockscale[idx], w2_q[idx],
w2_gs[idx], quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
dtype=dtype, dtype=dtype,
device=w2_q.device, device=w2_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
......
...@@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor ...@@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize) BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
...@@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): ...@@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
pc2, pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
quant_config = FusedMoEQuantConfig.make(
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2,
)
out_triton_monolithic = triton_kernel_moe_forward( out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri, hidden_states=x_tri,
w1=w1_tri, w1=w1_tri,
...@@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): ...@@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
gating_output=exp_data_tri, gating_output=exp_data_tri,
topk=topk, topk=topk,
renormalize=True, renormalize=True,
w1_bias=w1_bias_tri, quant_config=quant_config,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2,
) )
out_triton_monolithic = out_triton_monolithic[..., :K] out_triton_monolithic = out_triton_monolithic[..., :K]
...@@ -336,6 +341,13 @@ def batched_moe( ...@@ -336,6 +341,13 @@ def batched_moe(
) -> torch.Tensor: ) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64) max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
w1_precision=w1_precision,
w2_precision=w2_precision,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize( BatchedPrepareAndFinalize(
max_num_tokens, max_num_tokens,
...@@ -344,19 +356,12 @@ def batched_moe( ...@@ -344,19 +356,12 @@ def batched_moe(
rank=0, rank=0,
), ),
BatchedOAITritonExperts( BatchedOAITritonExperts(
None,
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
w1_precision=w1_precision, quant_config=quant_config,
w2_precision=w2_precision,
), ),
) )
extra_expert_args = {
"w1_bias": w1_bias,
"w2_bias": w2_bias,
}
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
return fused_experts( return fused_experts(
...@@ -365,7 +370,6 @@ def batched_moe( ...@@ -365,7 +370,6 @@ def batched_moe(
w2, w2,
topk_weight, topk_weight,
topk_ids, topk_ids,
extra_expert_args=extra_expert_args,
) )
......
...@@ -12,7 +12,6 @@ import torch ...@@ -12,7 +12,6 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
...@@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, ...@@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
run_modular_kernel) run_modular_kernel)
from .modular_kernel_tools.mk_objects import ( from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config) parallel_launch_with_config)
...@@ -55,7 +55,7 @@ def rank_worker( ...@@ -55,7 +55,7 @@ def rank_worker(
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
vllm_config: VllmConfig, vllm_config: VllmConfig,
cpu_group, cpu_group,
config: Config, base_config: Config,
weights: WeightTensors, weights: WeightTensors,
verbose: bool, verbose: bool,
): ):
...@@ -63,42 +63,44 @@ def rank_worker( ...@@ -63,42 +63,44 @@ def rank_worker(
# sanity check # sanity check
from vllm import envs from vllm import envs
if config.fused_moe_chunk_size is not None: if base_config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) assert (
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# get weights to this device # get weights to this device
weights.to_current_device() weights.to_current_device()
Ms = config.Ms Ms = base_config.Ms
assert isinstance(Ms, list) assert isinstance(Ms, list)
TOPKs = config.topks TOPKs = base_config.topks
assert isinstance(TOPKs, list) assert isinstance(TOPKs, list)
exceptions = [] exceptions = []
count = 0 count = 0
for m, topk in product(Ms, TOPKs): for m, topk in product(Ms, TOPKs):
# override m and topk
config = copy.deepcopy(base_config)
config.Ms = m
config.topks = topk
try: try:
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
count = count + 1 count = count + 1
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank # inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi) rank_tensors = RankTensors.make(config, pgi)
# modular kernel out # modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
rank_tensors) rank_tensors)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors) ref_out = reference_moe_impl(config, weights, rank_tensors)
if config.quant_dtype == "nvfp4": if config.quant_dtype == "nvfp4":
atol = 1e-1 atol = 1e-1 if config.K < 4096 else 2e-1
rtol = 1e-1 rtol = 1e-1 if config.K < 4096 else 2e-1
else: else:
atol = 3e-2 atol = 3e-2
rtol = 3e-2 rtol = 3e-2
...@@ -132,7 +134,7 @@ Ms = [32, 64] ...@@ -132,7 +134,7 @@ Ms = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail. # hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm. # Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048] Ks = [2048]
Ns = [2048] Ns = [1024]
TOPKs = [4, 1] TOPKs = [4, 1]
Es = [32] Es = [32]
DTYPEs = [torch.bfloat16] DTYPEs = [torch.bfloat16]
...@@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool: ...@@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool:
@meets_multi_gpu_requirements @meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu( def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[FusedMoEQuantConfig], quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
...@@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu( ...@@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu(
@pytest.mark.parametrize("world_size", [1]) @pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu( def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[FusedMoEQuantConfig], quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
......
...@@ -15,11 +15,14 @@ from transformers import MixtralConfig ...@@ -15,11 +15,14 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe) fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
...@@ -187,14 +190,9 @@ def test_fused_moe( ...@@ -187,14 +190,9 @@ def test_fused_moe(
# #
# Setup test functions # Setup test functions
# #
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, m_fused_moe_fn = modular_triton_fused_moe(quant_config)
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=None)
def m_fused_moe( def m_fused_moe(
a: torch.Tensor, a: torch.Tensor,
...@@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else: else:
e_map = None e_map = None
if weight_bits == 4:
quant_config_builder = int4_w4a16_moe_quant_config
else:
assert weight_bits == 8
quant_config_builder = int8_w8a16_moe_quant_config
quant_config = quant_config_builder(w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a, triton_output = fused_moe(a,
w1_qweight, w1_qweight,
...@@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
score, score,
topk, topk,
renormalize=False, renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
w1_scale=w1_scales, quant_config=quant_config)
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, torch_output = torch_moe(a,
w1_ref, w1_ref,
w2_ref, w2_ref,
......
...@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, ...@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
from tests.kernels.utils import torch_moe from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
in_dtype=dtype, in_dtype=dtype,
quant_dtype="nvfp4", quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize? block_shape=None, # use quant_blocksize?
per_act_token_quant=False, per_out_ch_quant=False,
) )
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
...@@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
assert w1_blockscale is not None assert w1_blockscale is not None
assert w2_blockscale is not None assert w2_blockscale is not None
quant_config = nvfp4_moe_quant_config(
g1_alphas=(1 / w1_gs),
g2_alphas=(1 / w2_gs),
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
)
cutlass_output = cutlass_moe_fp4( cutlass_output = cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs,
w1_fp4=w1_q, w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q, w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
g2_alphas=(1 / w2_gs),
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=quant_config,
m=m, m=m,
n=n, n=n,
k=k, k=k,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment