Unverified Commit c7c80af0 authored by yzds's avatar yzds Committed by GitHub
Browse files

fix pynccl reduce_scatter (#23648)


Co-authored-by: default avatarhongchao <hongchao@msh.team>
parent 6891205b
...@@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype, dtype=input_tensor.dtype,
device=input_tensor.device) device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_) pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
...@@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device) device=input_tensor.device)
if sizes is not None: if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else: else:
pynccl_comm.reduce_scatter(output, input_) pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment