Unverified Commit da364615 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Modular kernel refactor (#24812)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent f08919b7
......@@ -209,18 +209,18 @@ class Config:
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend
def is_valid(self):
def is_valid(self) -> tuple[bool, Optional[str]]:
# Check prepare-finalize and fused-experts compatibility
if self.is_batched_prepare_finalize():
if not self.is_batched_fused_experts():
return False
return False, "Mismatched format."
else:
if not self.is_standard_fused_experts():
return False
return False, "Mismatched format."
use_chunking = self.fused_moe_chunk_size is not None
if use_chunking and not self.is_fe_supports_chunking():
return False
return False, "Chunking not supported."
# Check quantization sanity
if (
......@@ -229,7 +229,7 @@ class Config:
+ int(self.quant_block_shape is not None)
) > 1:
# invalid quant config
return False
return False, f"Bad quant_config {self.quant_config}."
# check type support
if self.quant_dtype is None:
......@@ -237,34 +237,43 @@ class Config:
self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()
):
return False
return False, (
f"Unsupported type {self.dtype} not in "
f"{self.pf_supported_types()} and "
f"{self.fe_supported_types()}."
)
else:
if (
self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()
):
return False
return False, (
f"Unsupported quant type {self.quant_dtype} "
f"not in {self.pf_supported_types()} and "
f"{self.fe_supported_types()}."
)
# Check block quanization support
is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and self.quant_dtype is None:
return False
return False, "No block quantization support."
if is_block_quatized and not self.is_block_quant_supported():
return False
return False, "Mismatched block quantization support."
# deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized:
return False
return False, "Needs DeepGEMM but not block quantized."
# Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep():
return False
return False, "Needs DeepEP, but DeepEP not available."
if self.needs_deep_gemm() and not has_deep_gemm():
return False
return False, "Needs DeepGEMM, but DeepGEMM not available."
if self.needs_pplx() and not has_pplx(): # noqa: SIM103
return False
return False, "Needs PPLX, but PPLX not available."
return True
return True, None
@dataclass
......
......@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
)
success = None
if config.is_valid():
if config.is_valid()[0]:
print(f"Running config : {config.describe()} ...")
try:
weights: WeightTensors = WeightTensors.make(config)
......
......@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nvfp4_types,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
......@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_experts(
FlashInferExperts,
standard_format,
nvfp4_types,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
......@@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
(
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
BatchedTritonOrDeepGemmExperts,
......@@ -464,7 +462,7 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts {quant_config} ...")
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
......
......@@ -5,7 +5,7 @@ import copy
import textwrap
import traceback
from itertools import product
from typing import Optional
from typing import Any, Optional
import pytest
import torch
......@@ -13,10 +13,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (
Config,
RankTensors,
......@@ -132,7 +131,8 @@ def rank_worker(
def run(config: Config, verbose: bool):
assert config.is_valid()
assert config.is_valid()[0]
assert not is_nyi_config(config)
weights: WeightTensors = WeightTensors.make(config)
......@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
return not info.supports_expert_map
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
def generate_valid_test_cases(
world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
cases = []
total = 0
for k, n, e, dtype, quant_config, combination, chunk_size in product(
Ks,
Ns,
Es,
DTYPEs,
MK_QUANT_CONFIGS,
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
FUSED_MOE_CHUNK_SIZEs,
):
total = total + 1
config = Config(
Ms=Ms,
K=k,
N=n,
E=e,
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)
# TODO(bnell): figure out how to get verbose flag here.
verbose = False # pytestconfig.getoption('verbose') > 0
valid, reason = config.is_valid()
if not valid:
if verbose:
print(f"Test config {config} is not valid: {reason}")
continue
if is_nyi_config(config):
if verbose:
print(f"Test config {config} is nyi.")
continue
cases.append(
(
k,
n,
e,
dtype,
quant_config,
combination[0],
combination[1],
chunk_size,
world_size,
)
)
print(f"{len(cases)} of {total} valid configs generated.")
return cases
@pytest.mark.parametrize(
"combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
generate_valid_test_cases(
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
),
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int,
......@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
if cuda_device_count_stateless() < world_size:
pytest.skip(
f"Not enough GPUs available to run, got "
f"{cuda_device_count_stateless()} exepected "
f"{world_size}."
)
config = Config(
Ms=Ms,
K=k,
......@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=fused_moe_chunk_size,
prepare_finalize_type=prepare_finalize_type,
fused_experts_type=fused_experts_type,
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
generate_valid_test_cases(
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
),
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
......@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=fused_moe_chunk_size,
prepare_finalize_type=prepare_finalize_type,
fused_experts_type=fused_experts_type,
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)
......
......@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = (
a.size(0) if self.max_num_tokens is None else self.max_num_tokens
)
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
return (workspace13, workspace2, output)
def apply(
self,
......@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w2.size(1) == K
E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
E, max_num_tokens, N, K, _ = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)
......
......@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert bte_war is not None
return bte_war
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return act_dtype
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.workspace_shapes(
a,
aq,
M,
N,
K,
......@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
else:
assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes(
a,
aq,
M,
N,
K,
......
......@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
return (workspace1, workspace2, output)
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
......@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool:
return False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
padded_M = aq.size(1)
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
output = (self.max_experts_per_worker, M, K)
return (workspace1, workspace2, output)
def cutlass_moe_fp8(
......@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, (N // 2))
output = (self.max_experts_per_worker, M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
return (workspace1, workspace2, output)
def apply(
self,
......
......@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None
block_m = self.block_shape[0]
M_sum = compute_aligned_M(
......@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
def apply(
self,
......
......@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
......
......@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
......
......@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
......@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
aq_m, aq_n = aq.shape
workspace1 = (M, K)
workspace2 = (0,)
output_shape = (aq_m, aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, aq_n)
workspace_dtype = a.dtype
workspace1 = output_shape
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape, workspace_dtype)
return (workspace1, workspace2, output_shape)
def apply(
self,
......
......@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
......@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def _apply_router_weight_on_input(
self,
a1: torch.Tensor,
......@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP)
if self.use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output, dim=0, sizes=get_local_sizes()
......
......@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
......@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
workspace2 = (self.max_num_tokens * num_dp, N)
output = workspace13
return (workspace13, workspace2, output, a.dtype)
return (workspace13, workspace2, output)
def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
assert self.quant_config.is_quantized
......@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
return (workspace13, workspace2, output)
def apply(
self,
......
......@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined
# essentially as,
......@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2 = (M * topk * max(2 * N, K),)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
def apply(
self,
......
......@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
def apply(
self,
......
......@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
def apply(
self,
......
......@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
) -> Optional[FusedMoEQuantConfig]:
raise NotImplementedError
@property
def using_modular_kernel(self) -> bool:
return self.fused_experts is not None
@abstractmethod
def apply(
self,
......@@ -1237,39 +1241,25 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
# TODO(bnell): flashinfer uses non-batched format.
# Does it really need a batched buffer?
if (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels
):
if vllm_config.parallel_config.enable_dbo:
self.batched_hidden_states = torch.zeros(
(2, moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device(),
)
if self.use_dp_chunking:
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(2, moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device(),
)
# Note here we use `num_experts` which is logical expert count
if vllm_config.parallel_config.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.hidden_size)
logits_shape = (2, moe.max_num_tokens, num_experts)
else:
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device(),
)
states_shape = (moe.max_num_tokens, self.hidden_size)
logits_shape = (moe.max_num_tokens, num_experts)
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device(),
)
self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
self.batched_router_logits = torch.zeros(
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
......@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
and self.moe_config.use_flashinfer_cutlass_kernels
)
@property
def use_dp_chunking(self) -> bool:
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
)
def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
......@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert self.quant_method is not None
return (
self.use_pplx_kernels
or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels
self.quant_method.fused_experts is not None
and self.quant_method.fused_experts.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
The pplx combine kernel reduces across GPU ranks by default.
Some combine kernels reduce across GPU ranks by default.
"""
if (
self.use_pplx_kernels
or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels
):
if self.must_reduce_shared_expert_outputs():
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
self.ensure_moe_quant_config()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
_use_flashinfer_cutlass_kernels = (
self.dp_size > 1 and self.use_flashinfer_cutlass_kernels
)
if (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or _use_flashinfer_cutlass_kernels
):
if self.use_dp_chunking:
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_config.use_flashinfer_cutlass_kernels
self.dp_size > 1 and not self.quant_method.using_modular_kernel
)
# If there are shared experts but we are not using a modular kernel, the
......
......@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
def supports_async(self) -> bool:
return True
......
......@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment