Commit 7ef40bb9 authored by Varun Sundar Rabindranath, committed by GitHub

[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)


Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 767cbb01
@@ -93,6 +93,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
 | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
 | deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
 | marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
+| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
 | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
 | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
 | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
@@ -114,6 +116,6 @@ The following table shows "families" of modular kernels that are intended to wor
 | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
 |----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
-| deepep_high_throughput,</br>pplx | `DeepEPHTPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8` |
-| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8` |
+| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
+| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
 | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
@@ -303,7 +303,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert w2.size(1) == K
-        E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+        E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
             hidden_states, w1, w2, topk_ids)
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
......
@@ -712,7 +712,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
+        e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
         n = w2.shape[2] * 2
         run_cutlass_moe_fp4(
......
@@ -906,7 +906,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
-        E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+        E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
             hidden_states, w1, w2, topk_ids)
         assert w1.size(0) == E
......
@@ -4,11 +4,18 @@
 from typing import Optional
 import torch
+from typing_extensions import override
 import vllm._custom_ops as ops
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
+    marlin_make_workspace_new, marlin_moe_intermediate_size,
+    maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import direct_register_custom_op
@@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      bias2: Optional[torch.Tensor],
                      w1_scale: torch.Tensor,
                      w2_scale: torch.Tensor,
-                     gating_output: torch.Tensor,
+                     gating_output: Optional[torch.Tensor],
                      topk_weights: torch.Tensor,
                      topk_ids: torch.Tensor,
                      quant_type_id: int,
@@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      w1_zeros: Optional[torch.Tensor] = None,
                      w2_zeros: Optional[torch.Tensor] = None,
                      workspace: Optional[torch.Tensor] = None,
+                     intermediate_cache13: Optional[torch.Tensor] = None,
+                     intermediate_cache2: Optional[torch.Tensor] = None,
                      is_k_full: bool = True,
+                     output: Optional[torch.Tensor] = None,
                      inplace: bool = False) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     - w2 (torch.Tensor): The second set of expert weights.
     - w1_scale (torch.Tensor): Scale to be used for w1.
     - w2_scale (torch.Tensor): Scale to be used for w2.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
+    - gating_output (Optional[torch.Tensor]): The output of the gating
+        operation (before softmax).
     - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
     - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
     - sort_indices1 (Optional[torch.Tensor]): The first act_order input
@@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     num_bits = 4 if quant_type in bit4_scalar_types else 8
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[
-        0], "Number of tokens mismatch"
+    if gating_output is not None:
+        assert hidden_states.shape[0] == gating_output.shape[
+            0], "Number of tokens mismatch"
     assert hidden_states.shape[
         1] == w1.shape[1] * 16, "Hidden size mismatch w1"
     assert hidden_states.shape[1] == w2.shape[2] // (
@@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     M, K = hidden_states.shape
     E = w1.shape[0]
-    N = w2.shape[1] * 16
+    N = marlin_moe_intermediate_size(w1, w2)
     topk = topk_ids.shape[1]
     # M block size selection logic
@@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if workspace is None:
         workspace = marlin_make_workspace_new(hidden_states.device, 4)
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache13 = torch.empty(
-        (M * topk_ids.shape[1] * max(2 * N, K), ),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
-    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
-    intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
-    intermediate_cache3 = intermediate_cache3.view(-1, K)
+    if intermediate_cache2 is None:
+        intermediate_cache2 = torch.empty(
+            (M * topk, N),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+    if intermediate_cache13 is None:
+        intermediate_cache13 = torch.empty(
+            (M * topk * max(2 * N, K), ),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+    intermediate_cache1 = _resize_cache(intermediate_cache13,
+                                        (M * topk, 2 * N))
+    intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
+    intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
     maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
     use_atomic_add = hidden_states.dtype == torch.half or \
@@ -200,10 +215,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         use_fp32_reduce=True,
         is_zp_float=False).view(-1, topk, K)
-    output = hidden_states if inplace else torch.empty_like(hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1,
-                     out=output)
+    if output is None:
+        output = hidden_states if inplace else torch.empty_like(hidden_states)
+    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
 def fused_marlin_moe_fake(hidden_states: torch.Tensor,
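The final reduction in the hunk above views the second GEMM's result as `(M, topk, K)` and sums over the `topk` axis into `output`. A minimal pure-Python sketch of that reshape-and-sum, using nested lists instead of tensors; `reduce_topk` is a hypothetical helper written for illustration and is not part of vLLM:

```python
def reduce_topk(rows, topk):
    """Sum each consecutive group of `topk` rows element-wise.

    `rows` stands in for intermediate_cache3 with shape (M * topk, K).
    The result has shape (M, K), analogous to
    torch.sum(x.view(-1, topk, K), dim=1).
    """
    if len(rows) % topk != 0:
        raise ValueError("row count must be a multiple of topk")
    out = []
    for start in range(0, len(rows), topk):
        group = rows[start:start + topk]
        # zip(*group) walks the K columns of this token's topk rows.
        out.append([sum(col) for col in zip(*group)])
    return out


# M = 2 tokens, topk = 2 experts per token, K = 3 hidden features.
cache3 = [[1, 2, 3], [10, 20, 30], [4, 5, 6], [40, 50, 60]]
print(reduce_topk(cache3, topk=2))  # [[11, 22, 33], [44, 55, 66]]
```

Each output row combines the `topk` expert results for one token, which is why the output shape is `(M, K)` rather than `(M * topk, K)`.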
@@ -211,7 +225,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w2: torch.Tensor,
                           w1_scale: torch.Tensor,
                           w2_scale: torch.Tensor,
-                          gating_output: torch.Tensor,
+                          gating_output: Optional[torch.Tensor],
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           quant_type_id: int,
@@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1_zeros: Optional[torch.Tensor] = None,
                           w2_zeros: Optional[torch.Tensor] = None,
                           workspace: Optional[torch.Tensor] = None,
+                          intermediate_cache13: Optional[torch.Tensor] = None,
+                          intermediate_cache2: Optional[torch.Tensor] = None,
                           is_k_full: bool = True,
+                          output: Optional[torch.Tensor] = None,
                           inplace: bool = False) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -237,3 +254,124 @@ direct_register_custom_op(
     op_func=fused_marlin_moe,
     fake_impl=fused_marlin_moe_fake,
 )
+class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        super().__init__(quant_config)
+    @override
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        assert w1.dim() == 3 and w2.dim() == 3
+        E = w1.size(0)
+        K = a1.size(-1)
+        N = marlin_moe_intermediate_size(w1, w2)
+        if a1.dim() == 2:
+            # Make sure we are using the correct a1 (pre-permute).
+            assert topk_ids.size(0) == a1.size(0), \
+                f"{topk_ids.size(0)} != {a1.size(0)}"
+            M = a1.size(0)
+        else:
+            assert a1.dim() == 3
+            assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+            M = a1.size(1)  # This is max_num_tokens
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+        return E, M, N, K, topk
+    def supports_expert_map(self) -> bool:
+        return True
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+    def supports_chunking(self) -> bool:
+        return True
+    def workspace_shapes(
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # Modular Kernel provisions the output buffer from workspace1.
+        # However, in the fused_marlin_moe() function, the final torch.sum()
+        # is defined essentially as
+        # `torch.sum(workspace1, dim=1, out=output)`
+        # Having overlapping input and output tensors for torch.sum seems
+        # error prone and depends on how torch.sum is implemented.
+        # For this reason we swap the workspaces and let the output buffer
+        # be provisioned from workspace2.
+        # Workspace/IntermediateCache allocation matching fused_marlin_moe()
+        # workspace1 = (M * topk * max(2 * N, K),)
+        # workspace2 = (M * topk, N)
+        # Workspace/IntermediateCache allocation accounting for output buffer
+        # provisioning
+        workspace1 = (M * topk, max(N, K))
+        workspace2 = (M * topk * max(2 * N, K), )
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        assert self.w1_scale is not None
+        assert self.w2_scale is not None
+        return fused_marlin_moe(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            bias1=self.w1_bias,
+            bias2=self.w2_bias,
+            w1_scale=self.w1_scale,
+            w2_scale=self.w2_scale,
+            gating_output=None,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            activation=activation,
+            expert_map=expert_map,
+            output=output,
+            # Workspaces are swapped in workspace_shapes() to account for proper
+            # output buffer allocation. Please refer to workspace_shapes().
+            intermediate_cache13=workspace2,
+            intermediate_cache2=workspace13)
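The workspace swap described in the `workspace_shapes()` comment can be checked with plain shape arithmetic. The sketch below mirrors the shapes from the diff and verifies that every view `fused_marlin_moe()` resizes out of the two buffers fits. The helper names (`marlin_workspace_shapes`, `fits`) and the example sizes are assumptions made for illustration; that the modular kernel carves `output` from the first workspace is taken from the code comment, not verified here.

```python
from math import prod


def marlin_workspace_shapes(M, N, K, topk):
    """Mirror the shapes returned by MarlinExperts.workspace_shapes()."""
    workspace13 = (M * topk, max(N, K))        # passed as intermediate_cache2
    workspace2 = (M * topk * max(2 * N, K),)   # passed as intermediate_cache13
    output = (M, K)
    return workspace13, workspace2, output


def fits(view_shape, buffer_shape):
    """True if a view of `view_shape` can be carved out of the buffer."""
    return prod(view_shape) <= prod(buffer_shape)


M, N, K, topk = 8, 256, 128, 4  # arbitrary example sizes
ws13, ws2, out = marlin_workspace_shapes(M, N, K, topk)

# Views that fused_marlin_moe() resizes out of the two buffers.
assert fits((M * topk, 2 * N), ws2)   # intermediate_cache1
assert fits((M * topk, K), ws2)       # intermediate_cache3, input to torch.sum
assert fits((M * topk, N), ws13)      # intermediate_cache2
assert fits(out, ws13)                # output lives in workspace13, not workspace2
print(ws13, ws2, out)  # (32, 256) (16384,) (8, 128)
```

With this layout the input of the final `torch.sum` (a view of `workspace2`) and its `out=` tensor (a view of `workspace13`) come from different buffers, which is the overlap the comment is trying to avoid.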
@@ -1780,7 +1780,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
         ]
-        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
+        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
             hidden_states, w1, w2, topk_ids)
         if global_num_experts == -1:
......
@@ -55,46 +55,6 @@ from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
 #
-def _moe_problem_size(
-    a1: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_ids: torch.Tensor,
-) -> tuple[int, int, int, int, int]:
-    """
-    Extract the MoE problem size from the given tensor arguments:
-    - a: The hidden states, input to the MoE layer.
-    - w1: The first set of expert weights.
-    - w2: The second set of expert weights.
-    - topk_ids: The topk ids.
-    Note: extracting the problem shape from the weight and activation tensors is
-    not obvious. It needs to be done this way specifically due to subtle issues
-    with particular kernels, e.g. the int4 kernels divide the trailing dimension
-    by two, so it's not "correct" to extract N or K from the trailing dimension
-    of w1 or w2. Similarly, some kernels transpose the weights, so this needs
-    to be kept in mind.
-    """
-    assert w1.dim() == 3 and w2.dim() == 3
-    E, N, _ = w1.size()
-    K = a1.size(-1)
-    if a1.dim() == 2:
-        # Make sure we are using the correct a1 (pre-permute).
-        assert topk_ids.size(0) == a1.size(0), \
-            f"{topk_ids.size(0)} != {a1.size(0)}"
-        M = a1.size(0)
-    else:
-        assert a1.dim() == 3
-        assert a1.size(0) == E, f"{a1.size(0)} == {E}"
-        M = a1.size(1)  # This is max_num_tokens
-    assert topk_ids.dim() == 2
-    topk = topk_ids.size(1)
-    return E, M, N, K, topk
 class FusedMoEActivationFormat(Enum):
     """
     The standard activation format (num_tokens, hidden dim).
@@ -391,6 +351,50 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         """
         raise NotImplementedError
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        """
+        Extract the MoE problem size from the given tensor arguments:
+        - a: The hidden states, input to the MoE layer.
+        - w1: The first set of expert weights.
+        - w2: The second set of expert weights.
+        - topk_ids: The topk ids.
+        Note: extracting the problem shape from the weight and activation
+        tensors is not obvious. It needs to be done this way specifically
+        due to subtle issues with particular kernels, e.g. the int4 kernels
+        divide the trailing dimension by two, so it's not "correct" to
+        extract N or K from the trailing dimension of w1 or w2. Similarly,
+        some kernels transpose the weights, so this needs to be kept in mind.
+        Note: This implementation covers most cases. However, if experts
+        require a specialized implementation, like MarlinExperts, they are free
+        to override this function.
+        """
+        assert w1.dim() == 3 and w2.dim() == 3
+        E, N, _ = w1.size()
+        K = a1.size(-1)
+        if a1.dim() == 2:
+            # Make sure we are using the correct a1 (pre-permute).
+            assert topk_ids.size(0) == a1.size(0), \
+                f"{topk_ids.size(0)} != {a1.size(0)}"
+            M = a1.size(0)
+        else:
+            assert a1.dim() == 3
+            assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+            M = a1.size(1)  # This is max_num_tokens
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+        return E, M, N, K, topk
     #
     # Various helpers for accessing quantization parameters from the
     # quant_config.
@@ -674,7 +678,8 @@ class FusedMoEModularKernel(torch.nn.Module):
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
-        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+        _, M, N, K, top_k = self.fused_experts.moe_problem_size(
+            a1q, w1, w2, topk_ids)
         (workspace13_shape, workspace2_shape, fused_out_shape,
          workspace_dtype) = self.fused_experts.workspace_shapes(
@@ -737,7 +742,8 @@ class FusedMoEModularKernel(torch.nn.Module):
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
-        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+        _, M, N, K, top_k = self.fused_experts.moe_problem_size(
+            a1q, w1, w2, topk_ids)
         CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
         num_chunks = cdiv(M, CHUNK_SIZE)
......
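Moving `_moe_problem_size` from a module-level function to a method lets each experts class decide how `N` is derived, and the modular kernel picks up the right version through `self.fused_experts`. The sketch below shows that dispatch with shapes as plain tuples. `BaseExperts` and `MarlinLikeExperts` are hypothetical stand-ins, the trailing weight dimensions are placeholders, and the real `MarlinExperts` reimplements the method rather than calling `super()`.

```python
class BaseExperts:
    """Stand-in for FusedMoEPermuteExpertsUnpermute (shapes as tuples)."""

    def moe_problem_size(self, a1_shape, w1_shape, w2_shape, topk_ids_shape):
        E, N, _ = w1_shape  # default: N is w1's second dimension
        K = a1_shape[-1]
        M = a1_shape[0] if len(a1_shape) == 2 else a1_shape[1]
        topk = topk_ids_shape[1]
        return E, M, N, K, topk


class MarlinLikeExperts(BaseExperts):
    """Stand-in for MarlinExperts: N is recovered from the packed w2."""

    TILE = 16  # tile size used by marlin_moe_intermediate_size() in the diff

    def moe_problem_size(self, a1_shape, w1_shape, w2_shape, topk_ids_shape):
        E, M, _, K, topk = super().moe_problem_size(
            a1_shape, w1_shape, w2_shape, topk_ids_shape)
        return E, M, w2_shape[1] * self.TILE, K, topk


# E = 4 experts, K = 64, N = 128. Per the asserts in fused_marlin_moe(),
# packed w1 has K // 16 rows and packed w2 has N // 16 rows; the last
# dimension of each weight is a placeholder here.
a1, topk_ids = (10, 64), (10, 2)
w1_packed, w2_packed = (4, 64 // 16, 1), (4, 128 // 16, 1)

print(BaseExperts().moe_problem_size(a1, w1_packed, w2_packed, topk_ids))
# (4, 10, 4, 64, 2)
print(MarlinLikeExperts().moe_problem_size(a1, w1_packed, w2_packed, topk_ids))
# (4, 10, 128, 64, 2)
```

The default method would report `N = 4` for these packed weights, which is why a Marlin-specific override is needed before workspace shapes are computed.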
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
     mxfp4_w4a16_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts)
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
@@ -92,7 +93,7 @@ def get_mxfp4_backend():
                 "Please `pip install vllm[flashinfer]` for best results.")
         # If FlashInfer is not available, try either Marlin or Triton
-        if current_platform.get_device_capability(
+        if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability(
         )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
                 "2.8.0"):
             logger.info_once("Using Marlin backend")
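The changed condition in `get_mxfp4_backend()` can be restated as a small predicate. This is only a sketch of the boolean shown in the hunk: `falls_back_to_marlin` and its parameter names are hypothetical, and the inputs stand in for `envs.VLLM_MXFP4_USE_MARLIN`, the major device capability, `has_triton_kernels()`, and `is_torch_equal_or_newer("2.8.0")`.

```python
def falls_back_to_marlin(use_marlin_env, capability_major,
                         has_triton_kernels, torch_at_least_2_8):
    """Hypothetical restatement of the condition in get_mxfp4_backend()."""
    return (use_marlin_env
            or capability_major < 9
            or not has_triton_kernels
            or not torch_at_least_2_8)


# Before this commit only the last three terms existed; the environment
# variable now forces the Marlin branch even on capable hardware.
print(falls_back_to_marlin(False, 9, True, True))  # False
print(falls_back_to_marlin(True, 9, True, True))   # True
print(falls_back_to_marlin(False, 8, True, True))  # True
```

The predicate only covers the branch reached after the FlashInfer checks earlier in the function, which are outside this hunk.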
@@ -646,9 +647,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
-            return None
-        if self.mxfp4_backend == Mxfp4Backend.TRITON:
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+            )
+        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             w1_scale = self.w13_precision_config
             w2_scale = self.w2_precision_config
             return mxfp4_w4a16_moe_quant_config(
@@ -690,6 +695,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 }
                 return TrtLlmGenExperts(self.moe, self.moe_quant_config,
                                         **kwargs)
+            elif (self.mxfp4_backend == Mxfp4Backend.MARLIN):
+                return MarlinExperts(self.moe_quant_config)
             else:
                 return OAITritonExperts(self.moe_quant_config)
@@ -782,6 +789,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if enable_eplb:
             raise NotImplementedError("EPLB is not supported for mxfp4")
+        if self.fused_experts is not None:
+            return self._route_and_experts(
+                layer,
+                x,
+                router_logits,
+                top_k,
+                renormalize,
+                use_grouped_topk,
+                topk_group,
+                num_expert_group,
+                global_num_experts,
+                expert_map,
+                custom_routing_function,
+                scoring_func,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                expert_load_view,
+                logical_to_physical_map,
+                logical_replica_count,
+            )
         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             topk_weights, topk_ids, _ = FusedMoE.select_experts(
                 hidden_states=x,
@@ -815,29 +845,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 activation=activation,
                 expert_map=expert_map)
-        if self.fused_experts is not None:
-            return self._route_and_experts(
-                layer,
-                x,
-                router_logits,
-                top_k,
-                renormalize,
-                use_grouped_topk,
-                topk_group,
-                num_expert_group,
-                global_num_experts,
-                expert_map,
-                custom_routing_function,
-                scoring_func,
-                e_score_correction_bias,
-                apply_router_weight_on_input,
-                activation,
-                enable_eplb,
-                expert_load_view,
-                logical_to_physical_map,
-                logical_replica_count,
-            )
         assert _can_support_mxfp4(
             use_grouped_topk, topk_group, num_expert_group, expert_map,
             custom_routing_function, e_score_correction_bias,
......
@@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
         supports_router_weight and supports_activation
+def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
+                                 w2_packed: torch.Tensor):
+    """
+    Given Marlin packed weight matrices w1_packed, and w2_packed,
+    return the MoE intermediate size N
+    """
+    marlin_tile_size = 16
+    return w2_packed.size(1) * marlin_tile_size
 def marlin_make_workspace(output_size_per_partition: int,
                           device: torch.device) -> torch.Tensor:
     max_workspace_size = (output_size_per_partition //
......