Unverified commit 2aab9acf, authored by Fadi Arafeh, committed by GitHub
Browse files

[CPU][BugFix] Fix inter-node pipeline parallel (#40150)


Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
parent 58631d7c
......@@ -45,6 +45,9 @@ class CpuCommunicator(DeviceCommunicatorBase):
unique_name,
)
# send/recv tensor_dict is only supported through the SHM communicator backend
self.supports_tensor_dict = isinstance(self.dist_module, _CPUSHMDistributed)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
logger.warning(
......@@ -143,12 +146,22 @@ class CpuCommunicator(DeviceCommunicatorBase):
tensor_dict: dict[str, torch.Tensor | Any],
dst: int,
) -> None:
if not self.supports_tensor_dict:
raise NotImplementedError(
"CpuCommunicator does not support tensor dict fastpath with "
"torch.distributed backend."
)
return self.dist_module.send_tensor_dict(tensor_dict, dst)
def recv_tensor_dict(
self,
src: int,
) -> dict[str, torch.Tensor | Any]:
if not self.supports_tensor_dict:
raise NotImplementedError(
"CpuCommunicator does not support tensor dict fastpath with "
"torch.distributed backend."
)
return self.dist_module.recv_tensor_dict(src)
def dispatch_router_logits(
......
......@@ -394,8 +394,10 @@ class GroupCoordinator:
current_platform.is_tpu() or current_platform.use_custom_op_collectives()
)
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
torch.ops._C, "init_shm_manager"
self.use_cpu_custom_send_recv = (
current_platform.is_cpu()
and self.device_communicator
and getattr(self.device_communicator, "supports_tensor_dict", False)
)
def create_mq_broadcaster(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment