Unverified commit a4c931a9 authored by Da Zheng, committed by GitHub

[Distributed] Support multiple servers (#1886)



* client init graph on the backup servers.

* fix.

* test multi-server.

* fix anonymous dist tensors.

* check #parts.

* fix init_data

* add multi-server multi-client tests.

* update tests in kvstore.

* fix.

* verify the loaded partition.

* fix a bug.

* fix lint.

* fix.

* fix example.

* fix rpc.

* fix pull/push handler for backup kvstore

* fix example readme.

* change ip.

* update docstring.
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-1.us-west-2.compute.internal>
parent 9bcce7be
@@ -55,7 +55,7 @@ To run unsupervised training:
 python3 ~/dgl/tools/launch.py \
 --workspace ~/dgl/examples/pytorch/graphsage/experimental \
 --num_client 4 \
---conf_path data/ogb-product.json \
+--part_config data/ogb-product.json \
 --ip_config ip_config.txt \
 "python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --num-client 4"
 ```
@@ -76,13 +76,13 @@ python3 partition_graph.py --dataset ogb-product --num_parts 1
 To run supervised training:
 ```bash
-python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --conf_path data/ogb-product.json --standalone
+python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --part_config data/ogb-product.json --standalone
 ```
 To run unsupervised training:
 ```bash
-python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --conf_path data/ogb-product.json --standalone
+python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --part_config data/ogb-product.json --standalone
 ```
 Note: please ensure that all environment variables shown above are unset if they were set for testing distributed training.
-172.31.16.250 5555 1
-172.31.30.135 5555 1
-172.31.27.41 5555 1
-172.31.30.149 5555 1
+172.31.19.1 5555 2
+172.31.23.205 5555 2
+172.31.29.175 5555 2
+172.31.16.98 5555 2
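Each line of ip_config.txt describes one machine; the third field appears to be the number of server processes launched on that machine, so raising it from 1 to 2 gives every machine a main server plus one backup server, which is the multi-server setup this change enables.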
@@ -349,7 +349,7 @@ def run(args, device, data):
 def main(args):
     if not args.standalone:
         th.distributed.init_process_group(backend='gloo')
-    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, conf_file=args.conf_path)
+    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.conf_path)
     print('rank:', g.rank())
     print('number of edges', g.number_of_edges())
...
@@ -12,7 +12,7 @@ from .kvstore import KVServer, KVClient
 from .standalone_kvstore import KVClient as SA_KVClient
 from .._ffi.ndarray import empty_shared_mem
 from ..frame import infer_scheme
-from .partition import load_partition
+from .partition import load_partition, load_partition_book
 from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
 from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
 from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
@@ -23,6 +23,41 @@ from .rpc_server import start_server
 from .graph_services import find_edges as dist_find_edges
 from .dist_tensor import DistTensor, _get_data_name
 
+INIT_GRAPH = 800001
+
+class InitGraphRequest(rpc.Request):
+    """ Init graph on the backup servers.
+
+    When the backup server starts, they don't load the graph structure.
+    This request tells the backup servers that they can map to the graph structure
+    with shared memory.
+    """
+    def __init__(self, graph_name):
+        self._graph_name = graph_name
+
+    def __getstate__(self):
+        return self._graph_name
+
+    def __setstate__(self, state):
+        self._graph_name = state
+
+    def process_request(self, server_state):
+        if server_state.graph is None:
+            server_state.graph = _get_graph_from_shared_mem(self._graph_name)
+        return InitGraphResponse(self._graph_name)
+
+class InitGraphResponse(rpc.Response):
+    """ Ack the init graph request
+    """
+    def __init__(self, graph_name):
+        self._graph_name = graph_name
+
+    def __getstate__(self):
+        return self._graph_name
+
+    def __setstate__(self, state):
+        self._graph_name = state
+
 def _copy_graph_to_shared_mem(g, graph_name):
     new_g = g.shared_memory(graph_name, formats='csc')
     # We should share the node/edge data to the client explicitly instead of putting them
@@ -218,16 +253,20 @@ class DistGraphServer(KVServer):
                          num_clients=num_clients)
         self.ip_config = ip_config
         # Load graph partition data.
-        self.client_g, node_feats, edge_feats, self.gpb, graph_name = load_partition(part_config,
-                                                                                     server_id)
-        print('load ' + graph_name)
-        if not disable_shared_mem:
-            self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
-        # Init kvstore.
+        if self.is_backup_server():
+            # The backup server doesn't load the graph partition. It'll initialized afterwards.
+            self.gpb, graph_name = load_partition_book(part_config, self.part_id)
+            self.client_g = None
+        else:
+            self.client_g, node_feats, edge_feats, self.gpb, \
+                    graph_name = load_partition(part_config, self.part_id)
+            print('load ' + graph_name)
+            if not disable_shared_mem:
+                self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
+
         if not disable_shared_mem:
             self.gpb.shared_memory(graph_name)
-        assert self.gpb.partid == server_id
+        assert self.gpb.partid == self.part_id
         self.add_part_policy(PartitionPolicy(NODE_PART_POLICY, self.gpb))
         self.add_part_policy(PartitionPolicy(EDGE_PART_POLICY, self.gpb))
@@ -240,20 +279,13 @@ class DistGraphServer(KVServer):
                 self.init_data(name=_get_data_name(name, EDGE_PART_POLICY),
                                policy_str=EDGE_PART_POLICY,
                                data_tensor=edge_feats[name])
-        else:
-            for name in node_feats:
-                self.init_data(name=_get_data_name(name, NODE_PART_POLICY),
-                               policy_str=NODE_PART_POLICY)
-            for name in edge_feats:
-                self.init_data(name=_get_data_name(name, EDGE_PART_POLICY),
-                               policy_str=EDGE_PART_POLICY)
 
     def start(self):
         """ Start graph store server.
         """
         # start server
         server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
-        print('start graph service on server ' + str(self.server_id))
+        print('start graph service on server {} for part {}'.format(self.server_id, self.part_id))
         start_server(server_id=self.server_id, ip_config=self.ip_config,
                      num_clients=self.num_clients, server_state=server_state)
@@ -314,6 +346,12 @@ class DistGraph:
             self._gpb = get_shared_mem_partition_book(graph_name, self._g)
             if self._gpb is None:
                 self._gpb = gpb
+
+            # Tell the backup servers to load the graph structure from shared memory.
+            for server_id in range(self._client.num_servers):
+                rpc.send_request(server_id, InitGraphRequest(graph_name))
+            for server_id in range(self._client.num_servers):
+                rpc.recv_response()
             self._client.barrier()
             self._client.map_shared_data(self._gpb)
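Taken together, the flow implied by this diff is: the main server on each machine loads its partition via load_partition and maps the graph structure into shared memory; backup servers start with only a partition book and client_g = None; and connecting clients broadcast InitGraphRequest to every server so that the backup servers attach to the already-shared structure instead of loading the partition a second time.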
@@ -692,3 +730,5 @@ def edge_split(edges, partition_book=None, rank=None, force_even=True):
         # Get all edges that belong to the rank.
         local_eids = partition_book.partid2eids(partition_book.partid)
         return _split_local(partition_book, rank, edges, local_eids)
+
+rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
"""Define distributed tensor.""" """Define distributed tensor."""
import os import os
import uuid
from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY, EDGE_PART_POLICY from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY, EDGE_PART_POLICY
from .rpc_client import is_initialized from .rpc_client import is_initialized
@@ -71,13 +70,15 @@ class DistTensor:
         if init_func is None:
             init_func = _default_init_data
 
+        exist_names = g._client.data_name_list()
         # If a user doesn't provide a name, we generate a name ourselves.
+        # We need to generate the name in a deterministic way.
         if name is None:
             assert not persistent, 'We cannot generate anonymous persistent distributed tensors'
-            name = uuid.uuid4().hex[:10]
+            name = 'anonymous-' + str(len(exist_names) + 1)
         self._name = _get_data_name(name, part_policy.policy_str)
         self._persistent = persistent
-        if self._name not in g._client.data_name_list():
+        if self._name not in exist_names:
             g._client.init_data(self._name, shape, dtype, part_policy, init_func)
             self._owner = True
         else:
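The uuid-based name is dropped presumably because every trainer must derive the same name for the same anonymous tensor in order to bind to a single kvstore entry; a random per-process name would give each trainer its own tensor, which is what the new "deterministic way" comment points at.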
...
@@ -162,8 +162,7 @@ class InitDataRequest(rpc.Request):
                                policy_str=self.policy_str,
                                data_tensor=data_tensor)
         else:
-            kv_store.init_data(name=self.name,
-                               policy_str=self.policy_str)
+            kv_store.init_data(name=self.name, policy_str=self.policy_str)
         res = InitDataResponse(INIT_MSG)
         return res
@@ -439,18 +438,26 @@ class SendMetaToBackupRequest(rpc.Request):
         data shape
     policy_str : str
         partition-policy string, e.g., 'edge' or 'node'.
+    pull_handler : callable
+        The callback function when data is pulled from kvstore.
+    push_handler : callable
+        The callback function when data is pushed to kvstore.
     """
-    def __init__(self, name, dtype, shape, policy_str):
+    def __init__(self, name, dtype, shape, policy_str, pull_handler, push_handler):
         self.name = name
         self.dtype = dtype
         self.shape = shape
         self.policy_str = policy_str
+        self.pull_handler = pull_handler
+        self.push_handler = push_handler
 
     def __getstate__(self):
-        return self.name, self.dtype, self.shape, self.policy_str
+        return self.name, self.dtype, self.shape, self.policy_str, self.pull_handler, \
+                self.push_handler
 
     def __setstate__(self, state):
-        self.name, self.dtype, self.shape, self.policy_str = state
+        self.name, self.dtype, self.shape, self.policy_str, self.pull_handler, \
+                self.push_handler = state
 
     def process_request(self, server_state):
         kv_store = server_state.kv_store
@@ -460,6 +467,8 @@ class SendMetaToBackupRequest(rpc.Request):
             dlpack = shared_data.to_dlpack()
             kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)
             kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
+            kv_store.pull_handlers[self.name] = self.pull_handler
+            kv_store.push_handlers[self.name] = self.push_handler
         res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
         return res
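For context, a push handler runs on the server when pushed data arrives and a pull handler when data is read back. A minimal sketch of a custom push handler, assuming the `(target, name, id_tensor, data_tensor)` signature of the default handlers (the signature itself is not shown in this diff):

```python
def add_push(target, name, id_tensor, data_tensor):
    # Accumulate pushed values instead of overwriting them,
    # e.g. to sum gradients coming from several trainers.
    target[name][id_tensor] += data_tensor

# Registered on the client, as the kvstore test further below does:
# kvclient.register_push_handler('data_3', add_push)
```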
@@ -895,6 +904,11 @@ class KVClient(object):
         """Get machine ID"""
         return self._machine_id
 
+    @property
+    def num_servers(self):
+        """Get the number of servers"""
+        return self._server_count
+
     def barrier(self):
         """Barrier for all client nodes.
@@ -1032,6 +1046,21 @@ class KVClient(object):
         self._full_data_shape[name] = tuple(shape)
         self._pull_handlers[name] = default_pull_handler
         self._push_handlers[name] = default_push_handler
+
+        # Now we need to tell the backup server the new tensor.
+        if self._client_id % num_clients_per_part == 0:
+            request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype],
+                                              part_shape, part_policy.policy_str,
+                                              self._pull_handlers[name],
+                                              self._push_handlers[name])
+            # send request to all the backup server nodes
+            for i in range(self._group_count-1):
+                server_id = self._machine_id * self._group_count + i + 1
+                rpc.send_request(server_id, request)
+            # recv response from all the backup server nodes
+            for _ in range(self._group_count-1):
+                response = rpc.recv_response()
+                assert response.msg == SEND_META_TO_BACKUP_MSG
         self.barrier()
 
     def delete_data(self, name):
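A small worked example of the server-id arithmetic above (values assumed for illustration): with `_group_count = 2` servers per machine, machine m's main server has id `m * group_count` and its backups fill the remaining ids, so the loop skips the main server and only notifies the backups:

```python
machine_id, group_count = 1, 2
backup_ids = [machine_id * group_count + i + 1 for i in range(group_count - 1)]
print(backup_ids)  # [3] -- server 2 is machine 1's main server, server 3 its backup
```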
@@ -1047,7 +1076,7 @@ class KVClient(object):
         self.barrier()
         part_policy = self._part_policy[name]
         num_partitions = part_policy.partition_book.num_partitions()
-        num_clients_per_part = rpc.get_num_client() / num_partitions
+        num_clients_per_part = rpc.get_num_client() // num_partitions
         if self._client_id % num_clients_per_part == 0:
             # send request to every server nodes
             request = DeleteDataRequest(name)
@@ -1108,7 +1137,9 @@ class KVClient(object):
         # Send meta data to backup servers
         for name, meta in response.meta.items():
             shape, dtype, policy_str = meta
-            request = SendMetaToBackupRequest(name, dtype, shape, policy_str)
+            request = SendMetaToBackupRequest(name, dtype, shape, policy_str,
+                                              self._pull_handlers[name],
+                                              self._push_handlers[name])
             # send request to all the backup server nodes
             for i in range(self._group_count-1):
                 server_id = self._machine_id * self._group_count + i + 1
...
@@ -138,6 +138,9 @@ def load_partition(conf_file, part_id):
     assert EID in graph.edata, "the partition graph should contain edge mapping to global edge Id"
 
     gpb, graph_name = load_partition_book(conf_file, part_id, graph)
+    nids = F.boolean_mask(graph.ndata[NID], graph.ndata['inner_node'])
+    partids = gpb.nid2partid(nids)
+    assert np.all(F.asnumpy(partids == part_id)), 'load a wrong partition'
     return graph, node_feats, edge_feats, gpb, graph_name
 
 def load_partition_book(conf_file, part_id, graph=None):
@@ -162,6 +165,8 @@ def load_partition_book(conf_file, part_id, graph=None):
     with open(conf_file) as conf_f:
         part_metadata = json.load(conf_f)
     assert 'num_parts' in part_metadata, 'num_parts does not exist.'
+    assert part_metadata['num_parts'] > part_id, \
+            'part {} is out of range (#parts: {})'.format(part_id, part_metadata['num_parts'])
     num_parts = part_metadata['num_parts']
     assert 'num_nodes' in part_metadata, "cannot get the number of nodes of the global graph."
     assert 'num_edges' in part_metadata, "cannot get the number of edges of the global graph."
...
@@ -766,7 +766,9 @@ def send_requests_to_machine(target_and_requests):
         service_id = request.service_id
         msg_seq = incr_msg_seq()
         client_id = get_rank()
-        server_id = target
+        server_id = random.randint(target*get_num_server_per_machine(),
+                                   (target+1)*get_num_server_per_machine()-1)
         data, tensors = serialize_to_payload(request)
         msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
         send_rpc_message(msg, server_id)
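To make the routing change concrete: a request addressed to machine `target` is now spread uniformly over that machine's servers instead of always hitting its first server. A small sketch with assumed numbers:

```python
import random

num_server_per_machine = 2   # assumed: matches the two-server ip_config shown earlier
target = 1                   # machine id the request is addressed to
server_id = random.randint(target * num_server_per_machine,
                           (target + 1) * num_server_per_machine - 1)
print(server_id)             # 2 or 3: either the main server or the backup on machine 1
```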
...
@@ -42,7 +42,7 @@ class ServerState:
     def __init__(self, kv_store, local_g, partition_book):
         self._kv_store = kv_store
-        self.graph = local_g
+        self._graph = local_g
         self.partition_book = partition_book
 
     @property
...
@@ -66,14 +66,14 @@ def emb_init(shape, dtype):
 def rand_init(shape, dtype):
     return F.tensor(np.random.normal(size=shape), F.float32)
 
-def run_client(graph_name, part_id, num_nodes, num_edges):
+def run_client(graph_name, part_id, num_clients, num_nodes, num_edges):
     time.sleep(5)
     gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                           part_id, None)
     g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
-    check_dist_graph(g, num_nodes, num_edges)
+    check_dist_graph(g, num_clients, num_nodes, num_edges)
 
-def check_dist_graph(g, num_nodes, num_edges):
+def check_dist_graph(g, num_clients, num_nodes, num_edges):
     # Test API
     assert g.number_of_nodes() == num_nodes
     assert g.number_of_edges() == num_edges
@@ -129,6 +129,7 @@ def check_dist_graph(g, num_nodes, num_edges):
     loss.backward()
     optimizer.step()
     feats = emb(nids)
+    if num_clients == 1:
         assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
     rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
     feats1 = emb(rest)
@@ -137,7 +138,7 @@ def check_dist_graph(g, num_nodes, num_edges):
     policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
     grad_sum = dgl.distributed.DistTensor(g, (g.number_of_nodes(),), F.float32,
                                           'emb1_sum', policy)
-    assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)))
+    assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
     assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
 
     emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
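A note on the relaxed assertions above: with more than one client, every client pushes its own gradient of 1 into the shared 'emb1_sum' tensor, so the accumulated value becomes num_clients rather than 1, and the exact per-element values guarded by `if num_clients == 1:` are only deterministic in the single-client case.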
@@ -156,6 +157,7 @@ def check_dist_graph(g, num_nodes, num_edges):
     optimizer.step()
     with F.no_grad():
         feats = emb(nids)
+    if num_clients == 1:
         assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
     rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
     feats1 = emb(rest)
@@ -188,8 +190,8 @@ def check_dist_graph(g, num_nodes, num_edges):
     print('end')
 
-def check_server_client(shared_mem):
-    prepare_dist()
+def check_server_client(shared_mem, num_servers, num_clients):
+    prepare_dist(num_servers)
     g = create_random_graph(10000)
 
     # Partition the graph
@@ -203,15 +205,16 @@ def check_server_client(shared_mem):
     # We cannot run multiple servers and clients on the same machine.
     serv_ps = []
     ctx = mp.get_context('spawn')
-    for serv_id in range(1):
-        p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1, shared_mem))
+    for serv_id in range(num_servers):
+        p = ctx.Process(target=run_server, args=(graph_name, serv_id,
+                                                 num_clients, shared_mem))
         serv_ps.append(p)
         p.start()
 
     cli_ps = []
-    for cli_id in range(1):
+    for cli_id in range(num_clients):
         print('start client', cli_id)
-        p = ctx.Process(target=run_client, args=(graph_name, cli_id, g.number_of_nodes(),
+        p = ctx.Process(target=run_client, args=(graph_name, 0, num_clients, g.number_of_nodes(),
                                                  g.number_of_edges()))
         p.start()
         cli_ps.append(p)
@@ -227,8 +230,10 @@ def check_server_client(shared_mem):
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 def test_server_client():
     os.environ['DGL_DIST_MODE'] = 'distributed'
-    check_server_client(True)
-    check_server_client(False)
+    check_server_client(True, 1, 1)
+    check_server_client(False, 1, 1)
+    check_server_client(True, 2, 2)
+    check_server_client(False, 2, 2)
 
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 def test_standalone():
@@ -242,10 +247,10 @@ def test_standalone():
     partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
     dist_g = DistGraph("kv_ip_config.txt", graph_name,
                        part_config='/tmp/dist_graph/{}.json'.format(graph_name))
-    check_dist_graph(dist_g, g.number_of_nodes(), g.number_of_edges())
+    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
 
 def test_split():
-    prepare_dist()
+    #prepare_dist()
     g = create_random_graph(10000)
     num_parts = 4
     num_hops = 2
@@ -290,7 +295,7 @@ def test_split():
     assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))
 
 def test_split_even():
-    prepare_dist()
+    #prepare_dist(1)
     g = create_random_graph(10000)
     num_parts = 4
     num_hops = 2
@@ -348,10 +353,10 @@ def test_split_even():
     assert np.all(all_nodes == F.asnumpy(all_nodes2))
     assert np.all(all_edges == F.asnumpy(all_edges2))
 
-def prepare_dist():
+def prepare_dist(num_servers):
     ip_config = open("kv_ip_config.txt", "w")
     ip_addr = get_local_usable_addr()
-    ip_config.write('%s 1\n' % ip_addr)
+    ip_config.write('{} {}\n'.format(ip_addr, num_servers))
     ip_config.close()
 
 if __name__ == '__main__':
...
@@ -151,6 +151,7 @@ def start_client(num_clients):
     dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
     # Init kvclient
     kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
+    kvclient.map_shared_data(partition_book=gpb)
     assert dgl.distributed.get_num_client() == num_clients
     kvclient.init_data(name='data_1',
                        shape=F.shape(data_1),
@@ -163,8 +164,6 @@ def start_client(num_clients):
                        part_policy=node_policy,
                        init_func=init_zero_func)
-    kvclient.map_shared_data(partition_book=gpb)
-
     # Test data_name_list
     name_list = kvclient.data_name_list()
     print(name_list)
@@ -264,7 +263,6 @@ def start_client(num_clients):
                        part_policy=node_policy,
                        init_func=init_zero_func)
     kvclient.register_push_handler('data_3', add_push)
-    kvclient.map_shared_data(partition_book=gpb)
     data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32)
     kvclient.barrier()
     time.sleep(kvclient.client_id + 1)
...