Unverified Commit 62f8eb48 authored by blzheng, committed by GitHub

[CPU] Fix fallback allgather issue (#8041)

parent b7cd7430
@@ -650,17 +650,19 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
+        # All-gather.
+        if input_.is_cpu and is_shm_available(
+            input_.dtype, self.world_size, self.local_size
+        ):
+            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
         if input_.is_cpu:
-            if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-            else:
-                torch.distributed.all_gather_into_tensor(
-                    output_tensor, input_, group=self.device_group
-                )
-                return output_tensor
-        # All-gather.
-        self.all_gather_into_tensor(output_tensor, input_)
+            torch.distributed.all_gather_into_tensor(
+                output_tensor, input_, group=self.device_group
+            )
+        else:
+            self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
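As the hunk shows, the fix is in where the CPU fallback exits: previously the plain torch.distributed all-gather returned output_tensor before the reshape/movedim step, whereas the fixed code falls through to the same reshape as the other paths. Below is a minimal standalone sketch of the post-fix branch order. The function name and the coord/shm_available_fn parameters are hypothetical stand-ins for the GroupCoordinator context; torch.distributed.all_gather_into_tensor and torch.ops.sgl_kernel.shm_allgather are the calls from the diff.

import torch
import torch.distributed as dist

def allgather_with_cpu_fallback(coord, shm_available_fn, input_, dim, output_tensor):
    # Sketch only: `coord` stands in for the GroupCoordinator instance and
    # `shm_available_fn` for the is_shm_available helper used in the diff.

    # Fast path: shared-memory all-gather kernel, when it supports this dtype
    # and process-group layout on CPU. It returns the final gathered tensor.
    if input_.is_cpu and shm_available_fn(
        input_.dtype, coord.world_size, coord.local_size
    ):
        return torch.ops.sgl_kernel.shm_allgather(input_, dim)

    # CPU fallback: gather into the preallocated buffer with plain
    # torch.distributed; unlike before the fix, there is no early return here,
    # so the caller's reshape/movedim still runs on the result.
    if input_.is_cpu:
        dist.all_gather_into_tensor(
            output_tensor, input_, group=coord.device_group
        )
    else:
        # Non-CPU path: the coordinator's own all_gather_into_tensor.
        coord.all_gather_into_tensor(output_tensor, input_)
    return output_tensor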