Unverified commit 2ea50e97, authored by Shu Wang, committed by GitHub
Browse files

Enable Allgather/ReduceScatter backend for NaiveAllToAll (#23964)


Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent b419937c
...@@ -5,6 +5,7 @@ from typing import Any ...@@ -5,6 +5,7 @@ from typing import Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx from vllm.utils import has_deep_ep, has_pplx
...@@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase):
pass pass
class AgRsAll2AllManager(All2AllManagerBase):
    """
    All2all manager that expresses the exchange as two collectives on
    the dp group: dispatch is an all-gather and combine is a
    reduce-scatter.
    """

    def __init__(self, cpu_group):
        # No state of its own; construction is delegated to the base class.
        super().__init__(cpu_group)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """
        All-gather hidden_states and router_logits along dim 0 over the
        dp group.

        The sizes passed to the collective are whatever the current
        forward context's dp_metadata reports from
        get_chunk_sizes_across_dp_rank().

        Returns the gathered (hidden_states, router_logits) pair.
        """
        dp_metadata = get_forward_context().dp_metadata
        chunk_sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
        # Both tensors go through a single all_gatherv call.
        gathered_hidden, gathered_logits = get_dp_group().all_gatherv(
            [hidden_states, router_logits], dim=0, sizes=chunk_sizes)
        return gathered_hidden, gathered_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Reduce-scatter hidden_states along dim 0 over the dp group.

        Uses the same size source as dispatch: the forward context's
        dp_metadata.get_chunk_sizes_across_dp_rank().
        """
        dp_metadata = get_forward_context().dp_metadata
        chunk_sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
        return get_dp_group().reduce_scatterv(hidden_states,
                                              dim=0,
                                              sizes=chunk_sizes)

    def destroy(self):
        """No-op: this manager does nothing on destroy."""
        return None
class PPLXAll2AllManager(All2AllManagerBase): class PPLXAll2AllManager(All2AllManagerBase):
""" """
All2All communication based on PPLX kernels. All2All communication based on PPLX kernels.
......
...@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import NaiveAll2AllManager from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group) self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.") logger.info("Using naive all2all manager.")
elif all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AllGather-ReduceScatter all2all manager.")
elif all2all_backend == "pplx": elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group) self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
......
...@@ -149,8 +149,11 @@ if TYPE_CHECKING: ...@@ -149,8 +149,11 @@ if TYPE_CHECKING:
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput", VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
"deepep_low_latency"] = "naive" "deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter"] = \
"allgather_reducescatter"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_SLEEP_WHEN_IDLE: bool = False
...@@ -1124,14 +1127,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1124,14 +1127,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
# all2all backend for vllm's expert parallel communication # all2all backend for vllm's expert parallel communication
# Available options: # Available options:
# - "naive": naive all2all implementation using all-reduce # - "naive": naive all2all implementation using broadcasts
# - "allgather_reducescatter": all2all implementation based on allgather and
# reducescatter
# - "pplx": use pplx kernels # - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels # - "deepep_low_latency", use deepep low-latency kernels
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
env_with_choices("VLLM_ALL2ALL_BACKEND", "naive", env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
["naive", "pplx", ["naive", "pplx",
"deepep_high_throughput", "deepep_low_latency"]), "deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter"]),
# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
# Both require compute capability 10.0 or above. # Both require compute capability 10.0 or above.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment