Unverified Commit c910eeb1 authored by YiSheng5's avatar YiSheng5 Committed by GitHub
Browse files

[XPU]Bug fix for some unexpected error when use AgRs backend on XPU device. (#36593)


Signed-off-by: default avataryisheng <yi.sheng@intel.com>
parent f4ae58b3
......@@ -70,7 +70,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
dist.reduce_scatter_tensor(output, input_tensor)
dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......@@ -103,9 +103,9 @@ class XpuCommunicator(DeviceCommunicatorBase):
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
# if inputs shape in different ranks is not the same using reduce_scatter
input_splits = list(input_tensor.split(sizes, dim=0))
dist.reduce_scatter(output, input_splits)
dist.reduce_scatter(output, input_splits, group=self.device_group)
else:
dist.reduce_scatter_tensor(output, input_tensor)
dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......@@ -149,10 +149,10 @@ class XpuCommunicator(DeviceCommunicatorBase):
device=input_.device,
)
)
dist.all_gather(all_gather_list, input_)
dist.all_gather(all_gather_list, input_, group=self.device_group)
output_tensor = torch.cat(all_gather_list, dim=0)
else:
dist.all_gather([output_tensor], input_)
dist.all_gather([output_tensor], input_, group=self.device_group)
return output_tensor
if isinstance(input_, torch.Tensor):
......
......@@ -85,6 +85,9 @@ class XPUWorker(Worker):
current_platform.dist_backend,
)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
# Set random seed.
set_random_seed(self.model_config.seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment