Commit 993859ce authored by Chaojun Zhang, committed by GitHub

[XPU] fix all_reduce all-zero accuracy issue under torch.compile (#39844)


Signed-off-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 48a65ccb
......@@ -47,9 +47,10 @@ class XpuCommunicator(DeviceCommunicatorBase):
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AgRs manager on XPU device.")
-    def all_reduce(self, input_) -> torch.Tensor:
-        dist.all_reduce(input_, group=self.device_group)
-        return input_
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        output = input_.clone() if torch.compiler.is_compiling() else input_
+        dist.all_reduce(output, group=self.device_group)
+        return output
     def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
         world_size = self.world_size
......