Unverified Commit 2caa6bd0 authored by peizhou001, committed by GitHub

[Dist] Enable save and load for Distributed Optimizer (#4752)



* add save/load for distributed optimizer
Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
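A minimal usage sketch of the new API (the embedding size, name, and checkpoint path below are placeholders, not part of this change; it assumes ``dgl.distributed.initialize()`` and the ``torch.distributed`` process group are already set up on every rank):

.. code-block:: python

    import dgl
    from dgl.distributed.optim import SparseAdam

    # Hypothetical embedding; the name and sizes are illustrative only.
    emb = dgl.distributed.DistEmbedding(1000, 16, name="my_emb")
    optimizer = SparseAdam([emb], lr=0.01)

    # ... training loop calling optimizer.step() ...

    # Every rank writes its own shard, e.g. "ckpt/optim.pt_<rank>".
    optimizer.save("ckpt/optim.pt")

    # Must also be called on all ranks when restoring.
    optimizer.load("ckpt/optim.pt")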
parent 5833efe0
@@ -48,10 +48,10 @@ Distributed embedding optimizer
-------------------------
.. autoclass:: dgl.distributed.optim.SparseAdagrad
   :members: step, save, load
.. autoclass:: dgl.distributed.optim.SparseAdam
   :members: step, save, load
Distributed workload split
--------------------------
......
"""Node embedding optimizers for distributed training""" """Node embedding optimizers for distributed training"""
import abc import abc
import warnings
from abc import abstractmethod
from os.path import exists
import torch as th
import dgl
from .... import backend as F
from ...dist_tensor import DistTensor
from ...nn.pytorch import DistEmbedding
from .utils import alltoall_cpu, alltoallv_cpu
from ...graph_partition_book import EDGE_PART_POLICY, NODE_PART_POLICY
EMB_STATES = "emb_states"
WORLD_SIZE = "world_size"
IDS = "ids"
PARAMS = "params"
STATES = "states"
class DistSparseGradOptimizer(abc.ABC):
r"""The abstract dist sparse optimizer.
Note: dgl dist sparse optimizer only works with dgl.distributed.DistEmbedding
@@ -18,7 +32,8 @@ class DistSparseGradOptimizer(abc.ABC):
The list of DistEmbedding.
lr : float
The learning rate.
"""
def __init__(self, params, lr):
self._params = params
self._lr = lr
@@ -27,6 +42,9 @@ class DistSparseGradOptimizer(abc.ABC):
self._shared_cache = {}
self._clean_grad = False
self._opt_meta = {}
self._state = {}
# collect all hyperparameters of the optimizer for save/load
self._defaults = {}
if th.distributed.is_initialized():
self._rank = th.distributed.get_rank()
@@ -35,19 +53,215 @@ class DistSparseGradOptimizer(abc.ABC):
self._rank = 0
self._world_size = 1
def local_state_dict(self):
"""Return the state pertaining to current rank of the optimizer.
Returns
-------
dict
Local state dict
Example Dict of Adagrad Optimizer:
.. code-block:: json
{
"params": {
"_lr": 0.01,
"_eps": "1e-8",
"world_size": 2
},
"emb_states": {
"emb_name1": {
"ids": [0, 2, 4, 6 ,8 ,10], ## tensor,
"emb_name1_sum": [0.1 , 0.2, 0.5, 0.1, 0.2] ## tensor,
},
"emb_name2": {
"ids": [0, 2, 4, 6 ,8 ,10], ## tensor,
"emb_name2_sum": [0.3 , 0.2, 0.4, 0.5, 0.2] ## tensor,
}
}
}
See Also
--------
load_local_state_dict
"""
local_state_dict = {}
local_state_dict[EMB_STATES] = {}
local_state_dict[PARAMS] = {WORLD_SIZE: self._world_size}
for emb in self._params:
trainers_per_machine = self._world_size // max(
1, dgl.distributed.get_num_machines()
)
emb_state_dict = {}
part_policy = (
emb.part_policy if emb.part_policy else emb.weight.part_policy
)
idx = self._get_local_ids(part_policy)
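# When multiple trainers share one machine, split the local IDs round-robin
# (by ID modulo trainers_per_machine) so each trainer only saves the slice it owns.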
if trainers_per_machine > 1:
kv_idx_split = (idx % trainers_per_machine).long()
local_rank = self._rank % trainers_per_machine
mask = kv_idx_split == local_rank
idx = F.boolean_mask(idx, mask)
emb_state_dict.update({IDS: idx})
emb_state = {}
states = (
list(self._state[emb.name])
if isinstance(self._state[emb.name], tuple)
else [self._state[emb.name]]
)
emb_state = {state.name: state[idx] for state in states}
emb_state_dict.update({STATES: emb_state})
local_state_dict[EMB_STATES].update({emb.name: emb_state_dict})
local_state_dict[PARAMS].update(self._defaults)
return local_state_dict
def load_local_state_dict(self, local_state_dict):
"""Load the local state from the input state_dict,
updating the optimizer as needed.
Parameters
----------
local_state_dict : dict
Optimizer state; should be an object returned
from a call to local_state_dict().
See Also
--------
local_state_dict
"""
for emb_name, emb_state in local_state_dict[EMB_STATES].items():
idx = emb_state[IDS]
# As state of an embedding of different optimizers can be a single
# DistTensor(Adagrad) or a tuple(Adam) of that, converting it to list for
# consistency. The list contains reference(s) to original DistTensor(s).
states = (
list(self._state[emb_name])
if isinstance(self._state[emb_name], tuple)
else [self._state[emb_name]]
)
if len(emb_state[STATES]) != len(states):
raise ValueError(
f"loaded state dict has a different number of states"
f" of embedding {emb_name}"
)
name_to_index = {
state.name: index for index, state in enumerate(states)
}
for name, state in emb_state[STATES].items():
if name not in name_to_index:
raise ValueError(
"loaded state dict contains a state {name}"
"that can't be found in the optimizer states"
)
state_idx = name_to_index[name]
state = state.to(
th.device("cpu"), states[name_to_index[name]].dtype
)
states[state_idx][idx] = state
self._defaults.update(local_state_dict[PARAMS])
self.__dict__.update(local_state_dict[PARAMS])
def save(self, f):
"""Save the local state_dict to disk on per rank.
Saved dict contains 2 parts:
* 'params': hyper parameters of the optimizer.
* 'emb_states': partial optimizer states, each embedding contains 2 items:
1. ```ids```: global id of the nodes/edges stored in this rank.
2. ```states```: state data corrseponding to ```ids```.
NOTE: This needs to be called on all ranks.
Parameters
----------
f : Union[str, os.PathLike]
The path of the file to save to.
See Also
--------
load
"""
if self._world_size > 1:
th.distributed.barrier()
f = f if isinstance(f, str) else str(f)
f = f"{f}_{self._rank}"
th.save(self.local_state_dict(), f)
if self._world_size > 1:
th.distributed.barrier()
def load(self, f):
"""Load the local state of the optimizer from the file on per rank.
NOTE: This needs to be called on all ranks.
Parameters
----------
f : Union[str, os.PathLike]
The path of the file to load from.
See Also
--------
save
"""
if self._world_size > 1:
th.distributed.barrier()
f = f if isinstance(f, str) else str(f)
f_attach_rank = f"{f}_{self._rank}"
# Don't raise an error here, so that the number of trainers can be
# scaled out after reloading. Make sure the hyperparameters are the
# same as before, because newly added local optimizers will not be
# filled with any state.
if not exists(f_attach_rank):
warnings.warn(f"File {f_attach_rank} can't be found, load nothing.")
else:
old_world_size = self._load_state_from(f_attach_rank)
# Device number scale-in
if self._world_size < old_world_size:
for rank in range(
self._rank + self._world_size,
old_world_size,
self._world_size,
):
self._load_state_from(f"{f}_{rank}")
if self._world_size > 1:
th.distributed.barrier()
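# Load one saved shard from disk, merge it into the local optimizer state,
# and return the world size recorded at save time.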
def _load_state_from(self, f):
local_state_dict = th.load(f)
world_size = local_state_dict[PARAMS].pop(WORLD_SIZE)
self.load_local_state_dict(local_state_dict)
return world_size
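# Map the partition policy to the global node/edge IDs owned by this partition.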
def _get_local_ids(self, part_policy):
if EDGE_PART_POLICY in part_policy.policy_str:
return part_policy.partition_book.partid2eids(
part_policy.part_id, part_policy.type_name
)
elif NODE_PART_POLICY in part_policy.policy_str:
return part_policy.partition_book.partid2nids(
part_policy.part_id, part_policy.type_name
)
else:
raise RuntimeError(
"Cannot support policy: %s " % part_policy.policy_str
)
def step(self):
"""The step function.
The step function is invoked at the end of every batch to push the gradients
of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
"""
with th.no_grad():
local_indics = {emb.name: [] for emb in self._params}
local_grads = {emb.name: [] for emb in self._params}
device = th.device("cpu")
for emb in self._params:
name = emb.weight.name
kvstore = emb.weight.kvstore
trainers_per_server = self._world_size // kvstore.num_servers
idics = []
@@ -65,10 +279,20 @@ class DistSparseGradOptimizer(abc.ABC):
# Note: we cannot skip the gradient exchange and update steps as other
# working processes may send gradient update requests corresponding
# to certain embedding to this process.
idics = (
th.cat(idics, dim=0)
if len(idics) != 0
else th.zeros((0,), dtype=th.long, device=th.device("cpu"))
)
grads = (
th.cat(grads, dim=0)
if len(grads) != 0
else th.zeros(
(0, emb.embedding_dim),
dtype=th.float32,
device=th.device("cpu"),
)
)
device = grads.device
# will send grad to each corresponding trainer
@@ -85,36 +309,67 @@ class DistSparseGradOptimizer(abc.ABC):
grad_i = grads[mask]
if trainers_per_server <= 1:
idx_split_size.append(
th.tensor([idx_i.shape[0]], dtype=th.int64)
)
idics_list.append(idx_i)
grad_list.append(grad_i)
else:
kv_idx_split = th.remainder(
idx_i, trainers_per_server
).long()
for j in range(trainers_per_server):
mask = kv_idx_split == j
idx_j = idx_i[mask]
grad_j = grad_i[mask]
idx_split_size.append(
th.tensor([idx_j.shape[0]], dtype=th.int64)
)
idics_list.append(idx_j)
grad_list.append(grad_j)
# if one machine launches multiple KVServers, they share the same storage.
# For each machine, the pytorch rank is num_trainers *
# machine_id + i
# use scatter to sync across trainers about the p2p tensor size
# Note: If we have GPU nccl support, we can use all_to_all to
# sync information here
gather_list = list(
th.empty([self._world_size], dtype=th.int64).chunk(
self._world_size
)
)
alltoall_cpu(
self._rank,
self._world_size,
gather_list,
idx_split_size,
)
# use cpu until we have GPU alltoallv
idx_gather_list = [
th.empty((int(num_emb),), dtype=idics.dtype)
for num_emb in gather_list
]
alltoallv_cpu(
self._rank,
self._world_size,
idx_gather_list,
idics_list,
)
local_indics[name] = idx_gather_list
grad_gather_list = [
th.empty(
(int(num_emb), grads.shape[1]), dtype=grads.dtype
)
for num_emb in gather_list
]
alltoallv_cpu(
self._rank,
self._world_size,
grad_gather_list,
grad_list,
)
local_grads[name] = grad_gather_list
else:
local_indics[name] = [idics]
@@ -128,12 +383,14 @@ class DistSparseGradOptimizer(abc.ABC):
# do local update
for emb in self._params:
name = emb.weight.name
idx = th.cat(local_indics[name], dim=0)
grad = th.cat(local_grads[name], dim=0)
self.update(
idx.to(device, non_blocking=True),
grad.to(device, non_blocking=True),
emb,
)
# synchronized gradient update
if self._world_size > 1:
@@ -141,7 +398,7 @@ class DistSparseGradOptimizer(abc.ABC):
@abstractmethod
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
@@ -156,12 +413,12 @@ class DistSparseGradOptimizer(abc.ABC):
"""
def zero_grad(self):
"""clean grad cache"""
self._clean_grad = True
def initializer(shape, dtype):
"""Sparse optimizer state initializer
Parameters
----------
@@ -173,8 +430,9 @@ def initializer(shape, dtype):
arr = th.zeros(shape, dtype=dtype)
return arr
class SparseAdagrad(DistSparseGradOptimizer):
r"""Distributed Node embedding optimizer using the Adagrad algorithm.
This optimizer implements a distributed sparse version of Adagrad algorithm for
optimizing :class:`dgl.distributed.DistEmbedding`. Being sparse means it only updates
@@ -196,25 +454,34 @@ class SparseAdagrad(DistSparseGradOptimizer):
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-10
"""
def __init__(self, params, lr, eps=1e-10):
super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps
self._defaults = {"_lr": lr, "_eps": eps}
# We need to register a state sum for each embedding in the kvstore.
self._state = {}
for emb in params:
assert isinstance(
emb, DistEmbedding
), "SparseAdagrad only supports dgl.distributed.DistEmbedding"
name = emb.name + "_sum"
state = DistTensor(
(emb.num_embeddings, emb.embedding_dim),
th.float32,
name,
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
assert (
emb.name not in self._state
), "{} already registered in the optimizer".format(emb.name)
self._state[emb.name] = state
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
@@ -230,20 +497,24 @@ class SparseAdagrad(DistSparseGradOptimizer):
eps = self._eps
clr = self._lr
state_dev = th.device("cpu")
exec_dev = grad.device
# only perform async copies cpu -> gpu, or gpu-> gpu, but block
# when copying to the cpu, so as to ensure the copy is finished
# before operating on the data on the cpu
state_block = state_dev == th.device("cpu") and exec_dev != state_dev
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(
idx, return_inverse=True, return_counts=True
)
grad_values = th.zeros(
(grad_indices.shape[0], grad.shape[1]), device=exec_dev
)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_sum = grad_values * grad_values
# update grad state
grad_state = self._state[emb.name][grad_indices].to(exec_dev)
@@ -273,8 +544,9 @@ class SparseAdagrad(DistSparseGradOptimizer):
std_event.wait()
emb._tensor[grad_indices] -= tmp_dst
class SparseAdam(DistSparseGradOptimizer):
r"""Distributed Node embedding optimizer using the Adam algorithm.
This optimizer implements a distributed sparse version of Adam algorithm for
optimizing :class:`dgl.distributed.DistEmbedding`. Being sparse means it only updates
@@ -303,40 +575,57 @@ class SparseAdam(DistSparseGradOptimizer):
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-8
"""
def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
super(SparseAdam, self).__init__(params, lr)
self._eps = eps
# We need to register a state sum for each embedding in the kvstore.
self._beta1 = betas[0]
self._beta2 = betas[1]
self._defaults = {
"_lr": lr,
"_eps": eps,
"_beta1": betas[0],
"_beta2": betas[1],
}
for emb in params:
assert isinstance(
emb, DistEmbedding
), "SparseAdam only supports dgl.distributed.DistEmbedding"
state_step = DistTensor(
(emb.num_embeddings,),
th.float32,
emb.name + "_step",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
state_mem = DistTensor(
(emb.num_embeddings, emb.embedding_dim),
th.float32,
emb.name + "_mem",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
state_power = DistTensor(
(emb.num_embeddings, emb.embedding_dim),
th.float32,
emb.name + "_power",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
state = (state_step, state_mem, state_power)
assert (
emb.name not in self._state
), "{} already registered in the optimizer".format(emb.name)
self._state[emb.name] = state
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
@@ -355,16 +644,18 @@ class SparseAdam(DistSparseGradOptimizer):
clr = self._lr
state_step, state_mem, state_power = self._state[emb.name]
state_dev = th.device("cpu")
exec_dev = grad.device
# only perform async copies cpu -> gpu, or gpu-> gpu, but block
# when copying to the cpu, so as to ensure the copy is finished
# before operating on the data on the cpu
state_block = state_dev == th.device("cpu") and exec_dev != state_dev
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(
idx, return_inverse=True, return_counts=True
)
# update grad state
state_idx = grad_indices.to(state_dev)
# The original implementation will cause read/write contention.
@@ -375,20 +666,23 @@ class SparseAdam(DistSparseGradOptimizer):
# of code will also send read requests to kvstore servers. The write and read requests
# may be handled by different kvstore servers managing the same portion of the
# state_step dist tensor in the same node. As a result, the read request may read an old
# value (i.e., 0 in the first iteration) which will cause
# update_power_corr to be NaN
state_val = state_step[state_idx] + 1
state_step[state_idx] = state_val
state_step = state_val.to(exec_dev)
orig_mem = state_mem[state_idx].to(exec_dev)
orig_power = state_power[state_idx].to(exec_dev)
grad_values = th.zeros(
(grad_indices.shape[0], grad.shape[1]), device=exec_dev
)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_mem = grad_values
grad_power = grad_values * grad_values
update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem
update_power = beta2 * orig_power + (1.0 - beta2) * grad_power
update_mem_dst = update_mem.to(state_dev, non_blocking=True)
update_power_dst = update_power.to(state_dev, non_blocking=True)
if state_block:
@@ -396,10 +690,12 @@ class SparseAdam(DistSparseGradOptimizer):
update_event = th.cuda.Event()
update_event.record()
update_mem_corr = update_mem / (
1.0 - th.pow(th.tensor(beta1, device=exec_dev), state_step)
).unsqueeze(1)
update_power_corr = update_power / (
1.0 - th.pow(th.tensor(beta2, device=exec_dev), state_step)
).unsqueeze(1)
std_values = clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)
std_values_dst = std_values.to(state_dev, non_blocking=True)
......
@@ -62,13 +62,13 @@ def dist_tensor_test_sanity(data_shape, name=None):
stride = 3
pos = (part_id // 2) * num_client_per_machine + local_rank
if part_id % 2 == 0:
dist_ten[pos * stride: (pos + 1) * stride] = F.ones(
(stride, 2), dtype=F.int32, ctx=F.cpu()
) * (pos + 1)
dgl.distributed.client_barrier()
assert F.allclose(
dist_ten[pos * stride: (pos + 1) * stride],
F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos + 1),
)
@@ -102,7 +102,7 @@ def dist_tensor_test_persistent(data_shape):
data_shape, F.float32, dist_ten_name
)
raise Exception("")
except BaseException:
pass
@@ -163,7 +163,7 @@ def dist_embedding_check_existing(num_nodes):
num_nodes, 2, name=dist_emb_name, init_func=zeros_init
)
raise Exception("")
except BaseException:
pass
@@ -180,6 +180,59 @@ def test_dist_embedding(g):
dist_embedding_check_existing(num_nodes)
##########################################
############# DistOptimizer ##############
##########################################
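# Write randomized Adam states from rank 0, save on all ranks, reload into a
# freshly constructed optimizer with different hyperparameters, and verify that
# both the states and the hyperparameters are restored.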
def dist_optimizer_check_store(g):
num_nodes = g.number_of_nodes(g.ntypes[0])
rank = g.rank()
try:
emb = dgl.distributed.DistEmbedding(
num_nodes, 1, name="optimizer_test", init_func=zeros_init
)
emb2 = dgl.distributed.DistEmbedding(
num_nodes, 5, name="optimizer_test2", init_func=zeros_init
)
emb_optimizer = dgl.distributed.optim.SparseAdam([emb, emb2], lr=0.1)
if rank == 0:
name_to_state = {}
for _, emb_states in emb_optimizer._state.items():
for state in emb_states:
name_to_state[state.name] = F.uniform(
state.shape, F.float32, F.cpu(), 0, 1
)
state[
F.arange(0, num_nodes, F.int64, F.cpu())
] = name_to_state[state.name]
emb_optimizer.save("emb.pt")
new_emb_optimizer = dgl.distributed.optim.SparseAdam(
[emb, emb2], lr=000.1, eps=2e-08, betas=(0.1, 0.222)
)
new_emb_optimizer.load("emb.pt")
if rank == 0:
for _, emb_states in new_emb_optimizer._state.items():
for new_state in emb_states:
state = name_to_state[new_state.name]
new_state = new_state[
F.arange(0, num_nodes, F.int64, F.cpu())
]
assert F.allclose(state, new_state, 0.0, 0.0)
assert new_emb_optimizer._lr == emb_optimizer._lr
assert new_emb_optimizer._eps == emb_optimizer._eps
assert new_emb_optimizer._beta1 == emb_optimizer._beta1
assert new_emb_optimizer._beta2 == emb_optimizer._beta2
g.barrier()
finally:
file = f"emb.pt_{rank}"
if os.path.exists(file):
os.remove(file)
def test_dist_optimizer(g):
dist_optimizer_check_store(g)
if mode == "server": if mode == "server":
shared_mem = bool(int(os.environ.get("DIST_DGL_TEST_SHARED_MEM"))) shared_mem = bool(int(os.environ.get("DIST_DGL_TEST_SHARED_MEM")))
server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID")) server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
...@@ -203,6 +256,7 @@ elif mode == "client": ...@@ -203,6 +256,7 @@ elif mode == "client":
target_func_map = { target_func_map = {
"DistTensor": test_dist_tensor, "DistTensor": test_dist_tensor,
"DistEmbedding": test_dist_embedding, "DistEmbedding": test_dist_embedding,
"DistOptimizer": test_dist_optimizer,
}
target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
@@ -213,5 +267,4 @@ elif mode == "client":
target_func_map[target](g)
else:
print("DIST_DGL_TEST_MODE has to be either server or client")
exit(1)
@@ -13,6 +13,7 @@ from multiprocessing import Condition, Manager, Process, Value
import backend as F
import numpy as np
import pytest
import torch as th
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs
@@ -20,6 +21,7 @@ from utils import create_random_graph, generate_ip_config, reset_envs
import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
DistEmbedding,
DistGraph,
DistGraphServer,
edge_split,
@@ -28,6 +30,7 @@ from dgl.distributed import (
node_split,
partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
if os.name != "nt":
@@ -207,6 +210,67 @@ def run_emb_client(
check_dist_emb(g, num_clients, num_nodes, num_edges)
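# Client process for the optimizer save/load test: it sets up the process
# group, loads the partition book, creates the DistGraph, then runs either
# the save or the load phase depending on `save`.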
def run_optim_client(
graph_name,
part_id,
server_count,
rank,
world_size,
num_nodes,
optimizer_states,
save,
):
os.environ["DGL_NUM_SERVER"] = str(server_count)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
dgl.distributed.initialize("kv_ip_config.txt")
th.distributed.init_process_group(
backend="gloo", rank=rank, world_size=world_size
)
gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id, None
)
g = DistGraph(graph_name, gpb=gpb)
check_dist_optim_store(rank, num_nodes, optimizer_states, save)
def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
try:
total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
emb2 = DistEmbedding(
num_nodes, 1, name="optim_emb2", init_func=emb_init
)
if save:
optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
if rank == 0:
optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
optimizer.save("/tmp/dist_graph/emb.pt")
else:
optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
optimizer.load("/tmp/dist_graph/emb.pt")
if rank == 0:
assert F.allclose(
optimizer._state["optim_emb1"][total_idx],
optimizer_states[0],
0.0,
0.0,
)
assert F.allclose(
optimizer._state["optim_emb2"][total_idx],
optimizer_states[1],
0.0,
0.0,
)
assert 0.1 == optimizer._lr
assert 1e-08 == optimizer._eps
th.distributed.barrier()
except Exception as e:
print(e)
sys.exit(-1)
def run_client_hierarchy(
graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
@@ -233,9 +297,6 @@ def run_client_hierarchy(
def check_dist_emb(g, num_clients, num_nodes, num_edges):
from dgl.distributed import DistEmbedding
from dgl.distributed.optim import SparseAdagrad
# Test sparse emb
try:
emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
@@ -845,6 +906,87 @@ def test_dist_emb_server_client():
# check_dist_emb_server_client(True, 2, 2, 2)
@unittest.skipIf(
dgl.backend.backend_name == "tensorflow",
reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
dgl.backend.backend_name == "mxnet",
reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
optimizer_states = []
num_nodes = 10000
optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
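# Save with 4 trainers, then reload with 8 (scale-out) and with 2 (scale-in).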
check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)
def check_dist_optim_server_client(
num_nodes, num_servers, num_clients, optimizer_states, save
):
graph_name = f"check_dist_optim_{num_servers}_store"
if save:
prepare_dist(num_servers)
g = create_random_graph(num_nodes)
# Partition the graph
num_parts = 1
g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context("spawn")
for serv_id in range(num_servers):
p = ctx.Process(
target=run_server,
args=(
graph_name,
serv_id,
num_servers,
num_clients,
True,
False,
),
)
serv_ps.append(p)
p.start()
cli_ps = []
for cli_id in range(num_clients):
print("start client[{}] for group[0]".format(cli_id))
p = ctx.Process(
target=run_optim_client,
args=(
graph_name,
0,
num_servers,
cli_id,
num_clients,
num_nodes,
optimizer_states,
save,
),
)
p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
cli_ps.append(p)
for p in cli_ps:
p.join()
assert p.exitcode == 0
for p in serv_ps:
p.join()
@unittest.skipIf(
dgl.backend.backend_name == "tensorflow",
reason="TF doesn't support some of operations in DistGraph",
......