Unverified Commit da364615 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Modular kernel refactor (#24812)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent f08919b7
...@@ -209,18 +209,18 @@ class Config: ...@@ -209,18 +209,18 @@ class Config:
info = prepare_finalize_info(self.prepare_finalize_type) info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend return info.backend
def is_valid(self): def is_valid(self) -> tuple[bool, Optional[str]]:
# Check prepare-finalize and fused-experts compatibility # Check prepare-finalize and fused-experts compatibility
if self.is_batched_prepare_finalize(): if self.is_batched_prepare_finalize():
if not self.is_batched_fused_experts(): if not self.is_batched_fused_experts():
return False return False, "Mismatched format."
else: else:
if not self.is_standard_fused_experts(): if not self.is_standard_fused_experts():
return False return False, "Mismatched format."
use_chunking = self.fused_moe_chunk_size is not None use_chunking = self.fused_moe_chunk_size is not None
if use_chunking and not self.is_fe_supports_chunking(): if use_chunking and not self.is_fe_supports_chunking():
return False return False, "Chunking not supported."
# Check quantization sanity # Check quantization sanity
if ( if (
...@@ -229,7 +229,7 @@ class Config: ...@@ -229,7 +229,7 @@ class Config:
+ int(self.quant_block_shape is not None) + int(self.quant_block_shape is not None)
) > 1: ) > 1:
# invalid quant config # invalid quant config
return False return False, f"Bad quant_config {self.quant_config}."
# check type support # check type support
if self.quant_dtype is None: if self.quant_dtype is None:
...@@ -237,34 +237,43 @@ class Config: ...@@ -237,34 +237,43 @@ class Config:
self.dtype not in self.pf_supported_types() self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types() or self.dtype not in self.fe_supported_types()
): ):
return False return False, (
f"Unsupported type {self.dtype} not in "
f"{self.pf_supported_types()} and "
f"{self.fe_supported_types()}."
)
else: else:
if ( if (
self.quant_dtype not in self.pf_supported_types() self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types() or self.quant_dtype not in self.fe_supported_types()
): ):
return False return False, (
f"Unsupported quant type {self.quant_dtype} "
f"not in {self.pf_supported_types()} and "
f"{self.fe_supported_types()}."
)
# Check block quanization support # Check block quanization support
is_block_quatized = self.quant_block_shape is not None is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and self.quant_dtype is None: if is_block_quatized and self.quant_dtype is None:
return False return False, "No block quantization support."
if is_block_quatized and not self.is_block_quant_supported(): if is_block_quatized and not self.is_block_quant_supported():
return False return False, "Mismatched block quantization support."
# deep_gemm only works with block-quantized # deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized: if self.needs_deep_gemm() and not is_block_quatized:
return False return False, "Needs DeepGEMM but not block quantized."
# Check dependencies (turn into asserts?) # Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep(): if self.needs_deep_ep() and not has_deep_ep():
return False return False, "Needs DeepEP, but DeepEP not available."
if self.needs_deep_gemm() and not has_deep_gemm(): if self.needs_deep_gemm() and not has_deep_gemm():
return False return False, "Needs DeepGEMM, but DeepGEMM not available."
if self.needs_pplx() and not has_pplx(): # noqa: SIM103 if self.needs_pplx() and not has_pplx(): # noqa: SIM103
return False return False, "Needs PPLX, but PPLX not available."
return True return True, None
@dataclass @dataclass
......
...@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str): ...@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
) )
success = None success = None
if config.is_valid(): if config.is_valid()[0]:
print(f"Running config : {config.describe()} ...") print(f"Running config : {config.describe()} ...")
try: try:
weights: WeightTensors = WeightTensors.make(config) weights: WeightTensors = WeightTensors.make(config)
......
...@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability ...@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_prepare_and_finalize( register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize, FlashInferCutlassMoEPrepareAndFinalize,
standard_format, standard_format,
nvfp4_types, nvfp4_types + fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
backend=None, backend=None,
force_multigpu=True, force_multigpu=True,
...@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability ...@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_experts( register_experts(
FlashInferExperts, FlashInferExperts,
standard_format, standard_format,
nvfp4_types, nvfp4_types + fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True, supports_chunking=True,
# Note: this is a hack to get it to run for now # Note: this is a hack to get it to run for now
...@@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported(): ...@@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False, needs_matching_quant=False,
needs_deep_gemm=True, needs_deep_gemm=True,
) )
( register_experts(
register_experts( DeepGemmExperts,
DeepGemmExperts, standard_format,
standard_format, fp8_types,
fp8_types, blocked_quantization_support=True,
blocked_quantization_support=True, supports_chunking=True,
supports_chunking=True, supports_expert_map=True,
supports_expert_map=True, needs_matching_quant=False,
needs_matching_quant=False, needs_deep_gemm=True,
needs_deep_gemm=True,
),
) )
register_experts( register_experts(
BatchedTritonOrDeepGemmExperts, BatchedTritonOrDeepGemmExperts,
...@@ -464,7 +462,7 @@ def make_fused_experts( ...@@ -464,7 +462,7 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs) experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts: elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts {quant_config} ...") print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config) experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts: elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs kwargs = quant_kwargs
......
...@@ -5,7 +5,7 @@ import copy ...@@ -5,7 +5,7 @@ import copy
import textwrap import textwrap
import traceback import traceback
from itertools import product from itertools import product
from typing import Optional from typing import Any, Optional
import pytest import pytest
import torch import torch
...@@ -13,10 +13,9 @@ import torch ...@@ -13,10 +13,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from ...utils import multi_gpu_test
from .modular_kernel_tools.common import ( from .modular_kernel_tools.common import (
Config, Config,
RankTensors, RankTensors,
...@@ -132,7 +131,8 @@ def rank_worker( ...@@ -132,7 +131,8 @@ def rank_worker(
def run(config: Config, verbose: bool): def run(config: Config, verbose: bool):
assert config.is_valid() assert config.is_valid()[0]
assert not is_nyi_config(config)
weights: WeightTensors = WeightTensors.make(config) weights: WeightTensors = WeightTensors.make(config)
...@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool: ...@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
return not info.supports_expert_map return not info.supports_expert_map
@pytest.mark.parametrize("k", Ks) def generate_valid_test_cases(
@pytest.mark.parametrize("n", Ns) world_size: int, prepare_finalize_types
@pytest.mark.parametrize("e", Es) ) -> list[tuple[Any, ...]]:
@pytest.mark.parametrize("dtype", DTYPEs) cases = []
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) total = 0
for k, n, e, dtype, quant_config, combination, chunk_size in product(
Ks,
Ns,
Es,
DTYPEs,
MK_QUANT_CONFIGS,
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
FUSED_MOE_CHUNK_SIZEs,
):
total = total + 1
config = Config(
Ms=Ms,
K=k,
N=n,
E=e,
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)
# TODO(bnell): figure out how to get verbose flag here.
verbose = False # pytestconfig.getoption('verbose') > 0
valid, reason = config.is_valid()
if not valid:
if verbose:
print(f"Test config {config} is not valid: {reason}")
continue
if is_nyi_config(config):
if verbose:
print(f"Test config {config} is nyi.")
continue
cases.append(
(
k,
n,
e,
dtype,
quant_config,
combination[0],
combination[1],
chunk_size,
world_size,
)
)
print(f"{len(cases)} of {total} valid configs generated.")
return cases
@pytest.mark.parametrize( @pytest.mark.parametrize(
"combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
generate_valid_test_cases(
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
),
) )
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements @meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu( def test_modular_kernel_combinations_multigpu(
k: int, k: int,
...@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu( ...@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
e: int, e: int,
dtype: torch.dtype, dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig], quant_config: Optional[TestMoEQuantConfig],
combination: tuple[ prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
], chunk_size: Optional[int],
fused_moe_chunk_size: Optional[int],
world_size: int, world_size: int,
pytestconfig, pytestconfig,
): ):
if cuda_device_count_stateless() < world_size:
pytest.skip(
f"Not enough GPUs available to run, got "
f"{cuda_device_count_stateless()} exepected "
f"{world_size}."
)
config = Config( config = Config(
Ms=Ms, Ms=Ms,
K=k, K=k,
...@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu( ...@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
topks=TOPKs, topks=TOPKs,
dtype=dtype, dtype=dtype,
quant_config=quant_config, quant_config=quant_config,
prepare_finalize_type=combination[0], prepare_finalize_type=prepare_finalize_type,
fused_experts_type=combination[1], fused_experts_type=fused_experts_type,
fused_moe_chunk_size=fused_moe_chunk_size, fused_moe_chunk_size=chunk_size,
world_size=world_size, world_size=world_size,
) )
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption("verbose") verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0) run(config, verbosity > 0)
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
generate_valid_test_cases(
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
),
) )
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu( def test_modular_kernel_combinations_singlegpu(
k: int, k: int,
n: int, n: int,
e: int, e: int,
dtype: torch.dtype, dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig], quant_config: Optional[TestMoEQuantConfig],
combination: tuple[ prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
], chunk_size: Optional[int],
fused_moe_chunk_size: Optional[int],
world_size: int, world_size: int,
pytestconfig, pytestconfig,
): ):
...@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
topks=TOPKs, topks=TOPKs,
dtype=dtype, dtype=dtype,
quant_config=quant_config, quant_config=quant_config,
prepare_finalize_type=combination[0], prepare_finalize_type=prepare_finalize_type,
fused_experts_type=combination[1], fused_experts_type=fused_experts_type,
fused_moe_chunk_size=fused_moe_chunk_size, fused_moe_chunk_size=chunk_size,
world_size=world_size, world_size=world_size,
) )
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption("verbose") verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0) run(config, verbosity > 0)
......
...@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader # FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks # DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed. # end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers num_dispatchers = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = ( max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
a.size(0) if self.max_num_tokens is None else self.max_num_tokens
)
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K) output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output)
def apply( def apply(
self, self,
...@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w2.size(1) == K assert w2.size(1) == K
E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( E, max_num_tokens, N, K, _ = self.moe_problem_size(
hidden_states, w1, w2, topk_ids hidden_states, w1, w2, topk_ids
) )
......
...@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert bte_war is not None assert bte_war is not None
return bte_war return bte_war
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return act_dtype
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm: if self.allow_deep_gemm:
assert self.batched_deep_gemm_experts is not None assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.workspace_shapes( return self.batched_deep_gemm_experts.workspace_shapes(
a,
aq,
M, M,
N, N,
K, K,
...@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
else: else:
assert self.batched_triton_experts is not None assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes( return self.batched_triton_experts.workspace_shapes(
a,
aq,
M, M,
N, N,
K, K,
......
...@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# topk weights and reduction are fused in moe_unpermute cuda kernel # topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(N, K)) workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K)) workspace2 = (M * topk, max(N // 2, K))
output = (M, K) output = (M, K)
return ( return (workspace1, workspace2, output)
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
padded_M = aq.size(1)
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
assert num_dp is not None assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K)) workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K) output = (self.max_experts_per_worker, M, K)
return ( return (workspace1, workspace2, output)
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
def cutlass_moe_fp8( def cutlass_moe_fp8(
...@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1: tuple[int, ...] = () workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = () workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = () output: tuple[int, ...] = ()
if self.use_batched_format: if self.use_batched_format:
padded_M = aq.size(1) workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) workspace2 = (self.max_experts_per_worker, M, (N // 2))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) output = (self.max_experts_per_worker, M, K)
output = (self.max_experts_per_worker, padded_M, K)
else: else:
workspace1 = (M * topk, max(2 * N, K)) workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N) workspace2 = (M * topk, N)
output = (M, K) output = (M, K)
return ( return (workspace1, workspace2, output)
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
def apply( def apply(
self, self,
......
...@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None assert self.block_shape is not None
block_m = self.block_shape[0] block_m = self.block_shape[0]
M_sum = compute_aligned_M( M_sum = compute_aligned_M(
...@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = (M_sum, max(N, K)) workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K)) workspace2 = (M_sum, max(N // 2, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output)
def apply( def apply(
self, self,
......
...@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
......
...@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts return mk.FusedMoEActivationFormat.BatchedExperts
......
...@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles # We use global_num_experts due to how moe_align_block_size handles
# expert_maps. # expert_maps.
""" """
...@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension - Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens. of each tuple must be the number of tokens.
""" """
aq_m, aq_n = aq.shape workspace1 = (M, K)
workspace2 = (0,) workspace2 = (0,)
output_shape = (aq_m, aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, aq_n) output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
workspace_dtype = a.dtype
workspace1 = output_shape
# The workspace is determined by `aq`, since it comes after any # The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation. # potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape, workspace_dtype) return (workspace1, workspace2, output_shape)
def apply( def apply(
self, self,
......
...@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import ( ...@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
) )
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave from vllm.utils.flashinfer import nvfp4_block_scale_interleave
...@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def _apply_router_weight_on_input( def _apply_router_weight_on_input(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
...@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin ...@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> None:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP)
if self.use_dp: if self.use_dp:
fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output, dim=0, sizes=get_local_sizes() fused_expert_output, dim=0, sizes=get_local_sizes()
......
...@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
...@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert a.dim() == 2
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
workspace2 = (self.max_num_tokens * num_dp, N) workspace2 = (self.max_num_tokens * num_dp, N)
output = workspace13 output = workspace13
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output)
def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
assert self.quant_config.is_quantized assert self.quant_config.is_quantized
...@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert a.dim() == 2
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = self.max_num_tokens max_num_tokens = self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K) output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output)
def apply( def apply(
self, self,
......
...@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in # Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined # the fused_marlin_moe() function, the final torch.sum(), is defined
# essentially as, # essentially as,
...@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2 = (M * topk * max(2 * N, K),) workspace2 = (M * topk * max(2 * N, K),)
output = (M, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output)
def apply( def apply(
self, self,
......
...@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K)) workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K)) workspace2 = (M, topk, max(N, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output)
def apply( def apply(
self, self,
......
...@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
workspace1 = (M, K) workspace1 = (M, K)
workspace2 = (0, 0) workspace2 = (0, 0)
output = (M, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output)
def apply( def apply(
self, self,
......
...@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
) -> Optional[FusedMoEQuantConfig]: ) -> Optional[FusedMoEQuantConfig]:
raise NotImplementedError raise NotImplementedError
@property
def using_modular_kernel(self) -> bool:
return self.fused_experts is not None
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
...@@ -1237,39 +1241,25 @@ class FusedMoE(CustomOp): ...@@ -1237,39 +1241,25 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None
# TODO(bnell): flashinfer uses non-batched format. if self.use_dp_chunking:
# Does it really need a batched buffer? states_shape: tuple[int, ...]
if ( logits_shape: tuple[int, ...]
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels
):
if vllm_config.parallel_config.enable_dbo:
self.batched_hidden_states = torch.zeros(
(2, moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device(),
)
# Note here we use `num_experts` which is logical expert count # Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros( if vllm_config.parallel_config.enable_dbo:
(2, moe.max_num_tokens, num_experts), states_shape = (2, moe.max_num_tokens, self.hidden_size)
dtype=moe.in_dtype, logits_shape = (2, moe.max_num_tokens, num_experts)
device=torch.cuda.current_device(),
)
else: else:
self.batched_hidden_states = torch.zeros( states_shape = (moe.max_num_tokens, self.hidden_size)
(moe.max_num_tokens, self.hidden_size), logits_shape = (moe.max_num_tokens, num_experts)
dtype=moe.in_dtype,
device=torch.cuda.current_device(),
)
# Note here we use `num_experts` which is logical expert count self.batched_hidden_states = torch.zeros(
self.batched_router_logits = torch.zeros( states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
(moe.max_num_tokens, num_experts), )
dtype=moe.in_dtype,
device=torch.cuda.current_device(), self.batched_router_logits = torch.zeros(
) logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
@property @property
def shared_experts(self) -> Optional[torch.nn.Module]: def shared_experts(self) -> Optional[torch.nn.Module]:
...@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp): ...@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
and self.moe_config.use_flashinfer_cutlass_kernels and self.moe_config.use_flashinfer_cutlass_kernels
) )
@property
def use_dp_chunking(self) -> bool:
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
)
def update_expert_map(self): def update_expert_map(self):
# ep_size and ep_rank should already be updated # ep_size and ep_rank should already be updated
assert self.expert_map is not None assert self.expert_map is not None
...@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp): ...@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output Therefore it is required that we reduce the shared_experts output
early. early.
""" """
assert self.quant_method is not None
return ( return (
self.use_pplx_kernels self.quant_method.fused_experts is not None
or self.use_deepep_ht_kernels and self.quant_method.fused_experts.output_is_reduced()
or self.use_deepep_ll_kernels
) )
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
""" """
The pplx combine kernel reduces across GPU ranks by default. Some combine kernels reduce across GPU ranks by default.
""" """
if ( if self.must_reduce_shared_expert_outputs():
self.use_pplx_kernels
or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels
):
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
...@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp): ...@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
self.ensure_moe_quant_config() self.ensure_moe_quant_config()
# Route to the chunked forward path using the FlashInfer Cutlass kernel if self.use_dp_chunking:
# only when data parallelism (DP) is enabled.
_use_flashinfer_cutlass_kernels = (
self.dp_size > 1 and self.use_flashinfer_cutlass_kernels
)
if (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or _use_flashinfer_cutlass_kernels
):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = ( do_naive_dispatch_combine: bool = (
self.dp_size > 1 self.dp_size > 1 and not self.quant_method.using_modular_kernel
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_config.use_flashinfer_cutlass_kernels
) )
# If there are shared experts but we are not using a modular kernel, the # If there are shared experts but we are not using a modular kernel, the
......
...@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
def supports_async(self) -> bool: def supports_async(self) -> bool:
return True return True
......
...@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return 1 return 1
def output_is_reduced(self) -> bool:
return False
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment