Unverified Commit 2caa6bd0 authored by peizhou001 and committed by GitHub

[Dist] Enable save and load for Distributed Optimizer (#4752)



* add save/load for distributed optimizer
Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent 5833efe0
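In short, this change gives the distributed sparse optimizers (`SparseAdam`, `SparseAdagrad`) `save()`/`load()` methods for checkpointing optimizer state across trainers. A minimal usage sketch, with setup and names assumed from the tests in this diff (ip config path, embedding sizes, and checkpoint path are illustrative, not prescribed by the API):

```python
import dgl
import torch as th
from dgl.distributed import DistEmbedding
from dgl.distributed.optim import SparseAdam

# Assumes servers are already launched and an ip config file exists.
dgl.distributed.initialize("ip_config.txt")
th.distributed.init_process_group(backend="gloo")

emb = DistEmbedding(10000, 16, name="node_emb")
optimizer = SparseAdam([emb], lr=0.1)
# ... training loop calling optimizer.step() ...

# Every trainer calls save(); judging by the test cleanup below, each
# rank writes its own shard ("emb.pt_0", "emb.pt_1", ...).
optimizer.save("emb.pt")

# In a later job: constructor hyperparameters are placeholders here;
# load() restores lr/eps/betas from the checkpoint.
restored = SparseAdam([emb], lr=0.001)
restored.load("emb.pt")
```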
......@@ -48,10 +48,10 @@ Distributed embedding optimizer
-------------------------
.. autoclass:: dgl.distributed.optim.SparseAdagrad
-    :members: step
+    :members: step, save, load

.. autoclass:: dgl.distributed.optim.SparseAdam
-    :members: step
+    :members: step, save, load
Distributed workload split
--------------------------
......
......@@ -62,13 +62,13 @@ def dist_tensor_test_sanity(data_shape, name=None):
    stride = 3
    pos = (part_id // 2) * num_client_per_machine + local_rank
    if part_id % 2 == 0:
-        dist_ten[pos * stride : (pos + 1) * stride] = F.ones(
+        dist_ten[pos * stride: (pos + 1) * stride] = F.ones(
            (stride, 2), dtype=F.int32, ctx=F.cpu()
        ) * (pos + 1)
    dgl.distributed.client_barrier()
    assert F.allclose(
-        dist_ten[pos * stride : (pos + 1) * stride],
+        dist_ten[pos * stride: (pos + 1) * stride],
        F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos + 1),
    )
......@@ -102,7 +102,7 @@ def dist_tensor_test_persistent(data_shape):
            data_shape, F.float32, dist_ten_name
        )
        raise Exception("")
-    except:
+    except BaseException:
        pass
......@@ -163,7 +163,7 @@ def dist_embedding_check_existing(num_nodes):
            num_nodes, 2, name=dist_emb_name, init_func=zeros_init
        )
        raise Exception("")
-    except:
+    except BaseException:
        pass
......@@ -180,6 +180,59 @@ def test_dist_embedding(g):
    dist_embedding_check_existing(num_nodes)


##########################################
############# DistOptimizer ##############
##########################################


def dist_optimizer_check_store(g):
    num_nodes = g.number_of_nodes(g.ntypes[0])
    rank = g.rank()
    try:
        emb = dgl.distributed.DistEmbedding(
            num_nodes, 1, name="optimizer_test", init_func=zeros_init
        )
        emb2 = dgl.distributed.DistEmbedding(
            num_nodes, 5, name="optimizer_test2", init_func=zeros_init
        )
        emb_optimizer = dgl.distributed.optim.SparseAdam([emb, emb2], lr=0.1)
        if rank == 0:
            # Seed the optimizer state with random values so the
            # save/load round trip is non-trivial.
            name_to_state = {}
            for _, emb_states in emb_optimizer._state.items():
                for state in emb_states:
                    name_to_state[state.name] = F.uniform(
                        state.shape, F.float32, F.cpu(), 0, 1
                    )
                    state[
                        F.arange(0, num_nodes, F.int64, F.cpu())
                    ] = name_to_state[state.name]
        emb_optimizer.save("emb.pt")

        # Create a new optimizer with different hyperparameters;
        # load() should overwrite them with the checkpointed values.
        new_emb_optimizer = dgl.distributed.optim.SparseAdam(
            [emb, emb2], lr=0.001, eps=2e-08, betas=(0.1, 0.222)
        )
        new_emb_optimizer.load("emb.pt")
        if rank == 0:
            for _, emb_states in new_emb_optimizer._state.items():
                for new_state in emb_states:
                    state = name_to_state[new_state.name]
                    new_state = new_state[
                        F.arange(0, num_nodes, F.int64, F.cpu())
                    ]
                    assert F.allclose(state, new_state, 0.0, 0.0)
            assert new_emb_optimizer._lr == emb_optimizer._lr
            assert new_emb_optimizer._eps == emb_optimizer._eps
            assert new_emb_optimizer._beta1 == emb_optimizer._beta1
            assert new_emb_optimizer._beta2 == emb_optimizer._beta2
        g.barrier()
    finally:
        # save() writes one shard per trainer rank; clean it up.
        file = f"emb.pt_{rank}"
        if os.path.exists(file):
            os.remove(file)


def test_dist_optimizer(g):
    dist_optimizer_check_store(g)
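The `finally` block above implies that `save("emb.pt")` leaves one file per trainer rank (`emb.pt_{rank}`). A small helper generalizing that cleanup across ranks; the naming scheme is inferred from this test, not from the optimizer implementation:

```python
import os

def remove_optimizer_shards(prefix, world_size):
    # Remove the per-rank checkpoint files written by optimizer.save(prefix).
    # The f"{prefix}_{rank}" scheme is inferred from the test cleanup above.
    for rank in range(world_size):
        shard = f"{prefix}_{rank}"
        if os.path.exists(shard):
            os.remove(shard)
```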
if mode == "server":
    shared_mem = bool(int(os.environ.get("DIST_DGL_TEST_SHARED_MEM")))
    server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
......@@ -203,6 +256,7 @@ elif mode == "client":
    target_func_map = {
        "DistTensor": test_dist_tensor,
        "DistEmbedding": test_dist_embedding,
        "DistOptimizer": test_dist_optimizer,
    }
    target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
......@@ -213,5 +267,4 @@ elif mode == "client":
    target_func_map[target](g)
else:
    print("DIST_DGL_TEST_MODE has to be either server or client")
    exit(1)
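For reference, this script is driven entirely by environment variables; a launcher exercising the new target would set something like the following (the variable names are the ones read above, the values are illustrative):

```python
import os

# Select the new optimizer test target in client mode; connectivity
# settings (ip config, server id, shared memory) are set by the launcher.
os.environ["DIST_DGL_TEST_MODE"] = "client"
os.environ["DIST_DGL_TEST_OBJECT_TYPE"] = "DistOptimizer"
```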
......@@ -13,6 +13,7 @@ from multiprocessing import Condition, Manager, Process, Value
import backend as F
import numpy as np
import pytest
import torch as th
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs
......@@ -20,6 +21,7 @@ from utils import create_random_graph, generate_ip_config, reset_envs
import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
......@@ -28,6 +30,7 @@ from dgl.distributed import (
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
if os.name != "nt":
......@@ -207,6 +210,67 @@ def run_emb_client(
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                # Seed the optimizer state with known values before saving.
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            # Hyperparameters differ from the saved run on purpose;
            # load() should restore the checkpointed values.
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert optimizer._lr == 0.1
                assert optimizer._eps == 1e-08
        th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)
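The check above pins down the intended contract: hyperparameters passed to the constructor before `load()` are throwaways. A condensed single-rank restatement (a sketch only; it assumes the same initialized distributed context and the `emb`/`emb2` embeddings from the test above):

```python
# Condensed restatement of the round trip exercised above.
saver = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
saver.save("/tmp/dist_graph/emb.pt")

loader = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
loader.load("/tmp/dist_graph/emb.pt")
assert loader._lr == 0.1 and loader._eps == 1e-08  # checkpoint values win
```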
def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
......@@ -233,9 +297,6 @@ def run_client_hierarchy(
def check_dist_emb(g, num_clients, num_nodes, num_edges):
-    from dgl.distributed import DistEmbedding
-    from dgl.distributed.optim import SparseAdagrad

    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
......@@ -845,6 +906,87 @@ def test_dist_emb_server_client():
# check_dist_emb_server_client(True, 2, 2, 2)
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="MXNet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    # Save once with 4 trainers, then reload with different trainer counts.
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)
def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph.
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # Let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
                False,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
......