"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "769718dfcf6fb1d861e6e2e35b3d92627c5ccc78"
Commit a609b4f0 (unverified), authored by xiang song (charlie.song), committed by GitHub
Browse files

[Bugfix] Fix #3291 (#3333)



* Fix #3291

* update

* fix

* Unit key

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-2-66.ec2.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 983a4fdd
...@@ -35,6 +35,7 @@ class SparseGradOptimizer(abc.ABC): ...@@ -35,6 +35,7 @@ class SparseGradOptimizer(abc.ABC):
# otherwise it will crash the training # otherwise it will crash the training
self.shmem_buffer_holder = [] self.shmem_buffer_holder = []
assert len(params) > 0, 'Empty parameters'
# if we are using shared memory for communication # if we are using shared memory for communication
for emb in params: for emb in params:
assert isinstance(emb, NodeEmbedding), \ assert isinstance(emb, NodeEmbedding), \
...@@ -50,6 +51,7 @@ class SparseGradOptimizer(abc.ABC): ...@@ -50,6 +51,7 @@ class SparseGradOptimizer(abc.ABC):
'MultiGPU world_size for each embedding should be same.' 'MultiGPU world_size for each embedding should be same.'
assert not self._rank is None assert not self._rank is None
assert not self._world_size is None assert not self._world_size is None
self._nccl_root_id = 'SparseGradOptimizer.nccl_root_id'
def step(self): def step(self):
''' The step function. ''' The step function.
...@@ -111,17 +113,14 @@ class SparseGradOptimizer(abc.ABC): ...@@ -111,17 +113,14 @@ class SparseGradOptimizer(abc.ABC):
# root process broadcasts nccl id # root process broadcasts nccl id
nccl_id = nccl.UniqueId() nccl_id = nccl.UniqueId()
uid = str(nccl_id) uid = str(nccl_id)
store.set('nccl_root_id', uid) store.set(self._nccl_root_id, uid)
else: else:
uid = store.get('nccl_root_id') uid = store.get(self._nccl_root_id)
nccl_id = nccl.UniqueId(uid) nccl_id = nccl.UniqueId(uid)
# needs to be set for nccl to work # needs to be set for nccl to work
self._comm = nccl.Communicator(self._world_size, self._comm = nccl.Communicator(self._world_size,
self._rank, self._rank,
nccl_id) nccl_id)
if self._rank == 0:
# clear the store entry for future communicators
store.delete_key('nccl_root_id')
th.distributed.barrier() th.distributed.barrier()
def _shared_setup(self): def _shared_setup(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment