"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7fb481f840b5d73982cafd1affe89f21a5c0b20b"
Unverified Commit 6459a688 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable GB sampling on homograph (#7061)

parent 9273387e
...@@ -3,7 +3,9 @@ from collections import namedtuple ...@@ -3,7 +3,9 @@ from collections import namedtuple
import numpy as np import numpy as np
from .. import backend as F import torch
from .. import backend as F, graphbolt as gb
from ..base import EID, NID from ..base import EID, NID
from ..convert import graph, heterograph from ..convert import graph, heterograph
from ..sampling import ( from ..sampling import (
...@@ -65,6 +67,81 @@ class FindEdgeResponse(Response): ...@@ -65,6 +67,81 @@ class FindEdgeResponse(Response):
return self.global_src, self.global_dst, self.order_id return self.global_src, self.global_dst, self.order_id
def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, prob=None, replace=False
):
"""Sample from local partition via graphbolt.
The input nodes use global IDs. We need to map the global node IDs to local
node IDs, perform sampling and map the sampled results to the global IDs
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
The local partition.
gpb : GraphPartitionBook
The graph partition book.
nodes : tensor
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
If True, sample with replacement.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
"""
# 1. Map global node IDs to local node IDs.
nodes = gpb.nid2localnid(nodes, gpb.partid)
# 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
assert (
prob is None
), "DistGraphBolt does not support sampling with probability."
assert (
not replace
), "DistGraphBolt does not support sampling with replacement."
# Sanity checks.
assert isinstance(
g, gb.FusedCSCSamplingGraph
), "Expect a FusedCSCSamplingGraph."
assert isinstance(nodes, torch.Tensor), "Expect a tensor of nodes."
if isinstance(fanout, int):
fanout = torch.LongTensor([fanout])
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."
subgraph = g._sample_neighbors(nodes, fanout)
# 3. Map local node IDs to global node IDs.
local_src = subgraph.indices
local_dst = torch.repeat_interleave(
subgraph.original_column_node_ids, torch.diff(subgraph.indptr)
)
global_nid_mapping = g.node_attributes[NID]
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]
return global_src, global_dst, subgraph.type_per_edge
def _sample_neighbors( def _sample_neighbors(
local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
): ):
...@@ -212,12 +289,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes): ...@@ -212,12 +289,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
class SamplingRequest(Request): class SamplingRequest(Request):
"""Sampling Request""" """Sampling Request"""
def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False): def __init__(
self,
nodes,
fan_out,
edge_dir="in",
prob=None,
replace=False,
use_graphbolt=False,
):
self.seed_nodes = nodes self.seed_nodes = nodes
self.edge_dir = edge_dir self.edge_dir = edge_dir
self.prob = prob self.prob = prob
self.replace = replace self.replace = replace
self.fan_out = fan_out self.fan_out = fan_out
self.use_graphbolt = use_graphbolt
def __setstate__(self, state): def __setstate__(self, state):
( (
...@@ -226,6 +312,7 @@ class SamplingRequest(Request): ...@@ -226,6 +312,7 @@ class SamplingRequest(Request):
self.prob, self.prob,
self.replace, self.replace,
self.fan_out, self.fan_out,
self.use_graphbolt,
) = state ) = state
def __getstate__(self): def __getstate__(self):
...@@ -235,6 +322,7 @@ class SamplingRequest(Request): ...@@ -235,6 +322,7 @@ class SamplingRequest(Request):
self.prob, self.prob,
self.replace, self.replace,
self.fan_out, self.fan_out,
self.use_graphbolt,
) )
def process_request(self, server_state): def process_request(self, server_state):
...@@ -245,6 +333,16 @@ class SamplingRequest(Request): ...@@ -245,6 +333,16 @@ class SamplingRequest(Request):
prob = [kv_store.data_store[self.prob]] prob = [kv_store.data_store[self.prob]]
else: else:
prob = None prob = None
if self.use_graphbolt:
global_src, global_dst, etype_ids = _sample_neighbors_graphbolt(
local_g,
partition_book,
self.seed_nodes,
self.fan_out,
prob,
self.replace,
)
return SubgraphResponse(global_src, global_dst, etype_ids)
global_src, global_dst, global_eids = _sample_neighbors( global_src, global_dst, global_eids = _sample_neighbors(
local_g, local_g,
partition_book, partition_book,
...@@ -449,13 +547,14 @@ def merge_graphs(res_list, num_nodes): ...@@ -449,13 +547,14 @@ def merge_graphs(res_list, num_nodes):
eids.append(res.global_eids) eids.append(res.global_eids)
src_tensor = F.cat(srcs, 0) src_tensor = F.cat(srcs, 0)
dst_tensor = F.cat(dsts, 0) dst_tensor = F.cat(dsts, 0)
eid_tensor = F.cat(eids, 0) eid_tensor = None if eids[0] is None else F.cat(eids, 0)
else: else:
src_tensor = res_list[0].global_src src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids eid_tensor = res_list[0].global_eids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes) g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
g.edata[EID] = eid_tensor if eid_tensor is not None:
g.edata[EID] = eid_tensor
return g return g
...@@ -491,7 +590,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): ...@@ -491,7 +590,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
""" """
req_list = [] req_list = []
partition_book = g.get_partition_book() partition_book = g.get_partition_book()
nodes = toindex(nodes).tousertensor() if not isinstance(nodes, torch.Tensor):
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes) partition_id = partition_book.nid2partid(nodes)
local_nids = None local_nids = None
for pid in range(partition_book.num_partitions()): for pid in range(partition_book.num_partitions()):
...@@ -721,7 +821,15 @@ def sample_etype_neighbors( ...@@ -721,7 +821,15 @@ def sample_etype_neighbors(
return frontier return frontier
def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False): def sample_neighbors(
g,
nodes,
fanout,
edge_dir="in",
prob=None,
replace=False,
use_graphbolt=False,
):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -764,6 +872,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False): ...@@ -764,6 +872,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
For sampling without replacement, if fanout > the number of neighbors, all the For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected. neighbors are sampled. If fanout == -1, all neighbors are collected.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns Returns
------- -------
...@@ -795,12 +905,26 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False): ...@@ -795,12 +905,26 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
else: else:
_prob = None _prob = None
return SamplingRequest( return SamplingRequest(
node_ids, fanout, edge_dir=edge_dir, prob=_prob, replace=replace node_ids,
fanout,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
use_graphbolt=use_graphbolt,
) )
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
# See NOTE 1 # See NOTE 1
_prob = [g.edata[prob].local_partition] if prob is not None else None _prob = [g.edata[prob].local_partition] if prob is not None else None
if use_graphbolt:
return _sample_neighbors_graphbolt(
local_g,
partition_book,
local_nids,
fanout,
prob=_prob,
replace=replace,
)
return _sample_neighbors( return _sample_neighbors(
local_g, local_g,
partition_book, partition_book,
......
...@@ -31,6 +31,7 @@ def start_server( ...@@ -31,6 +31,7 @@ def start_server(
disable_shared_mem, disable_shared_mem,
graph_name, graph_name,
graph_format=["csc", "coo"], graph_format=["csc", "coo"],
use_graphbolt=False,
): ):
g = DistGraphServer( g = DistGraphServer(
rank, rank,
...@@ -40,6 +41,7 @@ def start_server( ...@@ -40,6 +41,7 @@ def start_server(
tmpdir / (graph_name + ".json"), tmpdir / (graph_name + ".json"),
disable_shared_mem=disable_shared_mem, disable_shared_mem=disable_shared_mem,
graph_format=graph_format, graph_format=graph_format,
use_graphbolt=use_graphbolt,
) )
g.start() g.start()
...@@ -72,6 +74,7 @@ def start_sample_client_shuffle( ...@@ -72,6 +74,7 @@ def start_sample_client_shuffle(
group_id, group_id,
orig_nid, orig_nid,
orig_eid, orig_eid,
use_graphbolt=False,
): ):
os.environ["DGL_GROUP_ID"] = str(group_id) os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None gpb = None
...@@ -80,17 +83,26 @@ def start_sample_client_shuffle( ...@@ -80,17 +83,26 @@ def start_sample_client_shuffle(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph(
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3) "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
)
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
src = orig_nid[src] src = orig_nid[src]
dst = orig_nid[dst] dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes() assert sampled_graph.num_nodes() == g.num_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst))) assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
eids = g.edge_ids(src, dst) if use_graphbolt:
eids1 = orig_eid[sampled_graph.edata[dgl.EID]] assert (
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids)) dgl.EID not in sampled_graph.edata
), "EID should not be in sampled graph if use_graphbolt=True."
else:
eids = g.edge_ids(src, dst)
eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None): def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
...@@ -378,7 +390,9 @@ def test_rpc_sampling(): ...@@ -378,7 +390,9 @@ def test_rpc_sampling():
check_rpc_sampling(Path(tmpdirname), 1) check_rpc_sampling(Path(tmpdirname), 1)
def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): def check_rpc_sampling_shuffle(
tmpdir, num_server, num_groups=1, use_graphbolt=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
...@@ -393,6 +407,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): ...@@ -393,6 +407,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
) )
pserver_list = [] pserver_list = []
...@@ -406,6 +421,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): ...@@ -406,6 +421,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_server > 1, num_server > 1,
"test_sampling", "test_sampling",
["csc", "coo"], ["csc", "coo"],
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -427,6 +443,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): ...@@ -427,6 +443,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
group_id, group_id,
orig_nids, orig_nids,
orig_eids, orig_eids,
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -1012,6 +1029,9 @@ def test_rpc_sampling_shuffle(num_server): ...@@ -1012,6 +1029,9 @@ def test_rpc_sampling_shuffle(num_server):
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt=True
)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and # [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now. # root cause is unknown. Let's disable them for now.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment