"...text-generation-inference.git" did not exist on "6f88bd9390a3edce1dfec025a526d6c2849effa4"
Unverified Commit 14e4e1b0 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] enable to specify sort_etype for sample_etype_neighbours (#4212)



* [Dist] enable to specify sort_etype for sample_etype_neighbours

* fix lint

* pass argument instead of env

* fix lint and doc string

* refine args

* remove unnecessary lines

* debug only

* debug add sort time log

* change interface

* fix typo
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent c65d6fa5
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from ..convert import heterograph as dgl_heterograph from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from ..transforms import compact_graphs from ..transforms import compact_graphs, sort_csr_by_tag, sort_csc_by_tag
from .. import heterograph_index from .. import heterograph_index
from .. import backend as F from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all
...@@ -350,6 +350,14 @@ class DistGraphServer(KVServer): ...@@ -350,6 +350,14 @@ class DistGraphServer(KVServer):
# Create the graph formats specified the users. # Create the graph formats specified the users.
self.client_g = self.client_g.formats(graph_format) self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_() self.client_g.create_formats_()
# Sort underlying matrix beforehand to avoid runtime overhead during sampling.
if len(etypes) > 1:
if 'csr' in graph_format:
self.client_g = sort_csr_by_tag(
self.client_g, tag=self.client_g.edata[ETYPE], tag_type='edge')
if 'csc' in graph_format:
self.client_g = sort_csc_by_tag(
self.client_g, tag=self.client_g.edata[ETYPE], tag_type='edge')
if not disable_shared_mem: if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format) self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)
...@@ -1255,14 +1263,14 @@ class DistGraph: ...@@ -1255,14 +1263,14 @@ class DistGraph:
self._client.barrier() self._client.barrier()
def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None, def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None,
exclude_edges=None, replace=False, exclude_edges=None, replace=False, etype_sorted=True,
output_device=None): output_device=None):
# pylint: disable=unused-argument # pylint: disable=unused-argument
"""Sample neighbors from a distributed graph.""" """Sample neighbors from a distributed graph."""
# Currently prob, exclude_edges, output_device, and edge_dir are ignored. # Currently prob, exclude_edges, output_device, and edge_dir are ignored.
if len(self.etypes) > 1: if len(self.etypes) > 1:
frontier = graph_services.sample_etype_neighbors( frontier = graph_services.sample_etype_neighbors(
self, seed_nodes, ETYPE, fanout, replace=replace) self, seed_nodes, ETYPE, fanout, replace=replace, etype_sorted=etype_sorted)
else: else:
frontier = graph_services.sample_neighbors( frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace) self, seed_nodes, fanout, replace=replace)
......
...@@ -164,21 +164,23 @@ class SamplingRequest(Request): ...@@ -164,21 +164,23 @@ class SamplingRequest(Request):
class SamplingRequestEtype(Request): class SamplingRequestEtype(Request):
"""Sampling Request""" """Sampling Request"""
def __init__(self, nodes, etype_field, fan_out, edge_dir='in', prob=None, replace=False): def __init__(self, nodes, etype_field, fan_out, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
self.seed_nodes = nodes self.seed_nodes = nodes
self.edge_dir = edge_dir self.edge_dir = edge_dir
self.prob = prob self.prob = prob
self.replace = replace self.replace = replace
self.fan_out = fan_out self.fan_out = fan_out
self.etype_field = etype_field self.etype_field = etype_field
self.etype_sorted = etype_sorted
def __setstate__(self, state): def __setstate__(self, state):
self.seed_nodes, self.edge_dir, self.prob, self.replace, \ self.seed_nodes, self.edge_dir, self.prob, self.replace, \
self.fan_out, self.etype_field = state self.fan_out, self.etype_field, self.etype_sorted = state
def __getstate__(self): def __getstate__(self):
return self.seed_nodes, self.edge_dir, self.prob, self.replace, \ return self.seed_nodes, self.edge_dir, self.prob, self.replace, \
self.fan_out, self.etype_field self.fan_out, self.etype_field, self.etype_sorted
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
...@@ -190,7 +192,8 @@ class SamplingRequestEtype(Request): ...@@ -190,7 +192,8 @@ class SamplingRequestEtype(Request):
self.fan_out, self.fan_out,
self.edge_dir, self.edge_dir,
self.prob, self.prob,
self.replace) self.replace,
self.etype_sorted)
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids)
class EdgesRequest(Request): class EdgesRequest(Request):
...@@ -418,7 +421,8 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): ...@@ -418,7 +421,8 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
hg.edges[etype].data[EID] = edge_ids[etype] hg.edges[etype].data[EID] = edge_ids[etype]
return hg return hg
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False): def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -471,6 +475,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -471,6 +475,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
For sampling without replacement, if fanout > the number of neighbors, all the For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected. neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
Indicates whether etypes are sorted.
Returns Returns
------- -------
...@@ -496,10 +502,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -496,10 +502,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
nodes = F.cat(homo_nids, 0) nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir, return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir,
prob=prob, replace=replace) prob=prob, replace=replace, etype_sorted=etype_sorted)
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
return _sample_etype_neighbors(local_g, partition_book, local_nids, return _sample_etype_neighbors(local_g, partition_book, local_nids,
etype_field, fanout, edge_dir, prob, replace) etype_field, fanout, edge_dir, prob, replace,
etype_sorted=etype_sorted)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous: if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
......
...@@ -335,7 +335,8 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes): ...@@ -335,7 +335,8 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
return block, gpb return block, gpb
def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3, def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
nodes={'n3': [0, 10, 99, 66, 124, 208]}): nodes={'n3': [0, 10, 99, 66, 124, 208]},
etype_sorted=False):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank) _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
...@@ -358,7 +359,7 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3, ...@@ -358,7 +359,7 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
try: try:
sampled_graph = sample_etype_neighbors(dist_graph, nodes, dgl.ETYPE, fanout) sampled_graph = sample_etype_neighbors(dist_graph, nodes, dgl.ETYPE, fanout, etype_sorted=etype_sorted)
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e: except Exception as e:
...@@ -461,7 +462,7 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): ...@@ -461,7 +462,7 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
assert block.number_of_edges() == 0 assert block.number_of_edges() == 0
assert len(block.etypes) == len(g.etypes) assert len(block.etypes) == len(g.etypes)
def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server): def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=False):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero(dense=True) g = create_random_hetero(dense=True)
...@@ -474,14 +475,15 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server): ...@@ -474,14 +475,15 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
pserver_list = [] pserver_list = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
for i in range(num_server): for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling')) p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling', ['csc', 'coo']))
p.start() p.start()
time.sleep(1) time.sleep(1)
pserver_list.append(p) pserver_list.append(p)
fanout = 3 fanout = 3
block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout, block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout,
nodes={'n3': [0, 10, 99, 66, 124, 208]}) nodes={'n3': [0, 10, 99, 66, 124, 208]},
etype_sorted=etype_sorted)
print("Done sampling") print("Done sampling")
for p in pserver_list: for p in pserver_list:
p.join() p.join()
...@@ -832,6 +834,7 @@ def test_rpc_sampling_shuffle(num_server): ...@@ -832,6 +834,7 @@ def test_rpc_sampling_shuffle(num_server):
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server, etype_sorted=True)
check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server) check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment