Unverified Commit cfa666ac authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Update PyTorch comm API (#100)



Use updated comm API PyTorch
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ed1a3116
......@@ -127,7 +127,7 @@ def gather_split_1d_tensor(
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(gathered, tensor, group=tp_group)
torch.distributed.all_gather_into_tensor(gathered, tensor, group=tp_group)
return gathered
......@@ -346,7 +346,7 @@ def reduce_scatter_along_first_dim(
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
handle = torch.distributed._reduce_scatter_base(
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=tp_group, async_op=async_op
)
return output, handle
......@@ -368,7 +368,7 @@ def gather_along_first_dim(
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
handle = torch.distributed._all_gather_base(
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=tp_group, async_op=async_op
)
......@@ -391,7 +391,7 @@ def gather_along_last_dim(
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
handle = torch.distributed._all_gather_base(
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=tp_group, async_op=async_op
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment