import os

import numpy as np

import dgl
import dgl.backend as F
from dgl.distributed import load_partition_book

mode = os.environ.get('DIST_DGL_TEST_MODE', "")
graph_name = os.environ.get('DIST_DGL_TEST_GRAPH_NAME', 'random_test_graph')
num_part = int(os.environ.get('DIST_DGL_TEST_NUM_PART'))
num_servers_per_machine = int(os.environ.get('DIST_DGL_TEST_NUM_SERVER'))
num_client_per_machine = int(os.environ.get('DIST_DGL_TEST_NUM_CLIENT'))
shared_workspace = os.environ.get('DIST_DGL_TEST_WORKSPACE')
graph_path = os.environ.get('DIST_DGL_TEST_GRAPH_PATH')
part_id = int(os.environ.get('DIST_DGL_TEST_PART_ID'))
net_type = os.environ.get('DIST_DGL_TEST_NET_TYPE')
ip_config = os.environ.get('DIST_DGL_TEST_IP_CONFIG', 'ip_config.txt')

os.environ['DGL_DIST_MODE'] = 'distributed'


def zeros_init(shape, dtype):
    return F.zeros(shape, dtype=dtype, ctx=F.cpu())


def run_server(graph_name, server_id, server_count, num_clients, shared_mem,
               keep_alive=False):
    # server_count = num_servers_per_machine
    g = dgl.distributed.DistGraphServer(
        server_id, ip_config, server_count, num_clients,
        graph_path + '/{}.json'.format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=['csc', 'coo'],
        keep_alive=keep_alive,
        net_type=net_type)
    print('start server', server_id)
    g.start()


##########################################
############### DistTensor ###############
##########################################

def dist_tensor_test_sanity(data_shape, name=None):
    # Clients on even partitions write a distinct value into their slice;
    # every client then reads the same slice back and verifies it.
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.int32, init_func=zeros_init, name=name)
    # arbitrary value
    stride = 3
    pos = (part_id // 2) * num_client_per_machine + local_rank
    if part_id % 2 == 0:
        dist_ten[pos*stride:(pos+1)*stride] = F.ones(
            (stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1)
    dgl.distributed.client_barrier()
    assert F.allclose(
        dist_ten[pos*stride:(pos+1)*stride],
        F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1))


def dist_tensor_test_destroy_recreate(data_shape, name):
    # A named, non-persistent DistTensor can be recreated with a new shape
    # after it has been deleted.
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.float32, name, init_func=zeros_init)
    del dist_ten
    dgl.distributed.client_barrier()
    new_shape = (data_shape[0], 4)
    dist_ten = dgl.distributed.DistTensor(
        new_shape, F.float32, name, init_func=zeros_init)


def dist_tensor_test_persistent(data_shape):
    # A persistent DistTensor keeps its name reserved on the servers, so
    # recreating it after deletion is expected to fail.
    dist_ten_name = 'persistent_dist_tensor'
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.float32, dist_ten_name, init_func=zeros_init,
        persistent=True)
    del dist_ten
    recreated = True
    try:
        dist_ten = dgl.distributed.DistTensor(
            data_shape, F.float32, dist_ten_name)
    except Exception:
        recreated = False
    assert not recreated, \
        'recreating a deleted persistent DistTensor should fail'


def test_dist_tensor(g):
    first_type = g.ntypes[0]
    data_shape = (g.number_of_nodes(first_type), 2)
    dist_tensor_test_sanity(data_shape)
    dist_tensor_test_sanity(data_shape, name="DistTensorSanity")
    dist_tensor_test_destroy_recreate(data_shape, name="DistTensorRecreate")
    dist_tensor_test_persistent(data_shape)


##########################################
############# DistEmbedding ##############
##########################################

def dist_embedding_check_sanity(num_nodes, optimizer, name=None):
    # Clients on even partitions run one optimizer step on their rows of the
    # embedding; every client then checks both the updated and the untouched rows.
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=name, init_func=zeros_init)
    lr = 0.001
    optim = optimizer(params=[emb], lr=lr)
    stride = 3
    pos = (part_id // 2) * num_client_per_machine + local_rank
    idx = F.arange(pos*stride, (pos+1)*stride)
    if part_id % 2 == 0:
        with F.record_grad():
            value = emb(idx)
            optim.zero_grad()
            loss = F.sum(value + 1, 0)
        loss.backward()
        optim.step()
    dgl.distributed.client_barrier()
    value = emb(idx)
    assert F.allclose(
        value, F.ones((len(idx), 1), dtype=F.int32, ctx=F.cpu()) * -lr)
    # Rows beyond the updated range must still hold their zero initialization.
    not_update_idx = F.arange(
        ((num_part + 1) // 2) * num_client_per_machine * stride, num_nodes)
    value = emb(not_update_idx)
    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))


def dist_embedding_check_existing(num_nodes):
    # Creating a second DistEmbedding with an existing name but a different
    # shape is expected to fail.
    dist_emb_name = "UniqueEmb"
    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=dist_emb_name, init_func=zeros_init)
    created = True
    try:
        emb1 = dgl.distributed.DistEmbedding(
            num_nodes, 2, name=dist_emb_name, init_func=zeros_init)
    except Exception:
        created = False
    assert not created, \
        'creating a DistEmbedding under an existing name with a different shape should fail'


def test_dist_embedding(g):
    num_nodes = g.number_of_nodes(g.ntypes[0])
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad)
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad,
                                name='SomeEmbedding')
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdam,
                                name='SomeEmbedding')
    dist_embedding_check_existing(num_nodes)


if mode == "server":
    shared_mem = bool(int(os.environ.get('DIST_DGL_TEST_SHARED_MEM')))
    server_id = int(os.environ.get('DIST_DGL_TEST_SERVER_ID'))
    run_server(graph_name, server_id, server_count=num_servers_per_machine,
               num_clients=num_part*num_client_per_machine,
               shared_mem=shared_mem, keep_alive=False)
elif mode == "client":
    os.environ['DGL_NUM_SERVER'] = str(num_servers_per_machine)
    dgl.distributed.initialize(ip_config, net_type=net_type)
    gpb, graph_name, _, _ = load_partition_book(
        graph_path + '/{}.json'.format(graph_name), part_id, None)
    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)
    target_func_map = {
        "DistTensor": test_dist_tensor,
        "DistEmbedding": test_dist_embedding,
    }
    target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
    if target not in target_func_map:
        # Run every check when no specific object type is requested.
        for test_func in target_func_map.values():
            test_func(g)
    else:
        target_func_map[target](g)
else:
    print("DIST_DGL_TEST_MODE has to be either server or client")
    exit(1)