Unverified Commit 8609e637 authored by Cheng Wan, committed by GitHub

Fix All-Gather under world size one (#7219)

parent dea2b84b
@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            return input_
-        if tensor_list is not None:
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
         assert (
...
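For context, a minimal, self-contained sketch of the new world_size == 1 behavior. The function name all_gather_single_rank and the surrounding scaffolding are illustrative assumptions, not part of the project's API; only the branch logic mirrors the diff above.

import torch

# Minimal sketch of the patched world_size == 1 branch.
# "all_gather_single_rank" is an illustrative name, not the project's API;
# only the branch logic mirrors the diff above.
def all_gather_single_rank(input_, output_tensor_list=None):
    if output_tensor_list is not None:
        # In-place variant: with a single rank, the gathered result is just
        # a copy of the local tensor into the caller-provided buffer.
        output_tensor_list[0].copy_(input_)
        return None
    # No output buffer: return the input unchanged, as before the fix.
    return input_

# Usage: the caller's buffer now receives the data even with one rank.
x = torch.arange(4, dtype=torch.float32)
out = [torch.empty_like(x)]
all_gather_single_rank(x, out)
assert torch.equal(out[0], x)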