"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d57ff78da11193fbbee7f37a69fcfe1c14da2ae4"
Unverified commit c37e0364, authored by xiang song (charlie.song), committed by GitHub

[Bug Fix] Fix sparse opt bug (#2859)



* Fix #2856

* upd

* Fix unittest

* upd

* upd

* upd

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-57-25.ec2.internal>
parent 3c387988
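The underlying failure (#2856): a NodeEmbedding that is registered with a sparse optimizer but receives no lookups in a forward pass has an empty gradient trace, and concatenating that empty trace crashes the optimizer step. A minimal standalone sketch in plain PyTorch of the fallback this patch introduces (variable names are illustrative, not taken from the DGL sources):

    import torch as th

    emb_dim = 4                       # illustrative embedding dimension
    traced_idx, traced_grad = [], []  # nothing was traced this iteration

    # Without the fix, th.cat([]) raises a RuntimeError and kills the step.
    # With the fix, zero-size tensors of the right dtype/shape are used instead,
    # so the rest of the step (including the cross-process gradient exchange) still runs.
    idx = th.cat(traced_idx, dim=0) if traced_idx else th.zeros((0,), dtype=th.long)
    grad = (th.cat(traced_grad, dim=0) if traced_grad
            else th.zeros((0, emb_dim), dtype=th.float32))
    print(idx.shape, grad.shape)      # torch.Size([0]) torch.Size([0, 4])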
@@ -171,6 +171,17 @@ class NodeEmbedding: # NodeEmbedding
         """
         return self._num_embeddings
 
+    @property
+    def embedding_dim(self):
+        """Return the dimension of embeddings.
+
+        Returns
+        -------
+        int
+            The dimension of embeddings.
+        """
+        return self._embedding_dim
+
     def set_optm_state(self, state):
         """Store the optimizer related state tensor.
@@ -214,6 +225,19 @@ class NodeEmbedding: # NodeEmbedding
     def emb_tensor(self):
         """Return the tensor storing the node embeddings
 
+        DEPRECATED: renamed weight
+
+        Returns
+        -------
+        torch.Tensor
+            The tensor storing the node embeddings
+        """
+        return self._tensor
+
+    @property
+    def weight(self):
+        """Return the tensor storing the node embeddings
+
         Returns
         -------
         torch.Tensor
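As the docstring notes, emb_tensor is kept only for backward compatibility; new code, and the rest of this diff, goes through weight. A brief migration sketch, assuming the patched DGL (the name 'rename_demo' is arbitrary):

    import torch as th
    from dgl.nn import NodeEmbedding

    emb = NodeEmbedding(10, 4, 'rename_demo')
    th.manual_seed(0)
    th.nn.init.uniform_(emb.weight, 0, 1.0)        # preferred accessor
    assert th.equal(emb.weight, emb.emb_tensor)    # deprecated alias, same storage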
@@ -77,8 +77,17 @@ class SparseGradOptimizer(abc.ABC):
            for i, data in emb._trace:
                idx.append(i)
                grad.append(data.grad.data)
-            idx = th.cat(idx, dim=0)
-            grad = th.cat(grad, dim=0)
+            # If the sparse embedding was not used in the previous forward step,
+            # idx and grad will be empty. Initialize them as zero-size tensors
+            # to avoid crashing the optimizer step logic.
+            #
+            # Note: we cannot skip the gradient exchange and update steps, as
+            # other worker processes may send gradient update requests for
+            # this embedding to this process.
+            idx = th.cat(idx, dim=0) if len(idx) != 0 else \
+                th.zeros((0,), dtype=th.long, device=th.device('cpu'))
+            grad = th.cat(grad, dim=0) if len(grad) != 0 else \
+                th.zeros((0, emb.embedding_dim), dtype=th.float32, device=th.device('cpu'))
 
            device = grad.device
            idx_dtype = idx.dtype
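The zero-size fallback is sufficient because indexing with an empty LongTensor selects zero rows, so the downstream state and weight updates degrade to no-ops while the process still participates in the gradient exchange. A standalone illustration in plain PyTorch (names are illustrative):

    import torch as th

    weight = th.rand(10, 4)                  # stand-in for an embedding table
    idx = th.zeros((0,), dtype=th.long)      # no rows were touched this step
    grad = th.zeros((0, 4), dtype=th.float32)

    weight[idx] -= 0.01 * grad               # selects zero rows: a harmless no-op
    print(weight.shape)                      # torch.Size([10, 4]), values unchanged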
@@ -164,6 +173,8 @@ class SparseGradOptimizer(abc.ABC):
        for emb in self._params: # pylint: disable=too-many-nested-blocks
            emb_name = emb.name
            if self._world_size > 1:
+                # The first element in shared_emb[emb_name][0] is the local idx
+                device = shared_emb[emb_name][0][0].device
                # gather gradients from all other processes
                for i in range(self._world_size):
                    if i != self._rank:
@@ -277,7 +288,7 @@ class SparseAdagrad(SparseGradOptimizer):
        if self._rank <= 0:
            emb_name = emb.name
            state = create_shared_mem_array(emb_name+'_state', \
-                emb.emb_tensor.shape, th.float32).zero_()
+                emb.weight.shape, th.float32).zero_()
            if self._rank == 0:
                if self._world_size > 1:
                    emb.store.set(emb_name+'_opt', emb_name)
@@ -286,7 +297,7 @@ class SparseAdagrad(SparseGradOptimizer):
            emb_name = emb.name
            emb.store.wait([emb_name+'_opt'])
            state = get_shared_mem_array(emb_name+'_state', \
-                emb.emb_tensor.shape, th.float32)
+                emb.weight.shape, th.float32)
        emb.set_optm_state(state)
 
    def update(self, idx, grad, emb):
@@ -322,7 +333,7 @@ class SparseAdagrad(SparseGradOptimizer):
        std_values = grad_state.add_(eps).sqrt_()
        tmp = clr * grad_values / std_values
-        emb.emb_tensor[state_idx] -= tmp.to(state_dev)
+        emb.weight[state_idx] -= tmp.to(state_dev)
 
class SparseAdam(SparseGradOptimizer):
    r''' Node embedding optimizer using the Adam algorithm.
@@ -383,11 +394,11 @@ class SparseAdam(SparseGradOptimizer):
        if self._rank <= 0:
            emb_name = emb.name
            state_step = create_shared_mem_array(emb_name+'_step', \
-                (emb.emb_tensor.shape[0],), th.float32).zero_()
+                (emb.weight.shape[0],), th.float32).zero_()
            state_mem = create_shared_mem_array(emb_name+'_mem', \
-                emb.emb_tensor.shape, th.float32).zero_()
+                emb.weight.shape, th.float32).zero_()
            state_power = create_shared_mem_array(emb_name+'_power', \
-                emb.emb_tensor.shape, th.float32).zero_()
+                emb.weight.shape, th.float32).zero_()
            if self._rank == 0:
                emb_name = emb.name
                if self._world_size > 1:
@@ -397,11 +408,11 @@ class SparseAdam(SparseGradOptimizer):
            emb_name = emb.name
            emb.store.wait([emb_name+'_opt'])
            state_step = get_shared_mem_array(emb_name+'_step', \
-                (emb.emb_tensor.shape[0],), th.float32)
+                (emb.weight.shape[0],), th.float32)
            state_mem = get_shared_mem_array(emb_name+'_mem', \
-                emb.emb_tensor.shape, th.float32)
+                emb.weight.shape, th.float32)
            state_power = get_shared_mem_array(emb_name+'_power', \
-                emb.emb_tensor.shape, th.float32)
+                emb.weight.shape, th.float32)
        state = (state_step, state_mem, state_power)
        emb.set_optm_state(state)
@@ -458,4 +469,4 @@ class SparseAdam(SparseGradOptimizer):
                state_step)).unsqueeze(1)
        std_values = clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)
-        emb.emb_tensor[state_idx] -= std_values.to(state_dev)
+        emb.weight[state_idx] -= std_values.to(state_dev)
 
+import time
+import multiprocessing as mp
+import unittest, os
+import pytest
 import torch as th
 import backend as F
 
 from dgl.nn import NodeEmbedding
 from dgl.optim import SparseAdam, SparseAdagrad
-import unittest, os
 
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_sparse_adam():
     num_embs = 10
@@ -16,7 +19,7 @@ def test_sparse_adam():
     th.manual_seed(0)
     th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
     th.manual_seed(0)
-    th.nn.init.uniform_(dgl_emb.emb_tensor, 0, 1.0)
+    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)
 
     dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
     torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)
@@ -36,11 +39,144 @@ def test_sparse_adam():
     dgl_adam.step()
     torch_adam.step()
-    assert F.allclose(dgl_emb.emb_tensor, torch_emb.weight)
+    assert F.allclose(dgl_emb.weight, torch_emb.weight)
 
     # Can not test second step
     # Pytorch sparseAdam maintains a global step
     # DGL sparseAdam use a per embedding step
 
+@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
+def test_sparse_adam_zero_step():
+    num_embs = 10
+    emb_dim = 4
+    device=F.ctx()
+    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test')
+    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
+    dgl_emb_zero = NodeEmbedding(num_embs, emb_dim, 'test2')
+    torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
+    th.manual_seed(0)
+    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
+    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
+    th.manual_seed(0)
+    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)
+    th.nn.init.uniform_(dgl_emb_zero.weight, 0, 1.0)
+    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
+    torch_adam = th.optim.SparseAdam(
+        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()), lr=0.01)
+
+    # first step
+    idx = th.randint(0, num_embs, size=(4,))
+    dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
+    torch_value = torch_emb(idx)
+    labels = th.ones((4,)).long()
+
+    dgl_adam.zero_grad()
+    torch_adam.zero_grad()
+    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
+    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
+    dgl_loss.backward()
+    torch_loss.backward()
+    dgl_adam.step()
+    torch_adam.step()
+    assert F.allclose(dgl_emb.weight, torch_emb.weight)
+
+def initializer(emb):
+    th.manual_seed(0)
+    emb.uniform_(-1.0, 1.0)
+    return emb
+
+def start_sparse_adam_worker(rank, world_size, has_zero_grad=False, num_embs=128, emb_dim=10):
+    print('start sparse worker for adam {}'.format(rank))
+    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
+        master_ip='127.0.0.1', master_port='12345')
+    backend = 'gloo'
+    device=F.ctx()
+
+    th.distributed.init_process_group(backend=backend,
+                                      init_method=dist_init_method,
+                                      world_size=world_size,
+                                      rank=rank)
+    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test', init_func=initializer)
+    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
+    th.manual_seed(0)
+    th.nn.init.uniform_(torch_emb.weight, -1.0, 1.0)
+    torch_emb = th.nn.parallel.DistributedDataParallel(torch_emb)
+
+    if has_zero_grad:
+        dgl_emb_zero = NodeEmbedding(num_embs, emb_dim, 'zero', init_func=initializer)
+        torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
+        th.manual_seed(0)
+        th.nn.init.uniform_(torch_emb_zero.weight, -1.0, 1.0)
+        torch_emb_zero = th.nn.parallel.DistributedDataParallel(torch_emb_zero)
+        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
+        torch_adam = th.optim.SparseAdam(
+            list(torch_emb.module.parameters()) + list(torch_emb_zero.module.parameters()),
+            lr=0.01)
+    else:
+        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
+        torch_adam = th.optim.SparseAdam(list(torch_emb.module.parameters()), lr=0.01)
+
+    start = (num_embs // world_size) * rank
+    end = (num_embs // world_size) * (rank + 1)
+    idx = th.randint(start, end, size=(4,))
+    dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
+    torch_value = torch_emb(idx)
+    labels = th.ones((4,)).long()
+
+    dgl_adam.zero_grad()
+    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
+    dgl_loss.backward()
+    dgl_adam.step()
+
+    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
+    torch_adam.zero_grad()
+    torch_loss.backward()
+    torch_adam.step()
+
+    if rank == 0:
+        after_step = dgl_emb(idx, device)
+        assert F.allclose(dgl_emb.weight, torch_emb.module.weight)
+        assert F.allclose(dgl_value, after_step) is False
+    th.distributed.barrier()
+
+@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
+@pytest.mark.parametrize("num_workers", [2, 4, 8])
+def test_multiprocess_sparse_adam(num_workers):
+    worker_list = []
+
+    ctx = mp.get_context('spawn')
+    for i in range(num_workers):
+        p = ctx.Process(target=start_sparse_adam_worker,
+                        args=(i, num_workers))
+        p.start()
+        worker_list.append(p)
+    for p in worker_list:
+        p.join()
+
+@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
+@pytest.mark.parametrize("num_workers", [2, 4, 8])
+def test_multiprocess_sparse_adam_zero_step(num_workers):
+    worker_list = []
+
+    ctx = mp.get_context('spawn')
+    for i in range(num_workers):
+        p = ctx.Process(target=start_sparse_adam_worker,
+                        args=(i, num_workers, True))
+        p.start()
+        worker_list.append(p)
+    for p in worker_list:
+        p.join()
+
 if __name__ == '__main__':
     test_sparse_adam()
+    test_sparse_adam_zero_step()
+    test_multiprocess_sparse_adam(2)
+    test_multiprocess_sparse_adam(4)
+    test_multiprocess_sparse_adam_zero_step(2)
+    test_multiprocess_sparse_adam_zero_step(4)