Unverified Commit cfa666ac authored by Kirthi Shankar Sivamani, committed by GitHub

Update PyTorch comm API (#100)



Use the updated communication API from PyTorch.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ed1a3116
@@ -127,7 +127,7 @@ def gather_split_1d_tensor(
         device=torch.cuda.current_device(),
         requires_grad=False,
     )
-    torch.distributed._all_gather_base(gathered, tensor, group=tp_group)
+    torch.distributed.all_gather_into_tensor(gathered, tensor, group=tp_group)
     return gathered
@@ -346,7 +346,7 @@ def reduce_scatter_along_first_dim(
     output = torch.empty(
         dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
     )
-    handle = torch.distributed._reduce_scatter_base(
+    handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=tp_group, async_op=async_op
     )
     return output, handle
@@ -368,7 +368,7 @@ def gather_along_first_dim(
     output = torch.empty(
         dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
     )
-    handle = torch.distributed._all_gather_base(
+    handle = torch.distributed.all_gather_into_tensor(
         output, input_.contiguous(), group=tp_group, async_op=async_op
     )
@@ -391,7 +391,7 @@ def gather_along_last_dim(
     output = torch.empty(
         dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
     )
-    handle = torch.distributed._all_gather_base(
+    handle = torch.distributed.all_gather_into_tensor(
         output, input_.contiguous(), group=tp_group, async_op=async_op
     )
...
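For reference, a minimal sketch of the public collectives this commit switches to: torch.distributed.all_gather_into_tensor replaces the private _all_gather_base, and torch.distributed.reduce_scatter_tensor replaces _reduce_scatter_base. The process-group setup and tensor shapes below are illustrative only and not part of the change.

    # Illustrative sketch, not part of the commit. Assumes torch.distributed
    # is already initialized (e.g. via init_process_group) and that tp_group
    # is a valid process group (or None for the default group).
    import torch
    import torch.distributed as dist

    def demo_updated_collectives(tp_group=None):
        world_size = dist.get_world_size(group=tp_group)
        local = torch.ones(4, device=torch.cuda.current_device())

        # all_gather_into_tensor: public replacement for _all_gather_base.
        # Output must be world_size times the size of the per-rank input.
        gathered = torch.empty(
            world_size * local.numel(), dtype=local.dtype, device=local.device
        )
        dist.all_gather_into_tensor(gathered, local, group=tp_group)

        # reduce_scatter_tensor: public replacement for _reduce_scatter_base.
        # Each rank receives one reduced shard of the full input.
        scattered = torch.empty_like(local)
        dist.reduce_scatter_tensor(scattered, gathered, group=tp_group)
        return gathered, scattered

Both functions keep the same positional (output, input) signature and the group/async_op keyword arguments used in the diff above, so the change is a drop-in rename.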