Unverified Commit ee8b7b39 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable GB sampling on heterograph (#7087)

parent a2e1c796
"""A set of graph services of getting subgraphs from DistGraph"""
import os
from collections import namedtuple
import numpy as np
......@@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
idtype=g.idtype,
)
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
# For DGL partitions, the global edge IDs are always stored in the edata.
# For GraphBolt partitions, the edge type IDs are always stored in the
# edata. As for the edge IDs, they are stored in the edata if the graph is
# partitioned with `store_eids=True`. Otherwise, the edge IDs are not
# stored.
etype_ids, type_wise_eids = (
gpb.map_to_per_etype(frontier.edata[EID])
if EID in frontier.edata
else (frontier.edata[ETYPE], None)
)
etype_ids, idx = F.sort_1d(etype_ids)
if type_wise_eids is not None:
type_wise_eids = F.gather_row(type_wise_eids, idx)
# Sort the edges by their edge types.
src, dst = frontier.edges()
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
src_ntype_ids, src = gpb.map_to_per_ntype(src)
dst_ntype_ids, dst = gpb.map_to_per_ntype(dst)
data_dict = dict()
edge_ids = {}
for etid, etype in enumerate(g.canonical_etypes):
src_ntype, _, dst_ntype = etype
src_ntype_id = g.get_ntype_id(src_ntype)
dst_ntype_id = g.get_ntype_id(dst_ntype)
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[etype] = (
F.boolean_mask(src, type_idx),
F.boolean_mask(dst, type_idx),
)
edge_ids[etype] = F.boolean_mask(eid, type_idx)
if "DGL_DIST_DEBUG" in os.environ:
assert torch.all(
src_ntype_id == src_ntype_ids[type_idx]
), "source ntype is is not expected."
assert torch.all(
dst_ntype_id == dst_ntype_ids[type_idx]
), "destination ntype is is not expected."
if type_wise_eids is not None:
edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx)
hg = heterograph(
data_dict,
{ntype: g.num_nodes(ntype) for ntype in g.ntypes},
......
......@@ -91,6 +91,9 @@ def start_sample_client_shuffle(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
)
assert (
dgl.ETYPE not in sampled_graph.edata
), "Etype should not be in homogeneous sampled graph."
src, dst = sampled_graph.edges()
src = orig_nid[src]
dst = orig_nid[dst]
......@@ -460,23 +463,37 @@ def check_rpc_sampling_shuffle(
assert p.exitcode == 0
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
def start_hetero_sample_client(
rank,
tmpdir,
disable_shared_mem,
nodes,
use_graphbolt=False,
return_eids=False,
):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data
if gpb is None:
gpb = dist_graph.get_partition_book()
try:
sampled_graph = sample_neighbors(dist_graph, nodes, 3)
# Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_neighbors(
dist_graph, nodes, 3, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e:
print(traceback.format_exc())
block = None
......@@ -528,7 +545,9 @@ def start_hetero_etype_sample_client(
return block, gpb
def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
def check_rpc_hetero_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero()
......@@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
pserver_list.append(p)
block, gpb = start_hetero_sample_client(
0, tmpdir, num_server > 1, nodes={"n3": [0, 10, 99, 66, 124, 208]}
0,
tmpdir,
num_server > 1,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
for p in pserver_list:
p.join()
assert p.exitcode == 0
......@@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
# Check the node Ids and edge Ids.
......@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype):
return deg
def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
def check_rpc_hetero_sampling_empty_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero(empty=True)
......@@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nids["n3"], "n3")
empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_sample_client(
0, tmpdir, num_server > 1, nodes={"n3": empty_nids}
0,
tmpdir,
num_server > 1,
nodes={"n3": empty_nids},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
for p in pserver_list:
p.join()
assert p.exitcode == 0
......@@ -759,22 +813,36 @@ def create_random_bipartite():
return g
def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes):
def start_bipartite_sample_client(
rank,
tmpdir,
disable_shared_mem,
nodes,
use_graphbolt=False,
return_eids=False,
):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data
if gpb is None:
gpb = dist_graph.get_partition_book()
sampled_graph = sample_neighbors(dist_graph, nodes, 3)
# Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_neighbors(
dist_graph, nodes, 3, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes)
if sampled_graph.num_edges() > 0:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
dgl.distributed.exit_client()
return block, gpb
......@@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client(
return block, gpb
def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
def check_rpc_bipartite_sampling_empty(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_neighbors() which yields empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0)
block, _ = start_bipartite_sample_client(
0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]}
0,
tmpdir,
num_server > 1,
nodes={"game": empty_nids, "user": [1]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
......@@ -856,7 +940,9 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
def check_rpc_bipartite_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_neighbors() which yields non-empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nid_map["game"], "game")
nids = F.nonzero_1d(deg > 0)
block, gpb = start_bipartite_sample_client(
0, tmpdir, num_server > 1, nodes={"game": nids, "user": [0]}
0,
tmpdir,
num_server > 1,
nodes={"game": nids, "user": [0]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
for p in pserver_list:
......@@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
# Check the node Ids and edge Ids.
......@@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_shuffle(num_server):
@pytest.mark.parametrize("use_graphbolt,", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_empty_shuffle(num_server):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1])
......@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_empty_shuffle(num_server):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_empty(
Path(tmpdirname), num_server, use_graphbolt, return_eids
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_shuffle(num_server):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt, return_eids
)
@pytest.mark.parametrize("num_server", [1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment