Unverified Commit 5dfaf99e authored by Israt Nisa, committed by GitHub

[Performance] Add NCCL support (#5929)


Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent e594b4a8
......@@ -12,7 +12,7 @@ from .... import backend as F
from ...dist_tensor import DistTensor
from ...graph_partition_book import EDGE_PART_POLICY, NODE_PART_POLICY
from ...nn.pytorch import DistEmbedding
from .utils import alltoall_cpu, alltoallv_cpu
from .utils import alltoall, alltoallv
EMB_STATES = "emb_states"
WORLD_SIZE = "world_size"
......@@ -256,9 +256,13 @@ class DistSparseGradOptimizer(abc.ABC):
of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
"""
with th.no_grad():
device = (
th.device(f"cuda:{self._rank}")
if th.distributed.get_backend() == "nccl"
else th.device("cpu")
)
local_indics = {emb.name: [] for emb in self._params}
local_grads = {emb.name: [] for emb in self._params}
device = th.device("cpu")
for emb in self._params:
name = emb.weight.name
kvstore = emb.weight.kvstore
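The hunk above keys the staging device off the active torch.distributed backend. A minimal standalone sketch of that selection (illustrative helper name, assuming one GPU per rank and an already initialized process group) could look like this:

import torch as th
import torch.distributed as dist

def comm_device(rank):
    # NCCL collectives operate on CUDA tensors, so each rank stages its
    # buffers on its own GPU; other backends (e.g. Gloo) keep them on CPU.
    # Assumes dist.init_process_group(...) has already been called.
    if dist.get_backend() == "nccl":
        return th.device(f"cuda:{rank}")
    return th.device("cpu")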
......@@ -310,7 +314,11 @@ class DistSparseGradOptimizer(abc.ABC):
if trainers_per_server <= 1:
idx_split_size.append(
th.tensor([idx_i.shape[0]], dtype=th.int64)
th.tensor(
[idx_i.shape[0]],
dtype=th.int64,
device=device,
)
)
idics_list.append(idx_i)
grad_list.append(grad_i)
......@@ -323,7 +331,11 @@ class DistSparseGradOptimizer(abc.ABC):
idx_j = idx_i[mask]
grad_j = grad_i[mask]
idx_split_size.append(
th.tensor([idx_j.shape[0]], dtype=th.int64)
th.tensor(
[idx_j.shape[0]],
dtype=th.int64,
device=device,
)
)
idics_list.append(idx_j)
grad_list.append(grad_j)
......@@ -336,39 +348,45 @@ class DistSparseGradOptimizer(abc.ABC):
# Note: If we have GPU nccl support, we can use all_to_all to
# sync information here
gather_list = list(
th.empty([self._world_size], dtype=th.int64).chunk(
self._world_size
)
th.empty(
[self._world_size], dtype=th.int64, device=device
).chunk(self._world_size)
)
alltoall_cpu(
alltoall(
self._rank,
self._world_size,
gather_list,
idx_split_size,
device,
)
# use cpu until we have GPU alltoallv
idx_gather_list = [
th.empty((int(num_emb),), dtype=idics.dtype)
th.empty(
(int(num_emb),), dtype=idics.dtype, device=device
)
for num_emb in gather_list
]
alltoallv_cpu(
alltoallv(
self._rank,
self._world_size,
idx_gather_list,
idics_list,
device,
)
local_indics[name] = idx_gather_list
grad_gather_list = [
th.empty(
(int(num_emb), grads.shape[1]), dtype=grads.dtype
(int(num_emb), grads.shape[1]),
dtype=grads.dtype,
device=device,
)
for num_emb in gather_list
]
alltoallv_cpu(
alltoallv(
self._rank,
self._world_size,
grad_gather_list,
grad_list,
device,
)
local_grads[name] = grad_gather_list
else:
......
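The exchange in the optimizer above is a two-phase pattern: a fixed-size all-to-all of the per-destination counts, followed by a variable-length exchange of the index and gradient chunks. A hedged, self-contained sketch of that pattern (illustrative helper, not the optimizer's code; requires a backend with all_to_all support such as NCCL) is:

import torch as th
import torch.distributed as dist

def exchange_uneven(chunks, world_size, device):
    # chunks[i] is the 1-D tensor this rank wants to deliver to rank i.
    chunks = [c.to(device) for c in chunks]
    # Phase 1: trade the per-destination lengths with a fixed-size all_to_all.
    send_sizes = [
        th.tensor([c.shape[0]], dtype=th.int64, device=device) for c in chunks
    ]
    recv_sizes = list(
        th.empty(world_size, dtype=th.int64, device=device).chunk(world_size)
    )
    dist.all_to_all(recv_sizes, send_sizes)
    # Phase 2: allocate receive buffers of the advertised lengths and trade
    # the payloads themselves.
    recv_bufs = [
        th.empty(int(n.item()), dtype=chunks[0].dtype, device=device)
        for n in recv_sizes
    ]
    dist.all_to_all(recv_bufs, chunks)
    return recv_bufs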
......@@ -13,7 +13,7 @@ def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
rank : int
The rank of current worker
world_size : int
The size of the entire
The size of the entire communicator
output_tensor_list : List of tensor
The received tensors
input_tensor_list : List of tensor
......@@ -37,7 +37,7 @@ def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
rank : int
The rank of current worker
world_size : int
The size of the entire
The size of the entire communicator
output_tensor_list : List of tensor
The received tensors
input_tensor_list : List of tensor
......@@ -60,3 +60,65 @@ def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
dist.recv(output_tensor_list[i], src=i)
th.distributed.barrier()
def alltoall(rank, world_size, output_tensor_list, input_tensor_list, device):
"""Each process scatters list of input tensors to all processes in a cluster
and return gathered list of tensors in output list. The tensors should have the same shape.
Parameters
----------
rank : int
The rank of current worker
world_size : int
The size of the entire communicator
output_tensor_list : List of tensor
The received tensors
input_tensor_list : List of tensor
The tensors to exchange
device: th.device
Device of the tensors
"""
if th.distributed.get_backend() == "nccl":
input_tensor_list = [
tensor.to(th.device(device)) for tensor in input_tensor_list
]
th.distributed.all_to_all(output_tensor_list, input_tensor_list)
else:
alltoall_cpu(
rank,
world_size,
output_tensor_list,
input_tensor_list,
)
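A possible usage of this wrapper (hedged example, assuming the process group is initialized and, for NCCL, that each rank owns cuda:<rank>): every rank exchanges one equal-shaped tensor with every other rank.

rank = th.distributed.get_rank()
world_size = th.distributed.get_world_size()
device = (
    th.device(f"cuda:{rank}")
    if th.distributed.get_backend() == "nccl"
    else th.device("cpu")
)
# Rank r sends a tensor filled with r to every peer.
inputs = [
    th.full((4,), rank, dtype=th.int64, device=device) for _ in range(world_size)
]
outputs = [th.empty(4, dtype=th.int64, device=device) for _ in range(world_size)]
alltoall(rank, world_size, outputs, inputs, device)
# outputs[i] now holds the chunk sent by rank i, i.e. four copies of i.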
def alltoallv(rank, world_size, output_tensor_list, input_tensor_list, device):
"""Each process scatters list of input tensors to all processes in a cluster
and return gathered list of tensors in output list.
Parameters
----------
rank : int
The rank of current worker
world_size : int
The size of the entire communicator
output_tensor_list : List of tensor
The received tensors
input_tensor_list : List of tensor
The tensors to exchange
device: th.device
Device of the tensors
"""
if th.distributed.get_backend() == "nccl":
input_tensor_list = [
tensor.to(th.device(device)) for tensor in input_tensor_list
]
th.distributed.all_to_all(output_tensor_list, input_tensor_list)
else:
alltoallv_cpu(
rank,
world_size,
output_tensor_list,
input_tensor_list,
)
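And a possible alltoallv usage (again a hedged sketch, not from the PR): the chunks have different lengths, so the receive buffers must be pre-sized with counts obtained beforehand, for example via an alltoall of the lengths as the optimizer does above.

rank = th.distributed.get_rank()
world_size = th.distributed.get_world_size()
device = (
    th.device(f"cuda:{rank}")
    if th.distributed.get_backend() == "nccl"
    else th.device("cpu")
)
# Rank r sends r + 1 elements to every peer, so rank i's contribution has
# length i + 1 and the receive buffers are sized accordingly.
inputs = [
    th.arange(rank + 1, dtype=th.int64, device=device) for _ in range(world_size)
]
outputs = [
    th.empty(i + 1, dtype=th.int64, device=device) for i in range(world_size)
]
alltoallv(rank, world_size, outputs, inputs, device)
# outputs[i] now holds the i + 1 elements sent by rank i.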