Unverified Commit 14e4e1b0 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] enable to specify sort_etype for sample_etype_neighbours (#4212)



* [Dist] enable to specify sort_etype for sample_etype_neighbours

* fix lint

* pass argument instead of env

* fix lint and doc string

* refine args

* remove unnecessary lines

* debug only

* debug add sort time log

* change interface

* fix typo
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent c65d6fa5
......@@ -9,7 +9,7 @@ import numpy as np
from ..heterograph import DGLHeteroGraph
from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph
from ..transforms import compact_graphs
from ..transforms import compact_graphs, sort_csr_by_tag, sort_csc_by_tag
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all
......@@ -350,6 +350,14 @@ class DistGraphServer(KVServer):
# Create the graph formats specified the users.
self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_()
# Sort underlying matrix beforehand to avoid runtime overhead during sampling.
if len(etypes) > 1:
if 'csr' in graph_format:
self.client_g = sort_csr_by_tag(
self.client_g, tag=self.client_g.edata[ETYPE], tag_type='edge')
if 'csc' in graph_format:
self.client_g = sort_csc_by_tag(
self.client_g, tag=self.client_g.edata[ETYPE], tag_type='edge')
if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)
......@@ -1255,14 +1263,14 @@ class DistGraph:
self._client.barrier()
def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None,
exclude_edges=None, replace=False,
exclude_edges=None, replace=False, etype_sorted=True,
output_device=None):
# pylint: disable=unused-argument
"""Sample neighbors from a distributed graph."""
# Currently prob, exclude_edges, output_device, and edge_dir are ignored.
if len(self.etypes) > 1:
frontier = graph_services.sample_etype_neighbors(
self, seed_nodes, ETYPE, fanout, replace=replace)
self, seed_nodes, ETYPE, fanout, replace=replace, etype_sorted=etype_sorted)
else:
frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace)
......
......@@ -164,21 +164,23 @@ class SamplingRequest(Request):
class SamplingRequestEtype(Request):
"""Sampling Request"""
def __init__(self, nodes, etype_field, fan_out, edge_dir='in', prob=None, replace=False):
def __init__(self, nodes, etype_field, fan_out, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
self.seed_nodes = nodes
self.edge_dir = edge_dir
self.prob = prob
self.replace = replace
self.fan_out = fan_out
self.etype_field = etype_field
self.etype_sorted = etype_sorted
def __setstate__(self, state):
self.seed_nodes, self.edge_dir, self.prob, self.replace, \
self.fan_out, self.etype_field = state
self.fan_out, self.etype_field, self.etype_sorted = state
def __getstate__(self):
return self.seed_nodes, self.edge_dir, self.prob, self.replace, \
self.fan_out, self.etype_field
self.fan_out, self.etype_field, self.etype_sorted
def process_request(self, server_state):
local_g = server_state.graph
......@@ -190,7 +192,8 @@ class SamplingRequestEtype(Request):
self.fan_out,
self.edge_dir,
self.prob,
self.replace)
self.replace,
self.etype_sorted)
return SubgraphResponse(global_src, global_dst, global_eids)
class EdgesRequest(Request):
......@@ -418,7 +421,8 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False):
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
"""Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
......@@ -471,6 +475,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
Indicates whether etypes are sorted.
Returns
-------
......@@ -496,10 +502,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids):
return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
prob=prob, replace=replace, etype_sorted=etype_sorted)
def local_access(local_g, partition_book, local_nids):
return _sample_etype_neighbors(local_g, partition_book, local_nids,
etype_field, fanout, edge_dir, prob, replace)
etype_field, fanout, edge_dir, prob, replace,
etype_sorted=etype_sorted)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
......
......@@ -335,7 +335,8 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
return block, gpb
def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
nodes={'n3': [0, 10, 99, 66, 124, 208]}):
nodes={'n3': [0, 10, 99, 66, 124, 208]},
etype_sorted=False):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
......@@ -358,7 +359,7 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
if gpb is None:
gpb = dist_graph.get_partition_book()
try:
sampled_graph = sample_etype_neighbors(dist_graph, nodes, dgl.ETYPE, fanout)
sampled_graph = sample_etype_neighbors(dist_graph, nodes, dgl.ETYPE, fanout, etype_sorted=etype_sorted)
block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e:
......@@ -461,7 +462,7 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
assert block.number_of_edges() == 0
assert len(block.etypes) == len(g.etypes)
def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=False):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero(dense=True)
......@@ -474,14 +475,15 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling', ['csc', 'coo']))
p.start()
time.sleep(1)
pserver_list.append(p)
fanout = 3
block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout,
nodes={'n3': [0, 10, 99, 66, 124, 208]})
nodes={'n3': [0, 10, 99, 66, 124, 208]},
etype_sorted=etype_sorted)
print("Done sampling")
for p in pserver_list:
p.join()
......@@ -832,6 +834,7 @@ def test_rpc_sampling_shuffle(num_server):
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server, etype_sorted=True)
check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment