Unverified Commit 975eb8fc authored by xiang song(charlie.song), committed by GitHub

[Distributed] Distributed node embedding and sparse optimizer (#2733)



* Draft for sparse emb

* add some notes

* Fix

* Add sparse optim for dist pytorch

* Update test

* Fix

* upd

* upd

* Fix

* Fix

* Fix bug

* add transductive example

* Fix example

* Some fix

* Upd

* Fix lint

* lint

* lint

* lint

* upd

* Fix lint

* lint

* upd

* remove dead import

* update

* lint

* update unittest

* update example

* Add adam optimizer

* Add unittest and update data

* upd

* upd

* upd

* Fix docstring and fix some bug in example code

* Update rgcn readme
Co-authored-by: Ubuntu <ubuntu@ip-172-31-57-25.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-24-210.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-2-66.ec2.internal>
parent 2d372e35
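For orientation before the diff, here is a minimal usage sketch of the API this PR adds (dgl.distributed.nn.NodeEmbedding together with the sparse optimizers in dgl.distributed.optim). It mirrors the docstring example further down; the graph g, the dataloader and the node IDs nids are assumed to come from an already initialized DGL distributed setup and are not defined here.

# Sketch only: assumes dgl.distributed.initialize(...) has been called and that
# g, dataloader and nids already exist, as in the docstring example below.
import dgl

emb = dgl.distributed.nn.NodeEmbedding(g.number_of_nodes(), 16, name='feat_emb')
optimizer = dgl.distributed.optim.SparseAdam([emb], lr=0.01)
for blocks in dataloader:
    feats = emb(nids)          # gather rows and start recording their gradients
    loss = (feats + 1).sum()
    loss.backward()
    optimizer.step()           # exchange sparse gradients and update the shards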
"""Define sparse embedding and optimizer."""
import torch as th
from .... import backend as F
from .... import utils
from ...dist_tensor import DistTensor
class NodeEmbedding:
'''Distributed node embeddings.

DGL provides distributed embeddings to support models that require learnable embeddings.
DGL's distributed embeddings are mainly used for learning node embeddings of graph models.
Because distributed embeddings are part of a model, they are updated by mini-batches.
The distributed embeddings have to be updated by DGL's optimizers instead of
the optimizers provided by the deep learning frameworks (e.g., PyTorch and MXNet).

To support efficient training on a graph with many nodes, the embeddings support sparse
updates. That is, only the embeddings involved in a mini-batch computation are updated.
Currently, DGL provides two optimizers for distributed embeddings: `SparseAdagrad` and
`SparseAdam`.

Distributed embeddings are sharded and stored in a cluster of machines in the same way as
:class:`dgl.distributed.DistTensor`, except that distributed embeddings are trainable.
Because distributed embeddings are sharded in the same way as the nodes and edges of a
distributed graph, they are usually much more efficient to access than the sparse
embeddings provided by the deep learning frameworks.

Parameters
----------
num_embeddings : int
The number of embeddings. Currently, the number of embeddings has to be the same as
the number of nodes or the number of edges.
embedding_dim : int
The dimension size of embeddings.
name : str, optional
The name of the embeddings. The name can uniquely identify embeddings in a system
so that another NodeEmbedding object can refer to the same embeddings.
init_func : callable, optional
The function to create the initial data. If the init function is not provided,
the values of the embeddings are initialized to zero.
part_policy : PartitionPolicy, optional
The partition policy that assigns embeddings to different machines in the cluster.
Currently, it only supports the node partition policy or the edge partition policy.
The system determines the right partition policy automatically.

Examples
--------
>>> def initializer(shape, dtype):
...     arr = th.zeros(shape, dtype=dtype)
...     arr.uniform_(-1, 1)
...     return arr
>>> emb = dgl.distributed.nn.NodeEmbedding(g.number_of_nodes(), 10, init_func=initializer)
>>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader:
... feats = emb(nids)
... loss = F.sum(feats + 1, 0)
... loss.backward()
... optimizer.step()

Note
----
When a ``NodeEmbedding`` object is used while the deep learning framework is recording
the forward computation, users have to invoke
:py:meth:`~dgl.distributed.optim.SparseAdagrad.step` afterwards. Otherwise, the recorded
traces accumulate and leak memory.
'''
def __init__(self, num_embeddings, embedding_dim, name=None,
init_func=None, part_policy=None):
self._tensor = DistTensor((num_embeddings, embedding_dim), F.float32, name,
init_func=init_func, part_policy=part_policy)
self._trace = []
self._name = name
self._num_embeddings = num_embeddings
self._embedding_dim = embedding_dim
# Check whether it is multi-gpu/distributed training or not
if th.distributed.is_initialized():
self._rank = th.distributed.get_rank()
self._world_size = th.distributed.get_world_size()
else:
raise RuntimeError('torch.distributed should be initialized')
self._optm_state = None # track optimizer state
self._part_policy = part_policy
def __call__(self, idx, device=th.device('cpu')):
"""
node_ids : th.tensor
Index of the embeddings to collect.
device : th.device
Target device to put the collected embeddings.
Returns
-------
Tensor
The requested node embeddings
"""
idx = utils.toindex(idx).tousertensor()
emb = self._tensor[idx].to(device, non_blocking=True)
if F.is_recording():
emb = F.attach_grad(emb)
self._trace.append((idx.to(device, non_blocking=True), emb))
return emb
def reset_trace(self):
'''Reset the traced data.
'''
self._trace = []
@property
def part_policy(self):
"""Return the partition policy
Returns
-------
PartitionPolicy
partition policy
"""
return self._part_policy
@property
def name(self):
"""Return the name of the embeddings
Returns
-------
str
The name of the embeddings
"""
return self._tensor.tensor_name
@property
def kvstore(self):
"""Return the kvstore client
Returns
-------
KVClient
The kvstore client
"""
return self._tensor.kvstore
@property
def num_embeddings(self):
"""Return the number of embeddings
Returns
-------
int
The number of embeddings
"""
return self._num_embeddings
@property
def embedding_dim(self):
"""Return the dimension of embeddings
Returns
-------
int
The dimension of embeddings
"""
return self._embedding_dim
@property
def optm_state(self):
"""Return the optimizer related state tensor.
Returns
-------
tuple of torch.Tensor
The optimizer related state.
"""
return self._optm_state
@property
def weight(self):
"""Return the tensor storing the node embeddings
Returns
-------
torch.Tensor
The tensor storing the node embeddings
"""
return self._tensor
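A note on the mechanism above: __call__ attaches gradients to the gathered rows and appends (idx, emb) pairs to self._trace, so a sparse optimizer can later scatter the per-row gradients back into the distributed tensor. The following is a conceptual sketch of that idea in plain PyTorch, with a local tensor standing in for the DistTensor shard; it is an illustration, not DGL code.

# Conceptual sketch of why __call__ records (idx, emb): the optimizer only
# needs emb.grad for the rows that were actually gathered in this mini-batch.
import torch as th

weight = th.rand(100, 16)                   # stands in for the sharded DistTensor
idx = th.tensor([3, 7, 7, 42])
emb = weight[idx].clone().requires_grad_()  # analogous to F.attach_grad(emb)
loss = (emb + 1).sum()
loss.backward()
# emb.grad has one row per gathered row; a sparse optimizer writes updates back
# to weight only at positions idx, leaving all other rows untouched.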
"""dgl distributed.optims."""
import importlib
import sys
import os
from ...backend import backend_name
from ...utils import expand_as_pair
def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
_load_backend(backend_name)
"""dgl distributed sparse optimizer for pytorch."""
from .sparse_optim import SparseAdagrad, SparseAdam
"""Node embedding optimizers for distributed training"""
import abc
from abc import abstractmethod
import torch as th
from ...dist_tensor import DistTensor
from ...nn.pytorch import NodeEmbedding
from .utils import alltoallv_cpu, alltoall_cpu
class DistSparseGradOptimizer(abc.ABC):
r''' The abstract distributed sparse optimizer.

Note: DGL's distributed sparse optimizers only work with dgl.distributed.nn.NodeEmbedding.
Parameters
----------
params : list of NodeEmbedding
The list of NodeEmbedding.
lr : float
The learning rate.
'''
def __init__(self, params, lr):
self._params = params
self._lr = lr
self._rank = None
self._world_size = None
self._shared_cache = {}
self._clean_grad = False
self._opt_meta = {}
if th.distributed.is_initialized():
self._rank = th.distributed.get_rank()
self._world_size = th.distributed.get_world_size()
else:
raise RuntimeError('torch.distributed should be initialized')
def step(self):
''' The step function.
The step function is invoked at the end of every batch to push the gradients
of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
'''
with th.no_grad():
local_indics = {emb.name: [] for emb in self._params}
local_grads = {emb.name: [] for emb in self._params}
device = th.device('cpu')
for emb in self._params:
name = emb._tensor.name
kvstore = emb._tensor.kvstore
trace = emb._trace
trainers_per_server = self._world_size // kvstore.num_servers
idics = [t[0] for t in trace]
grads = [t[1].grad.data for t in trace]
# If the sparse embedding was not used in the previous forward step,
# idx and grad will be empty. Initialize them as empty tensors to
# avoid crashing the optimizer step logic.
#
# Note: we cannot skip the gradient exchange and update steps, as other
# worker processes may send gradient update requests for
# certain embeddings to this process.
idics = th.cat(idics, dim=0) if len(idics) != 0 else \
th.zeros((0,), dtype=th.long, device=th.device('cpu'))
grads = th.cat(grads, dim=0) if len(grads) != 0 else \
th.zeros((0, emb.embedding_dim), dtype=th.float32, device=th.device('cpu'))
device = grads.device
# will send grad to each corresponding trainer
if self._world_size > 1:
# get idx split from kvstore
idx_split = kvstore.get_partid(name, idics)
idx_split_size = []
idics_list = []
grad_list = []
# split idx and grad first
for i in range(kvstore.num_servers):
mask = idx_split == i
idx_i = idics[mask]
grad_i = grads[mask]
if trainers_per_server <= 1:
idx_split_size.append(th.tensor([idx_i.shape[0]], dtype=th.int64))
idics_list.append(idx_i)
grad_list.append(grad_i)
else:
kv_idx_split = th.remainder(idx_i, trainers_per_server).long()
for j in range(trainers_per_server):
mask = kv_idx_split == j
idx_j = idx_i[mask]
grad_j = grad_i[mask]
idx_split_size.append(th.tensor([idx_j.shape[0]], dtype=th.int64))
idics_list.append(idx_j)
grad_list.append(grad_j)
# If one machine launches multiple KVServers, they share the same storage.
# For each machine, the PyTorch rank is num_trainers * machine_id + i.
# Use scatter to sync the peer-to-peer tensor sizes across trainers.
# Note: if we have GPU NCCL support, we can use all_to_all to
# sync this information instead.
gather_list = list(th.empty([self._world_size],
dtype=th.int64).chunk(self._world_size))
alltoall_cpu(self._rank, self._world_size, gather_list, idx_split_size)
# use cpu until we have GPU alltoallv
idx_gather_list = [th.empty((int(num_emb),),
dtype=idics.dtype) for num_emb in gather_list]
alltoallv_cpu(self._rank, self._world_size, idx_gather_list, idics_list)
local_indics[name] = idx_gather_list
grad_gather_list = [th.empty((int(num_emb), grads.shape[1]),
dtype=grads.dtype) for num_emb in gather_list]
alltoallv_cpu(self._rank, self._world_size, grad_gather_list, grad_list)
local_grads[name] = grad_gather_list
else:
local_indics[name] = [idics]
local_grads[name] = [grads]
if self._clean_grad:
# clean gradient track
for emb in self._params:
emb.reset_trace()
self._clean_grad = False
# do local update
for emb in self._params:
name = emb._tensor.name
idx = th.cat(local_indics[name], dim=0)
grad = th.cat(local_grads[name], dim=0)
self.update(idx.to(device, non_blocking=True),
grad.to(device, non_blocking=True), emb)
# synchronized gradient update
if self._world_size > 1:
th.distributed.barrier()
@abstractmethod
def update(self, idx, grad, emb):
""" Update embeddings in a sparse manner
Sparse embeddings are updated in mini-batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.distributed.nn.NodeEmbedding
Sparse node embedding to update.
"""
def zero_grad(self):
"""clean grad cache
"""
self._clean_grad = True
def initializer(shape, dtype):
""" Sparse optimizer state initializer
Parameters
----------
shape : tuple of ints
The shape of the state tensor
dtype : torch dtype
The data type of the state tensor

Returns
-------
torch.Tensor
The zero-initialized state tensor
"""
arr = th.zeros(shape, dtype=dtype)
return arr
class SparseAdagrad(DistSparseGradOptimizer):
r''' Distributed Node embedding optimizer using the Adagrad algorithm.
This optimizer implements a distributed sparse version of Adagrad algorithm for
optimizing :class:`dgl.distributed.nn.NodeEmbedding`. Being sparse means it only updates
the embeddings whose gradients have updates, which are usually a very
small portion of the total embeddings.
Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the embeddings, where
:math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and :math:`g_{t,i,j}` is the gradient of
the dimension :math:`j` of embedding :math:`i` at step :math:`t`.
NOTE: Support for the sparse Adagrad optimizer is experimental.
Parameters
----------
params : list[dgl.distributed.nn.NodeEmbedding]
The list of dgl.distributed.nn.NodeEmbedding.
lr : float
The learning rate.
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-10
'''
def __init__(self, params, lr, eps=1e-10):
super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps
# We need to register a state sum for each embedding in the kvstore.
self._state = {}
for emb in params:
assert isinstance(emb, NodeEmbedding), \
'SparseAdagrad only supports dgl.distributed.nn.NodeEmbedding'
name = emb.name + "_sum"
state = DistTensor((emb.num_embeddings, emb.embedding_dim), th.float32, name,
init_func=initializer, part_policy=emb.part_policy, is_gdata=False)
assert emb.name not in self._state, \
"{} already registered in the optimizer".format(emb.name)
self._state[emb.name] = state
def update(self, idx, grad, emb):
""" Update embeddings in a sparse manner
Sparse embeddings are updated in mini-batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.distributed.nn.NodeEmbedding
Sparse embedding to update.
"""
eps = self._eps
clr = self._lr
exec_dev = grad.device
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(idx, return_inverse=True, return_counts=True)
grad_values = th.zeros((grad_indices.shape[0], grad.shape[1]), device=exec_dev)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_sum = (grad_values * grad_values)
# update grad state
grad_state = self._state[emb.name][grad_indices].to(exec_dev, non_blocking=True)
grad_state += grad_sum
self._state[emb.name][grad_indices] = grad_state.to(th.device('cpu'), non_blocking=True)
# update emb
std_values = grad_state.add_(eps).sqrt_()
tmp = clr * grad_values / std_values
emb._tensor[grad_indices] -= tmp.to(th.device('cpu'), non_blocking=True)
class SparseAdam(DistSparseGradOptimizer):
r''' Distributed Node embedding optimizer using the Adam algorithm.
This optimizer implements a distributed sparse version of Adam algorithm for
optimizing :class:`dgl.distributed.nn.NodeEmbedding`. Being sparse means it only updates
the embeddings whose gradients have updates, which are usually a very
small portion of the total embeddings.
Adam maintains :math:`Gm_{t,i,j}` and :math:`Gp_{t,i,j}` for every parameter
in the embeddings, where
:math:`Gm_{t,i,j}=\beta_1 Gm_{t-1,i,j} + (1-\beta_1) g_{t,i,j}`,
:math:`Gp_{t,i,j}=\beta_2 Gp_{t-1,i,j} + (1-\beta_2) g_{t,i,j}^2`, and
:math:`g_{t,i,j}` is the gradient of dimension :math:`j` of embedding :math:`i`
at step :math:`t`. The resulting update to that parameter is
:math:`lr \cdot \frac{Gm_{t,i,j}/(1-\beta_1^t)}{\sqrt{Gp_{t,i,j}/(1-\beta_2^t)} + \epsilon}`.

NOTE: Support for the sparse Adam optimizer is experimental.
Parameters
----------
params : list[dgl.distributed.nn.NodeEmbedding]
The list of dgl.distributed.nn.NodeEmbedding.
lr : float
The learning rate.
betas : tuple[float, float], Optional
Coefficients used for computing running averages of gradient and its square.
Default: (0.9, 0.999)
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-8
'''
def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
super(SparseAdam, self).__init__(params, lr)
self._eps = eps
# We need to register the optimizer states (step, momentum, power) for each embedding in the kvstore.
self._beta1 = betas[0]
self._beta2 = betas[1]
self._state = {}
for emb in params:
assert isinstance(emb, NodeEmbedding), \
'SparseAdam only supports dgl.distributed.nn.NodeEmbedding'
state_step = DistTensor((emb.num_embeddings,),
th.float32, emb.name + "_step",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False)
state_mem = DistTensor((emb.num_embeddings, emb.embedding_dim),
th.float32, emb.name + "_mem",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False)
state_power = DistTensor((emb.num_embeddings, emb.embedding_dim),
th.float32, emb.name + "_power",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False)
state = (state_step, state_mem, state_power)
assert emb.name not in self._state, \
"{} already registered in the optimizer".format(emb.name)
self._state[emb.name] = state
def update(self, idx, grad, emb):
""" Update embeddings in a sparse manner
Sparse embeddings are updated in mini-batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.distributed.nn.NodeEmbedding
Sparse embedding to update.
"""
beta1 = self._beta1
beta2 = self._beta2
eps = self._eps
clr = self._lr
state_step, state_mem, state_power = self._state[emb.name]
state_dev = th.device('cpu')
exec_dev = grad.device
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(idx, return_inverse=True, return_counts=True)
# update grad state
state_idx = grad_indices.to(state_dev)
state_step[state_idx] += 1
state_step = state_step[state_idx].to(exec_dev, non_blocking=True)
orig_mem = state_mem[state_idx].to(exec_dev, non_blocking=True)
orig_power = state_power[state_idx].to(exec_dev, non_blocking=True)
grad_values = th.zeros((grad_indices.shape[0], grad.shape[1]), device=exec_dev)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_mem = grad_values
grad_power = grad_values * grad_values
update_mem = beta1 * orig_mem + (1.-beta1) * grad_mem
update_power = beta2 * orig_power + (1.-beta2) * grad_power
state_mem[state_idx] = update_mem.to(state_dev, non_blocking=True)
state_power[state_idx] = update_power.to(state_dev, non_blocking=True)
update_mem_corr = update_mem / (1. - th.pow(th.tensor(beta1, device=exec_dev),
state_step)).unsqueeze(1)
update_power_corr = update_power / (1. - th.pow(th.tensor(beta2, device=exec_dev),
state_step)).unsqueeze(1)
std_values = clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)
emb._tensor[state_idx] -= std_values.to(state_dev)
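Both update() implementations above first collapse duplicate indices and average their gradients before applying the non-linear Adagrad/Adam update. Below is a self-contained sketch of just that reduction step, with arbitrary toy values; it only illustrates the torch.unique / index_add_ pattern used above.

import torch as th

idx = th.tensor([5, 2, 5, 9])                 # node 5 appears twice in the batch
grad = th.arange(12, dtype=th.float32).reshape(4, 3)
uniq, inverse, cnt = th.unique(idx, return_inverse=True, return_counts=True)
acc = th.zeros((uniq.shape[0], grad.shape[1]))
acc.index_add_(0, inverse, grad)              # sum gradients of duplicate indices
avg = acc / cnt.unsqueeze(1)                  # average, matching grad_values above
print(uniq)                                   # tensor([2, 5, 9])
print(avg[uniq == 5])                         # mean of the two gradient rows for node 5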
"""Provide utils for distributed sparse optimizers
"""
import torch as th
import torch.distributed as dist
def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
"""Each process scatters list of input tensors to all processes in a cluster
and return gathered list of tensors in output list. The tensors should have the same shape.
Parameters
----------
rank : int
The rank of current worker
world_size : int
The size of the entire
output_tensor_list : List of tensor
The received tensors
input_tensor_list : List of tensor
The tensors to exchange
"""
input_tensor_list = [tensor.to(th.device('cpu')) for tensor in input_tensor_list]
for i in range(world_size):
dist.scatter(output_tensor_list[i], input_tensor_list if i == rank else [], src=i)
def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
"""Each process scatters list of input tensors to all processes in a cluster
and return gathered list of tensors in output list.
Parameters
----------
rank : int
The rank of current worker
world_size : int
The size of the entire
output_tensor_list : List of tensor
The received tensors
input_tensor_list : List of tensor
The tensors to exchange
"""
# send tensor to each target trainer using torch.distributed.isend
# isend is async
senders = []
for i in range(world_size):
if i == rank:
output_tensor_list[i] = input_tensor_list[i].to(th.device('cpu'))
else:
sender = dist.isend(input_tensor_list[i].to(th.device('cpu')), dst=i)
senders.append(sender)
for i in range(world_size):
if i != rank:
dist.recv(output_tensor_list[i], src=i)
th.distributed.barrier()
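The two helpers implement an all-to-all exchange on top of torch.distributed scatter and isend/recv: after the exchange, rank r's i-th output equals rank i's r-th input. A pure-Python illustration of that data-movement pattern (no process group involved, just the index bookkeeping):

# Illustration only: outputs[r][i] == inputs[i][r] for every pair of ranks.
world_size = 3
inputs = [[f"rank{src}->rank{dst}" for dst in range(world_size)]
          for src in range(world_size)]
outputs = [[inputs[src][dst] for src in range(world_size)]
           for dst in range(world_size)]
assert outputs[1][2] == "rank2->rank1"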
@@ -6,24 +6,23 @@ from .dist_tensor import DistTensor
class DistEmbedding:
'''Distributed embeddings.
DGL provides a distributed embedding to support models that require learnable embeddings.
DGL's distributed embeddings are mainly used for learning node embeddings of graph models.
Because distributed embeddings are part of a model, they are updated by mini-batches.
The distributed embeddings have to be updated by DGL's optimizers instead of
the optimizers provided by the deep learning frameworks (e.g., Pytorch and MXNet).
To support efficient training on a graph with many nodes, the embeddings support sparse
updates. That is, only the embeddings involved in a mini-batch computation are updated.
Currently, DGL provides only one optimizer: `SparseAdagrad`. DGL will provide more
optimizers in the future.
Distributed embeddings are sharded and stored in a cluster of machines in the same way as
py:meth:`dgl.distributed.DistTensor`, except that distributed embeddings are trainable.
Because distributed embeddings are sharded
in the same way as nodes and edges of a distributed graph, it is usually much more
efficient to access than the sparse embeddings provided by the deep learning frameworks.
DEPRECATED: Please use dgl.distributed.nn.NodeEmbedding instead.
Parameters
----------
num_embeddings : int
@@ -41,7 +40,6 @@ class DistEmbedding:
The partition policy that assigns embeddings to different machines in the cluster.
Currently, it only supports node partition policy or edge partition policy.
The system determines the right partition policy automatically.
Examples
--------
>>> def initializer(shape, dtype):
@@ -55,7 +53,6 @@ class DistEmbedding:
... loss = F.sum(feats + 1, 0)
... loss.backward()
... optimizer.step()
Note
----
When a ``DistEmbedding`` object is used when the deep learning framework is recording
@@ -83,7 +80,6 @@ class DistEmbedding:
class SparseAdagradUDF:
''' The UDF to update the embeddings with sparse Adagrad.
Parameters
----------
lr : float
@@ -94,10 +90,8 @@ class SparseAdagradUDF:
def __call__(self, data_store, name, indices, data):
''' Update the embeddings with sparse Adagrad.
This function runs on the KVStore server. It updates the gradients by scaling them
according to the state sum.
Parameters
----------
data_store : dict of data
@@ -125,27 +119,20 @@ def _init_state(shape, dtype):
class SparseAdagrad:
r''' The sparse Adagrad optimizer.
This optimizer implements a lightweight version of Adagrad algorithm for optimizing
:func:`dgl.distributed.DistEmbedding`. In each mini-batch, it only updates the embeddings
involved in the mini-batch to support efficient training on a graph with many
nodes and edges.
Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the embeddings, where
:math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and :math:`g_{t,i,j}` is the gradient of
the dimension :math:`j` of embedding :math:`i` at step :math:`t`.
Instead of maintaining :math:`G_{t,i,j}`, this implementation maintains :math:`G_{t,i}`
for every embedding :math:`i`:
.. math::
G_{t,i}=G_{t-1,i}+ \frac{1}{p} \sum_{0 \le j \lt p}g_{t,i,j}^2
where :math:`p` is the dimension size of an embedding.
The benefit of the implementation is that it consumes much smaller memory and runs
much faster if users' model requires learnable embeddings for nodes or edges.
Parameters
----------
params : list of DistEmbeddings
@@ -170,7 +157,6 @@ class SparseAdagrad:
def step(self):
''' The step function.
The step function is invoked at the end of every batch to push the gradients
of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
'''
...
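For the deprecated SparseAdagrad above, the per-embedding state adds the squared gradient averaged over the p dimensions instead of keeping one state value per dimension. A tiny numeric sketch of that reduction (arbitrary values, not DGL code):

import torch as th

g_ti = th.tensor([0.2, -0.4, 0.6])         # gradient of one embedding, p = 3
p = g_ti.shape[0]
state_increment = (g_ti * g_ti).sum() / p  # (1/p) * sum_j g_{t,i,j}^2
print(float(state_increment))              # 0.1866...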
@@ -13,7 +13,6 @@ from dgl.heterograph_index import create_unitgraph_from_coo
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import DistGraphServer, DistGraph
from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
from dgl.distributed import SparseAdagrad, DistEmbedding
from numpy.testing import assert_almost_equal
import backend as F
import math
@@ -148,6 +147,67 @@ def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_ed
g = DistGraph(graph_name, gpb=gpb)
check_dist_graph(g, num_clients, num_nodes, num_edges)
def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph(graph_name, gpb=gpb)
check_dist_emb(g, num_clients, num_nodes, num_edges)
def check_dist_emb(g, num_clients, num_nodes, num_edges):
from dgl.distributed.optim import SparseAdagrad
from dgl.distributed.nn import NodeEmbedding
# Test sparse emb
try:
emb = NodeEmbedding(g.number_of_nodes(), 1, 'emb1', emb_init)
lr = 0.001
optimizer = SparseAdagrad([emb], lr=lr)
with F.record_grad():
feats = emb(nids)
assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
loss = F.sum(feats + 1, 0)
loss.backward()
optimizer.step()
feats = emb(nids)
if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(),), F.float32,
'emb1_sum', policy)
if num_clients == 1:
assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
emb = NodeEmbedding(g.number_of_nodes(), 1, 'emb2', emb_init)
with F.no_grad():
feats1 = emb(nids)
assert np.all(F.asnumpy(feats1) == 0)
optimizer = SparseAdagrad([emb], lr=lr)
with F.record_grad():
feats1 = emb(nids)
feats2 = emb(nids)
feats = F.cat([feats1, feats2], 0)
assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
loss = F.sum(feats + 1, 0)
loss.backward()
optimizer.step()
with F.no_grad():
feats = emb(nids)
if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
except NotImplementedError as e:
pass
def check_dist_graph(g, num_clients, num_nodes, num_edges):
# Test API
assert g.number_of_nodes() == num_nodes
@@ -200,55 +260,6 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
except:
pass
# Test sparse emb
try:
emb = DistEmbedding(g.number_of_nodes(), 1, 'emb1', emb_init)
lr = 0.001
optimizer = SparseAdagrad([emb], lr=lr)
with F.record_grad():
feats = emb(nids)
assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
loss = F.sum(feats + 1, 0)
loss.backward()
optimizer.step()
feats = emb(nids)
if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(),), F.float32,
'emb1_sum', policy)
if num_clients == 1:
assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
emb = DistEmbedding(g.number_of_nodes(), 1, 'emb2', emb_init)
with F.no_grad():
feats1 = emb(nids)
assert np.all(F.asnumpy(feats1) == 0)
optimizer = SparseAdagrad([emb], lr=lr)
with F.record_grad():
feats1 = emb(nids)
feats2 = emb(nids)
feats = F.cat([feats1, feats2], 0)
assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
loss = F.sum(feats + 1, 0)
loss.backward()
optimizer.step()
with F.no_grad():
feats = emb(nids)
if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
except NotImplementedError as e:
pass
# Test write data
new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
g.ndata['test1'][nids] = new_feats
@@ -274,6 +285,44 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
print('end')
def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
prepare_dist()
g = create_random_graph(10000)
# Partition the graph
num_parts = 1
graph_name = 'dist_graph_test_2'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context('spawn')
for serv_id in range(num_servers):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
num_clients, shared_mem))
serv_ps.append(p)
p.start()
cli_ps = []
for cli_id in range(num_clients):
print('start client', cli_id)
p = ctx.Process(target=run_emb_client, args=(graph_name, 0, num_servers, num_clients,
g.number_of_nodes(),
g.number_of_edges()))
p.start()
cli_ps.append(p)
for p in cli_ps:
p.join()
for p in serv_ps:
p.join()
print('clients have terminated')
def check_server_client(shared_mem, num_servers, num_clients):
prepare_dist()
g = create_random_graph(10000)
@@ -461,6 +510,16 @@ def test_server_client():
check_server_client(True, 2, 2)
check_server_client(False, 2, 2)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed NodeEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed NodeEmbedding")
def test_dist_emb_server_client():
os.environ['DGL_DIST_MODE'] = 'distributed'
check_dist_emb_server_client(True, 1, 1)
check_dist_emb_server_client(False, 1, 1)
check_dist_emb_server_client(True, 2, 2)
check_dist_emb_server_client(False, 2, 2)
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_standalone():
os.environ['DGL_DIST_MODE'] = 'standalone'
@@ -481,6 +540,27 @@ def test_standalone():
print(e)
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed NodeEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed NodeEmbedding")
def test_standalone_node_emb():
os.environ['DGL_DIST_MODE'] = 'standalone'
g = create_random_graph(10000)
# Partition the graph
num_parts = 1
graph_name = 'dist_graph_test_3'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
dgl.distributed.initialize("kv_ip_config.txt")
dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
try:
check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
except Exception as e:
print(e)
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split():
#prepare_dist()
@@ -617,3 +697,5 @@ if __name__ == '__main__':
test_split_even()
test_server_client()
test_standalone()
test_standalone_node_emb()
import os
os.environ['OMP_NUM_THREADS'] = '1'
import dgl
import sys
import numpy as np
import time
import socket
from scipy import sparse as spsp
import torch as th
from dgl.distributed import DistGraphServer, DistGraph
from dgl.distributed import partition_graph, load_partition_book
import multiprocessing as mp
from dgl import function as fn
import backend as F
import unittest
import pickle
import random
from dgl.distributed.nn import NodeEmbedding
from dgl.distributed.optim import SparseAdagrad, SparseAdam
def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
return dgl.from_scipy(arr)
def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address and port separated by a space, e.g., '192.168.8.12 50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()
return ip_addr + ' ' + str(port)
def prepare_dist():
ip_config = open("optim_ip_config.txt", "w")
ip_addr = get_local_usable_addr()
ip_config.write('{}\n'.format(ip_addr))
ip_config.close()
def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
g = DistGraphServer(server_id, "optim_ip_config.txt", num_clients, server_count,
'/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem)
print('start server', server_id)
g.start()
def initializer(shape, dtype):
arr = th.zeros(shape, dtype=dtype)
th.manual_seed(0)
th.nn.init.uniform_(arr, 0, 1.0)
return arr
def run_client(graph_name, cli_id, part_id, server_count):
device=F.ctx()
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("optim_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph(graph_name, gpb=gpb)
policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
num_nodes = g.number_of_nodes()
emb_dim = 4
dgl_emb = NodeEmbedding(num_nodes, emb_dim, name='optim', init_func=initializer, part_policy=policy)
dgl_emb_zero = NodeEmbedding(num_nodes, emb_dim, name='optim-zero', init_func=initializer, part_policy=policy)
dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
dgl_adam._world_size = 1
dgl_adam._rank = 0
torch_emb = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
torch_emb_zero = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
th.manual_seed(0)
th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
th.manual_seed(0)
th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
torch_adam = th.optim.SparseAdam(
list(torch_emb.parameters()) + list(torch_emb_zero.parameters()), lr=0.01)
labels = th.ones((4,)).long()
idx = th.randint(0, num_nodes, size=(4,))
dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
torch_value = torch_emb(idx)
torch_adam.zero_grad()
torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
torch_loss.backward()
torch_adam.step()
dgl_adam.zero_grad()
dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
dgl_loss.backward()
dgl_adam.step()
assert F.allclose(dgl_emb.weight[0 : num_nodes//2], torch_emb.weight[0 : num_nodes//2])
def check_sparse_adam(num_trainer=1, shared_mem=True):
prepare_dist()
g = create_random_graph(2000)
num_servers = num_trainer
num_clients = num_trainer
num_parts = 1
graph_name = 'dist_graph_test'
partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context('spawn')
for serv_id in range(num_servers):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
num_clients, shared_mem))
serv_ps.append(p)
p.start()
cli_ps = []
for cli_id in range(num_clients):
print('start client', cli_id)
p = ctx.Process(target=run_client, args=(graph_name, cli_id, 0, num_servers))
p.start()
cli_ps.append(p)
for p in cli_ps:
p.join()
for p in serv_ps:
p.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_sparse_opt():
os.environ['DGL_DIST_MODE'] = 'distributed'
check_sparse_adam(1, True)
check_sparse_adam(1, False)
if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True)
test_sparse_opt()
\ No newline at end of file