Unverified Commit 167216af authored by Da Zheng, committed by GitHub

[Distributed] add in_subgraph on DistGraph. (#1755)

* add in_subgraph on DistGraph.

* check in more.

* fix test.

* add comments.

* fix test.

* update test.

* update.

* rename.

* update comment

* fix test
parent bdc1e649
@@ -10,4 +10,4 @@ from .rpc_server import start_server
from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .sampling import sample_neighbors
from .graph_services import sample_neighbors, in_subgraph
"""Sampling module"""
"""A set of graph services of getting subgraphs from DistGraph"""
from collections import namedtuple
from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors
from ..transform import in_subgraph as local_in_subgraph
from . import register_service
from ..convert import graph
from ..base import NID, EID
from ..utils import toindex
from .. import backend as F
__all__ = ['sample_neighbors']
__all__ = ['sample_neighbors', 'in_subgraph']
SAMPLING_SERVICE_ID = 6657
INSUBGRAPH_SERVICE_ID = 6658
class SamplingResponse(Response):
"""Sampling Response"""
class SubgraphResponse(Response):
"""The response for sampling and in_subgraph"""
def __init__(self, global_src, global_dst, global_eids):
self.global_src = global_src
@@ -49,6 +50,25 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
return global_src, global_dst, global_eids
def _in_subgraph(local_g, partition_book, seed_nodes):
""" Get in subgraph from local partition.
The input nodes use global Ids. We need to map the global node Ids to local node Ids,
get in-subgraph and map the sampled results to the global Ids space again.
The results are stored in three vectors that store source nodes, destination nodes
and edge Ids.
"""
local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)
sampled_graph = local_in_subgraph(local_g, local_ids)
global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges()
global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
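The docstring above describes a global-to-local-to-global ID round trip. The following is a purely illustrative toy sketch of that mapping using plain NumPy arrays; the mapping values are made up and none of this is the DGL API.

```python
import numpy as np

# Hypothetical partition: local node i corresponds to global node global_nid_mapping[i]
# (the analogue of local_g.ndata[NID]); edges likewise (local_g.edata[EID]).
global_nid_mapping = np.array([12, 40, 7, 95])
global_eid_mapping = np.array([100, 101, 102])

# Seed nodes arrive as global IDs; the partition book maps them to local IDs.
seeds_global = np.array([40, 7])
local_ids = np.array([int(np.where(global_nid_mapping == s)[0][0]) for s in seeds_global])  # [1, 2]

# Suppose the local in-subgraph call returns these local endpoints and local edge IDs.
src_local, dst_local = np.array([0, 3]), np.array([1, 2])
eids_local = np.array([0, 2])

# Map the results back to the global ID space before sending them to the client.
global_src = global_nid_mapping[src_local]    # [12, 95]
global_dst = global_nid_mapping[dst_local]    # [40,  7]
global_eids = global_eid_mapping[eids_local]  # [100, 102]
print(global_src, global_dst, global_eids)
```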
class SamplingRequest(Request):
"""Sampling Request"""
@@ -72,7 +92,27 @@ class SamplingRequest(Request):
self.seed_nodes,
self.fan_out, self.edge_dir,
self.prob, self.replace)
return SamplingResponse(global_src, global_dst, global_eids)
return SubgraphResponse(global_src, global_dst, global_eids)
class InSubgraphRequest(Request):
"""InSubgraph Request"""
def __init__(self, nodes):
self.seed_nodes = nodes
def __setstate__(self, state):
self.seed_nodes = state
def __getstate__(self):
return self.seed_nodes
def process_request(self, server_state):
local_g = server_state.graph
partition_book = server_state.partition_book
global_src, global_dst, global_eids = _in_subgraph(local_g, partition_book,
self.seed_nodes)
return SubgraphResponse(global_src, global_dst, global_eids)
def merge_graphs(res_list, num_nodes):
@@ -99,47 +139,33 @@ def merge_graphs(res_list, num_nodes):
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph.
When sampling with replacement, the sampled subgraph could have parallel edges.
For sampling without replacement, if the fanout is larger than the number of neighbors, all the
neighbors are sampled.
def _distributed_access(g, nodes, issue_remote_req, local_access):
'''A routine that fetches the local neighborhoods of nodes from the distributed graph.
Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph.
The local neighborhoods of some nodes are stored on the local machine, while other
nodes have their neighborhoods on remote machines. This routine issues the remote
access requests before fetching data from the local machine and, in the end,
combines the data from the local machine and the remote machines.
In this way, the latency of accessing data on remote machines is hidden behind local computation.
Parameters
----------
g : DistGraph
The distributed graph.
nodes : tensor or dict
Node ids to sample neighbors from. The allowed types are a dictionary mapping
node types to node id tensors, or simply a node id tensor if
the given graph g has only one node type.
fanout : int or dict[etype, int]
The number of sampled neighbors for each node on each edge type. Provide a dict
to specify different fanout values for each edge type.
edge_dir : str, optional
Edge direction ('in' or 'out'). If it is 'in', sample from in-edges; otherwise,
sample from out-edges.
prob : str, optional
Feature name used as the probabilities associated with each neighbor of a node.
Its shape should be compatible with a scalar edge feature tensor.
replace : bool, optional
If True, sample with replacement.
The distributed graph.
nodes : tensor
The nodes whose neighborhoods are to be fetched.
issue_remote_req : callable
The function that issues requests to access remote data.
local_access : callable
The function that reads data on the local machine.
Returns
-------
DGLHeteroGraph
A sampled subgraph containing only the sampled neighbor edges from
``nodes``. The sampled subgraph has the same metagraph as the original
one.
"""
assert edge_dir == 'in'
The subgraph that contains the neighborhoods of all input nodes.
'''
req_list = []
partition_book = dist_graph.get_partition_book()
partition_book = g.get_partition_book()
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes)
local_nids = None
@@ -149,12 +175,11 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
# run on the same machine. With a good partitioning, most of the seed nodes
# should reside in the local partition. If the server and the client
# are not co-located, the client doesn't have a local partition.
if pid == partition_book.partid and dist_graph.local_partition is not None:
if pid == partition_book.partid and g.local_partition is not None:
assert local_nids is None
local_nids = node_id
elif len(node_id) != 0:
req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
req = issue_remote_req(node_id)
req_list.append((pid, req))
# send requests to the remote machine.
@@ -165,8 +190,7 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
src, dst, eids = _sample_neighbors(dist_graph.local_partition, partition_book,
local_nids, fanout, edge_dir, prob, replace)
src, dst, eids = local_access(g.local_partition, partition_book, local_nids)
res_list.append(LocalSampledGraph(src, dst, eids))
# receive responses from remote machines.
@@ -174,8 +198,80 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
results = recv_responses(msgseq2pos)
res_list.extend(results)
sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
sampled_graph = merge_graphs(res_list, g.number_of_nodes())
return sampled_graph
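`_distributed_access` issues the remote requests before doing local work so that remote latency overlaps with local computation. Below is a hypothetical, self-contained sketch of the same overlap pattern; the partition dict, `fetch_from_partition`, and the thread pool are stand-ins for illustration, not DGL code.

```python
from concurrent.futures import ThreadPoolExecutor

def fetch_from_partition(pid, items):
    # Stand-in for a remote RPC: pretend each partition returns its items doubled.
    return [x * 2 for x in items]

def distributed_fetch(parts, local_partid):
    """Fire remote requests first, do local work while they are in flight,
    then collect the remote responses."""
    with ThreadPoolExecutor() as pool:
        futures, local_items = [], []
        for pid, items in parts.items():
            if pid == local_partid:
                local_items = items  # defer the local partition's work
            else:
                futures.append(pool.submit(fetch_from_partition, pid, items))
        # Local work overlaps with the in-flight remote requests.
        local_result = fetch_from_partition(local_partid, local_items)
        remote_results = [f.result() for f in futures]
    return [local_result] + remote_results

print(distributed_fetch({0: [1, 2], 1: [3], 2: [4, 5]}, local_partid=0))
# [[2, 4], [6], [8, 10]]
```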
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph.
When sampling with replacement, the sampled subgraph could have parallel edges.
For sampling without replacement, if the fanout is larger than the number of neighbors, all the
neighbors are sampled.
Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph.
For now, we only support input graphs with one node type and one edge type.
Parameters
----------
g : DistGraph
The distributed graph.
nodes : tensor
Node ids to sample neighbors from.
fanout : int
The number of sampled neighbors for each node.
edge_dir : str, optional
Edge direction ('in' or 'out'). If it is 'in', sample from in-edges; otherwise,
sample from out-edges.
prob : str, optional
Feature name used as the probabilities associated with each neighbor of a node.
Its shape should be compatible with a scalar edge feature tensor.
replace : bool, optional
If True, sample with replacement.
Returns
-------
DGLHeteroGraph
A sampled subgraph containing only the sampled neighbor edges from
``nodes``. The sampled subgraph has the same metagraph as the original
one.
"""
def issue_remote_req(node_ids):
return SamplingRequest(node_ids, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
def local_access(local_g, partition_book, local_nids):
return _sample_neighbors(local_g, partition_book, local_nids,
fanout, edge_dir, prob, replace)
return _distributed_access(g, nodes, issue_remote_req, local_access)
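As a usage sketch, assuming DistGraphServer processes are already serving a partitioned graph named "test_sampling" (the setup used in the test below); this is not a self-starting script:

```python
import dgl
from dgl.distributed import DistGraph, sample_neighbors

dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling")
# Sample up to 3 in-neighbors of each seed node, possibly across partitions.
frontier = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
src, dst = frontier.edges()
orig_eids = frontier.edata[dgl.EID]  # original IDs of the sampled edges
```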
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SamplingResponse)
def in_subgraph(g, nodes):
"""Extract the subgraph containing only the in edges of the given nodes.
The subgraph keeps the same type schema and node cardinality as the original graph.
Node/edge features are not preserved. The original IDs of
the extracted edges are stored as the `dgl.EID` feature in the returned graph.
For now, we only support input graphs with one node type and one edge type.
Parameters
----------
g : DistGraph
The distributed graph structure.
nodes : tensor
Node IDs whose in-edges will be included in the subgraph.
Returns
-------
DGLHeteroGraph
The subgraph.
"""
def issue_remote_req(node_ids):
return InSubgraphRequest(node_ids)
def local_access(local_g, partition_book, local_nids):
return _in_subgraph(local_g, partition_book, local_nids)
return _distributed_access(g, nodes, issue_remote_req, local_access)
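A matching usage sketch for `in_subgraph`, under the same assumption that servers are already serving a partitioned graph, here named "test_in_subgraph" as in the new test:

```python
import dgl
from dgl.distributed import DistGraph, in_subgraph

dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph")
subg = in_subgraph(dist_graph, [0, 10, 99, 66, 1024, 2008])
# The subgraph keeps the parent's node set; edges carry their original IDs.
assert subg.number_of_nodes() == dist_graph.number_of_nodes()
orig_eids = subg.edata[dgl.EID]
```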
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
@@ -2,7 +2,7 @@ import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed.sampling import sample_neighbors
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
@@ -15,19 +15,19 @@ from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir, disable_shared_mem):
def start_server(rank, tmpdir, disable_shared_mem, graph_name):
import dgl
g = DistGraphServer(rank, "rpc_sampling_ip_config.txt", 1, "test_sampling",
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, graph_name,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
g.start()
def start_client(rank, tmpdir, disable_shared_mem):
def start_sample_client(rank, tmpdir, disable_shared_mem):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
dist_graph = DistGraph("rpc_sampling_ip_config.txt", "test_sampling", gpb=gpb)
dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
@@ -35,7 +35,7 @@ def start_client(rank, tmpdir, disable_shared_mem):
def check_rpc_sampling(tmpdir, num_server):
ip_config = open("rpc_sampling_ip_config.txt", "w")
ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
ip_config.close()
@@ -52,13 +52,13 @@ def check_rpc_sampling(tmpdir, num_server):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_client(0, tmpdir, num_server > 1)
sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
print("Done sampling")
for p in pserver_list:
p.join()
@@ -75,11 +75,10 @@ def check_rpc_sampling(tmpdir, num_server):
def test_rpc_sampling():
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = "/tmp/sampling"
check_rpc_sampling(Path(tmpdirname), 2)
def check_rpc_sampling_shuffle(tmpdir, num_server):
ip_config = open("rpc_sampling_ip_config.txt", "w")
ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
ip_config.close()
@@ -95,13 +94,13 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_client(0, tmpdir, num_server > 1)
sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
print("Done sampling")
for p in pserver_list:
p.join()
@@ -128,14 +127,70 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
def test_rpc_sampling_shuffle():
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = "/tmp/sampling"
check_rpc_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_sampling_shuffle(Path(tmpdirname), 1)
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
return sampled_graph
def check_rpc_in_subgraph(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
ip_config.close()
g = CitationGraphDataset("cora")[0]
g.readonly()
num_parts = num_server
partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
num_hops=1, part_method='metis', reshuffle=False)
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_in_subgraph'))
p.start()
time.sleep(1)
pserver_list.append(p)
nodes = [0, 10, 99, 66, 1024, 2008]
time.sleep(3)
sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
for p in pserver_list:
p.join()
src, dst = sampled_graph.edges()
g = dgl.as_heterograph(g)
assert sampled_graph.number_of_nodes() == g.number_of_nodes()
subg1 = dgl.in_subgraph(g, nodes)
src1, dst1 = subg1.edges()
assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
eids = g.edge_ids(src, dst)
assert np.array_equal(
F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_in_subgraph(Path(tmpdirname), 2)
if __name__ == "__main__":
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = "/tmp/sampling"
check_rpc_in_subgraph(Path(tmpdirname), 2)
check_rpc_sampling_shuffle(Path(tmpdirname), 1)
check_rpc_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_sampling(Path(tmpdirname), 2)