Unverified Commit f3af2a9f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] return eids together with etype_ids in sampling (#7084)

parent 346197c4
......@@ -6,7 +6,7 @@ import numpy as np
import torch
from .. import backend as F, graphbolt as gb
from ..base import EID, NID
from ..base import EID, ETYPE, NID
from ..convert import graph, heterograph
from ..sampling import (
sample_etype_neighbors as local_sample_etype_neighbors,
......@@ -40,16 +40,29 @@ ETYPE_SAMPLING_SERVICE_ID = 6662
class SubgraphResponse(Response):
    """The response for sampling and in_subgraph.

    Parameters
    ----------
    global_src : tensor
        Source node IDs of the returned edges, in the global ID space.
    global_dst : tensor
        Destination node IDs of the returned edges, in the global ID space.
    global_eids : tensor, optional
        Global edge IDs of the returned edges. ``None`` when the sampler
        does not provide them.
    etype_ids : tensor, optional
        Edge type IDs of the returned edges. ``None`` when the sampler
        does not provide them.
    """

    def __init__(
        self, global_src, global_dst, *, global_eids=None, etype_ids=None
    ):
        self.global_src = global_src
        self.global_dst = global_dst
        self.global_eids = global_eids
        self.etype_ids = etype_ids

    def __setstate__(self, state):
        # The state layout must mirror __getstate__ exactly: four fields,
        # in this order.
        (
            self.global_src,
            self.global_dst,
            self.global_eids,
            self.etype_ids,
        ) = state

    def __getstate__(self):
        # Both optional fields are always serialized (possibly as None) so
        # that the receiving side can unpack a fixed-size tuple.
        return (
            self.global_src,
            self.global_dst,
            self.global_eids,
            self.etype_ids,
        )
class FindEdgeResponse(Response):
......@@ -68,7 +81,7 @@ class FindEdgeResponse(Response):
def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, prob=None, replace=False
g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False
):
"""Sample from local partition via graphbolt.
......@@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt(
space again. The sampled results are stored in four vectors that store
source nodes, destination nodes, edge IDs and etype IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
......@@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt(
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
......@@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt(
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
assert (
edge_dir == "in"
), f"GraphBolt only supports inbound edge sampling but got {edge_dir}."
# 1. Map global node IDs to local node IDs.
nodes = gpb.nid2localnid(nodes, gpb.partid)
......@@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt(
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]
return global_src, global_dst, subgraph.type_per_edge
# [Rui][TODO] edge IDs are not supported yet.
return LocalSampledGraph(
global_src, global_dst, None, subgraph.type_per_edge
)
def _sample_neighbors(
local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
def _sample_neighbors_dgl(
local_g,
partition_book,
seed_nodes,
fan_out,
edge_dir="in",
prob=None,
replace=False,
):
"""Sample from local partition.
......@@ -170,7 +196,38 @@ def _sample_neighbors(
global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
return LocalSampledGraph(global_src, global_dst, global_eids)
def _sample_neighbors(use_graphbolt, *args, **kwargs):
    """Dispatch neighbor sampling to the GraphBolt or the DGL sampler.

    All positional and keyword arguments after ``use_graphbolt`` are
    forwarded untouched to the selected implementation.

    Parameters
    ----------
    use_graphbolt : bool
        If true, sample with ``_sample_neighbors_graphbolt``; otherwise
        sample with ``_sample_neighbors_dgl``.
    args : list
        Positional arguments for the selected sampling function.
    kwargs : dict
        Keyword arguments for the selected sampling function.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    tensor
        The edge ID array.
    tensor
        The edge type ID array.
    """
    if use_graphbolt:
        sampler = _sample_neighbors_graphbolt
    else:
        sampler = _sample_neighbors_dgl
    return sampler(*args, **kwargs)
def _sample_etype_neighbors(
......@@ -211,7 +268,7 @@ def _sample_etype_neighbors(
global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
return LocalSampledGraph(global_src, global_dst, global_eids)
def _find_edges(local_g, partition_book, seed_edges):
......@@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
src, dst = sampled_graph.edges()
global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
return LocalSampledGraph(global_src, global_dst, global_eids)
# --- NOTE 1 ---
......@@ -333,26 +390,22 @@ class SamplingRequest(Request):
prob = [kv_store.data_store[self.prob]]
else:
prob = None
if self.use_graphbolt:
global_src, global_dst, etype_ids = _sample_neighbors_graphbolt(
local_g,
partition_book,
self.seed_nodes,
self.fan_out,
prob,
self.replace,
)
return SubgraphResponse(global_src, global_dst, etype_ids)
global_src, global_dst, global_eids = _sample_neighbors(
res = _sample_neighbors(
self.use_graphbolt,
local_g,
partition_book,
self.seed_nodes,
self.fan_out,
self.edge_dir,
prob,
self.replace,
edge_dir=self.edge_dir,
prob=prob,
replace=self.replace,
)
return SubgraphResponse(
res.global_src,
res.global_dst,
global_eids=res.global_eids,
etype_ids=res.etype_ids,
)
return SubgraphResponse(global_src, global_dst, global_eids)
class SamplingRequestEtype(Request):
......@@ -407,7 +460,7 @@ class SamplingRequestEtype(Request):
]
else:
probs = None
global_src, global_dst, global_eids = _sample_etype_neighbors(
res = _sample_etype_neighbors(
local_g,
partition_book,
self.seed_nodes,
......@@ -418,7 +471,12 @@ class SamplingRequestEtype(Request):
self.replace,
self.etype_sorted,
)
return SubgraphResponse(global_src, global_dst, global_eids)
return SubgraphResponse(
res.global_src,
res.global_dst,
global_eids=res.global_eids,
etype_ids=res.etype_ids,
)
class EdgesRequest(Request):
......@@ -532,7 +590,7 @@ class InSubgraphRequest(Request):
global_src, global_dst, global_eids = _in_subgraph(
local_g, partition_book, self.seed_nodes
)
return SubgraphResponse(global_src, global_dst, global_eids)
return SubgraphResponse(global_src, global_dst, global_eids=global_eids)
def merge_graphs(res_list, num_nodes):
......@@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes):
srcs = []
dsts = []
eids = []
etype_ids = []
for res in res_list:
srcs.append(res.global_src)
dsts.append(res.global_dst)
eids.append(res.global_eids)
etype_ids.append(res.etype_ids)
src_tensor = F.cat(srcs, 0)
dst_tensor = F.cat(dsts, 0)
eid_tensor = None if eids[0] is None else F.cat(eids, 0)
etype_id_tensor = None if etype_ids[0] is None else F.cat(etype_ids, 0)
else:
src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
etype_id_tensor = res_list[0].etype_ids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
if eid_tensor is not None:
g.edata[EID] = eid_tensor
if etype_id_tensor is not None:
g.edata[ETYPE] = etype_id_tensor
return g
# Result of sampling on one partition. Every field defaults to None so that
# producers without edge IDs or edge type IDs can simply omit them.
LocalSampledGraph = namedtuple(  # pylint: disable=unexpected-keyword-arg
    "LocalSampledGraph",
    "global_src global_dst global_eids etype_ids",
    defaults=(None, None, None, None),
)
......@@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
src, dst, eids = local_access(
g.local_partition, partition_book, local_nids
)
res_list.append(LocalSampledGraph(src, dst, eids))
res = local_access(g.local_partition, partition_book, local_nids)
res_list.append(res)
# receive responses from remote machines.
if msgseq2pos is not None:
......@@ -916,23 +980,15 @@ def sample_neighbors(
def local_access(local_g, partition_book, local_nids):
    """Sample neighbors of ``local_nids`` on the local partition.

    ``g``, ``prob``, ``fanout``, ``edge_dir``, ``replace`` and
    ``use_graphbolt`` are taken from the enclosing scope.
    """
    # See NOTE 1
    _prob = [g.edata[prob].local_partition] if prob is not None else None
    # A single dispatcher call covers both backends; the options are passed
    # by keyword so they cannot be bound to the wrong parameter.
    return _sample_neighbors(
        use_graphbolt,
        local_g,
        partition_book,
        local_nids,
        fanout,
        edge_dir=edge_dir,
        prob=_prob,
        replace=replace,
    )
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
......
import multiprocessing as mp
import os
import random
import sys
import tempfile
import time
import traceback
import unittest
......@@ -1013,47 +1013,85 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
assert np.all(F.asnumpy(orig_dst1) == orig_dst)
# Wait non shared memory graph store
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
dgl.backend.backend_name == "tensorflow",
reason="Not support tensorflow for now",
)
@unittest.skipIf(
dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt):
    """Run check_rpc_sampling_shuffle with and without GraphBolt."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(
            Path(tmpdirname), num_server, use_graphbolt=use_graphbolt
        )
        # [TODO][Rhett] Tests for multiple groups may fail sometimes and
        # root cause is unknown. Let's disable them for now.
        # check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_shuffle(num_server):
    """Run check_rpc_hetero_sampling_shuffle in a fresh temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_hetero_sampling_shuffle(workdir_path, num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_empty_shuffle(num_server):
    """Run check_rpc_hetero_sampling_empty_shuffle in a fresh temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_hetero_sampling_empty_shuffle(workdir_path, num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(
Path(tmpdirname), num_server, ["csc"]
)
check_rpc_hetero_etype_sampling_shuffle(
Path(tmpdirname), num_server, ["csr"]
)
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
    """Run check_rpc_hetero_etype_sampling_shuffle for each graph format."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Pass the parametrized value, not a hard-coded format list, so each
        # parametrization exercises a different format.
        check_rpc_hetero_etype_sampling_shuffle(
            Path(tmpdirname), num_server, graph_formats=graph_formats
        )
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
    """Run check_rpc_hetero_etype_sampling_empty_shuffle in a temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_hetero_etype_sampling_empty_shuffle(workdir_path, num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_empty_shuffle(num_server):
    """Run check_rpc_bipartite_sampling_empty in a fresh temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_bipartite_sampling_empty(workdir_path, num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_shuffle(num_server):
    """Run check_rpc_bipartite_sampling_shuffle in a fresh temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_bipartite_sampling_shuffle(workdir_path, num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server):
    """Run check_rpc_bipartite_etype_sampling_empty in a fresh temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_bipartite_etype_sampling_empty(workdir_path, num_server)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_shuffle(num_server):
    """Run check_rpc_bipartite_etype_sampling_shuffle in a temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)
        check_rpc_bipartite_etype_sampling_shuffle(workdir_path, num_server)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment