Unverified Commit 45e1333e authored by Da Zheng, committed by GitHub

[Distributed] Optimize distributed sampling (#1700)



* optimize sampling neighbors from the local partition.

* fix minor bugs.

* fix lint

* overlap local sampling and remote sampling.

* fix.

* fix lint.

* fix a potential bug.
Co-authored-by: Chao Ma <mctt90@gmail.com>
parent 4b96addc
@@ -16,6 +16,7 @@ from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path,
 from .rpc_client import connect_to_server
 from .server_state import ServerState
 from .rpc_server import start_server
+from ..transform import as_heterograph

 def _get_graph_path(graph_name):
     return "/" + graph_name
@@ -332,7 +333,11 @@ class DistGraph:
     def __init__(self, ip_config, graph_name, gpb=None):
         connect_to_server(ip_config=ip_config)
         self._client = KVClient(ip_config)
-        self._g = _get_graph_from_shared_mem(graph_name)
+        g = _get_graph_from_shared_mem(graph_name)
+        if g is not None:
+            self._g = as_heterograph(g)
+        else:
+            self._g = None
         self._gpb = get_shared_mem_partition_book(graph_name, self._g)
         if self._gpb is None:
             self._gpb = gpb
@@ -418,6 +423,21 @@ class DistGraph:
             # TODO(zhengda)
             raise NotImplementedError("get_node_embeddings isn't supported yet")

+    @property
+    def local_partition(self):
+        '''Return the local partition on the client.
+
+        DistGraph provides a global view of the distributed graph. Internally,
+        it may contain a partition of the graph if it is co-located with
+        the server. If there is no co-location, this returns None.
+
+        Returns
+        -------
+        DGLHeterograph
+            The local partition
+        '''
+        return self._g
+
     @property
     def ndata(self):
         """Return the data view of all the nodes.
...
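The new `local_partition` property lets client code branch on whether a server partition is co-located. A minimal usage sketch, assuming a hypothetical ip config file and graph name:

```python
from dgl.distributed import DistGraph

# Hypothetical setup; the constructor arguments mirror the tests below.
g = DistGraph("ip_config.txt", "graph_name")
if g.local_partition is not None:
    # Co-located with a server: local seed nodes can be sampled in-process.
    print("local partition:", g.local_partition.number_of_nodes(), "nodes")
else:
    # Not co-located: every sampling request goes over RPC.
    print("no local partition")
```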
@@ -238,7 +238,9 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
         g.edata[EID] = F.arange(0, g.number_of_edges())
         g.ndata['inner_node'] = F.ones((g.number_of_nodes(),), F.int64, F.cpu())
         g.edata['inner_edge'] = F.ones((g.number_of_edges(),), F.int64, F.cpu())
-        g.ndata['orig_id'] = F.arange(0, g.number_of_nodes())
+        if reshuffle:
+            g.ndata['orig_id'] = F.arange(0, g.number_of_nodes())
+            g.edata['orig_id'] = F.arange(0, g.number_of_edges())
     elif part_method == 'metis':
         node_parts = metis_partition_assignment(g, num_parts)
         client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
...
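With `reshuffle`, partitioning relabels node and edge IDs, and the `'orig_id'` arrays record the pre-shuffle IDs of both. A toy numpy sketch of what the mapping buys you; all values are hypothetical:

```python
import numpy as np

# Hypothetical: after reshuffle, node u's original ID is orig_id[u], so data
# stored in the original ID order can be gathered back into the new order.
orig_id = np.array([2, 0, 1])            # stands in for g.ndata['orig_id']
feat = np.array([[0.0], [1.0], [2.0]])   # features indexed by original node IDs
print(feat[orig_id].ravel())             # features in the new ID order: [2. 0. 1.]
```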
@@ -3,6 +3,7 @@ server and clients."""
 import abc
 import pickle
 import random
+import numpy as np

 from .._ffi.object import register_object, ObjectBase
 from .._ffi.function import _init_api
@@ -731,15 +732,10 @@ def remote_call(target_and_requests, timeout=0):
         all_res[msgseq2pos[msg.msg_seq]] = res
     return all_res

-def remote_call_to_machine(target_and_requests, timeout=0):
-    """Invoke registered services on remote machines
-    (which randomly select a server to process the request) and collect responses.
-
-    The operation is blocking -- it returns when it receives all responses
-    or it times out.
-
-    If the target server state is available locally, it invokes local computation
-    to calculate the response.
+def send_requests_to_machine(target_and_requests):
+    """Send requests to the remote machines.
+
+    This operation isn't blocking. It returns immediately once it sends all the requests.

     Parameters
     ----------
@@ -750,19 +746,10 @@ def remote_call_to_machine(target_and_requests, timeout=0):
     Returns
     -------
-    list[Response]
-        Responses for each target-request pair. If a request does not have
-        a response, None is placed.
-
-    Raises
-    ------
-    ConnectionError if there is any problem with the connection.
+    msgseq2pos : dict
+        Maps each request's message sequence number to its position in the input list.
     """
-    # TODO(chao): handle timeout
-    all_res = [None] * len(target_and_requests)
     msgseq2pos = {}
-    num_res = 0
-    myrank = get_rank()
     for pos, (target, request) in enumerate(target_and_requests):
         # send request
         service_id = request.service_id
@@ -775,8 +762,35 @@ def send_requests_to_machine(target_and_requests):
         # check if has response
         res_cls = get_service_property(service_id)[1]
         if res_cls is not None:
-            num_res += 1
             msgseq2pos[msg_seq] = pos
+    return msgseq2pos
+
+def recv_responses(msgseq2pos, timeout=0):
+    """Receive responses.
+
+    The responses are returned in the same order as the requests; the order of
+    the requests is stored in msgseq2pos.
+
+    The operation is blocking -- it returns when it receives all responses
+    or it times out.
+
+    Parameters
+    ----------
+    msgseq2pos : dict
+        Maps each request's message sequence number to its position in the input list.
+    timeout : int, optional
+        The timeout value in milliseconds. If zero, wait indefinitely.
+
+    Returns
+    -------
+    list[Response]
+        Responses for each target-request pair. If a request does not have
+        a response, None is placed.
+    """
+    myrank = get_rank()
+    size = np.max(list(msgseq2pos.values())) + 1
+    all_res = [None] * size
+    num_res = len(msgseq2pos)
     while num_res != 0:
         # recv response
         msg = recv_rpc_message(timeout)
@@ -793,6 +807,37 @@
         all_res[msgseq2pos[msg.msg_seq]] = res
     return all_res

+def remote_call_to_machine(target_and_requests, timeout=0):
+    """Invoke registered services on remote machines
+    (which randomly select a server to process the request) and collect responses.
+
+    The operation is blocking -- it returns when it receives all responses
+    or it times out.
+
+    If the target server state is available locally, it invokes local computation
+    to calculate the response.
+
+    Parameters
+    ----------
+    target_and_requests : list[(int, Request)]
+        A list of requests and the machines they should be sent to.
+    timeout : int, optional
+        The timeout value in milliseconds. If zero, wait indefinitely.
+
+    Returns
+    -------
+    list[Response]
+        Responses for each target-request pair. If a request does not have
+        a response, None is placed.
+
+    Raises
+    ------
+    ConnectionError if there is any problem with the connection.
+    """
+    # TODO(chao): handle timeout
+    msgseq2pos = send_requests_to_machine(target_and_requests)
+    return recv_responses(msgseq2pos, timeout)

 def send_rpc_message(msg, target):
     """Send one message to the target server.
...
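The point of splitting the old blocking call into `send_requests_to_machine` and `recv_responses` is that callers can do useful work between the send and the receive. A sketch of that pattern; `req_list` and `do_local_work` are hypothetical stand-ins for the sampling code below, and the module path is an assumption based on the relative imports in this diff:

```python
from dgl.distributed.rpc import send_requests_to_machine, recv_responses

def overlapped_call(req_list, do_local_work, timeout=0):
    msgseq2pos = send_requests_to_machine(req_list)   # non-blocking sends
    local_result = do_local_work()                    # runs while servers respond
    remote_results = recv_responses(msgseq2pos, timeout)  # blocks for all responses
    return local_result, remote_results
```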
"""Sampling module""" """Sampling module"""
from .rpc import Request, Response, remote_call_to_machine from collections import namedtuple
from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors from ..sampling import sample_neighbors as local_sample_neighbors
from . import register_service from . import register_service
from ..convert import graph from ..convert import graph
...@@ -27,6 +29,26 @@ class SamplingResponse(Response): ...@@ -27,6 +29,26 @@ class SamplingResponse(Response):
return self.global_src, self.global_dst, self.global_eids return self.global_src, self.global_dst, self.global_eids

+def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace):
+    """Sample from the local partition.
+
+    The input nodes use global IDs. We need to map the global node IDs to local node IDs,
+    perform sampling, and map the sampled results back to the global ID space.
+    The sampled results are stored in three vectors that hold the source nodes,
+    destination nodes and edge IDs.
+    """
+    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
+    local_ids = F.astype(local_ids, local_g.idtype)
+    sampled_graph = local_sample_neighbors(
+        local_g, local_ids, fan_out, edge_dir, prob, replace)
+    global_nid_mapping = local_g.ndata[NID]
+    src, dst = sampled_graph.edges()
+    global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
+    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
+    return global_src, global_dst, global_eids
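At its boundaries the helper works entirely in global IDs; only the sampling step uses local IDs. A toy numpy illustration of the mapping round-trip, with hypothetical values:

```python
import numpy as np

# Hypothetical partition: local node i corresponds to global ID global_nid[i],
# playing the role of local_g.ndata[NID].
global_nid = np.array([10, 11, 12])
# A sampled edge list in local IDs, as returned by the local sampler.
src_local, dst_local = np.array([0, 2]), np.array([1, 1])
# Map back to the global ID space, as _sample_neighbors does.
print(global_nid[src_local], global_nid[dst_local])   # [10 12] [11 11]
```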

 class SamplingRequest(Request):
     """Sampling Request"""

@@ -46,52 +68,112 @@ class SamplingRequest(Request):
     def process_request(self, server_state):
         local_g = server_state.graph
         partition_book = server_state.partition_book
-        local_ids = F.astype(partition_book.nid2localnid(
-            self.seed_nodes, partition_book.partid), local_g.idtype)
-        # local_ids = self.seed_nodes
-        sampled_graph = local_sample_neighbors(
-            local_g, local_ids, self.fan_out, self.edge_dir, self.prob, self.replace)
-        global_nid_mapping = local_g.ndata[NID]
-        src, dst = sampled_graph.edges()
-        global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
-        global_eids = F.gather_row(
-            local_g.edata[EID], sampled_graph.edata[EID])
-        res = SamplingResponse(global_src, global_dst, global_eids)
-        return res
+        global_src, global_dst, global_eids = _sample_neighbors(local_g, partition_book,
+                                                                self.seed_nodes,
+                                                                self.fan_out, self.edge_dir,
+                                                                self.prob, self.replace)
+        return SamplingResponse(global_src, global_dst, global_eids)

 def merge_graphs(res_list, num_nodes):
     """Merge the sampling results from multiple servers"""
-    srcs = []
-    dsts = []
-    eids = []
-    for res in res_list:
-        srcs.append(res.global_src)
-        dsts.append(res.global_dst)
-        eids.append(res.global_eids)
-    src_tensor = F.cat(srcs, 0)
-    dst_tensor = F.cat(dsts, 0)
-    eid_tensor = F.cat(eids, 0)
+    if len(res_list) > 1:
+        srcs = []
+        dsts = []
+        eids = []
+        for res in res_list:
+            srcs.append(res.global_src)
+            dsts.append(res.global_dst)
+            eids.append(res.global_eids)
+        src_tensor = F.cat(srcs, 0)
+        dst_tensor = F.cat(dsts, 0)
+        eid_tensor = F.cat(eids, 0)
+    else:
+        src_tensor = res_list[0].global_src
+        dst_tensor = res_list[0].global_dst
+        eid_tensor = res_list[0].global_eids
     g = graph((src_tensor, dst_tensor),
               restrict_format='coo', num_nodes=num_nodes)
     g.edata[EID] = eid_tensor
     return g

+LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
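`LocalSampledGraph` deliberately exposes the same three fields as `SamplingResponse`, so `merge_graphs` can consume local and remote results interchangeably. A standalone numpy sketch of the merge, with hypothetical IDs and `np.concatenate` standing in for `F.cat`:

```python
import numpy as np
from collections import namedtuple

LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')

# One local result and one remote result (hypothetical global IDs).
res_list = [LocalSampledGraph(np.array([10]), np.array([11]), np.array([5])),
            LocalSampledGraph(np.array([12]), np.array([10]), np.array([9]))]
src = np.concatenate([r.global_src for r in res_list])
dst = np.concatenate([r.global_dst for r in res_list])
eid = np.concatenate([r.global_eids for r in res_list])
print(src, dst, eid)   # [10 12] [11 10] [5 9]
```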

 def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replace=False):
-    """Sample neighbors"""
+    """Sample from the neighbors of the given nodes in a distributed graph.
+
+    When sampling with replacement, the sampled subgraph could have parallel edges.
+    For sampling without replacement, if the fanout is greater than the number of
+    neighbors, all the neighbors are sampled.
+
+    Node/edge features are not preserved. The original IDs of
+    the sampled edges are stored as the `dgl.EID` feature in the returned graph.
+
+    Parameters
+    ----------
+    dist_graph : DistGraph
+        The distributed graph.
+    nodes : tensor or dict
+        Node IDs to sample neighbors from. The allowed types are a dictionary of
+        node types to node ID tensors, or simply a node ID tensor if the given
+        graph has only one type of nodes.
+    fanout : int or dict[etype, int]
+        The number of sampled neighbors for each node on each edge type. Provide a dict
+        to specify different fanout values for each edge type.
+    edge_dir : str, optional
+        Edge direction ('in' or 'out'). If 'in', sample from in edges; otherwise,
+        sample from out edges.
+    prob : str, optional
+        Feature name used as the probabilities associated with each neighbor of a node.
+        Its shape should be compatible with a scalar edge feature tensor.
+    replace : bool, optional
+        If True, sample with replacement.
+
+    Returns
+    -------
+    DGLHeteroGraph
+        A sampled subgraph containing only the sampled neighbor edges from
+        ``nodes``. The sampled subgraph has the same metagraph as the original
+        one.
+    """
     assert edge_dir == 'in'
     req_list = []
     partition_book = dist_graph.get_partition_book()
-    np_nodes = toindex(nodes).tousertensor()
-    partition_id = partition_book.nid2partid(np_nodes)
+    nodes = toindex(nodes).tousertensor()
+    partition_id = partition_book.nid2partid(nodes)
+    local_nids = None
     for pid in range(partition_book.num_partitions()):
-        node_id = F.boolean_mask(np_nodes, partition_id == pid)
-        if len(node_id) != 0:
+        node_id = F.boolean_mask(nodes, partition_id == pid)
+        # We optimize the sampling on a local partition if the server and the client
+        # run on the same machine. With a good partitioning, most of the seed nodes
+        # should reside in the local partition. If the server and the client
+        # are not co-located, the client doesn't have a local partition.
+        if pid == partition_book.partid and dist_graph.local_partition is not None:
+            assert local_nids is None
+            local_nids = node_id
+        elif len(node_id) != 0:
             req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
                                   prob=prob, replace=replace)
             req_list.append((pid, req))
-    res_list = remote_call_to_machine(req_list)
+
+    # send requests to the remote machines.
+    msgseq2pos = None
+    if len(req_list) > 0:
+        msgseq2pos = send_requests_to_machine(req_list)
+
+    # sample neighbors for the nodes in the local partition.
+    res_list = []
+    if local_nids is not None:
+        src, dst, eids = _sample_neighbors(dist_graph.local_partition, partition_book,
+                                           local_nids, fanout, edge_dir, prob, replace)
+        res_list.append(LocalSampledGraph(src, dst, eids))
+
+    # receive responses from remote machines.
+    if msgseq2pos is not None:
+        results = recv_responses(msgseq2pos)
+        res_list.extend(results)
+
     sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
     return sampled_graph
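As a usage sketch mirroring the test below; the import path for `sample_neighbors` is an assumption, since this diff only shows `DistGraph` being imported:

```python
import dgl
from dgl.distributed import DistGraph
# Hypothetical import path for the distributed sampler.
from dgl.distributed.sampling import sample_neighbors

g = DistGraph("rpc_sampling_ip_config.txt", "test_sampling")
subg = sample_neighbors(g, [0, 10, 99, 66, 1024, 2008], 3)
print(subg.edata[dgl.EID])   # original IDs of the sampled edges
```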
...
@@ -15,16 +15,18 @@ from pathlib import Path
 from dgl.distributed import DistGraphServer, DistGraph

-def start_server(rank, tmpdir):
+def start_server(rank, tmpdir, disable_shared_mem):
     import dgl
     g = DistGraphServer(rank, "rpc_sampling_ip_config.txt", 1, "test_sampling",
-                        tmpdir / 'test_sampling.json', disable_shared_mem=True)
+                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
     g.start()

-def start_client(rank, tmpdir):
+def start_client(rank, tmpdir, disable_shared_mem):
     import dgl
-    _, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
+    gpb = None
+    if disable_shared_mem:
+        _, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
     dist_graph = DistGraph("rpc_sampling_ip_config.txt", "test_sampling", gpb=gpb)
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
     dgl.distributed.shutdown_servers()
@@ -32,8 +34,7 @@ def start_client(rank, tmpdir):
     return sampled_graph

-def check_rpc_sampling(tmpdir):
-    num_server = 2
+def check_rpc_sampling(tmpdir, num_server):
     ip_config = open("rpc_sampling_ip_config.txt", "w")
     for _ in range(num_server):
         ip_config.write('{} 1\n'.format(get_local_usable_addr()))
@@ -51,13 +52,13 @@ def check_rpc_sampling(tmpdir):
     pserver_list = []
     ctx = mp.get_context('spawn')
     for i in range(num_server):
-        p = ctx.Process(target=start_server, args=(i, tmpdir))
+        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
         p.start()
         time.sleep(1)
         pserver_list.append(p)
     time.sleep(3)
-    sampled_graph = start_client(0, tmpdir)
+    sampled_graph = start_client(0, tmpdir, num_server > 1)
     print("Done sampling")
     for p in pserver_list:
         p.join()
@@ -75,10 +76,9 @@ def test_rpc_sampling():
     import tempfile
     with tempfile.TemporaryDirectory() as tmpdirname:
         tmpdirname = "/tmp/sampling"
-        check_rpc_sampling(Path(tmpdirname))
+        check_rpc_sampling(Path(tmpdirname), 2)

-def check_rpc_sampling_shuffle(tmpdir):
-    num_server = 2
+def check_rpc_sampling_shuffle(tmpdir, num_server):
     ip_config = open("rpc_sampling_ip_config.txt", "w")
     for _ in range(num_server):
         ip_config.write('{} 1\n'.format(get_local_usable_addr()))
@@ -95,13 +95,13 @@ def check_rpc_sampling_shuffle(tmpdir):
     pserver_list = []
     ctx = mp.get_context('spawn')
     for i in range(num_server):
-        p = ctx.Process(target=start_server, args=(i, tmpdir))
+        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
         p.start()
         time.sleep(1)
         pserver_list.append(p)
     time.sleep(3)
-    sampled_graph = start_client(0, tmpdir)
+    sampled_graph = start_client(0, tmpdir, num_server > 1)
     print("Done sampling")
     for p in pserver_list:
         p.join()
@@ -129,11 +129,14 @@ def test_rpc_sampling_shuffle():
     import tempfile
     with tempfile.TemporaryDirectory() as tmpdirname:
         tmpdirname = "/tmp/sampling"
-        check_rpc_sampling_shuffle(Path(tmpdirname))
+        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
+        check_rpc_sampling_shuffle(Path(tmpdirname), 1)

 if __name__ == "__main__":
     import tempfile
     with tempfile.TemporaryDirectory() as tmpdirname:
         tmpdirname = "/tmp/sampling"
-        check_rpc_sampling(Path(tmpdirname))
-        check_rpc_sampling_shuffle(Path(tmpdirname))
+        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
+        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
+        check_rpc_sampling(Path(tmpdirname), 2)
+        check_rpc_sampling(Path(tmpdirname), 1)