"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f4af03b350136795375dbd913567857a4ce04fd5"
Unverified Commit f3af2a9f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] return eids together with etype_ids in sampling (#7084)

parent 346197c4
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
from .. import backend as F, graphbolt as gb from .. import backend as F, graphbolt as gb
from ..base import EID, NID from ..base import EID, ETYPE, NID
from ..convert import graph, heterograph from ..convert import graph, heterograph
from ..sampling import ( from ..sampling import (
sample_etype_neighbors as local_sample_etype_neighbors, sample_etype_neighbors as local_sample_etype_neighbors,
...@@ -40,16 +40,29 @@ ETYPE_SAMPLING_SERVICE_ID = 6662 ...@@ -40,16 +40,29 @@ ETYPE_SAMPLING_SERVICE_ID = 6662
class SubgraphResponse(Response): class SubgraphResponse(Response):
"""The response for sampling and in_subgraph""" """The response for sampling and in_subgraph"""
def __init__(self, global_src, global_dst, global_eids): def __init__(
self, global_src, global_dst, *, global_eids=None, etype_ids=None
):
self.global_src = global_src self.global_src = global_src
self.global_dst = global_dst self.global_dst = global_dst
self.global_eids = global_eids self.global_eids = global_eids
self.etype_ids = etype_ids
def __setstate__(self, state): def __setstate__(self, state):
self.global_src, self.global_dst, self.global_eids = state (
self.global_src,
self.global_dst,
self.global_eids,
self.etype_ids,
) = state
def __getstate__(self): def __getstate__(self):
return self.global_src, self.global_dst, self.global_eids return (
self.global_src,
self.global_dst,
self.global_eids,
self.etype_ids,
)
class FindEdgeResponse(Response): class FindEdgeResponse(Response):
...@@ -68,7 +81,7 @@ class FindEdgeResponse(Response): ...@@ -68,7 +81,7 @@ class FindEdgeResponse(Response):
def _sample_neighbors_graphbolt( def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, prob=None, replace=False g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False
): ):
"""Sample from local partition via graphbolt. """Sample from local partition via graphbolt.
...@@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt( ...@@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt(
space again. The sampled results are stored in three vectors that store space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs. source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters Parameters
---------- ----------
g : FusedCSCSamplingGraph g : FusedCSCSamplingGraph
...@@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt( ...@@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt(
The nodes to sample neighbors from. The nodes to sample neighbors from.
fanout : tensor or int fanout : tensor or int
The number of edges to be sampled for each node. The number of edges to be sampled for each node.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
prob : tensor, optional prob : tensor, optional
The probability associated with each neighboring edge of a node. The probability associated with each neighboring edge of a node.
replace : bool, optional replace : bool, optional
...@@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt( ...@@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt(
The source node ID array. The source node ID array.
tensor tensor
The destination node ID array. The destination node ID array.
tensor
The edge type ID array.
tensor tensor
The edge ID array. The edge ID array.
tensor
The edge type ID array.
""" """
assert (
edge_dir == "in"
), f"GraphBolt only supports inbound edge sampling but got {edge_dir}."
# 1. Map global node IDs to local node IDs. # 1. Map global node IDs to local node IDs.
nodes = gpb.nid2localnid(nodes, gpb.partid) nodes = gpb.nid2localnid(nodes, gpb.partid)
...@@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt( ...@@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt(
global_src = global_nid_mapping[local_src] global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst] global_dst = global_nid_mapping[local_dst]
return global_src, global_dst, subgraph.type_per_edge # [Rui][TODO] edge IDs are not supported yet.
return LocalSampledGraph(
global_src, global_dst, None, subgraph.type_per_edge
)
def _sample_neighbors( def _sample_neighbors_dgl(
local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace local_g,
partition_book,
seed_nodes,
fan_out,
edge_dir="in",
prob=None,
replace=False,
): ):
"""Sample from local partition. """Sample from local partition.
...@@ -170,7 +196,38 @@ def _sample_neighbors( ...@@ -170,7 +196,38 @@ def _sample_neighbors(
global_nid_mapping, src global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst) ), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids return LocalSampledGraph(global_src, global_dst, global_eids)
def _sample_neighbors(use_graphbolt, *args, **kwargs):
"""Wrapper for sampling neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func = (
_sample_neighbors_graphbolt if use_graphbolt else _sample_neighbors_dgl
)
return func(*args, **kwargs)
def _sample_etype_neighbors( def _sample_etype_neighbors(
...@@ -211,7 +268,7 @@ def _sample_etype_neighbors( ...@@ -211,7 +268,7 @@ def _sample_etype_neighbors(
global_nid_mapping, src global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst) ), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids return LocalSampledGraph(global_src, global_dst, global_eids)
def _find_edges(local_g, partition_book, seed_edges): def _find_edges(local_g, partition_book, seed_edges):
...@@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes): ...@@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst] global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids return LocalSampledGraph(global_src, global_dst, global_eids)
# --- NOTE 1 --- # --- NOTE 1 ---
...@@ -333,26 +390,22 @@ class SamplingRequest(Request): ...@@ -333,26 +390,22 @@ class SamplingRequest(Request):
prob = [kv_store.data_store[self.prob]] prob = [kv_store.data_store[self.prob]]
else: else:
prob = None prob = None
if self.use_graphbolt: res = _sample_neighbors(
global_src, global_dst, etype_ids = _sample_neighbors_graphbolt( self.use_graphbolt,
local_g, local_g,
partition_book, partition_book,
self.seed_nodes, self.seed_nodes,
self.fan_out, self.fan_out,
prob, edge_dir=self.edge_dir,
self.replace, prob=prob,
replace=self.replace,
) )
return SubgraphResponse(global_src, global_dst, etype_ids) return SubgraphResponse(
global_src, global_dst, global_eids = _sample_neighbors( res.global_src,
local_g, res.global_dst,
partition_book, global_eids=res.global_eids,
self.seed_nodes, etype_ids=res.etype_ids,
self.fan_out,
self.edge_dir,
prob,
self.replace,
) )
return SubgraphResponse(global_src, global_dst, global_eids)
class SamplingRequestEtype(Request): class SamplingRequestEtype(Request):
...@@ -407,7 +460,7 @@ class SamplingRequestEtype(Request): ...@@ -407,7 +460,7 @@ class SamplingRequestEtype(Request):
] ]
else: else:
probs = None probs = None
global_src, global_dst, global_eids = _sample_etype_neighbors( res = _sample_etype_neighbors(
local_g, local_g,
partition_book, partition_book,
self.seed_nodes, self.seed_nodes,
...@@ -418,7 +471,12 @@ class SamplingRequestEtype(Request): ...@@ -418,7 +471,12 @@ class SamplingRequestEtype(Request):
self.replace, self.replace,
self.etype_sorted, self.etype_sorted,
) )
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(
res.global_src,
res.global_dst,
global_eids=res.global_eids,
etype_ids=res.etype_ids,
)
class EdgesRequest(Request): class EdgesRequest(Request):
...@@ -532,7 +590,7 @@ class InSubgraphRequest(Request): ...@@ -532,7 +590,7 @@ class InSubgraphRequest(Request):
global_src, global_dst, global_eids = _in_subgraph( global_src, global_dst, global_eids = _in_subgraph(
local_g, partition_book, self.seed_nodes local_g, partition_book, self.seed_nodes
) )
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids=global_eids)
def merge_graphs(res_list, num_nodes): def merge_graphs(res_list, num_nodes):
...@@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes): ...@@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes):
srcs = [] srcs = []
dsts = [] dsts = []
eids = [] eids = []
etype_ids = []
for res in res_list: for res in res_list:
srcs.append(res.global_src) srcs.append(res.global_src)
dsts.append(res.global_dst) dsts.append(res.global_dst)
eids.append(res.global_eids) eids.append(res.global_eids)
etype_ids.append(res.etype_ids)
src_tensor = F.cat(srcs, 0) src_tensor = F.cat(srcs, 0)
dst_tensor = F.cat(dsts, 0) dst_tensor = F.cat(dsts, 0)
eid_tensor = None if eids[0] is None else F.cat(eids, 0) eid_tensor = None if eids[0] is None else F.cat(eids, 0)
etype_id_tensor = None if etype_ids[0] is None else F.cat(etype_ids, 0)
else: else:
src_tensor = res_list[0].global_src src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids eid_tensor = res_list[0].global_eids
etype_id_tensor = res_list[0].etype_ids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes) g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
if eid_tensor is not None: if eid_tensor is not None:
g.edata[EID] = eid_tensor g.edata[EID] = eid_tensor
if etype_id_tensor is not None:
g.edata[ETYPE] = etype_id_tensor
return g return g
LocalSampledGraph = namedtuple( LocalSampledGraph = namedtuple( # pylint: disable=unexpected-keyword-arg
"LocalSampledGraph", "global_src global_dst global_eids" "LocalSampledGraph",
"global_src global_dst global_eids etype_ids",
defaults=(None, None, None, None),
) )
...@@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): ...@@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition. # sample neighbors for the nodes in the local partition.
res_list = [] res_list = []
if local_nids is not None: if local_nids is not None:
src, dst, eids = local_access( res = local_access(g.local_partition, partition_book, local_nids)
g.local_partition, partition_book, local_nids res_list.append(res)
)
res_list.append(LocalSampledGraph(src, dst, eids))
# receive responses from remote machines. # receive responses from remote machines.
if msgseq2pos is not None: if msgseq2pos is not None:
...@@ -916,24 +980,16 @@ def sample_neighbors( ...@@ -916,24 +980,16 @@ def sample_neighbors(
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
# See NOTE 1 # See NOTE 1
_prob = [g.edata[prob].local_partition] if prob is not None else None _prob = [g.edata[prob].local_partition] if prob is not None else None
if use_graphbolt: return _sample_neighbors(
return _sample_neighbors_graphbolt( use_graphbolt,
local_g, local_g,
partition_book, partition_book,
local_nids, local_nids,
fanout, fanout,
edge_dir=edge_dir,
prob=_prob, prob=_prob,
replace=replace, replace=replace,
) )
return _sample_neighbors(
local_g,
partition_book,
local_nids,
fanout,
edge_dir,
_prob,
replace,
)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous: if not gpb.is_homogeneous:
......
import multiprocessing as mp import multiprocessing as mp
import os import os
import random import random
import sys import tempfile
import time import time
import traceback import traceback
import unittest import unittest
...@@ -1013,47 +1013,85 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): ...@@ -1013,47 +1013,85 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
assert np.all(F.asnumpy(orig_dst1) == orig_dst) assert np.all(F.asnumpy(orig_dst1) == orig_dst)
# Wait non shared memory graph store
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
dgl.backend.backend_name == "tensorflow",
reason="Not support tensorflow for now",
)
@unittest.skipIf(
dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_sampling_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt):
reset_envs() reset_envs()
import tempfile
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle( check_rpc_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt=True Path(tmpdirname), num_server, use_graphbolt=use_graphbolt
) )
check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now. @pytest.mark.parametrize("num_server", [1])
# check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2) def test_rpc_hetero_sampling_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_empty_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(
Path(tmpdirname), num_server, ["csc"] @pytest.mark.parametrize("num_server", [1])
) @pytest.mark.parametrize(
check_rpc_hetero_etype_sampling_shuffle( "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
Path(tmpdirname), num_server, ["csr"] )
) def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_etype_sampling_shuffle( check_rpc_hetero_etype_sampling_shuffle(
Path(tmpdirname), num_server, ["csc", "coo"] Path(tmpdirname), num_server, graph_formats=graph_formats
) )
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_etype_sampling_empty_shuffle( check_rpc_hetero_etype_sampling_empty_shuffle(
Path(tmpdirname), num_server Path(tmpdirname), num_server
) )
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_empty_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server) check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server) check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_shuffle(num_server):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment