"tests/vscode:/vscode.git/clone" did not exist on "98060b001dfae385c73d2380ad6a38456cbf42c9"
Unverified Commit c7c80af0 authored by yzds's avatar yzds Committed by GitHub
Browse files

fix pynccl reduce_scatter (#23648)


Co-authored-by: default avatarhongchao <hongchao@msh.team>
parent 6891205b
......@@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......@@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment