Unverified commit de5e8e23, authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Fix a bug for graphs without node/edge data. (#2838)

* fix.

* test distributed graph without node/edge data.

* remove some tests.

* fix lint
parent afc83aa2
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import numpy as np import numpy as np
from . import rpc from . import rpc
from .graph_partition_book import PartitionPolicy from .graph_partition_book import NodePartitionPolicy, EdgePartitionPolicy
from .standalone_kvstore import KVClient as SA_KVClient from .standalone_kvstore import KVClient as SA_KVClient
from .. import backend as F from .. import backend as F
...@@ -365,8 +365,6 @@ class GetSharedDataRequest(rpc.Request): ...@@ -365,8 +365,6 @@ class GetSharedDataRequest(rpc.Request):
meta[name] = (F.shape(data), meta[name] = (F.shape(data),
F.reverse_data_type_dict[F.dtype(data)], F.reverse_data_type_dict[F.dtype(data)],
kv_store.part_policy[name].policy_str) kv_store.part_policy[name].policy_str)
if len(meta) == 0:
raise RuntimeError('There is no data on kvserver.')
res = GetSharedDataResponse(meta) res = GetSharedDataResponse(meta)
return res return res
...@@ -1058,6 +1056,14 @@ class KVClient(object): ...@@ -1058,6 +1056,14 @@ class KVClient(object):
partition_book : GraphPartitionBook partition_book : GraphPartitionBook
Store the partition information Store the partition information
""" """
# Get all partition policies
for ntype in partition_book.ntypes:
policy = NodePartitionPolicy(partition_book, ntype)
self._all_possible_part_policy[policy.policy_str] = policy
for etype in partition_book.etypes:
policy = EdgePartitionPolicy(partition_book, etype)
self._all_possible_part_policy[policy.policy_str] = policy
# Get shared data from server side # Get shared data from server side
self.barrier() self.barrier()
request = GetSharedDataRequest(GET_SHARED_MSG) request = GetSharedDataRequest(GET_SHARED_MSG)
...@@ -1066,11 +1072,11 @@ class KVClient(object): ...@@ -1066,11 +1072,11 @@ class KVClient(object):
for name, meta in response.meta.items(): for name, meta in response.meta.items():
if name not in self._data_name_list: if name not in self._data_name_list:
shape, dtype, policy_str = meta shape, dtype, policy_str = meta
assert policy_str in self._all_possible_part_policy
shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype) shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, partition_book) self._part_policy[name] = self._all_possible_part_policy[policy_str]
self._all_possible_part_policy[policy_str] = self._part_policy[name]
self._pull_handlers[name] = default_pull_handler self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler self._push_handlers[name] = default_push_handler
# Get full data shape across servers # Get full data shape across servers
......
...@@ -66,6 +66,79 @@ def emb_init(shape, dtype): ...@@ -66,6 +66,79 @@ def emb_init(shape, dtype):
def rand_init(shape, dtype): def rand_init(shape, dtype):
return F.tensor(np.random.normal(size=shape), F.float32) return F.tensor(np.random.normal(size=shape), F.float32)
def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
    """Run basic DistGraph/DistTensor checks on a distributed graph.

    Used by the "empty" server/client test, i.e. the scenario from the
    commit under test: a distributed graph without node/edge data.

    Parameters
    ----------
    g : DistGraph
        The distributed graph handle to exercise.
    num_clients : int
        Accepted for signature parity with the caller; not used here.
    num_nodes : int
        Expected result of ``g.number_of_nodes()``.
    num_edges : int
        Expected result of ``g.number_of_edges()``.
    """
    # The size API must agree with what the caller measured on the
    # original graph.
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Attach a brand-new int32 node feature and read back the first half
    # of the node IDs; with no init function given, the values must be 0.
    node_count = g.number_of_nodes()
    feat_shape = (node_count, 2)
    g.ndata['test1'] = dgl.distributed.DistTensor(feat_shape, F.int32)
    nids = F.arange(0, int(node_count / 2))
    fetched = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(fetched) == 0)

    # Create a named tensor, drop it, then create one with the same name
    # but a different shape and drop it again.
    # NOTE(review): presumably verifies that a deleted name can be reused
    # -- confirm against DistTensor's lifetime semantics.
    scratch = dgl.distributed.DistTensor(feat_shape, F.float32, 'test3',
                                         init_func=rand_init)
    del scratch
    scratch = dgl.distributed.DistTensor((node_count, 3), F.float32, 'test3')
    del scratch

    # Write ones into the selected rows and make sure they read back as 1.
    ones = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata['test1'][nids] = ones
    fetched = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(fetched) == 1)

    # The schema reported for the new feature must carry its dtype.
    schemes = g.node_attr_schemes()
    assert schemes['test1'].dtype == F.int32

    print('end')
def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
    """Client-process entry point for the "empty" distributed graph test.

    Initializes the distributed runtime, loads the partition book for
    ``part_id``, opens the graph as a ``DistGraph`` and hands it to
    ``check_dist_graph_empty``.

    Parameters
    ----------
    graph_name : str
        Name used to locate ``/tmp/dist_graph/<graph_name>.json``.
    part_id : int
        Partition whose book is loaded.
    server_count : int
        Exported to the environment as ``DGL_NUM_SERVER``.
    num_clients, num_nodes, num_edges : int
        Forwarded unchanged to ``check_dist_graph_empty``.
    """
    # NOTE(review): the fixed delay presumably gives the server processes
    # time to come up before the client connects -- confirm.
    time.sleep(5)

    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")

    part_config = '/tmp/dist_graph/{}.json'.format(graph_name)
    # The graph name returned with the partition book replaces the
    # argument from here on.
    gpb, graph_name, _, _ = load_partition_book(part_config, part_id, None)

    dist_graph = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(dist_graph, num_clients, num_nodes, num_edges)
def check_server_client_empty(shared_mem, num_servers, num_clients):
    """Run servers and clients on a graph partitioned without extra data.

    Builds a random graph, partitions it into a single part under
    ``/tmp/dist_graph``, spawns ``num_servers`` server processes
    (``run_server``) and ``num_clients`` client processes
    (``run_client_empty``), waits for all of them, and fails if any
    client exited with a non-zero code.

    Parameters
    ----------
    shared_mem : bool
        Forwarded to ``run_server``.
    num_servers : int
        Number of server processes to spawn.
    num_clients : int
        Number of client processes to spawn.

    Raises
    ------
    AssertionError
        If at least one client process did not exit with code 0.
    """
    prepare_dist()
    # NOTE(review): 10000 is presumably the node count -- confirm against
    # create_random_graph.
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_1'
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    ctx = mp.get_context('spawn')

    serv_ps = []
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server,
                        args=(graph_name, serv_id, num_servers,
                              num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print('start client', cli_id)
        # Every client works on partition 0, the only partition created.
        p = ctx.Process(target=run_client_empty,
                        args=(graph_name, 0, num_servers, num_clients,
                              g.number_of_nodes(), g.number_of_edges()))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()
    print('clients have terminated')

    # The real checks run inside the client processes; an assertion that
    # fails there only shows up here as a non-zero exit code. Without this
    # check the test would pass even when every client crashed.
    failed_clients = {idx: p.exitcode for idx, p in enumerate(cli_ps)
                      if p.exitcode != 0}
    assert not failed_clients, \
        'client processes exited abnormally (index -> exit code): {}'.format(failed_clients)
def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges): def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5) time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count) os.environ['DGL_NUM_SERVER'] = str(server_count)
...@@ -380,6 +453,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients): ...@@ -380,6 +453,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph") @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_server_client(): def test_server_client():
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
check_server_client_empty(True, 1, 1)
check_server_client_hetero(True, 1, 1) check_server_client_hetero(True, 1, 1)
check_server_client_hetero(False, 1, 1) check_server_client_hetero(False, 1, 1)
check_server_client(True, 1, 1) check_server_client(True, 1, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment