"vscode:/vscode.git/clone" did not exist on "24d0ef89705e0ab8df3d79fcbfd669cf5575772b"
Unverified Commit eb19955c authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[WideEP] Remove pplx all2all backend (#33724)


Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 0f2f24c8
...@@ -76,11 +76,4 @@ popd ...@@ -76,11 +76,4 @@ popd
export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH
# build and install pplx, require pytorch installed
pushd "$WORKSPACE"
git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
...@@ -4,12 +4,10 @@ set -ex ...@@ -4,12 +4,10 @@ set -ex
# usage: ./install_python_libraries.sh [options] # usage: ./install_python_libraries.sh [options]
# --workspace <dir> workspace directory (default: ./ep_kernels_workspace) # --workspace <dir> workspace directory (default: ./ep_kernels_workspace)
# --mode <mode> "install" (default) or "wheel" # --mode <mode> "install" (default) or "wheel"
# --pplx-ref <commit> pplx-kernels commit hash
# --deepep-ref <commit> DeepEP commit hash # --deepep-ref <commit> DeepEP commit hash
# --nvshmem-ver <ver> NVSHMEM version # --nvshmem-ver <ver> NVSHMEM version
CUDA_HOME=${CUDA_HOME:-/usr/local/cuda} CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"}
DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"} DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"}
NVSHMEM_VER=${NVSHMEM_VER:-"3.3.24"} # Default supports both CUDA 12 and 13 NVSHMEM_VER=${NVSHMEM_VER:-"3.3.24"} # Default supports both CUDA 12 and 13
WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace} WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace}
...@@ -35,14 +33,6 @@ while [[ $# -gt 0 ]]; do ...@@ -35,14 +33,6 @@ while [[ $# -gt 0 ]]; do
MODE="$2" MODE="$2"
shift 2 shift 2
;; ;;
--pplx-ref)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --pplx-ref requires an argument." >&2
exit 1
fi
PPLX_COMMIT_HASH="$2"
shift 2
;;
--deepep-ref) --deepep-ref)
if [[ -z "$2" || "$2" =~ ^- ]]; then if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --deepep-ref requires an argument." >&2 echo "Error: --deepep-ref requires an argument." >&2
...@@ -188,14 +178,6 @@ do_build() { ...@@ -188,14 +178,6 @@ do_build() {
popd popd
} }
# build pplx-kernels
do_build \
"https://github.com/ppl-ai/pplx-kernels" \
"pplx-kernels" \
"setup.py" \
"$PPLX_COMMIT_HASH" \
""
# build DeepEP # build DeepEP
do_build \ do_build \
"https://github.com/deepseek-ai/DeepEP" \ "https://github.com/deepseek-ai/DeepEP" \
......
...@@ -988,7 +988,7 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): ...@@ -988,7 +988,7 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
return output_tensor return output_tensor
def get_cutlass_pplx_moe_mm_data( def get_cutlass_batched_moe_mm_data(
expert_offsets: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor, problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor, problem_sizes2: torch.Tensor,
...@@ -1011,7 +1011,7 @@ def get_cutlass_pplx_moe_mm_data( ...@@ -1011,7 +1011,7 @@ def get_cutlass_pplx_moe_mm_data(
multiplication in two grouped MMs used in multiplication in two grouped MMs used in
the fused MoE operation. the fused MoE operation.
""" """
return torch.ops._C.get_cutlass_pplx_moe_mm_data( return torch.ops._C.get_cutlass_batched_moe_mm_data(
expert_offsets, expert_offsets,
problem_sizes1, problem_sizes1,
problem_sizes2, problem_sizes2,
......
...@@ -1045,7 +1045,7 @@ class CompilationConfig: ...@@ -1045,7 +1045,7 @@ class CompilationConfig:
"are optimized for prefill and are incompatible with CUDA Graphs. " "are optimized for prefill and are incompatible with CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, " "In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as " "use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter." "deepep_low_latency or allgather_reducescatter."
) )
self.cudagraph_mode = CUDAGraphMode.NONE self.cudagraph_mode = CUDAGraphMode.NONE
......
...@@ -152,7 +152,6 @@ class ParallelConfig: ...@@ -152,7 +152,6 @@ class ParallelConfig:
- "naive": Naive all2all implementation using broadcasts\n - "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n - "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n - "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n - "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n - "mori": Use mori kernels\n
...@@ -310,6 +309,13 @@ class ParallelConfig: ...@@ -310,6 +309,13 @@ class ParallelConfig:
f"but found: {self._api_process_rank}" f"but found: {self._api_process_rank}"
) )
if self.all2all_backend == "pplx":
logger.warning(
"The 'pplx' all2all backend has been removed. "
"Falling back to 'allgather_reducescatter'."
)
self.all2all_backend = "allgather_reducescatter"
if self.data_parallel_size_local > self.data_parallel_size: if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError( raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) " f"data_parallel_size_local ({self.data_parallel_size_local}) "
...@@ -442,7 +448,6 @@ class ParallelConfig: ...@@ -442,7 +448,6 @@ class ParallelConfig:
# In this case, ensure the input to the experts is sequence parallel # In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work. # to avoid the excess work.
# #
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property @property
def use_sequence_parallel_moe(self) -> bool: def use_sequence_parallel_moe(self) -> bool:
return ( return (
......
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
from typing import Any from typing import Any
import torch import torch
import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache from .base_device_communicator import All2AllManagerBase, Cache
...@@ -235,96 +234,6 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -235,96 +234,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install pplx_kernels."
)
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase): class DeepEPAll2AllManagerBase(All2AllManagerBase):
""" """
All2All communication based on DeepEP High-Throughput kernels. All2All communication based on DeepEP High-Throughput kernels.
......
...@@ -112,10 +112,6 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -112,10 +112,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import AgRsAll2AllManager from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group) self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
elif self.all2all_backend == "deepep_high_throughput": elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager from .all2all import DeepEPHTAll2AllManager
...@@ -298,7 +294,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -298,7 +294,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None self.fi_ar_comm = None
if self.all2all_manager is not None: if self.all2all_manager is not None:
self.all2all_manager.destroy() self.all2all_manager.destroy()
self.all2all_manager = None self.all2all_manager = None # type: ignore[assignment]
def all_gatherv( def all_gatherv(
self, self,
......
...@@ -159,7 +159,7 @@ class EplbModelState: ...@@ -159,7 +159,7 @@ class EplbModelState:
NOTE: The expert_load_view now records load for all physical experts NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels). across different dispatch methods (naive all-to-all, DeepEP).
The recorded load will be multiplied by dp_size when using naive all-to-all The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation. due to each DP rank contributing the same token set to the calculation.
See: See:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch import torch
...@@ -24,16 +25,11 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( ...@@ -24,16 +25,11 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from vllm.utils.import_utils import has_deep_ep, has_mori
logger = init_logger(__name__) logger = init_logger(__name__)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
if has_deep_ep(): if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import ( from .deepep_ll_prepare_finalize import (
...@@ -120,51 +116,10 @@ def maybe_make_prepare_finalize( ...@@ -120,51 +116,10 @@ def maybe_make_prepare_finalize(
prepare_finalize: FusedMoEPrepareAndFinalize | None = None prepare_finalize: FusedMoEPrepareAndFinalize | None = None
if moe.use_pplx_kernels: if moe.use_deepep_ht_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
num_dispatchers = (
all2all_manager.world_size // all2all_manager.tp_group.world_size
)
# Intranode pplx a2a takes a group name while internode does not.
if not all2all_manager.internode:
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers,
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict() all_to_all_args: dict[str, Any] = dict()
handle = all2all_manager.get_handle(all_to_all_args) handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize( prepare_finalize = DeepEPHTPrepareAndFinalize(
handle, handle,
......
...@@ -939,10 +939,6 @@ class FusedMoEParallelConfig: ...@@ -939,10 +939,6 @@ class FusedMoEParallelConfig:
def use_all2all_kernels(self): def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "pplx"
@property @property
def use_deepep_ht_kernels(self): def use_deepep_ht_kernels(self):
return ( return (
...@@ -962,7 +958,7 @@ class FusedMoEParallelConfig: ...@@ -962,7 +958,7 @@ class FusedMoEParallelConfig:
@property @property
def use_batched_activation_format(self): def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels return self.use_deepep_ll_kernels
@property @property
def use_naive_all2all_kernels(self): def use_naive_all2all_kernels(self):
...@@ -1221,10 +1217,6 @@ class FusedMoEConfig: ...@@ -1221,10 +1217,6 @@ class FusedMoEConfig:
def use_ep(self): def use_ep(self):
return self.moe_parallel_config.use_ep return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property @property
def use_deepep_ht_kernels(self): def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels return self.moe_parallel_config.use_deepep_ht_kernels
......
...@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8( ...@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8(
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
ops.get_cutlass_pplx_moe_mm_data( ops.get_cutlass_batched_moe_mm_data(
expert_offsets, expert_offsets,
problem_sizes1, problem_sizes1,
problem_sizes2, problem_sizes2,
......
...@@ -493,7 +493,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -493,7 +493,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
""" """
A reference prepare/finalize class that reorganizes the tokens into A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format expert batched format, i.e. E x max_num_tokens x K. This is the format
that the PPLX dispatch/combine kernels use. that the batched dispatch/combine kernels use.
""" """
def __init__( def __init__(
...@@ -648,7 +648,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -648,7 +648,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
A reference MoE expert class that operates on expert batched format, A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use. dispatch/combine kernels use.
""" """
...@@ -880,7 +880,7 @@ def batched_moe_kernel_quantize_input( ...@@ -880,7 +880,7 @@ def batched_moe_kernel_quantize_input(
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
A Triton based MoE expert class that operates on expert batched format, A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use. dispatch/combine kernels use.
""" """
......
...@@ -1172,9 +1172,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1172,9 +1172,9 @@ class FusedMoEModularKernel(torch.nn.Module):
# This happens when none of the tokens from the all2all reach this # This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph # EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput # incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx # kernels. CUDAGraph compatible all2all kernels like the DeepEP
# kernels and the DeepEP low-latency kernels are always batched # low-latency kernels are always batched and can never run into
# and can never run into the tensor.numel() == 0 case. # the tensor.numel() == 0 case.
if M_full == 0: if M_full == 0:
assert num_chunks == 0 assert num_chunks == 0
workspace13 = None workspace13 = None
......
...@@ -143,10 +143,7 @@ def select_nvfp4_moe_backend( ...@@ -143,10 +143,7 @@ def select_nvfp4_moe_backend(
# NOTE(rob): this is kind of a hack. We need to peak into # NOTE(rob): this is kind of a hack. We need to peak into
# the prepare-finalize selection to determine if we are using # the prepare-finalize selection to determine if we are using
# the batched or standard expert format. # the batched or standard expert format.
use_batched = ( use_batched = config.moe_parallel_config.use_deepep_ll_kernels
config.moe_parallel_config.use_deepep_ll_kernels
or config.moe_parallel_config.use_pplx_kernels
)
activation_format = ( activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts mk.FusedMoEActivationFormat.BatchedExperts
if use_batched if use_batched
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import (
_validate_scale_shape,
moe_kernel_quantize_input,
)
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: torch.dtype | str | None,
per_act_token_quant: bool,
block_shape: list[int] | None,
):
# All pplx byte sizes must be 16-byte aligned.
align = 16
# For blocked per token: set to
# cdiv(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert isinstance(quant_dtype, torch.dtype)
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
if per_act_token_quant:
# per-token (M x 1)
assert block_shape is None
hidden_scale_bytes = elem_size
elif block_shape is not None:
# per-group (M x K_tiles)
block_size = block_shape[1]
num_blocks = cdiv(hidden_dim, block_size)
hidden_scale_bytes = num_blocks * elem_size
else:
# per-tensor (1 x 1)
hidden_scale_bytes = elem_size
else:
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
hidden_scale_bytes = 0
return (
round_up(hidden_dim_bytes, align),
round_up(hidden_scale_bytes, align),
)
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""PPLX-based prepare and finalize for expert parallelism."""
def __init__(
self,
a2a: pplx.AllToAll,
max_num_tokens: int,
num_local_experts: int,
num_dispatchers: int,
):
super().__init__()
assert max_num_tokens > 0
assert num_local_experts > 0
self.a2a = a2a
self.max_num_tokens = max_num_tokens
self.num_local_experts = num_local_experts
self.num_dispatchers_ = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.uint32
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert topk_ids.size(0) == num_tokens
# expert_map should be None because with expert map, -1 id is used for
# non-local token; this causes error when casting ids to the
# topk_indices_dtype() int32
#
if expert_map is not None:
logger.warning_once(
"The PPLX backend does not support expert mapping. "
"The provided `expert_map` will be ignored."
)
expert_map = None # noqa: F841
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
(None if quant_config.per_act_token_quant else quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
_validate_scale_shape(
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
)
orig_a_scale_block_shape: int | None = None
if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1
# pplx requires 2-d scales even for scalar scales
if a1q_scale.dim() <= 1:
assert scalar_scales
a1q_scale = a1q_scale.view(1, 1)
orig_a_scale_block_shape = a1q_scale.shape[-1]
if not quant_config.is_block_quantized:
# TODO (bnell): use group_broadcast instead?
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
assert a1q_scale is None or a1q_scale.ndim == 2, (
f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
)
expert_num_tokens = torch.empty(
self.num_local_experts,
dtype=torch.int32,
device=device,
)
expert_x = torch.empty(
(
self.num_local_experts,
self.max_num_tokens * self.num_dispatchers(),
hidden_dim,
),
dtype=a1q.dtype,
device=device,
)
expert_x_scale: torch.Tensor | None = None
if a1q.dtype.itemsize == 1:
if quant_config.is_per_act_token:
# (M x 1) -> (E x M x K)
final_dim = expert_x.size(2)
elif quant_config.is_per_tensor:
# (1 x 1) -> (E x 1 x 1)
final_dim = 1
else:
# (M x K_tiles) -> (E x M x K_tiles)
assert quant_config.block_shape is not None
num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
final_dim = num_blocks
expert_x_scale_shape = (
self.num_local_experts,
expert_x.size(1),
round_up(final_dim, 4), # round up for alignment
)
expert_x_scale = torch.empty(
expert_x_scale_shape,
dtype=torch.float32,
device=expert_x.device,
)
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: torch.Tensor | None = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
hook = lambda: self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
return (
hook,
lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
orig_a_scale_block_shape,
),
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: torch.Tensor | None,
orig_a_scale_block_shape: int | None,
) -> mk.PrepareResultType:
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant=defer_input_quant,
)
hook()
return receiver()
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
"Weight application and reduction happens in the combine kernel."
)
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: torch.Tensor | None = None
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
# num_tokens = output.size(0) # M
# assert topk_ids.size(0) == num_tokens, (
# f"{topk_ids.size(0)} == {num_tokens}")
assert topk_ids.size() == topk_weights.size(), (
f"{topk_ids.size()} == {topk_weights.size()}"
)
assert output.size(0) <= self.max_num_tokens, (
f"{output.size(0)} <= {self.max_num_tokens}"
)
assert output.size(1) == fused_expert_output.size(-1)
# Set weights to 1 if we did them in dispatch. This is hacky.
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
...@@ -216,8 +216,7 @@ class DefaultMoERunner(MoERunner): ...@@ -216,8 +216,7 @@ class DefaultMoERunner(MoERunner):
@property @property
def use_dp_chunking(self) -> bool: def use_dp_chunking(self) -> bool:
return ( return (
self.moe_config.moe_parallel_config.use_pplx_kernels self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels or self.moe_config.moe_parallel_config.use_mori_kernels
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
......
...@@ -14,10 +14,11 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): ...@@ -14,10 +14,11 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
implementation does not perform weight application and reduction implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize but cannot address the needs of all the compatible PrepareAndFinalize
implementations. implementations.
For example, BatchedTritonExperts is compatible with both For example, BatchedTritonExperts is compatible with both batched
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and
does the weight-application + reduction as part of the pplx combine kernel. BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do
But the BatchedPrepareAndFinalize needs an implementation. To facilitate the weight-application + reduction as part of the combine kernel, while
BatchedPrepareAndFinalize needs an explicit implementation. To facilitate
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
so the PrepareAndFinalize implementations could choose how to so the PrepareAndFinalize implementations could choose how to
weight + reduce. weight + reduce.
......
...@@ -798,7 +798,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -798,7 +798,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# batched activation format. As self.fused_experts is not # batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config # initialized at this point, we resort to checking the MoE config
# directly. # directly.
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels is_batched_moe = self.moe.use_deepep_ll_kernels
if is_batched_moe: if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else: else:
......
...@@ -402,11 +402,6 @@ def _has_module(module_name: str) -> bool: ...@@ -402,11 +402,6 @@ def _has_module(module_name: str) -> bool:
return importlib.util.find_spec(module_name) is not None return importlib.util.find_spec(module_name) is not None
def has_pplx() -> bool:
"""Whether the optional `pplx_kernels` package is available."""
return _has_module("pplx_kernels")
def has_deep_ep() -> bool: def has_deep_ep() -> bool:
"""Whether the optional `deep_ep` package is available.""" """Whether the optional `deep_ep` package is available."""
return _has_module("deep_ep") return _has_module("deep_ep")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment