Unverified Commit edadca10 authored by kzwrime's avatar kzwrime Committed by GitHub
Browse files

[Bugfix] Add CpuCommunicator.dispatch and combine to fix DP+MoE inference (#31867)


Signed-off-by: default avatarkunzh <zhikun.wu@outlook.com>
parent d86fc23b
...@@ -286,7 +286,10 @@ class DeviceCommunicatorBase: ...@@ -286,7 +286,10 @@ class DeviceCommunicatorBase:
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
""" """
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
......
...@@ -8,11 +8,14 @@ import torch ...@@ -8,11 +8,14 @@ import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.distributed.utils import pickle from vllm.distributed.utils import pickle
from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CpuCommunicator(DeviceCommunicatorBase): class CpuCommunicator(DeviceCommunicatorBase):
def __init__( def __init__(
...@@ -32,6 +35,20 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -32,6 +35,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
): ):
self.dist_module = _CPUSHMDistributed(self) self.dist_module = _CPUSHMDistributed(self)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
logger.warning(
"`%s` all2all manager is not supported on CPU. "
"Falling back to `naive` all2all manager for CPU.",
self.all2all_backend, # type: ignore[has-type]
)
self.all2all_backend = "naive"
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_): def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group) self.dist_module.all_reduce(input_, group=self.device_group)
return input_ return input_
...@@ -110,6 +127,30 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -110,6 +127,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states
class _CPUSHMDistributed: class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator): def __init__(self, communicator: CpuCommunicator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment