Unverified Commit c7d8724e authored by Shu Wang's avatar Shu Wang Committed by GitHub
Browse files

[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)


Signed-off-by: default avatarshuw <shuw@nvidia.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent b38baabc
...@@ -956,11 +956,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, ...@@ -956,11 +956,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
c_strides, per_act_token, per_out_ch) c_strides, per_act_token, per_out_ch)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor,
alphas: torch.Tensor, problem_sizes: torch.Tensor, b_scales: torch.Tensor, alphas: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, problem_sizes: torch.Tensor,
out_dtype: torch.dtype, device: torch.device): expert_offsets: torch.Tensor, sf_offsets: torch.Tensor):
""" """
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes. the gemms for each combination based on the specified problem sizes.
...@@ -977,14 +977,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, ...@@ -977,14 +977,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation. MMs used in the fused MoE operation.
""" """
m_topk = a_tensors.shape[0] return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
n = b_tensors.shape[1] a_scales, b_scales, alphas,
c_shape = (m_topk, n) problem_sizes, expert_offsets,
c = torch.empty(c_shape, device=device, dtype=out_dtype) sf_offsets)
torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
b_scales, alphas, problem_sizes,
expert_offsets, sf_offsets)
return c.to(out_dtype)
# aqlm # aqlm
......
...@@ -119,6 +119,7 @@ if TYPE_CHECKING: ...@@ -119,6 +119,7 @@ if TYPE_CHECKING:
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_FLASHINFER_MOE: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
...@@ -853,6 +854,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -853,6 +854,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM": "VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Allow use of FlashInfer CUTLASS kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
# Control the cache size used by the xgrammar compiler. The default # Control the cache size used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas. # of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason. # It can be changed with this variable if needed for some reason.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -255,28 +255,18 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -255,28 +255,18 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dispatchers, K) output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output, a.dtype)
def apply( def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
self, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
output: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): extra_expert_args: Optional[dict[str, Any]]):
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -142,7 +142,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -142,7 +142,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool): apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
experts = (self.batched_deep_gemm_experts experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts) if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None assert experts is not None
...@@ -150,4 +151,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -150,4 +151,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, global_num_experts, expert_map, w1_scale, activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta, workspace2, expert_tokens_meta,
apply_router_weight_on_input) apply_router_weight_on_input, extra_expert_args)
...@@ -15,6 +15,7 @@ from vllm.logger import init_logger ...@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -188,6 +189,11 @@ class FusedMoEParallelConfig: ...@@ -188,6 +189,11 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_flashinfer_cutlass_kernels(self):
    """Whether the FlashInfer CUTLASS fused-MoE kernels should be used.

    The result is truthy only when the ``VLLM_USE_FLASHINFER_MOE`` setting
    is enabled and ``has_flashinfer_cutlass_fused_moe()`` also returns a
    truthy value.
    """
    opted_in = envs.VLLM_USE_FLASHINFER_MOE
    if not opted_in:
        # Same short-circuit as ``and``: hand back the falsy setting itself
        # without probing for FlashInfer.
        return opted_in
    return has_flashinfer_cutlass_fused_moe()
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
...@@ -392,6 +398,10 @@ class FusedMoEConfig: ...@@ -392,6 +398,10 @@ class FusedMoEConfig:
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_flashinfer_cutlass_kernels(self):
    """Forward the FlashInfer CUTLASS kernel switch of the parallel config."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.use_flashinfer_cutlass_kernels
@staticmethod @staticmethod
def make( def make(
num_experts: int, num_experts: int,
...@@ -435,6 +445,12 @@ class FusedMoEConfig: ...@@ -435,6 +445,12 @@ class FusedMoEConfig:
if quant_dtype is None and isinstance(quant_config, Fp8Config): if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = torch.uint8
if weight_quant is not None: if weight_quant is not None:
per_out_ch_quant = ( per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL) weight_quant.strategy == QuantizationStrategy.CHANNEL)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels.""" """ CUTLASS based Fused MoE kernels."""
from typing import Callable, Optional from typing import Any, Callable, Optional
import torch import torch
...@@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( ...@@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache) _resize_cache,
extract_required_args)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -298,7 +299,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -298,7 +299,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool): apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
...@@ -431,7 +433,9 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() ...@@ -431,7 +433,9 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def cutlass_moe_fp4(a: torch.Tensor, def run_cutlass_moe_fp4(
output: torch.Tensor,
a: torch.Tensor,
a1_gscale: torch.Tensor, a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor, w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor, w1_blockscale: torch.Tensor,
...@@ -442,12 +446,15 @@ def cutlass_moe_fp4(a: torch.Tensor, ...@@ -442,12 +446,15 @@ def cutlass_moe_fp4(a: torch.Tensor,
w2_alphas: torch.Tensor, w2_alphas: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
m: int, m: int,
n: int, n: int,
k: int, k: int,
e: int, e: int,
device: torch.device, device: torch.device,
apply_router_weight_on_input: bool = False): apply_router_weight_on_input: bool = False,
) -> None:
""" """
MoE implementation for FP4 Inputs MoE implementation for FP4 Inputs
...@@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor, ...@@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor,
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
" between weights.") " between weights.")
assert (k_a // 2 == half_k_w1 assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2") and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
"expected `n`") "expected `n`")
assert (m == m_a), "input shape mismatch" assert (m == m_a), "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.size(0) == m and topk_ids.size(0) assert (topk_weights.size(0) == m and topk_ids.size(0)
== m), ("topk must be provided for each row of a") == m), ("topk must be provided for each row of a")
topk = topk_ids.size(1)
out_dtype = a.dtype out_dtype = a.dtype
num_topk = topk_ids.size(1) num_topk = topk_ids.size(1)
...@@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor, ...@@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor,
blockscale_offsets) blockscale_offsets)
a = ops.shuffle_rows(a, a_map) a = ops.shuffle_rows(a, a_map)
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a, a,
a1_gscale, a1_gscale,
...@@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor, ...@@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor,
blockscale_offsets, blockscale_offsets,
num_topk, num_topk,
) )
c1 = _resize_cache(workspace13, (m * topk, n * 2))
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1, w1_blockscale, w1_alphas, problem_sizes1,
expert_offsets[:-1], blockscale_offsets[:-1], expert_offsets[:-1], blockscale_offsets[:-1])
out_dtype, device)
del rep_a_fp4, rep_a_blockscale del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split to one halfpytho sized tensor. torch.ops._C.silu_and_mul(c2, c1)
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
device=device,
dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk) c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1], w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device) blockscale_offsets[:-1])
del int_fp4, int_blockscale del int_fp4, int_blockscale
c2 = ops.shuffle_rows(c2, c_map) c3 = ops.shuffle_rows(c3, c_map)
assert output.dtype == out_dtype
if not apply_router_weight_on_input: if not apply_router_weight_on_input:
out = (c2.view(m, num_topk, k) * output.copy_(
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1) (c3.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1),
non_blocking=True)
else:
output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
return
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
    """Experts implementation that runs the FP4 CUTLASS MoE path.

    ``apply`` forwards to ``run_cutlass_moe_fp4``; the other methods report
    activation formats, capabilities and workspace shapes.
    """

    def __init__(
        self,
        max_experts_per_worker: int,
        out_dtype: torch.dtype,
        per_act_token_quant: bool,
        per_out_ch_quant: bool,
        block_shape: Optional[list[int]] = None,
        use_batched_format: bool = False,
    ):
        """Store the configuration and register a uint8 quant config.

        Args:
            max_experts_per_worker: Leading dimension used for the batched
                workspace shapes.
            out_dtype: Workspace dtype reported by ``workspace_shapes``; when
                it is None the dtype of the input activations is used.
            per_act_token_quant: Forwarded to ``FusedMoEQuantConfig``.
            per_out_ch_quant: Forwarded to ``FusedMoEQuantConfig``.
            block_shape: Forwarded to ``FusedMoEQuantConfig``.
            use_batched_format: Selects the batched activation format and
                the batched workspace shapes.
        """
        super().__init__(
            FusedMoEQuantConfig(
                quant_dtype=torch.uint8,
                per_act_token_quant=per_act_token_quant,
                per_out_ch_quant=per_out_ch_quant,
                block_shape=block_shape,
            ))
        self.max_experts_per_worker = max_experts_per_worker
        self.out_dtype = out_dtype
        self.use_batched_format = use_batched_format

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Return the (input, output) activation formats for this instance."""
        if self.use_batched_format:
            return (mk.FusedMoEActivationFormat.BatchedExperts,
                    mk.FusedMoEActivationFormat.BatchedExperts)
        else:
            return (mk.FusedMoEActivationFormat.Standard,
                    mk.FusedMoEActivationFormat.Standard)

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported by this implementation."""
        return False

    def supports_chunking(self) -> bool:
        """Chunking is reported as supported."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return the delegate weight-and-reduce implementation."""
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """Return ``(workspace1, workspace2, output, dtype)``.

        The three shapes depend on ``use_batched_format``. The dtype is
        ``self.out_dtype`` unless that is None, in which case ``a.dtype``.
        """
        workspace1: tuple[int, ...] = ()
        workspace2: tuple[int, ...] = ()
        output: tuple[int, ...] = ()
        if self.use_batched_format:
            # The padded token count is read from the second dim of ``aq``.
            padded_M = aq.size(1)
            workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
            workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
            output = (self.max_experts_per_worker, padded_M, K)
        else:
            workspace1 = (M * topk, max(2 * N, K))
            workspace2 = (M * topk, N)
            output = (M, K)
        return (workspace1, workspace2, output,
                self.out_dtype if self.out_dtype is not None else a.dtype)

    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
              expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
              w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
              a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
              workspace2: Optional[torch.Tensor],
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool,
              extra_expert_args: Optional[dict[str, Any]]):
        """Run the FP4 MoE computation via ``run_cutlass_moe_fp4``.

        The alphas, global scales, problem dimensions and device are taken
        from ``extra_expert_args``. ``activation``, ``global_num_experts``,
        ``expert_map``, ``w1_zp``, ``w2_zp``, ``a1q_scale``, ``a2_scale`` and
        ``expert_tokens_meta`` are not read by this implementation.
        """
        required_keys = [
            "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
            "e", "device"
        ]
        (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
         device) = extract_required_args(extra_expert_args, required_keys)
        run_cutlass_moe_fp4(
            output=output,
            a=hidden_states,
            a1_gscale=a1_gscale,
            w1_fp4=w1,
            w1_blockscale=w1_scale,
            w1_alphas=g1_alphas,
            a2_gscale=a2_gscale,
            w2_fp4=w2,
            w2_blockscale=w2_scale,
            w2_alphas=g2_alphas,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            workspace13=workspace13,
            workspace2=workspace2,
            m=m,
            n=n,
            k=k,
            e=e,
            device=device,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
def cutlass_moe_fp4(
        a: torch.Tensor,
        w1_fp4: torch.Tensor,
        w2_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w2_blockscale: torch.Tensor,
        g1_alphas: torch.Tensor,
        g2_alphas: torch.Tensor,
        a1_gscale: torch.Tensor,
        a2_gscale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False) -> torch.Tensor:
    """Run the FP4 CUTLASS MoE through a ``FusedMoEModularKernel``.

    A kernel is assembled from ``MoEPrepareAndFinalizeNoEP`` and
    ``CutlassExpertsFp4`` (standard, non-batched format) and invoked with
    ``activation="silu"``. The alphas, global scales, problem dimensions and
    device travel to the experts via ``extra_expert_args``.

    ``expert_map`` must be None; anything else trips an assertion.
    """
    assert expert_map is None, ("Expert Parallelism / expert_map "
                                "is currently not supported for "
                                "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")

    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = CutlassExpertsFp4(
        max_experts_per_worker=e,
        out_dtype=a.dtype,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        use_batched_format=False,
    )
    modular_kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)

    # Everything the experts need beyond the generic modular-kernel
    # arguments.
    expert_kwargs = {
        "g1_alphas": g1_alphas,
        "g2_alphas": g2_alphas,
        "a1_gscale": a1_gscale,
        "a2_gscale": a2_gscale,
        "m": m,
        "n": n,
        "k": k,
        "e": e,
        "device": device,
    }

    # NVFP4 requires two levels of quantization, which involves computing
    # some scaling factors dynamically. That does not fit the usual
    # prepare -> MoE -> finalize split, so quantization is done inside the
    # MoE body and the prepare step is told to skip it.
    prepare_kwargs = {"skip_quant": True}
    # For the same reason the finalize step skips the weight/reduce.
    finalize_kwargs = {"skip_weight_reduce": True}

    return modular_kernel(
        hidden_states=a,
        w1=w1_fp4,
        w2=w2_fp4,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        activation="silu",
        global_num_experts=e,
        expert_map=None,
        w1_scale=w1_blockscale,
        w2_scale=w2_blockscale,
        a1_scale=None,
        a2_scale=None,
        apply_router_weight_on_input=apply_router_weight_on_input,
        extra_expert_args=expert_kwargs,
        extra_prepare_args=prepare_kwargs,
        extra_finalize_args=finalize_kwargs,
    )
def _valid_cutlass_block_scaled_grouped_gemm( def _valid_cutlass_block_scaled_grouped_gemm(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -152,6 +152,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -152,6 +152,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
): ):
assert self.block_shape is not None assert self.block_shape is not None
assert a1q_scale is not None assert a1q_scale is not None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Any, Optional
import deep_ep import deep_ep
import torch import torch
...@@ -127,16 +127,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -127,16 +127,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights) expert_topk_weights)
def prepare( def prepare(
self, self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a1: torch.Tensor, a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a1_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int,
a2_scale: Optional[torch.Tensor], expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
...@@ -191,7 +187,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -191,7 +187,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
assert self.handle is not None assert self.handle is not None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Any, Optional, Union
import deep_ep import deep_ep
import torch import torch
...@@ -111,16 +111,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -111,16 +111,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales return x, x_scales
def prepare( def prepare(
self, self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a1: torch.Tensor, a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a1_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int,
a2_scale: Optional[torch.Tensor], expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
...@@ -169,7 +165,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -169,7 +165,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
assert isinstance( assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.") ), ("Weight application and reduction happens in the combine kernel.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe)
logger = init_logger(__name__)
def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
                                          w1: torch.Tensor,
                                          w2: torch.Tensor) -> bool:
    """
    Tell whether the FlashInfer CUTLASS MoE kernel can handle these tensors.

    Returns False (after a one-time debug log) when the FlashInfer kernel is
    unavailable, when either weight tensor is not ``torch.uint8``, or when
    ``hidden_states`` is not float32, float16 or bfloat16. Returns True
    otherwise.
    """
    if not has_flashinfer_cutlass_fused_moe():
        logger.debug_once("FlashInferExperts disabled: "
                          "flashinfer_cutlass_fused_moe not available.")
        return False

    # Data type checks
    activation_dtypes = [torch.float32, torch.float16, torch.bfloat16]
    weights_are_uint8 = (w1.dtype == torch.uint8
                         and w2.dtype == torch.uint8)
    if weights_are_uint8 and hidden_states.dtype in activation_dtypes:
        return True

    logger.debug_once(
        "FlashInferExperts disabled: w1/w2 must be torch.uint8 "
        f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
        f"float32, float16, or bfloat16 (got {hidden_states.dtype}).")
    return False
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_nvfp4_w4a4: bool = False,
use_fp8_w8a8: bool = False,
use_dp: bool = False,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
per_act_token_quant=False,
block_shape=None,
))
self.use_nvfp4_w4a4 = use_nvfp4_w4a4
self.use_fp8_w8a8 = use_fp8_w8a8
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.use_dp = use_dp
assert not use_batched_format or num_dispatchers is not None
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_expert_map(self) -> bool:
return False
def supports_chunking(self) -> bool:
# This refers to TP chunking; DP chunking is handled separately.
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
aq_m, aq_n = aq.shape
workspace2 = ()
output_shape = (aq_m, aq_n * 2)
workspace_dtype = a.dtype
workspace1 = output_shape
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape, workspace_dtype)
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],  # Not used
    workspace13: Optional[torch.Tensor],
    workspace2: Optional[torch.Tensor],
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    apply_router_weight_on_input: Optional[bool],
    extra_expert_args: Optional[dict[str, Any]],
):
    """
    Run the FlashInfer CUTLASS fused MoE kernel, writing into `output`.

    `extra_expert_args` must be provided and must contain the keys
    'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale' and 'out_dtype'.
    `w1_scale` and `w2_scale` must not be None, and
    `apply_router_weight_on_input` must be falsy. Only the nvfp4 w4a4
    configuration is accepted.

    The following arguments are accepted but not read by this method:
    `activation`, `global_num_experts`, `expert_map`, `w1_zp`, `w2_zp`,
    `a2_scale`, `workspace13`, `workspace2` and `expert_tokens_meta`.

    Nothing is returned; the kernel is handed `output` as its destination.
    """
    assert extra_expert_args is not None, \
        "extra_expert_args must be provided"
    g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
        extract_required_args(
            extra_expert_args,
            ['g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale',
             'out_dtype'],
        ))

    assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
                                         "currently supported.")
    # The weight scales are reinterpreted with .view() below, so they
    # have to be real tensors.
    assert w1_scale is not None and w2_scale is not None, (
        "w1_scale and w2_scale must not "
        "be None for FlashInferExperts")
    assert not apply_router_weight_on_input

    # Scales are passed as one flat list: the three entries for the
    # a1/w1/g1 group followed by the three for the a2/w2/g2 group.
    w1_scale_i32 = w1_scale.view(torch.int32)
    w2_scale_i32 = w2_scale.view(torch.int32)
    quant_scales = [
        a1_gscale, w1_scale_i32, g1_alphas,
        a2_gscale, w2_scale_i32, g2_alphas,
    ]

    routed_ids = topk_ids.to(torch.int)
    # FlashInfer API requires weight to be long for nvfp4
    w1_long = w1.view(torch.long)
    w2_long = w2.view(torch.long)

    # The return value is not needed; results land in `output`.
    flashinfer_cutlass_fused_moe(
        hidden_states,
        routed_ids,
        topk_weights,
        w1_long,
        w2_long,
        output_dtype=out_dtype,
        quant_scales=quant_scales,
        input_sf=a1q_scale,
        tp_size=self.tp_size,
        tp_rank=self.tp_rank,
        ep_size=self.ep_size,
        ep_rank=self.ep_rank,
        output=output,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
extract_required_args, moe_kernel_quantize_input)
from vllm.utils.flashinfer import fp4_swizzle_blockscale
def get_local_sizes(local_tokens):
    """Return the per-DP-rank token counts to use for the current chunk.

    Per-rank totals are recovered by differencing the cumulative counts
    in the forward context's `dp_metadata.cu_tokens_across_dp_cpu`. The
    chunk size is `envs.VLLM_MOE_DP_CHUNK_SIZE`.

    If `local_tokens` is smaller than the chunk size, each entry is that
    rank's total modulo the chunk size; otherwise every entry is the chunk
    size itself.
    """
    cu_tokens = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    per_rank = [cu_tokens[0].item()]
    per_rank.extend((cu_tokens[idx] - cu_tokens[idx - 1]).item()
                    for idx in range(1, len(cu_tokens)))

    chunk_size = envs.VLLM_MOE_DP_CHUNK_SIZE
    if local_tokens < chunk_size:
        # When the number of local tokens is less than the chunk size, all
        # other ranks will also have fewer than a full chunk. The remaining
        # tokens are accounted for as residual.
        return [count % chunk_size for count in per_rank]
    return [chunk_size] * len(per_rank)
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """Prepare/finalize stage used with the FlashInfer CUTLASS MoE experts.

    `prepare` quantizes the activations and, when `use_dp` is set, gathers
    them across the DP group; `finalize` optionally reduce-scatters the
    expert output across the DP group and copies it into `output`.
    """

    def __init__(
        self,
        quant_dtype: Optional[torch.dtype] = None,
        per_channel_quant: bool = False,
        block_shape: Optional[list[int]] = None,
        num_dispatchers: int = 1,
    ):
        """Store the quantization settings and the dispatcher count."""
        super().__init__()
        self.quant_dtype = quant_dtype
        self.per_channel_quant = per_channel_quant
        self.block_shape = block_shape
        # Trailing underscore: `num_dispatchers` is also a method name.
        self.num_dispatchers_ = num_dispatchers

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format of this stage: always Standard."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> Optional[int]:
        """Return None: no per-rank token limit is reported."""
        return None

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """Return None: no specific dtype is requested for top-k indices."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction."""
        return self.num_dispatchers_

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],  # Not used
        a2_scale: Optional[torch.Tensor],  # Not used
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Quantize `a1` and, under DP, gather inputs across the DP group.

        `extra_prepare_args` must supply 'a1_gscale', 'use_dp' and
        'local_tokens'. `apply_router_weight_on_input` must be falsy.

        Returns `(a1q, a1q_scale, None, topk_ids, topk_weights)`. When
        `use_dp` is set, all four tensors are the results of the DP
        group's `all_gatherv` along dim 0.
        """
        assert not apply_router_weight_on_input

        a1_gscale, use_dp, local_tokens = extract_required_args(
            extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])

        # Under DP the scales are swizzled only after communication, so the
        # quantizer is asked to swizzle just in the non-DP case.
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            a1_gscale,
            quant_config.quant_dtype,
            self.per_channel_quant,
            self.block_shape,
            is_fp4_scale_swizzled=not use_dp,
        )

        if not use_dp:
            return a1q, a1q_scale, None, topk_ids, topk_weights

        gathered = get_dp_group().all_gatherv(
            [topk_weights, topk_ids, a1q, a1q_scale],
            dim=0,
            sizes=get_local_sizes(local_tokens),
        )
        topk_weights, topk_ids, a1q, a1q_scale = gathered

        # Swizzle the gathered block scales using the gathered activation
        # shape, with the column count doubled.
        num_rows, packed_cols = a1q.shape
        a1q_scale = fp4_swizzle_blockscale(a1q_scale, num_rows,
                                           packed_cols * 2)
        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
        """Copy the expert output into `output`, reducing across DP first.

        `extra_finalize_args` must supply 'use_dp' and 'local_tokens'.
        When `use_dp` is set, `fused_expert_output` first goes through the
        DP group's `reduce_scatterv` along dim 0. `topk_weights`,
        `topk_ids`, `apply_router_weight_on_input` and
        `weight_and_reduce_impl` are not read here.
        """
        use_dp, local_tokens = extract_required_args(
            extra_finalize_args, ['use_dp', 'local_tokens'])

        result = fused_expert_output
        if use_dp:
            result = get_dp_group().reduce_scatterv(
                result,
                dim=0,
                sizes=get_local_sizes(local_tokens),
            )
        output.copy_(result)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel.""" """Fused batched MoE kernel."""
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -496,16 +496,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -496,16 +496,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_ return self.num_dispatchers_
def prepare( def prepare(
self, self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a1: torch.Tensor, a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a1_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int,
a2_scale: Optional[torch.Tensor], expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
...@@ -594,15 +590,11 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -594,15 +590,11 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return b_a1, b_a1_scale, expert_tokens_meta, None, None return b_a1, b_a1_scale, expert_tokens_meta, None, None
def finalize( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
self, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: extra_finalize_args: Optional[dict[str, Any]]) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
weight_and_reduce_impl.apply( weight_and_reduce_impl.apply(
...@@ -706,7 +698,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -706,7 +698,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool): apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert hidden_states.dim() == 3 assert hidden_states.dim() == 3
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
...@@ -911,7 +904,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -911,7 +904,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool): apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), ( assert hidden_states.size(-1) // 2 == w1.size(2), (
......
...@@ -1646,6 +1646,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1646,6 +1646,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
): ):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
......
...@@ -34,6 +34,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -34,6 +34,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
...@@ -45,6 +46,9 @@ if current_platform.is_cuda_alike(): ...@@ -45,6 +46,9 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
if has_flashinfer():
from .flashinfer_cutlass_prepare_finalize import (
FlashInferCutlassMoEPrepareAndFinalize)
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore
...@@ -99,6 +103,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -99,6 +103,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
if moe.use_flashinfer_cutlass_kernels:
prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=moe.quant_dtype, )
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens, moe.max_num_tokens,
...@@ -204,6 +211,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -204,6 +211,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm " f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize") "implementation based on the prepare_finalize")
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
pass
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
...@@ -744,12 +757,15 @@ class FusedMoE(torch.nn.Module): ...@@ -744,12 +757,15 @@ class FusedMoE(torch.nn.Module):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
if isinstance(self.quant_method, FusedMoEMethodBase):
self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)
# Chunked all2all staging tensor # Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size), (moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype, dtype=moe.in_dtype,
...@@ -801,6 +817,10 @@ class FusedMoE(torch.nn.Module): ...@@ -801,6 +817,10 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -1402,9 +1422,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1402,9 +1422,9 @@ class FusedMoE(torch.nn.Module):
final_hidden_states, non_blocking=True) final_hidden_states, non_blocking=True)
ctx = get_forward_context() ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
num_tokens = full_hidden_states.size(0) num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp, for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank): moe_dp_chunk_size_per_rank):
...@@ -1424,13 +1444,20 @@ class FusedMoE(torch.nn.Module): ...@@ -1424,13 +1444,20 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
assert self.quant_method is not None assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
use_flashinfer_cutlass_kernels = (
self.dp_size > 1
and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels
or use_flashinfer_cutlass_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = ( do_naive_dispatch_combine: bool = (
self.dp_size > 1 self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels) and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
...@@ -1460,7 +1487,6 @@ class FusedMoE(torch.nn.Module): ...@@ -1460,7 +1487,6 @@ class FusedMoE(torch.nn.Module):
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states) final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs. # Default set to False. (May have to add shared expert outputs.
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
......
...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from math import prod from math import prod
from typing import Optional, final from typing import Any, Optional, final
import torch import torch
...@@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod @abstractmethod
def prepare( def prepare(
self, self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a1: torch.Tensor, a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a1_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int,
a2_scale: Optional[torch.Tensor], expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor], Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
...@@ -190,15 +186,11 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -190,15 +186,11 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def finalize( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
self, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce, weight_and_reduce_impl: TopKWeightAndReduce,
) -> None: extra_finalize_args: Optional[dict[str, Any]]) -> None:
""" """
Perform any combine plus apply weights and perform a reduction on the Perform any combine plus apply weights and perform a reduction on the
fused experts output. fused experts output.
...@@ -376,6 +368,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -376,6 +368,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
): ):
""" """
This function computes the intermediate result of a Mixture of Experts This function computes the intermediate result of a Mixture of Experts
...@@ -460,21 +453,19 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -460,21 +453,19 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}." f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}") f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(self, fused_out: Optional[torch.Tensor], def _do_fused_experts(
a1: torch.Tensor, a1q: torch.Tensor, self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, activation: str, global_num_experts: int, local_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool) -> torch.Tensor: apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
...@@ -517,7 +508,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -517,7 +508,8 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
return fused_out return fused_out
...@@ -541,6 +533,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -541,6 +533,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor: ) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
...@@ -568,7 +561,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -568,7 +561,8 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
# Chunking required case # Chunking required case
assert num_chunks > 1 assert num_chunks > 1
...@@ -624,6 +618,15 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -624,6 +618,15 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=c_expert_num_tokens, expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu) expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks): for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx)) slice_input_tensors(chunk_idx))
...@@ -634,6 +637,11 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -634,6 +637,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts, expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map) expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts( self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx), fused_out=slice_output_tensor(chunk_idx),
a1=a1, a1=a1,
...@@ -653,7 +661,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -653,7 +661,8 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=c_a1q_scale, a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale, a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
return fused_out return fused_out
...@@ -675,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -675,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -707,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -707,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are - apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is applied directly on the inputs. This is only applicable when topk is
1. 1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -730,6 +748,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -730,6 +748,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
extra_prepare_args,
) )
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks. # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
...@@ -766,11 +785,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -766,11 +785,13 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
self.prepare_finalize.finalize( self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids, output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl()) self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
return output return output
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Any, Optional
import pplx_kernels as pplx import pplx_kernels as pplx
import torch import torch
...@@ -89,16 +89,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -89,16 +89,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_ return self.num_dispatchers_
def prepare( def prepare(
self, self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a1: torch.Tensor, a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a1_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int,
a2_scale: Optional[torch.Tensor], expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
...@@ -217,15 +213,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -217,15 +213,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None return expert_x, expert_x_scale, expert_tokens_meta, None, None
def finalize( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
self, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: extra_finalize_args: Optional[dict[str, Any]]) -> None:
assert isinstance( assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.") ), ("Weight application and reduction happens in the combine kernel.")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -38,6 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -38,6 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]],
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
...@@ -48,21 +49,28 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -48,21 +49,28 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, \ assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1" "apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype)) a1.mul_(topk_weights.to(a1.dtype))
if (extra_prepare_args is not None
and extra_prepare_args.get("skip_quant", True)):
# Skip quantization if explicitly requested
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype, a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None return a1q, a1q_scale, None, None, None
def finalize( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
self, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: extra_finalize_args: Optional[dict[str, Any]]) -> None:
if (extra_finalize_args is not None
and extra_finalize_args.get("skip_weight_reduce", True)):
assert output.shape == fused_expert_output.shape
output.copy_(fused_expert_output)
else:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous() weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply( weight_and_reduce_impl.apply(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -119,28 +119,18 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -119,28 +119,18 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts, local_num_experts,
expert_tokens_meta) expert_tokens_meta)
def apply( def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
self, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
output: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): extra_expert_args: Optional[dict[str, Any]]):
use_deep_gemm = (self.allow_deep_gemm use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2) and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_used())) or is_blackwell_deep_gemm_used()))
...@@ -168,4 +158,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -168,4 +158,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2, workspace2,
expert_tokens_meta, expert_tokens_meta,
apply_router_weight_on_input, apply_router_weight_on_input,
extra_expert_args,
) )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod from math import prod
from typing import Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.utils.flashinfer import fp4_quantize
@triton.jit @triton.jit
...@@ -98,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: ...@@ -98,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
return x.flatten()[:prod(v)].view(*v) return x.flatten()[:prod(v)].view(*v)
def _fp4_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``A`` by delegating to ``fp4_quantize``.

    ``A`` and ``A_scale`` are forwarded positionally and the layout flag is
    forwarded by keyword; nothing is validated or transformed here.

    Args:
        A: Tensor to quantize.
        A_scale: Optional scale handed through unchanged.
        is_sf_swizzled_layout: Forwarded as ``is_sf_swizzled_layout``.

    Returns:
        Whatever ``fp4_quantize`` returns, annotated as a pair of tensors.
        NOTE(review): presumably (quantized data, scale factors) -- confirm
        against the ``fp4_quantize`` implementation.
    """
    quantized_pair = fp4_quantize(
        A,
        A_scale,
        is_sf_swizzled_layout=is_sf_swizzled_layout,
    )
    return quantized_pair
def _fp8_quantize( def _fp8_quantize(
A: torch.Tensor, A: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
...@@ -172,11 +183,16 @@ def moe_kernel_quantize_input( ...@@ -172,11 +183,16 @@ def moe_kernel_quantize_input(
quant_dtype: Union[None, torch.dtype, str], quant_dtype: Union[None, torch.dtype, str],
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
is_fp4_scale_swizzled: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8: elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.uint8: # nvfp4
return _fp4_quantize(A,
A_scale,
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4": elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
else: else:
...@@ -236,3 +252,17 @@ def _validate_scale_shape( ...@@ -236,3 +252,17 @@ def _validate_scale_shape(
assert block_shape is not None assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def extract_required_args(
    extra_args: Optional[dict[str, Any]],
    required_keys: list[str],
) -> tuple[Any, ...]:
    """Pull the values for ``required_keys`` out of ``extra_args``.

    Args:
        extra_args: Mapping holding the extra arguments; must not be None.
        required_keys: Keys to look up, in the order the values are wanted.

    Returns:
        A tuple with one value per entry of ``required_keys``, in that order.

    Raises:
        ValueError: If ``extra_args`` is None, or if any required key is
            absent (the message lists every absent key).
    """
    if extra_args is None:
        raise ValueError("`extra_args` must be provided.")

    # Collect hits and misses in one sweep so that every missing key is
    # reported together rather than only the first one encountered.
    found: list[Any] = []
    absent: list[str] = []
    for key in required_keys:
        if key in extra_args:
            found.append(extra_args[key])
        else:
            absent.append(key)

    if absent:
        raise ValueError(f"Missing keys in `extra_args`: {absent}")
    return tuple(found)
...@@ -339,19 +339,19 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -339,19 +339,19 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return cutlass_moe_fp4( return cutlass_moe_fp4(
a=x, a=x,
w1_fp4=layer.w13_weight, w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight, w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled, w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas, g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device, device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to( apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype) x.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment