import os

os.environ["OMP_NUM_THREADS"] = "1"

import multiprocessing as mp
import pickle
import random
import socket
import sys
import time
import unittest

import backend as F
import numpy as np
import torch as th
from scipy import sparse as spsp

import dgl
from dgl import function as fn
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    load_partition_book,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad, SparseAdam


def create_random_graph(n):
    arr = (
        spsp.random(n, n, density=0.001, format="coo", random_state=100) != 0
    ).astype(np.int64)
    return dgl.from_scipy(arr)


def get_local_usable_addr():
    """Get a local usable IP address and free port.

    Returns
    -------
    str
        Space-separated IP address and port, e.g., '192.168.8.12 50051'.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # The address doesn't even have to be reachable.
        sock.connect(("10.255.255.255", 1))
        ip_addr = sock.getsockname()[0]
    except OSError:
        ip_addr = "127.0.0.1"
    finally:
        sock.close()

    # Bind to port 0 to let the OS pick a free port.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    sock.listen(1)
    port = sock.getsockname()[1]
    sock.close()

    return ip_addr + " " + str(port)


def prepare_dist():
    with open("optim_ip_config.txt", "w") as ip_config:
        ip_config.write("{}\n".format(get_local_usable_addr()))


def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
    g = DistGraphServer(
        server_id,
        "optim_ip_config.txt",
        num_clients,
        server_count,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
    )
    print("start server", server_id)
    g.start()


def initializer(shape, dtype):
    arr = th.zeros(shape, dtype=dtype)
    th.manual_seed(0)
    th.nn.init.uniform_(arr, 0, 1.0)
    return arr


def run_client(graph_name, cli_id, part_id, server_count):
    device = F.ctx()
    time.sleep(5)
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("optim_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
    num_nodes = g.number_of_nodes()
    emb_dim = 4

    # Distributed embeddings optimized by DGL's SparseAdam.
    dgl_emb = DistEmbedding(
        num_nodes,
        emb_dim,
        name="optim",
        init_func=initializer,
        part_policy=policy,
    )
    dgl_emb_zero = DistEmbedding(
        num_nodes,
        emb_dim,
        name="optim-zero",
        init_func=initializer,
        part_policy=policy,
    )
    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    dgl_adam._world_size = 1
    dgl_adam._rank = 0

    # Reference embeddings with identical initialization, optimized by
    # PyTorch's SparseAdam.
    torch_emb = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
    torch_emb_zero = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
    torch_adam = th.optim.SparseAdam(
        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),
        lr=0.01,
    )

    # Run one optimization step on the same indices and compare the results.
    labels = th.ones((4,)).long()
    idx = th.randint(0, num_nodes, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)

    torch_adam.zero_grad()
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    torch_loss.backward()
    torch_adam.step()

    dgl_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_loss.backward()
    dgl_adam.step()

    assert F.allclose(
        dgl_emb.weight[0 : num_nodes // 2], torch_emb.weight[0 : num_nodes // 2]
    )


def check_sparse_adam(num_trainer=1, shared_mem=True):
    prepare_dist()
    g = create_random_graph(2000)
    num_servers = num_trainer
    num_clients = num_trainer
    num_parts = 1
    graph_name = "dist_graph_test"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client, args=(graph_name, cli_id, 0, num_servers)
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_sparse_opt():
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_sparse_adam(1, True)
    check_sparse_adam(1, False)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_sparse_opt()