"tests/vscode:/vscode.git/clone" did not exist on "4429d934de3c5cc327b0d7aec8e473aeba38db90"
Unverified Commit 268b1c55 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][13/N] Convert FI to Use PFNoEP (#31533)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Signed-off-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 4f9ce35a
......@@ -15,6 +15,9 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx
......@@ -77,10 +80,17 @@ def maybe_make_prepare_finalize(
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO: could allow this now
assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"
if moe.use_flashinfer_cutlass_kernels:
assert quant_config is not None
use_deepseek_fp8_block_scale = (
quant_config is not None and quant_config.is_block_quantized
)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
if moe.use_pplx_kernels:
elif moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
......
......@@ -10,6 +10,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
......@@ -349,14 +352,23 @@ def create_flashinfer_prepare_finalize(
use_nvfp4: bool = False,
enable_alltoallv: bool = False,
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize:
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation."""
# TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
# once we complete the FP8 refactor.
if use_nvfp4:
if enable_alltoallv:
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
# FP8 path currently supported via AllGather; optionally enable block-scale
# FP8 DP path currently supported via AllGather.
if use_dp:
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
else:
# NOTE(rob): CUTLASS FP8 block quant executes the input
# quantzation and grouped gemm in a single kernel.
return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale)
......@@ -49,7 +49,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
getattr(moe_layer, "shared_experts_stream", None),
moe_parallel_config=moe_layer.moe_parallel_config,
),
)
......
......@@ -356,14 +356,14 @@ class FusedMoE(CustomOp):
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once(
logger.debug_once(
"Enabled separate cuda stream for MoE shared_experts", scope="local"
)
......
......@@ -21,7 +21,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens,
disable_inplace,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_enabled,
......@@ -682,14 +681,12 @@ class FusedMoEModularKernel(torch.nn.Module):
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
shared_experts_stream: torch.cuda.Stream | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self.shared_experts_stream = shared_experts_stream
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
......@@ -904,34 +901,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
)
def _maybe_setup_shared_experts_stream(
self, hidden_states: torch.Tensor
) -> tuple[bool, torch.Tensor | None]:
# decide whether to run shared experts on a separate CUDA stream to
# overlap with the main fused MoE kernel.
use_shared_experts_stream = (
self.shared_experts is not None
and self.shared_experts_stream is not None
and hidden_states.is_cuda
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream and self.shared_experts_stream is not None:
# TODO: Optimize this (complicated)
# Note: this clone adds overhead but is required
# for correctness with multiple CUDA streams and CUDA graph capture.
hidden_states_clone = hidden_states.clone()
# record that the clone will be used by the separate stream so its
# lifetime is correctly tracked.
hidden_states_clone.record_stream(self.shared_experts_stream)
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
return use_shared_experts_stream, hidden_states_clone
def _prepare(
self,
hidden_states: torch.Tensor,
......@@ -1119,30 +1088,12 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
hidden_states_clone: torch.Tensor | None = None,
use_shared_experts_stream: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
def maybe_run_shared_experts() -> torch.Tensor | None:
if self.shared_experts is None:
return None
if (
not use_shared_experts_stream
or self.shared_experts_stream is not None
and (not hidden_states.is_cuda or not torch.cuda.is_available())
):
# fall back to running on the current stream
return self.shared_experts(hidden_states)
assert hidden_states_clone is not None
# launch shared experts on the dedicated stream.
with torch.cuda.stream(self.shared_experts_stream):
return self.shared_experts(hidden_states_clone)
shared_output: torch.Tensor | None = None
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
......@@ -1155,7 +1106,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
shared_output = maybe_run_shared_experts()
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
......@@ -1165,8 +1117,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
shared_output = maybe_run_shared_experts()
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
......@@ -1189,28 +1141,12 @@ class FusedMoEModularKernel(torch.nn.Module):
receiver()
self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
def _wait_for_shared_experts_stream(
self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
) -> None:
# ensure that any work enqueued on the shared_experts_stream is
# completed before the shared_output tensor is consumed
if (
self.shared_experts is not None
and use_shared_experts_stream
and self.shared_experts_stream is not None
and hidden_states.is_cuda
and current_platform.is_cuda()
):
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
def forward(
self,
hidden_states: torch.Tensor,
......@@ -1257,10 +1193,6 @@ class FusedMoEModularKernel(torch.nn.Module):
else:
output = torch.zeros_like(hidden_states)
use_shared_experts_stream, hidden_states_clone = (
self._maybe_setup_shared_experts_stream(hidden_states)
)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
......@@ -1297,6 +1229,4 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
hidden_states_clone=hidden_states_clone,
use_shared_experts_stream=use_shared_experts_stream,
)
......@@ -48,7 +48,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
......@@ -973,27 +972,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# done, then we will initialzie the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferAllGatherMoEPrepareAndFinalize,
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.use_inplace = True
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.kernel = mk.FusedMoEModularKernel(
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
# with the changes to defer input quantization
FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
),
# TODO: make defer_input_quant an attr of the FlashInferExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
FlashInferExperts(
out_dtype=torch.get_default_dtype(),
out_dtype=layer.orig_dtype,
quant_config=self.moe_quant_config,
ep_rank=self.moe.ep_rank,
ep_size=self.moe.ep_size,
......@@ -1005,30 +1010,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
self.use_inplace = False
elif self.fp8_backend in [
Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.AITER,
]:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
if self.fp8_backend == Fp8MoeBackend.AITER:
elif self.fp8_backend == Fp8MoeBackend.AITER:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
......@@ -1047,7 +1029,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
)
self.use_inplace = True
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
......@@ -1121,20 +1102,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
# Wire block-scale flag through prepare/finalize when using CUTLASS
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
......
......@@ -46,7 +46,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
......@@ -751,13 +750,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment