Unverified Commit 45e1333e authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] Optimize distributed sampling (#1700)



* optimize sampling neighbors from the local partition.

* fix minor bugs.

* fix lint

* overlap local sampling and remote sampling.

* fix.

* fix lint.

* fix a potential bug.
Co-authored-by: default avatarChao Ma <mctt90@gmail.com>
parent 4b96addc
......@@ -16,6 +16,7 @@ from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path,
from .rpc_client import connect_to_server
from .server_state import ServerState
from .rpc_server import start_server
from ..transform import as_heterograph
def _get_graph_path(graph_name):
return "/" + graph_name
......@@ -332,7 +333,11 @@ class DistGraph:
def __init__(self, ip_config, graph_name, gpb=None):
connect_to_server(ip_config=ip_config)
self._client = KVClient(ip_config)
self._g = _get_graph_from_shared_mem(graph_name)
g = _get_graph_from_shared_mem(graph_name)
if g is not None:
self._g = as_heterograph(g)
else:
self._g = None
self._gpb = get_shared_mem_partition_book(graph_name, self._g)
if self._gpb is None:
self._gpb = gpb
......@@ -418,6 +423,21 @@ class DistGraph:
# TODO(zhengda)
raise NotImplementedError("get_node_embeddings isn't supported yet")
@property
def local_partition(self):
    ''' Return the local partition on the client.

    DistGraph provides a global view of the distributed graph. Internally,
    it may contain a partition of the graph if it is co-located with
    the server. If there is no co-location, this returns None.

    Returns
    -------
    DGLHeterograph
        The local partition, or None when the client holds no partition.
    '''
    # self._g is set in __init__ from the shared-memory graph; it stays
    # None when the server did not publish a partition to shared memory.
    return self._g
@property
def ndata(self):
"""Return the data view of all the nodes.
......
......@@ -238,7 +238,9 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
g.edata[EID] = F.arange(0, g.number_of_edges())
g.ndata['inner_node'] = F.ones((g.number_of_nodes(),), F.int64, F.cpu())
g.edata['inner_edge'] = F.ones((g.number_of_edges(),), F.int64, F.cpu())
if reshuffle:
g.ndata['orig_id'] = F.arange(0, g.number_of_nodes())
g.edata['orig_id'] = F.arange(0, g.number_of_edges())
elif part_method == 'metis':
node_parts = metis_partition_assignment(g, num_parts)
client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
......
......@@ -3,6 +3,7 @@ server and clients."""
import abc
import pickle
import random
import numpy as np
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
......@@ -731,15 +732,10 @@ def remote_call(target_and_requests, timeout=0):
all_res[msgseq2pos[msg.msg_seq]] = res
return all_res
def remote_call_to_machine(target_and_requests, timeout=0):
"""Invoke registered services on remote machine
(which will randomly select a server to process the request) and collect responses.
def send_requests_to_machine(target_and_requests):
""" Send requests to the remote machines.
The operation is blocking -- it returns when it receives all responses
or it times out.
If the target server state is available locally, it invokes local computation
to calculate the response.
This operation isn't blocking. It returns immediately once it sends all requests.
Parameters
----------
......@@ -750,19 +746,10 @@ def remote_call_to_machine(target_and_requests, timeout=0):
Returns
-------
list[Response]
Responses for each target-request pair. If the request does not have
response, None is placed.
Raises
------
ConnectionError if there is any problem with the connection.
msgseq2pos : dict
map the message sequence number to its position in the input list.
"""
# TODO(chao): handle timeout
all_res = [None] * len(target_and_requests)
msgseq2pos = {}
num_res = 0
myrank = get_rank()
for pos, (target, request) in enumerate(target_and_requests):
# send request
service_id = request.service_id
......@@ -775,8 +762,35 @@ def remote_call_to_machine(target_and_requests, timeout=0):
# check if has response
res_cls = get_service_property(service_id)[1]
if res_cls is not None:
num_res += 1
msgseq2pos[msg_seq] = pos
return msgseq2pos
def recv_responses(msgseq2pos, timeout=0):
""" Receive responses
It returns the responses in the same order as the requests. The order of requests
is stored in msgseq2pos.
The operation is blocking -- it returns when it receives all responses
or it times out.
Parameters
----------
msgseq2pos : dict
map the message sequence number to its position in the input list.
timeout : int, optional
The timeout value in milliseconds. If zero, wait indefinitely.
Returns
-------
list[Response]
Responses for each target-request pair. If the request does not have
response, None is placed.
"""
myrank = get_rank()
size = np.max(list(msgseq2pos.values())) + 1
all_res = [None] * size
num_res = len(msgseq2pos)
while num_res != 0:
# recv response
msg = recv_rpc_message(timeout)
......@@ -793,6 +807,37 @@ def remote_call_to_machine(target_and_requests, timeout=0):
all_res[msgseq2pos[msg.msg_seq]] = res
return all_res
def remote_call_to_machine(target_and_requests, timeout=0):
    """Invoke registered services on remote machines
    (which will randomly select a server to process each request) and collect responses.

    The operation is blocking -- it returns when it receives all responses
    or it times out.

    If the target server state is available locally, it invokes local computation
    to calculate the response.

    Parameters
    ----------
    target_and_requests : list[(int, Request)]
        A list of requests and the machine they should be sent to.
    timeout : int, optional
        The timeout value in milliseconds. If zero, wait indefinitely.

    Returns
    -------
    list[Response]
        Responses for each target-request pair. If the request does not have
        response, None is placed.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    # TODO(chao): handle timeout
    # Fire off all requests, then block collecting the responses.
    # msgseq2pos maps each message sequence number back to the position of
    # its request in the input list so responses come back in order.
    msgseq2pos = send_requests_to_machine(target_and_requests)
    return recv_responses(msgseq2pos, timeout)
def send_rpc_message(msg, target):
"""Send one message to the target server.
......
"""Sampling module"""
from .rpc import Request, Response, remote_call_to_machine
from collections import namedtuple
from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors
from . import register_service
from ..convert import graph
......@@ -27,6 +29,26 @@ class SamplingResponse(Response):
return self.global_src, self.global_dst, self.global_eids
def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace):
    """ Sample from the local partition.

    The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    perform sampling and map the sampled results to the global ID space again.
    The sampled results are stored in three vectors that store source nodes, destination nodes
    and edge IDs.

    Parameters
    ----------
    local_g : DGLHeteroGraph
        The local partition; its NID/EID node/edge data map local IDs to global IDs.
    partition_book : GraphPartitionBook
        Used to translate global seed-node IDs into the partition's local IDs.
    seed_nodes : tensor
        Global IDs of the nodes whose neighbors are sampled.
    fan_out, edge_dir, prob, replace
        Passed through to ``local_sample_neighbors``.

    Returns
    -------
    (tensor, tensor, tensor)
        Source node IDs, destination node IDs and edge IDs of the sampled
        edges, all in the global ID space.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_sample_neighbors(
        local_g, local_ids, fan_out, edge_dir, prob, replace)
    # Map the sampled edges' endpoints and edge IDs back to the global space
    # using the mapping stored on the local partition.
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids
class SamplingRequest(Request):
"""Sampling Request"""
......@@ -46,22 +68,16 @@ class SamplingRequest(Request):
def process_request(self, server_state):
    """Sample neighbors of the requested seed nodes from the server's local partition.

    Parameters
    ----------
    server_state : ServerState
        Carries the local partition graph and the partition book.

    Returns
    -------
    SamplingResponse
        The sampled source nodes, destination nodes and edge IDs, all in
        the global ID space.
    """
    local_g = server_state.graph
    partition_book = server_state.partition_book
    # Delegate to the shared helper so the server-side path and the
    # client-side local-partition path use exactly the same sampling code.
    global_src, global_dst, global_eids = _sample_neighbors(
        local_g, partition_book, self.seed_nodes,
        self.fan_out, self.edge_dir, self.prob, self.replace)
    return SamplingResponse(global_src, global_dst, global_eids)
def merge_graphs(res_list, num_nodes):
"""Merge request from multiple servers"""
if len(res_list) > 1:
srcs = []
dsts = []
eids = []
......@@ -72,26 +88,92 @@ def merge_graphs(res_list, num_nodes):
src_tensor = F.cat(srcs, 0)
dst_tensor = F.cat(dsts, 0)
eid_tensor = F.cat(eids, 0)
else:
src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
g = graph((src_tensor, dst_tensor),
restrict_format='coo', num_nodes=num_nodes)
g.edata[EID] = eid_tensor
return g
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replace=False):
    """Sample from the neighbors of the given nodes from a distributed graph.

    When sampling with replacement, the sampled subgraph could have parallel edges.
    For sampling without replacement, if fanout > the number of neighbors, all the
    neighbors are sampled.

    Node/edge features are not preserved. The original IDs of
    the sampled edges are stored as the `dgl.EID` feature in the returned graph.

    Parameters
    ----------
    dist_graph : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node ids to sample neighbors from. The allowed types
        are dictionary of node types to node id tensors, or simply node id tensor if
        the given graph has only one type of nodes.
    fanout : int or dict[etype, int]
        The number of sampled neighbors for each node on each edge type. Provide a dict
        to specify different fanout values for each edge type.
    edge_dir : str, optional
        Edge direction ('in' or 'out'). If is 'in', sample from in edges. Otherwise,
        sample from out edges. Currently only 'in' is supported.
    prob : str, optional
        Feature name used as the probabilities associated with each neighbor of a node.
        Its shape should be compatible with a scalar edge feature tensor.
    replace : bool, optional
        If True, sample with replacement.

    Returns
    -------
    DGLHeteroGraph
        A sampled subgraph containing only the sampled neighbor edges from
        ``nodes``. The sampled subgraph has the same metagraph as the original
        one.
    """
    # Only in-edge sampling is implemented for distributed graphs so far.
    assert edge_dir == 'in'
    req_list = []
    partition_book = dist_graph.get_partition_book()
    nodes = toindex(nodes).tousertensor()
    partition_id = partition_book.nid2partid(nodes)
    local_nids = None
    for pid in range(partition_book.num_partitions()):
        node_id = F.boolean_mask(nodes, partition_id == pid)
        # We optimize the sampling on a local partition if the server and the client
        # run on the same machine. With a good partitioning, most of the seed nodes
        # should reside in the local partition. If the server and the client
        # are not co-located, the client doesn't have a local partition.
        if pid == partition_book.partid and dist_graph.local_partition is not None:
            assert local_nids is None
            local_nids = node_id
        elif len(node_id) != 0:
            req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
                                  prob=prob, replace=replace)
            req_list.append((pid, req))

    # Send requests to the remote machines first so remote sampling overlaps
    # with the local sampling performed below.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # Sample neighbors for the nodes in the local partition.
    res_list = []
    if local_nids is not None:
        src, dst, eids = _sample_neighbors(dist_graph.local_partition, partition_book,
                                           local_nids, fanout, edge_dir, prob, replace)
        res_list.append(LocalSampledGraph(src, dst, eids))

    # Receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        res_list.extend(results)

    sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
    return sampled_graph
......
......@@ -15,15 +15,17 @@ from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir, disable_shared_mem=True):
    """Start a DistGraphServer for the sampling tests in this process.

    Parameters
    ----------
    rank : int
        The server rank.
    tmpdir : Path
        Directory holding the partition files ('test_sampling.json').
    disable_shared_mem : bool, optional
        If True, the server does not publish its partition to shared memory,
        so clients must load the partition book themselves. Defaults to True
        for backward compatibility with the earlier two-argument signature.
    """
    # Imported in the child process so the spawned worker initializes dgl itself.
    import dgl
    g = DistGraphServer(rank, "rpc_sampling_ip_config.txt", 1, "test_sampling",
                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
    g.start()
def start_client(rank, tmpdir):
def start_client(rank, tmpdir, disable_shared_mem):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
dist_graph = DistGraph("rpc_sampling_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
......@@ -32,8 +34,7 @@ def start_client(rank, tmpdir):
return sampled_graph
def check_rpc_sampling(tmpdir):
num_server = 2
def check_rpc_sampling(tmpdir, num_server):
ip_config = open("rpc_sampling_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
......@@ -51,13 +52,13 @@ def check_rpc_sampling(tmpdir):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_client(0, tmpdir)
sampled_graph = start_client(0, tmpdir, num_server > 1)
print("Done sampling")
for p in pserver_list:
p.join()
......@@ -75,10 +76,9 @@ def test_rpc_sampling():
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = "/tmp/sampling"
check_rpc_sampling(Path(tmpdirname))
check_rpc_sampling(Path(tmpdirname), 2)
def check_rpc_sampling_shuffle(tmpdir):
num_server = 2
def check_rpc_sampling_shuffle(tmpdir, num_server):
ip_config = open("rpc_sampling_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
......@@ -95,13 +95,13 @@ def check_rpc_sampling_shuffle(tmpdir):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_client(0, tmpdir)
sampled_graph = start_client(0, tmpdir, num_server > 1)
print("Done sampling")
for p in pserver_list:
p.join()
......@@ -129,11 +129,14 @@ def test_rpc_sampling_shuffle():
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = "/tmp/sampling"
check_rpc_sampling_shuffle(Path(tmpdirname))
check_rpc_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_sampling_shuffle(Path(tmpdirname), 1)
if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): the temporary directory is immediately overridden with
        # a fixed path so server and client processes agree on where the
        # partition files live -- confirm this is intentional before cleanup.
        tmpdirname = "/tmp/sampling"
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment