"src/runtime/vscode:/vscode.git/clone" did not exist on "c454d419cc5e036daaf8ebf73ccb82fa751a5cd0"
Unverified Commit fb248b67 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[RPC] Sampling over RPC (#1616)



* fix

* test1111

* 111

* 111

* fff

* lint

* 111

* lint

* lint

* 111

* fix

* 111

* fix

* 111

* commit

* 111

* 111

* lint

* fix typo

* fix

* lint

* fix

* 111

* support mxnet

* support mxnet

* lint

* remove print

* fix

* fix test.

* fix test.

* fix test.

* try to fix a nondeterministic error.
Co-authored-by: default avatarChao Ma <mctt90@gmail.com>
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent ff5b5a4a
......@@ -59,6 +59,8 @@ def cpu():
def tensor(data, dtype=None):
return tf.convert_to_tensor(data, dtype=dtype)
def initialize_context():
tf.zeros(1)
def as_scalar(data):
return data.numpy().asscalar()
......@@ -582,3 +584,5 @@ def _reduce_grad(grad, shape):
def sync():
context = context().context()
context.async_wait()
initialize_context()
\ No newline at end of file
......@@ -9,3 +9,4 @@ from .rpc_server import start_server
from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .sampling import sample_neighbors
......@@ -272,7 +272,6 @@ class DistGraphServer(KVServer):
super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config,
num_clients=num_clients)
self.ip_config = ip_config
# Load graph partition data.
self.client_g, node_feats, edge_feats, self.gpb = load_partition(conf_file, server_id)
if not disable_shared_mem:
......@@ -301,8 +300,8 @@ class DistGraphServer(KVServer):
""" Start graph store server.
"""
# start server
server_state = ServerState(kv_store=self)
start_server(server_id=0, ip_config=self.ip_config,
server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
start_server(server_id=self.server_id, ip_config=self.ip_config,
num_clients=self.num_clients, server_state=server_state)
def _default_init_data(shape, dtype):
......@@ -333,7 +332,6 @@ class DistGraph:
def __init__(self, ip_config, graph_name, gpb=None):
connect_to_server(ip_config=ip_config)
self._client = KVClient(ip_config)
self._g = _get_graph_from_shared_mem(graph_name)
self._gpb = get_shared_mem_partition_book(graph_name, self._g)
if self._gpb is None:
......
......@@ -339,6 +339,17 @@ class GraphPartitionBook:
"""
return self._edge_size
@property
def partid(self):
"""Get the current partition id
Return
------
int
The partition id of current machine
"""
return self._part_id
class RangePartitionBook:
"""RangePartitionBook is used to store partition information.
......@@ -593,6 +604,17 @@ class RangePartitionBook:
range_end = self._edge_map[self._partid]
return range_end - range_start
@property
def partid(self):
"""Get the current partition id
Return
------
int
The partition id of current machine
"""
return self._partid
class PartitionPolicy(object):
"""Wrapper for GraphPartitionBook and RangePartitionBook.
......
......@@ -111,6 +111,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook)
rpc.set_num_server(num_servers)
......
......@@ -78,24 +78,14 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.send_response(client_id, register_res)
# main service loop
while True:
try:
req, client_id = rpc.recv_request()
res = req.process_request(server_state)
if res is not None:
if isinstance(res, list):
for response in res:
target_id, res_data = response
rpc.send_response(target_id, res_data)
elif isinstance(res, str) and res == 'exit':
break # break the loop and exit server
else:
rpc.send_response(client_id, res)
except KeyboardInterrupt:
print("Exit kvserver!")
rpc.finalize_sender()
rpc.finalize_receiver()
except:
print("Error on kvserver!")
rpc.finalize_sender()
rpc.finalize_receiver()
raise
req, client_id = rpc.recv_request()
res = req.process_request(server_state)
if res is not None:
if isinstance(res, list):
for response in res:
target_id, res_data = response
rpc.send_response(target_id, res_data)
elif isinstance(res, str) and res == 'exit':
break # break the loop and exit server
else:
rpc.send_response(client_id, res)
"""Sampling module"""
from .rpc import Request, Response, remote_call_to_machine
from ..sampling import sample_neighbors as local_sample_neighbors
from . import register_service
from ..convert import graph
from ..base import NID, EID
from .. import backend as F
__all__ = ['sample_neighbors']
SAMPLING_SERVICE_ID = 6657
class SamplingResponse(Response):
    """Response message carrying edges sampled on a remote server.

    The endpoints and edge IDs are expressed in the global ID space so
    that the client can merge results coming from several partitions.
    """

    def __init__(self, global_src, global_dst, global_eids):
        self.global_src = global_src
        self.global_dst = global_dst
        self.global_eids = global_eids

    def __getstate__(self):
        # Serialized as a plain tuple for the RPC layer.
        return self.global_src, self.global_dst, self.global_eids

    def __setstate__(self, state):
        self.global_src, self.global_dst, self.global_eids = state
class SamplingRequest(Request):
    """Request asking a remote server to sample neighbors of seed nodes."""

    def __init__(self, nodes, fan_out, edge_dir='in', prob=None, replace=False):
        self.seed_nodes = nodes
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace
        self.fan_out = fan_out

    def __getstate__(self):
        # Serialized as a plain tuple for the RPC layer.
        return self.seed_nodes, self.edge_dir, self.prob, self.replace, self.fan_out

    def __setstate__(self, state):
        self.seed_nodes, self.edge_dir, self.prob, self.replace, self.fan_out = state

    def process_request(self, server_state):
        """Run local sampling on the server and reply with global IDs."""
        graph_local = server_state.graph
        book = server_state.partition_book
        # Map the global seed IDs onto this partition's local ID space.
        local_nids = F.astype(
            book.nid2localnid(F.tensor(self.seed_nodes), book.partid),
            graph_local.idtype)
        subg = local_sample_neighbors(
            graph_local, local_nids, self.fan_out, self.edge_dir,
            self.prob, self.replace)
        # Translate the sampled endpoints and edge IDs back to global space.
        nid_map = graph_local.ndata[NID]
        u, v = subg.edges()
        geids = F.gather_row(graph_local.edata[EID], subg.edata[EID])
        return SamplingResponse(nid_map[u], nid_map[v], geids)
def merge_graphs(res_list, num_nodes):
    """Combine sampling responses from multiple servers into one graph.

    The merged graph keeps the global edge IDs in ``edata[EID]``.
    """
    src_parts = [res.global_src for res in res_list]
    dst_parts = [res.global_dst for res in res_list]
    eid_parts = [res.global_eids for res in res_list]
    all_src = F.cat(src_parts, 0)
    all_dst = F.cat(dst_parts, 0)
    all_eid = F.cat(eid_parts, 0)
    merged = graph((all_src, all_dst),
                   restrict_format='coo', num_nodes=num_nodes)
    merged.edata[EID] = all_eid
    return merged
def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replace=False):
    """Sample neighbors of the given nodes from a distributed graph.

    Groups the seed nodes by the partition that owns them, issues one
    ``SamplingRequest`` per owning partition, and merges the responses
    into a single graph in the global ID space.

    Parameters
    ----------
    dist_graph : DistGraph
        The distributed graph to sample from.
    nodes : sequence of int
        Global IDs of the seed nodes.
    fanout : int
        Number of neighbors to sample per seed node (forwarded to the
        local sampler on each server).
    edge_dir : str, optional
        Edge direction; only ``'in'`` is supported.
    prob : optional
        Edge-probability specifier forwarded to the local sampler.
    replace : bool, optional
        Whether to sample with replacement.

    Returns
    -------
    DGLGraph
        Graph holding the sampled edges; edge IDs are in ``edata[EID]``.

    Raises
    ------
    NotImplementedError
        If ``edge_dir`` is not ``'in'``.
    """
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if edge_dir != 'in':
        raise NotImplementedError(
            "distributed sample_neighbors only supports edge_dir='in'")
    partition_book = dist_graph.get_partition_book()
    # Decide which partition owns each seed node.
    owner_ids = F.asnumpy(
        partition_book.nid2partid(F.tensor(nodes))).tolist()
    nodes_per_part = [[] for _ in range(partition_book.num_partitions())]
    for pid, node in zip(owner_ids, nodes):
        nodes_per_part[pid].append(node)
    # Only contact partitions that actually own some of the seeds.
    req_list = [
        (pid, SamplingRequest(part_nodes, fanout, edge_dir=edge_dir,
                              prob=prob, replace=replace))
        for pid, part_nodes in enumerate(nodes_per_part) if part_nodes
    ]
    res_list = remote_call_to_machine(req_list)
    return merge_graphs(res_list, dist_graph.number_of_nodes())
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SamplingResponse)
"""Server data"""
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
from ..graph import DGLGraph
from ..transform import as_heterograph
@register_object('server_state.ServerState')
class ServerState(ObjectBase):
# Remove C++ bindings for now, since not used
class ServerState:
"""Data stored in one DGL server.
In a distributed setting, DGL partitions all data associated with the graph
......@@ -35,9 +38,14 @@ class ServerState(ObjectBase):
Total number of nodes
total_num_edges : int
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
"""
def __init__(self, kv_store):
def __init__(self, kv_store, local_g, partition_book):
self._kv_store = kv_store
self.graph = local_g
self.partition_book = partition_book
@property
def kv_store(self):
......@@ -50,33 +58,15 @@ class ServerState(ObjectBase):
@property
def graph(self):
"""Get graph."""
return _CAPI_DGLRPCServerStateGetGraph(self)
"""Get graph data."""
return self._graph
@property
def total_num_nodes(self):
"""Get total number of nodes."""
return _CAPI_DGLRPCServerStateGetTotalNumNodes(self)
@graph.setter
def graph(self, graph):
if isinstance(graph, DGLGraph):
self._graph = as_heterograph(graph)
else:
self._graph = graph
@property
def total_num_edges(self):
"""Get total number of edges."""
return _CAPI_DGLRPCServerStateGetTotalNumEdges(self)
def get_server_state():
"""Get server state data.
If the process is a server, this stores necessary
server-side data. Otherwise, the process is a client and it stores a cache
of the server co-located with the client (if available). When the client
invokes a RPC to the co-located server, it can thus perform computation
locally without an actual remote call.
Returns
-------
ServerState
Server state data
"""
return _CAPI_DGLRPCGetServerState()
_init_api("dgl.distributed.server_state")
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed.sampling import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import backend as F
import time
from utils import get_local_usable_addr
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir):
    """Start a DistGraphServer for partition `rank`; blocks until shutdown.

    Runs in a spawned child process, so it relies only on module-level
    imports (the previous unused function-local `import dgl` was removed).
    """
    g = DistGraphServer(rank, "rpc_sampling_ip_config.txt", 1, "test_sampling",
                        tmpdir / 'test_sampling.json', disable_shared_mem=True)
    g.start()
def start_client(rank, tmpdir):
    """Connect to the servers, sample once, tear everything down, and
    return the sampled graph."""
    import dgl
    _, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
    dist_graph = DistGraph("rpc_sampling_ip_config.txt", "test_sampling", gpb=gpb)
    result = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
    dgl.distributed.shutdown_servers()
    dgl.distributed.finalize_client()
    return result
def check_rpc_sampling(tmpdir):
    """End-to-end check of distributed sampling without node reshuffling.

    Partitions the cora graph, starts one server process per partition,
    samples from a client, and verifies the sampled edges exist in the
    original graph with matching edge IDs.
    """
    num_server = 2
    # Context manager guarantees the config file is closed even on error.
    with open("rpc_sampling_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{} 1\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_client(0, tmpdir)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # Every sampled edge must exist in the original graph, with the same ID.
    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling():
    import tempfile
    # Use the real temporary directory; the previous hard-coded
    # "/tmp/sampling" override defeated isolation and cleanup.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname))
def check_rpc_sampling_shuffle(tmpdir):
    """End-to-end check of distributed sampling with node reshuffling.

    Same flow as `check_rpc_sampling`, but the partitioning reshuffles
    node/edge IDs, so the original IDs must be recovered from each
    partition's 'orig_id' data before validating the sampled edges.
    """
    num_server = 2
    # Context manager guarantees the config file is closed even on error.
    with open("rpc_sampling_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{} 1\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_client(0, tmpdir)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # Reshuffling remapped the IDs; rebuild the shuffled->original maps.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64)
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64)
    for i in range(num_server):
        part, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling_shuffle():
    import tempfile
    # Use the real temporary directory; the previous hard-coded
    # "/tmp/sampling" override defeated isolation and cleanup.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(Path(tmpdirname))
if __name__ == "__main__":
    import tempfile
    # Run both checks inside one automatically-cleaned temporary directory
    # (previously overridden by a hard-coded "/tmp/sampling" path, which
    # defeated TemporaryDirectory's isolation and cleanup).
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname))
        check_rpc_sampling_shuffle(Path(tmpdirname))
......@@ -119,7 +119,7 @@ def start_server():
kvserver.init_data('data_0_2', 'node', data_0_2)
kvserver.init_data('data_0_3', 'node', data_0_3)
# start server
server_state = dgl.distributed.ServerState(kv_store=kvserver)
server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
dgl.distributed.start_server(server_id=0,
ip_config='kv_ip_config.txt',
num_clients=1,
......
......@@ -13,8 +13,7 @@ import pickle
import random
def create_random_graph(n):
random.seed(100)
arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
ig = create_graph_index(arr, readonly=True)
return dgl.DGLGraph(ig)
......
......@@ -108,7 +108,7 @@ class HelloRequest(dgl.distributed.Request):
return res
def start_server():
server_state = dgl.distributed.ServerState(None)
server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.start_server(server_id=0,
ip_config='rpc_ip_config.txt',
......
import socket
def get_local_usable_addr():
    """Get a locally usable IP address and a currently free port.

    Returns
    -------
    str
        Space-separated address and port, e.g. '192.168.8.12 50051'.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # The target doesn't have to be reachable: connecting a UDP socket
        # only selects the outgoing interface; no packet is actually sent.
        sock.connect(('10.255.255.255', 1))
        ip_addr = sock.getsockname()[0]
    except OSError:
        # Socket failures raise OSError (not ValueError, which the old
        # code caught and therefore never triggered); fall back to loopback.
        ip_addr = '127.0.0.1'
    finally:
        sock.close()

    # Bind to port 0 so the OS assigns an unused port.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    sock.listen(1)
    port = sock.getsockname()[1]
    sock.close()

    return ip_addr + ' ' + str(port)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment