Unverified Commit eb19955c authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[WideEP] Remove pplx all2all backend (#33724)


Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 0f2f24c8
......@@ -76,11 +76,4 @@ popd
export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH
# build and install pplx, require pytorch installed
pushd "$WORKSPACE"
git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
......@@ -4,12 +4,10 @@ set -ex
# usage: ./install_python_libraries.sh [options]
# --workspace <dir> workspace directory (default: ./ep_kernels_workspace)
# --mode <mode> "install" (default) or "wheel"
# --pplx-ref <commit> pplx-kernels commit hash
# --deepep-ref <commit> DeepEP commit hash
# --nvshmem-ver <ver> NVSHMEM version
CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"}
DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"}
NVSHMEM_VER=${NVSHMEM_VER:-"3.3.24"} # Default supports both CUDA 12 and 13
WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace}
......@@ -35,14 +33,6 @@ while [[ $# -gt 0 ]]; do
MODE="$2"
shift 2
;;
--pplx-ref)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --pplx-ref requires an argument." >&2
exit 1
fi
PPLX_COMMIT_HASH="$2"
shift 2
;;
--deepep-ref)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --deepep-ref requires an argument." >&2
......@@ -188,14 +178,6 @@ do_build() {
popd
}
# build pplx-kernels
do_build \
"https://github.com/ppl-ai/pplx-kernels" \
"pplx-kernels" \
"setup.py" \
"$PPLX_COMMIT_HASH" \
""
# build DeepEP
do_build \
"https://github.com/deepseek-ai/DeepEP" \
......
......@@ -988,7 +988,7 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
return output_tensor
def get_cutlass_pplx_moe_mm_data(
def get_cutlass_batched_moe_mm_data(
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
......@@ -1011,7 +1011,7 @@ def get_cutlass_pplx_moe_mm_data(
multiplication in two grouped MMs used in
the fused MoE operation.
"""
return torch.ops._C.get_cutlass_pplx_moe_mm_data(
return torch.ops._C.get_cutlass_batched_moe_mm_data(
expert_offsets,
problem_sizes1,
problem_sizes2,
......
......@@ -1045,7 +1045,7 @@ class CompilationConfig:
"are optimized for prefill and are incompatible with CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
"deepep_low_latency or allgather_reducescatter."
)
self.cudagraph_mode = CUDAGraphMode.NONE
......
......@@ -152,7 +152,6 @@ class ParallelConfig:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
......@@ -310,6 +309,13 @@ class ParallelConfig:
f"but found: {self._api_process_rank}"
)
if self.all2all_backend == "pplx":
logger.warning(
"The 'pplx' all2all backend has been removed. "
"Falling back to 'allgather_reducescatter'."
)
self.all2all_backend = "allgather_reducescatter"
if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
......@@ -442,7 +448,6 @@ class ParallelConfig:
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (
......
......@@ -3,14 +3,13 @@
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
......@@ -235,96 +234,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install pplx_kernels."
)
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
......
......@@ -112,10 +112,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
......@@ -298,7 +294,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
self.all2all_manager = None # type: ignore[assignment]
def all_gatherv(
self,
......
......@@ -159,7 +159,7 @@ class EplbModelState:
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
across different dispatch methods (naive all-to-all, DeepEP).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
......@@ -24,16 +25,11 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (
......@@ -120,51 +116,10 @@ def maybe_make_prepare_finalize(
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
if moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
num_dispatchers = (
all2all_manager.world_size // all2all_manager.tp_group.world_size
)
# Intranode pplx a2a takes a group name while internode does not.
if not all2all_manager.internode:
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers,
)
elif moe.use_deepep_ht_kernels:
if moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict()
all_to_all_args: dict[str, Any] = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
......
......@@ -939,10 +939,6 @@ class FusedMoEParallelConfig:
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "pplx"
@property
def use_deepep_ht_kernels(self):
return (
......@@ -962,7 +958,7 @@ class FusedMoEParallelConfig:
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
return self.use_deepep_ll_kernels
@property
def use_naive_all2all_kernels(self):
......@@ -1221,10 +1217,6 @@ class FusedMoEConfig:
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
......
......@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8(
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
ops.get_cutlass_pplx_moe_mm_data(
ops.get_cutlass_batched_moe_mm_data(
expert_offsets,
problem_sizes1,
problem_sizes2,
......
......@@ -493,7 +493,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format
that the PPLX dispatch/combine kernels use.
that the batched dispatch/combine kernels use.
"""
def __init__(
......@@ -648,7 +648,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use.
"""
......@@ -880,7 +880,7 @@ def batched_moe_kernel_quantize_input(
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use.
"""
......
......@@ -1172,9 +1172,9 @@ class FusedMoEModularKernel(torch.nn.Module):
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
# kernels. CUDAGraph compatible all2all kernels like the DeepEP
# low-latency kernels are always batched and can never run into
# the tensor.numel() == 0 case.
if M_full == 0:
assert num_chunks == 0
workspace13 = None
......
......@@ -143,10 +143,7 @@ def select_nvfp4_moe_backend(
# NOTE(rob): this is kind of a hack. We need to peak into
# the prepare-finalize selection to determine if we are using
# the batched or standard expert format.
use_batched = (
config.moe_parallel_config.use_deepep_ll_kernels
or config.moe_parallel_config.use_pplx_kernels
)
use_batched = config.moe_parallel_config.use_deepep_ll_kernels
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if use_batched
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import (
_validate_scale_shape,
moe_kernel_quantize_input,
)
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: torch.dtype | str | None,
per_act_token_quant: bool,
block_shape: list[int] | None,
):
# All pplx byte sizes must be 16-byte aligned.
align = 16
# For blocked per token: set to
# cdiv(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert isinstance(quant_dtype, torch.dtype)
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
if per_act_token_quant:
# per-token (M x 1)
assert block_shape is None
hidden_scale_bytes = elem_size
elif block_shape is not None:
# per-group (M x K_tiles)
block_size = block_shape[1]
num_blocks = cdiv(hidden_dim, block_size)
hidden_scale_bytes = num_blocks * elem_size
else:
# per-tensor (1 x 1)
hidden_scale_bytes = elem_size
else:
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
hidden_scale_bytes = 0
return (
round_up(hidden_dim_bytes, align),
round_up(hidden_scale_bytes, align),
)
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""PPLX-based prepare and finalize for expert parallelism."""
def __init__(
self,
a2a: pplx.AllToAll,
max_num_tokens: int,
num_local_experts: int,
num_dispatchers: int,
):
super().__init__()
assert max_num_tokens > 0
assert num_local_experts > 0
self.a2a = a2a
self.max_num_tokens = max_num_tokens
self.num_local_experts = num_local_experts
self.num_dispatchers_ = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.uint32
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert topk_ids.size(0) == num_tokens
# expert_map should be None because with expert map, -1 id is used for
# non-local token; this causes error when casting ids to the
# topk_indices_dtype() int32
#
if expert_map is not None:
logger.warning_once(
"The PPLX backend does not support expert mapping. "
"The provided `expert_map` will be ignored."
)
expert_map = None # noqa: F841
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
(None if quant_config.per_act_token_quant else quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
_validate_scale_shape(
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
)
orig_a_scale_block_shape: int | None = None
if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1
# pplx requires 2-d scales even for scalar scales
if a1q_scale.dim() <= 1:
assert scalar_scales
a1q_scale = a1q_scale.view(1, 1)
orig_a_scale_block_shape = a1q_scale.shape[-1]
if not quant_config.is_block_quantized:
# TODO (bnell): use group_broadcast instead?
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
assert a1q_scale is None or a1q_scale.ndim == 2, (
f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
)
expert_num_tokens = torch.empty(
self.num_local_experts,
dtype=torch.int32,
device=device,
)
expert_x = torch.empty(
(
self.num_local_experts,
self.max_num_tokens * self.num_dispatchers(),
hidden_dim,
),
dtype=a1q.dtype,
device=device,
)
expert_x_scale: torch.Tensor | None = None
if a1q.dtype.itemsize == 1:
if quant_config.is_per_act_token:
# (M x 1) -> (E x M x K)
final_dim = expert_x.size(2)
elif quant_config.is_per_tensor:
# (1 x 1) -> (E x 1 x 1)
final_dim = 1
else:
# (M x K_tiles) -> (E x M x K_tiles)
assert quant_config.block_shape is not None
num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
final_dim = num_blocks
expert_x_scale_shape = (
self.num_local_experts,
expert_x.size(1),
round_up(final_dim, 4), # round up for alignment
)
expert_x_scale = torch.empty(
expert_x_scale_shape,
dtype=torch.float32,
device=expert_x.device,
)
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: torch.Tensor | None = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
hook = lambda: self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
return (
hook,
lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
orig_a_scale_block_shape,
),
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: torch.Tensor | None,
orig_a_scale_block_shape: int | None,
) -> mk.PrepareResultType:
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant=defer_input_quant,
)
hook()
return receiver()
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
"Weight application and reduction happens in the combine kernel."
)
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: torch.Tensor | None = None
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
# num_tokens = output.size(0) # M
# assert topk_ids.size(0) == num_tokens, (
# f"{topk_ids.size(0)} == {num_tokens}")
assert topk_ids.size() == topk_weights.size(), (
f"{topk_ids.size()} == {topk_weights.size()}"
)
assert output.size(0) <= self.max_num_tokens, (
f"{output.size(0)} <= {self.max_num_tokens}"
)
assert output.size(1) == fused_expert_output.size(-1)
# Set weights to 1 if we did them in dispatch. This is hacky.
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
......@@ -216,8 +216,7 @@ class DefaultMoERunner(MoERunner):
@property
def use_dp_chunking(self) -> bool:
return (
self.moe_config.moe_parallel_config.use_pplx_kernels
or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
......
......@@ -14,10 +14,11 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
For example, BatchedTritonExperts is compatible with both
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
does the weight-application + reduction as part of the pplx combine kernel.
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
For example, BatchedTritonExperts is compatible with both batched
PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and
BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do
the weight-application + reduction as part of the combine kernel, while
BatchedPrepareAndFinalize needs an explicit implementation. To facilitate
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
so the PrepareAndFinalize implementations could choose how to
weight + reduce.
......
......@@ -798,7 +798,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
is_batched_moe = self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
......
......@@ -402,11 +402,6 @@ def _has_module(module_name: str) -> bool:
return importlib.util.find_spec(module_name) is not None
def has_pplx() -> bool:
"""Whether the optional `pplx_kernels` package is available."""
return _has_module("pplx_kernels")
def has_deep_ep() -> bool:
"""Whether the optional `deep_ep` package is available."""
return _has_module("deep_ep")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment