"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b3b2d30cd832bf205819b5d17457bf2f2182b3a7"
Unverified commit ae4a5b73, authored by Xin Yao and committed by GitHub

[Feature] Add `state_dict`, `load_state_dict`, `param_groups` to `dgl.optim.SparseGradOptimizer` (#5311)

* init update

* all get/set optm_state

* add unit tests

* add docstring

* fix for multiple embeddings

* move embedding methods to private

* fix lint

* fix unit tests

* resolve comments

* merge master
parent 17829024
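
In practice, the new API mirrors `torch.optim.Optimizer.state_dict()` / `load_state_dict()`. Below is a minimal single-process sketch of checkpointing a `NodeEmbedding` together with its sparse optimizer; the file name `ckpt.pt` is illustrative, and the hyperparameter values in the comments assume `SparseAdam` defaults.

```python
import torch as th
from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam

emb = NodeEmbedding(1000, 16, "my_emb")   # single process, CPU storage
optm = SparseAdam(params=[emb], lr=0.01)  # optimizer states now exist right after construction

# ... run some training iterations that call emb(...) and optm.step() ...

# Save the embedding weights plus the optimizer states and hyperparameters.
# optm.state_dict() has the form
#   {"state": {"my_emb": (step, mem, power)},   # per-embedding state tensors (SparseAdam)
#    "param_groups": [{"lr": 0.01, "betas": (0.9, 0.999), "eps": 1e-08}]}
th.save(
    {"weight": emb.all_get_embedding(), "optm": optm.state_dict()},
    "ckpt.pt",
)

# Restore; in a fresh run you would first re-create `emb` and `optm` exactly as above.
ckpt = th.load("ckpt.pt")
emb.all_set_embedding(ckpt["weight"])
optm.load_state_dict(ckpt["optm"])
```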
@@ -348,6 +348,41 @@ class NodeEmbedding: # NodeEmbedding
if th.distributed.is_initialized():
th.distributed.barrier()
def _all_get_tensor(self, shared_name, tensor, shape):
"""A helper function to get model-parallel tensors.
This method must and only need to be called in multi-GPU DDP training.
For now, it's only used in ``all_get_embedding`` and
``_all_get_optm_state``.
"""
# create a shared memory tensor
if self._rank == 0:
# root process creates shared memory
val = create_shared_mem_array(
shared_name,
shape,
tensor.dtype,
)
self._store.set(shared_name, shared_name)
else:
self._store.wait([shared_name])
val = get_shared_mem_array(
shared_name,
shape,
tensor.dtype,
)
# need to map indices and slice into existing tensor
idxs = self._partition.map_to_global(
F.arange(0, tensor.shape[0], ctx=F.context(tensor)),
self._rank,
).to(val.device)
val[idxs] = tensor.to(val.device)
self._store.delete_key(shared_name)
# wait for all processes to finish
th.distributed.barrier()
return val
def all_get_embedding(self):
"""Return a copy of the embedding stored in CPU memory. If this is a
multi-processing instance, the tensor will be returned in shared
@@ -367,35 +402,78 @@ class NodeEmbedding: # NodeEmbedding
# non-multiprocessing
return self._tensor.to(th.device("cpu"))
else:
- # create a shared memory tensor
- shared_name = self._name + "_gather"
- if self._rank == 0:
- # root process creates shared memory
- emb = create_shared_mem_array(
- shared_name,
- (self._num_embeddings, self._embedding_dim),
- self._tensor.dtype,
- )
- self._store.set(shared_name, shared_name)
- else:
- self._store.wait([shared_name])
- emb = get_shared_mem_array(
- shared_name,
- (self._num_embeddings, self._embedding_dim),
- self._tensor.dtype,
- )
- # need to map indices and slice into existing tensor
- idxs = self._partition.map_to_global(
- F.arange(
- 0, self._tensor.shape[0], ctx=F.context(self._tensor)
- ),
- self._rank,
- ).to(emb.device)
- emb[idxs] = self._tensor.to(emb.device)
- # wait for all processes to finish
- th.distributed.barrier()
- return emb
+ return self._all_get_tensor(
+ f"{self._name}_gather",
+ self._tensor,
+ (self._num_embeddings, self._embedding_dim),
+ )
else:
# already stored in CPU memory
return self._tensor
def _all_get_optm_state(self):
"""Return a copy of the whole optimizer states stored in CPU memory.
If this is a multi-processing instance, the states will be returned in
shared memory. If the embedding is currently stored on multiple GPUs,
all processes must call this method in the same order.
NOTE: This method must be called by all processes sharing the
embedding, or it may result in a deadlock.
Returns
-------
tuple of torch.Tensor
The optimizer states stored in CPU memory.
"""
if self._partition:
if self._world_size == 0:
# non-multiprocessing
return tuple(
state.to(th.device("cpu")) for state in self._optm_state
)
else:
return tuple(
self._all_get_tensor(
f"state_gather_{self._name}_{i}",
state,
(self._num_embeddings, *state.shape[1:]),
)
for i, state in enumerate(self._optm_state)
)
else:
# already stored in CPU memory
return self._optm_state
def _all_set_optm_state(self, states):
"""Set the optimizer states of the embedding. This method must be
called by all processes sharing the embedding with identical
:attr:`states`.
NOTE: This method must be called by all processes sharing the
embedding, or it may result in a deadlock.
Parameters
----------
states : tuple of torch.Tensor
The global states to pull values from.
"""
if self._partition:
idxs = F.copy_to(
self._partition.get_local_indices(
max(self._rank, 0), ctx=F.context(self._tensor)
),
F.context(states[0]),
)
for state, new_state in zip(self._optm_state, states):
state[:] = F.copy_to(
F.gather_row(new_state, idxs), ctx=F.context(self._tensor)
)[:]
else:
# stored in CPU memory
if self._rank <= 0:
for state, new_state in zip(self._optm_state, states):
state[:] = F.copy_to(
new_state, ctx=F.context(self._tensor)
)[:]
if th.distributed.is_initialized():
th.distributed.barrier()
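
For reference, here is a minimal sketch of the round trip these two private helpers provide, in the simplest single-process CPU setting (it mirrors what the new unit tests exercise through `SparseAdam`; in this setting both calls reduce to plain tensor copies):

```python
import torch as th
from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam

emb = NodeEmbedding(1000, 8, "test")
optm = SparseAdam(params=[emb], lr=0.01)  # registers the (step, mem, power) state tensors

states = emb._all_get_optm_state()  # tuple of state tensors in CPU (or shared) memory
emb._all_set_optm_state(states)     # every process sharing `emb` must pass the same tuple
```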
@@ -95,7 +95,6 @@ class SparseGradOptimizer(abc.ABC):
self._comm_setup()
else:
self._shared_setup()
- self.setup(self._params)
self._first_step = False
if self._comm:
@@ -103,6 +102,7 @@ class SparseGradOptimizer(abc.ABC):
else:
self._shared_step()
@abstractmethod
def setup(self, params):
"""This is function where subclasses can perform any setup they need
to. It will be called during the first step, and communicators or
@@ -452,6 +452,59 @@ class SparseGradOptimizer(abc.ABC):
"""clean grad cache"""
self._clean_grad = True
def state_dict(self, **kwargs): # pylint: disable=unused-argument
"""Return a copy of the whole optimizer states stored in CPU memory.
If this is a multi-processing instance, the states will be returned in
shared memory. If the underlying embedding is currently stored on
multiple GPUs, all processes must call this method in the same order.
NOTE: This method must be called by all processes sharing the
underlying embedding, or it may result in a deadlock.
Returns
-------
dictionary of optimizer states
The optimizer states stored in CPU memory.
"""
return {
"state": {
emb.name: emb._all_get_optm_state() for emb in self._params
},
"param_groups": self.param_groups,
}
def load_state_dict(
self, state_dict, **kwargs
): # pylint: disable=unused-argument
"""Load the optimizer states. This method must be called by all
processes sharing the underlying embedding with identical
:attr:`state_dict`.
NOTE: This method must be called by all processes sharing the
underlying embedding, or it may result in a deadlock.
Parameters
----------
state_dict : dictionary of optimizer states
The global states to pull values from.
"""
for emb in self._params:
emb._all_set_optm_state(state_dict["state"][emb.name])
self._set_param_groups(state_dict["param_groups"])
@property
@abstractmethod
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Different from that, the returned 'param_groups' doesn't contain
parameters because getting the whole embedding is very expensive.
It contains other attributes, e.g., lr, eps, for debugging.
"""
@abstractmethod
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
class SparseAdagrad(SparseGradOptimizer):
r"""Node embedding optimizer using the Adagrad algorithm.
@@ -496,6 +549,9 @@ class SparseAdagrad(SparseGradOptimizer):
super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps
# setup tensors for optimizer states
self.setup(self._params)
def setup(self, params):
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
@@ -532,7 +588,7 @@ class SparseAdagrad(SparseGradOptimizer):
dtype=th.float32,
device=emb.weight.device,
).zero_()
- emb.set_optm_state(state)
+ emb.set_optm_state((state,))
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
@@ -562,7 +618,7 @@ class SparseAdagrad(SparseGradOptimizer):
grad_values = grad_values / cnt.unsqueeze(1)
grad_sum = grad_values * grad_values
- state = emb.optm_state
+ (state,) = emb.optm_state
state_dev = state.device
state_idx = grad_indices.to(state_dev)
grad_state = state[state_idx].to(grad.device)
@@ -573,6 +629,20 @@ class SparseAdagrad(SparseGradOptimizer):
tmp = clr * grad_values / std_values
emb.weight[state_idx] -= tmp.to(state_dev)
@property
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Different from that, the returned 'param_groups' doesn't contain
parameters because getting the whole embedding is very expensive.
It contains other attributes, e.g., lr, eps, for debugging.
"""
return [{"lr": self._lr, "eps": self._eps}]
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
self._lr = groups[0]["lr"]
self._eps = groups[0]["eps"]
class SparseAdam(SparseGradOptimizer):
r"""Node embedding optimizer using the Adam algorithm.
@@ -653,6 +723,9 @@ class SparseAdam(SparseGradOptimizer):
)
self._dtype = dtype
# setup tensors for optimizer states
self.setup(self._params)
def _setup_uva(self, name, mem, power):
self._is_using_uva[name] = True
mem_nd = pin_memory_inplace(mem)
@@ -863,3 +936,24 @@ class SparseAdam(SparseGradOptimizer):
# can use it
std_event.wait()
emb.weight[state_idx] -= std_values_dst
@property
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Different from that, the returned 'param_groups' doesn't contain
parameters because getting the whole embedding is very expensive.
It contains other attributes, e.g., lr, betas, eps, for debugging.
"""
return [
{
"lr": self._lr,
"betas": (self._beta1, self._beta2),
"eps": self._eps,
}
]
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
self._lr = groups[0]["lr"]
self._beta1, self._beta2 = groups[0]["betas"]
self._eps = groups[0]["eps"]
@@ -7,6 +7,7 @@ import pytest
import torch as th
from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam
def initializer(emb):
@@ -15,7 +16,7 @@ def initializer(emb):
return emb
- def check_all_set_all_get_func(device, init_emb):
+ def check_all_set_all_get_emb(device, init_emb):
num_embs = init_emb.shape[0]
emb_dim = init_emb.shape[1]
dgl_emb = NodeEmbedding(num_embs, emb_dim, "test", device=device)
@@ -25,6 +26,23 @@ def check_all_set_all_get_func(device, init_emb):
assert F.allclose(init_emb, out_emb)
def check_all_set_all_get_optm_state(
device, state_step, state_mem, state_power
):
num_embs = state_mem.shape[0]
emb_dim = state_mem.shape[1]
dgl_emb = NodeEmbedding(num_embs, emb_dim, "test", device=device)
optm = SparseAdam(params=[dgl_emb], lr=0.01)
dgl_emb._all_set_optm_state((state_step, state_mem, state_power))
out_step, out_mem, out_power = dgl_emb._all_get_optm_state()
assert F.allclose(state_step, out_step)
assert F.allclose(state_mem, out_mem)
assert F.allclose(state_power, out_power)
def start_sparse_worker(rank, world_size, test, args):
print("start sparse worker {}".format(rank))
dist_init_method = "tcp://{master_ip}:{master_port}".format(
@@ -44,6 +62,7 @@ def start_sparse_worker(rank, world_size, test, args):
test(device, *args)
th.distributed.barrier()
th.distributed.destroy_process_group()
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@@ -60,7 +79,40 @@ def test_multiprocess_sparse_emb_get_set(num_workers):
for i in range(num_workers):
p = ctx.Process(
target=start_sparse_worker,
- args=(i, num_workers, check_all_set_all_get_func, (init_emb,)),
+ args=(i, num_workers, check_all_set_all_get_emb, (init_emb,)),
)
p.start()
worker_list.append(p)
for p in worker_list:
p.join()
for p in worker_list:
assert p.exitcode == 0
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set_optm_state(num_workers):
if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
pytest.skip("Not enough GPUs to run test.")
worker_list = []
num_embs, emb_dim = 1000, 8
state_step = th.randint(1000, (num_embs,))
state_mem = th.rand((num_embs, emb_dim))
state_power = th.rand((num_embs, emb_dim))
ctx = mp.get_context("spawn")
for i in range(num_workers):
p = ctx.Process(
target=start_sparse_worker,
args=(
i,
num_workers,
check_all_set_all_get_optm_state,
(state_step, state_mem, state_power),
),
)
p.start()
worker_list.append(p)
@@ -72,6 +124,10 @@ def test_multiprocess_sparse_emb_get_set(num_workers):
if __name__ == "__main__": if __name__ == "__main__":
test_sparse_emb_get_set(1) # test_multiprocess_sparse_emb_get_set(1)
test_sparse_emb_get_set(2) # test_multiprocess_sparse_emb_get_set(2)
test_sparse_emb_get_set(3) # test_multiprocess_sparse_emb_get_set(3)
test_multiprocess_sparse_emb_get_set_optm_state(1)
# test_multiprocess_sparse_emb_get_set_optm_state(2)
# test_multiprocess_sparse_emb_get_set_optm_state(3)
import os
- import time
import unittest
import backend as F
@@ -590,6 +589,86 @@ def test_multiprocess_sparse_adam_zero_step_cuda_tensor(num_workers):
assert F.allclose(dgl_weight, torch_weight)
def start_sparse_adam_state_dict_worker(
rank,
world_size,
init_weight,
backend,
num_embs,
emb_dim,
):
print("start sparse worker for adam {}".format(rank))
dist_init_method = "tcp://{master_ip}:{master_port}".format(
master_ip="127.0.0.1", master_port="12345"
)
device = th.device(f"cuda:{rank}")
th.cuda.set_device(device)
tensor_dev = device if backend == "nccl" else th.device("cpu")
th.distributed.init_process_group(
backend=backend,
init_method=dist_init_method,
world_size=world_size,
rank=rank,
)
th.manual_seed(0)
dgl_emb = NodeEmbedding(
num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev
)
dgl_emb.all_set_embedding(init_weight)
dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
start = (num_embs // world_size) * rank
end = (num_embs // world_size) * (rank + 1)
th.manual_seed(rank)
idx = th.randint(start, end, size=(4,)).to(tensor_dev)
dgl_value = dgl_emb(idx, device)
labels = th.ones((4,)).long().to(device)
dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
dgl_adam.zero_grad()
dgl_loss.backward()
dgl_adam.step()
th.distributed.barrier()
worker_state_dict = [t.detach().clone() for t in dgl_emb.optm_state]
state_dict = dgl_adam.state_dict()
for t in dgl_emb.optm_state:
t.zero_()
dgl_adam.load_state_dict(state_dict)
for i, j in zip(worker_state_dict, dgl_emb.optm_state):
F.allclose(i, j)
th.distributed.barrier()
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [1, 2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
def test_multiprocess_sparse_adam_state_dict(num_workers, backend):
if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
pytest.skip("Not enough GPUs to run test.")
num_embs = 128
emb_dim = 10
init_weight = th.rand((num_embs, emb_dim))
mp.spawn(
start_sparse_adam_state_dict_worker,
(
num_workers,
init_weight,
backend,
num_embs,
emb_dim,
),
nprocs=num_workers,
)
if __name__ == "__main__": if __name__ == "__main__":
test_sparse_adam(1) test_sparse_adam(1)
test_sparse_adam(4) test_sparse_adam(4)
...@@ -614,3 +693,6 @@ if __name__ == "__main__": ...@@ -614,3 +693,6 @@ if __name__ == "__main__":
test_multiprocess_sparse_adam_cuda_tensor(2) test_multiprocess_sparse_adam_cuda_tensor(2)
test_multiprocess_sparse_adam_zero_step_cuda_tensor(4) test_multiprocess_sparse_adam_zero_step_cuda_tensor(4)
test_multiprocess_sparse_adam_state_dict(2, "nccl")
test_multiprocess_sparse_adam_state_dict(2, "gloo")