Unverified commit 2aab9acf, authored by Fadi Arafeh, committed by GitHub
Browse files

[CPU][BugFix] Fix inter-node pipeline parallel (#40150)


Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
parent 58631d7c
...@@ -45,6 +45,9 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -45,6 +45,9 @@ class CpuCommunicator(DeviceCommunicatorBase):
unique_name, unique_name,
) )
# send/recv tensor_dict is only supported through the SHM communicator backend
self.supports_tensor_dict = isinstance(self.dist_module, _CPUSHMDistributed)
if self.use_all2all: if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type] if self.all2all_backend != "naive": # type: ignore[has-type]
logger.warning( logger.warning(
...@@ -143,12 +146,22 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -143,12 +146,22 @@ class CpuCommunicator(DeviceCommunicatorBase):
tensor_dict: dict[str, torch.Tensor | Any], tensor_dict: dict[str, torch.Tensor | Any],
dst: int, dst: int,
) -> None: ) -> None:
if not self.supports_tensor_dict:
raise NotImplementedError(
"CpuCommunicator does not support tensor dict fastpath with "
"torch.distributed backend."
)
return self.dist_module.send_tensor_dict(tensor_dict, dst) return self.dist_module.send_tensor_dict(tensor_dict, dst)
def recv_tensor_dict( def recv_tensor_dict(
self, self,
src: int, src: int,
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
if not self.supports_tensor_dict:
raise NotImplementedError(
"CpuCommunicator does not support tensor dict fastpath with "
"torch.distributed backend."
)
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch_router_logits( def dispatch_router_logits(
......
...@@ -394,8 +394,10 @@ class GroupCoordinator: ...@@ -394,8 +394,10 @@ class GroupCoordinator:
current_platform.is_tpu() or current_platform.use_custom_op_collectives() current_platform.is_tpu() or current_platform.use_custom_op_collectives()
) )
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr( self.use_cpu_custom_send_recv = (
torch.ops._C, "init_shm_manager" current_platform.is_cpu()
and self.device_communicator
and getattr(self.device_communicator, "supports_tensor_dict", False)
) )
def create_mq_broadcaster( def create_mq_broadcaster(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment