Unverified Commit 6459a688 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable GB sampling on homograph (#7061)

parent 9273387e
......@@ -3,7 +3,9 @@ from collections import namedtuple
import numpy as np
from .. import backend as F
import torch
from .. import backend as F, graphbolt as gb
from ..base import EID, NID
from ..convert import graph, heterograph
from ..sampling import (
......@@ -65,6 +67,81 @@ class FindEdgeResponse(Response):
return self.global_src, self.global_dst, self.order_id
def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, prob=None, replace=False
):
"""Sample from local partition via graphbolt.
The input nodes use global IDs. We need to map the global node IDs to local
node IDs, perform sampling and map the sampled results to the global IDs
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
The local partition.
gpb : GraphPartitionBook
The graph partition book.
nodes : tensor
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
If True, sample with replacement.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
"""
# 1. Map global node IDs to local node IDs.
nodes = gpb.nid2localnid(nodes, gpb.partid)
# 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
assert (
prob is None
), "DistGraphBolt does not support sampling with probability."
assert (
not replace
), "DistGraphBolt does not support sampling with replacement."
# Sanity checks.
assert isinstance(
g, gb.FusedCSCSamplingGraph
), "Expect a FusedCSCSamplingGraph."
assert isinstance(nodes, torch.Tensor), "Expect a tensor of nodes."
if isinstance(fanout, int):
fanout = torch.LongTensor([fanout])
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."
subgraph = g._sample_neighbors(nodes, fanout)
# 3. Map local node IDs to global node IDs.
local_src = subgraph.indices
local_dst = torch.repeat_interleave(
subgraph.original_column_node_ids, torch.diff(subgraph.indptr)
)
global_nid_mapping = g.node_attributes[NID]
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]
return global_src, global_dst, subgraph.type_per_edge
def _sample_neighbors(
local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
):
......@@ -212,12 +289,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
class SamplingRequest(Request):
"""Sampling Request"""
def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False):
def __init__(
self,
nodes,
fan_out,
edge_dir="in",
prob=None,
replace=False,
use_graphbolt=False,
):
self.seed_nodes = nodes
self.edge_dir = edge_dir
self.prob = prob
self.replace = replace
self.fan_out = fan_out
self.use_graphbolt = use_graphbolt
def __setstate__(self, state):
(
......@@ -226,6 +312,7 @@ class SamplingRequest(Request):
self.prob,
self.replace,
self.fan_out,
self.use_graphbolt,
) = state
def __getstate__(self):
......@@ -235,6 +322,7 @@ class SamplingRequest(Request):
self.prob,
self.replace,
self.fan_out,
self.use_graphbolt,
)
def process_request(self, server_state):
......@@ -245,6 +333,16 @@ class SamplingRequest(Request):
prob = [kv_store.data_store[self.prob]]
else:
prob = None
if self.use_graphbolt:
global_src, global_dst, etype_ids = _sample_neighbors_graphbolt(
local_g,
partition_book,
self.seed_nodes,
self.fan_out,
prob,
self.replace,
)
return SubgraphResponse(global_src, global_dst, etype_ids)
global_src, global_dst, global_eids = _sample_neighbors(
local_g,
partition_book,
......@@ -449,12 +547,13 @@ def merge_graphs(res_list, num_nodes):
eids.append(res.global_eids)
src_tensor = F.cat(srcs, 0)
dst_tensor = F.cat(dsts, 0)
eid_tensor = F.cat(eids, 0)
eid_tensor = None if eids[0] is None else F.cat(eids, 0)
else:
src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
if eid_tensor is not None:
g.edata[EID] = eid_tensor
return g
......@@ -491,6 +590,7 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
if not isinstance(nodes, torch.Tensor):
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes)
local_nids = None
......@@ -721,7 +821,15 @@ def sample_etype_neighbors(
return frontier
def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
def sample_neighbors(
g,
nodes,
fanout,
edge_dir="in",
prob=None,
replace=False,
use_graphbolt=False,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
......@@ -764,6 +872,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
-------
......@@ -795,12 +905,26 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
else:
_prob = None
return SamplingRequest(
node_ids, fanout, edge_dir=edge_dir, prob=_prob, replace=replace
node_ids,
fanout,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
use_graphbolt=use_graphbolt,
)
def local_access(local_g, partition_book, local_nids):
# See NOTE 1
_prob = [g.edata[prob].local_partition] if prob is not None else None
if use_graphbolt:
return _sample_neighbors_graphbolt(
local_g,
partition_book,
local_nids,
fanout,
prob=_prob,
replace=replace,
)
return _sample_neighbors(
local_g,
partition_book,
......
......@@ -31,6 +31,7 @@ def start_server(
disable_shared_mem,
graph_name,
graph_format=["csc", "coo"],
use_graphbolt=False,
):
g = DistGraphServer(
rank,
......@@ -40,6 +41,7 @@ def start_server(
tmpdir / (graph_name + ".json"),
disable_shared_mem=disable_shared_mem,
graph_format=graph_format,
use_graphbolt=use_graphbolt,
)
g.start()
......@@ -72,6 +74,7 @@ def start_sample_client_shuffle(
group_id,
orig_nid,
orig_eid,
use_graphbolt=False,
):
os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None
......@@ -80,14 +83,23 @@ def start_sample_client_shuffle(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
)
src, dst = sampled_graph.edges()
src = orig_nid[src]
dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
if use_graphbolt:
assert (
dgl.EID not in sampled_graph.edata
), "EID should not be in sampled graph if use_graphbolt=True."
else:
eids = g.edge_ids(src, dst)
eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
......@@ -378,7 +390,9 @@ def test_rpc_sampling():
check_rpc_sampling(Path(tmpdirname), 1)
def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
def check_rpc_sampling_shuffle(
tmpdir, num_server, num_groups=1, use_graphbolt=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = CitationGraphDataset("cora")[0]
......@@ -393,6 +407,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
)
pserver_list = []
......@@ -406,6 +421,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
......@@ -427,6 +443,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
group_id,
orig_nids,
orig_eids,
use_graphbolt,
),
)
p.start()
......@@ -1012,6 +1029,9 @@ def test_rpc_sampling_shuffle(num_server):
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt=True
)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment