Unverified Commit 73594814 authored by nv-dlasalle, committed by GitHub

[Feature][GPU] Add function for setting weights of a sparse embedding on multiple GPUs. (#3047)



* add unit test

* Extend NDArrayPartition object

* Add method for setting embedding, and improve documentation

* Sync before returning

* Use a store key name unique to the sparse embedding class so the entry does not need to be deleted

Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 70af1945
@@ -123,16 +123,11 @@ class NodeEmbedding: # NodeEmbedding
         if rank == 0:
             # root process broadcasts nccl id
             nccl_id = nccl.UniqueId()
-            self._store.set('nccl_root_id', str(nccl_id))
+            self._store.set('nccl_root_id_sparse_emb', str(nccl_id))
         else:
-            nccl_id = nccl.UniqueId(self._store.get('nccl_root_id'))
+            nccl_id = nccl.UniqueId(self._store.get('nccl_root_id_sparse_emb'))
         _COMM = nccl.Communicator(self._world_size, self._rank,
                                   nccl_id)
-        if self._rank == 0:
-            # clear the store entry for future communicators
-            self._store.delete_key('nccl_root_id')
-        th.distributed.barrier()
         self._comm = _COMM

         if not self._partition:
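The hunk above replaces the shared 'nccl_root_id' store entry (and its deletion) with a key used only by the sparse embedding path, so other communicators sharing the same store are unaffected. Below is a minimal sketch of the underlying rendezvous pattern, assuming a torch.distributed.TCPStore; the host, port, key name, and the make_unique_id/from_bytes helpers are illustrative placeholders, not DGL API.

import torch.distributed as dist

def exchange_nccl_id(rank, world_size, make_unique_id, from_bytes):
    # Every rank connects to the same key-value store; rank 0 acts as the server.
    store = dist.TCPStore('127.0.0.1', 29500, world_size, rank == 0)
    key = 'nccl_root_id_sparse_emb'  # key scoped to this feature to avoid clashes
    if rank == 0:
        uid = make_unique_id()            # e.g. nccl.UniqueId()
        store.set(key, str(uid))
    else:
        uid = from_bytes(store.get(key))  # e.g. nccl.UniqueId(raw_bytes)
    # No rank deletes the key, so late-arriving ranks can still read it.
    return uid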
@@ -335,12 +330,43 @@ class NodeEmbedding: # NodeEmbedding
         """
         return self._tensor

-    def gather_embedding(self):
-        """Return a copy of the embedding stored in CPU memory. If this is a
+    def all_set_embedding(self, values):
+        """ Set the values of the embedding. This method must be called by all
+        processes sharing the embedding with identical tensors for
+        :attr:`values`.
+
+        NOTE: This method must be called by all processes sharing the
+        embedding, or it may result in a deadlock.
+
+        Parameters
+        ----------
+        values : Tensor
+            The global tensor to pull values from.
+        """
+        if self._partition:
+            idxs = F.copy_to(
+                self._partition.get_local_indices(
+                    self._comm.rank(),
+                    ctx=F.context(self._tensor)),
+                F.context(values))
+            self._tensor[:] = F.copy_to(F.gather_row(values, idxs),
+                                        ctx=F.context(self._tensor))[:]
+        else:
+            if self._rank == 0:
+                self._tensor[:] = F.copy_to(values,
+                                            ctx=F.context(self._tensor))[:]
+        if th.distributed.is_initialized():
+            th.distributed.barrier()
+
+    def all_get_embedding(self):
+        """ Return a copy of the embedding stored in CPU memory. If this is a
         multi-processing instance, the tensor will be returned in shared
         memory. If the embedding is currently stored on multiple GPUs, all
         processes must call this method in the same order.

+        NOTE: This method must be called by all processes sharing the
+        embedding, or it may result in a deadlock.
+
         Returns
         -------
         torch.Tensor
......
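For orientation, a brief usage sketch of the new collective methods; the constructor arguments mirror the unit test added below, while the embedding name 'my_emb' and the helper set_and_read_back are illustrative.

import torch as th
from dgl.nn import NodeEmbedding

def set_and_read_back(device, values):
    # Every process sharing the embedding must construct it with the same name
    # and call both collectives with an identical `values` tensor.
    emb = NodeEmbedding(values.shape[0], values.shape[1], 'my_emb', device=device)
    emb.all_set_embedding(values)    # collective across all sharing processes
    return emb.all_get_embedding()   # collective; returns a CPU copy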
@@ -419,13 +419,29 @@ class NDArrayPartition(object):
                 array_size, num_parts)
         else:
             assert False, 'Unknown partition mode "{}"'.format(mode)
+        self._array_size = array_size
+        self._num_parts = num_parts
+
+    def num_parts(self):
+        """ Get the number of partitions.
+        """
+        return self._num_parts
+
+    def array_size(self):
+        """ Get the total size of the first dimension of the partitioned array.
+        """
+        return self._array_size

     def get(self):
         """ Get the C-handle for this object.
         """
         return self._partition

+    def get_local_indices(self, part, ctx):
+        """ Get the set of global indices in the given partition.
+        """
+        return self.map_to_global(F.arange(0, self.local_size(part), ctx=ctx), part)
+
     def local_size(self, part):
         """ Get the number of rows/items assigned to the given part.
         """
......
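As a rough illustration of how the new NDArrayPartition accessors compose (not part of the diff): assuming the class is importable as dgl.partition.NDArrayPartition, the 'remainder' mode referenced in the hunk, and a GPU device for ctx, one might inspect a partition like this. The new unit test follows.

import torch as th
from dgl.partition import NDArrayPartition

part = NDArrayPartition(1000, 2, mode='remainder')
part.num_parts()    # -> 2
part.array_size()   # -> 1000
# Global row ids owned by partition 0, materialized on the requesting device;
# by definition this is map_to_global over arange(local_size(0)).
owned = part.get_local_indices(0, ctx=th.device('cuda:0'))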
import multiprocessing as mp
import unittest, os
import pytest

import torch as th

import backend as F

from dgl.nn import NodeEmbedding

def initializer(emb):
    th.manual_seed(0)
    emb.uniform_(-1.0, 1.0)
    return emb

def check_all_set_all_get_func(device, init_emb):
    num_embs = init_emb.shape[0]
    emb_dim = init_emb.shape[1]
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test', device=device)
    # Collective calls: every worker passes the same tensor and reads it back.
    dgl_emb.all_set_embedding(init_emb)
    out_emb = dgl_emb.all_get_embedding()
    assert F.allclose(init_emb, out_emb)

def start_sparse_worker(rank, world_size, test, args):
    print('start sparse worker {}'.format(rank))
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12345')
    backend = 'gloo'
    device = F.ctx()
    if device.type == 'cuda':
        # One GPU per worker process.
        device = th.device(rank)
        th.cuda.set_device(device)
    th.distributed.init_process_group(backend=backend,
                                      init_method=dist_init_method,
                                      world_size=world_size,
                                      rank=rank)

    test(device, *args)
    th.distributed.barrier()

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set(num_workers):
    if F.ctx().type == 'cuda' and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    worker_list = []
    init_emb = th.rand([1000, 8])

    ctx = mp.get_context('spawn')
    for i in range(num_workers):
        p = ctx.Process(target=start_sparse_worker,
                        args=(i, num_workers, check_all_set_all_get_func, (init_emb,)))
        p.start()
        worker_list.append(p)

    for p in worker_list:
        p.join()
    for p in worker_list:
        assert p.exitcode == 0
if __name__ == '__main__':
    test_multiprocess_sparse_emb_get_set(1)
    test_multiprocess_sparse_emb_get_set(2)
    test_multiprocess_sparse_emb_get_set(3)