Unverified Commit ee8b7b39 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable GB sampling on heterograph (#7087)

parent a2e1c796
"""A set of graph services of getting subgraphs from DistGraph""" """A set of graph services of getting subgraphs from DistGraph"""
import os
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
...@@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): ...@@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
idtype=g.idtype, idtype=g.idtype,
) )
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID]) # For DGL partitions, the global edge IDs are always stored in the edata.
src, dst = frontier.edges() # For GraphBolt partitions, the edge type IDs are always stored in the
# edata. As for the edge IDs, they are stored in the edata if the graph is
# partitioned with `store_eids=True`. Otherwise, the edge IDs are not
# stored.
etype_ids, type_wise_eids = (
gpb.map_to_per_etype(frontier.edata[EID])
if EID in frontier.edata
else (frontier.edata[ETYPE], None)
)
etype_ids, idx = F.sort_1d(etype_ids) etype_ids, idx = F.sort_1d(etype_ids)
if type_wise_eids is not None:
type_wise_eids = F.gather_row(type_wise_eids, idx)
# Sort the edges by their edge types.
src, dst = frontier.edges()
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx) src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx) src_ntype_ids, src = gpb.map_to_per_ntype(src)
_, src = gpb.map_to_per_ntype(src) dst_ntype_ids, dst = gpb.map_to_per_ntype(dst)
_, dst = gpb.map_to_per_ntype(dst)
data_dict = dict() data_dict = dict()
edge_ids = {} edge_ids = {}
for etid, etype in enumerate(g.canonical_etypes): for etid, etype in enumerate(g.canonical_etypes):
src_ntype, _, dst_ntype = etype
src_ntype_id = g.get_ntype_id(src_ntype)
dst_ntype_id = g.get_ntype_id(dst_ntype)
type_idx = etype_ids == etid type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0: if F.sum(type_idx, 0) > 0:
data_dict[etype] = ( data_dict[etype] = (
F.boolean_mask(src, type_idx), F.boolean_mask(src, type_idx),
F.boolean_mask(dst, type_idx), F.boolean_mask(dst, type_idx),
) )
edge_ids[etype] = F.boolean_mask(eid, type_idx) if "DGL_DIST_DEBUG" in os.environ:
assert torch.all(
src_ntype_id == src_ntype_ids[type_idx]
), "source ntype is is not expected."
assert torch.all(
dst_ntype_id == dst_ntype_ids[type_idx]
), "destination ntype is is not expected."
if type_wise_eids is not None:
edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx)
hg = heterograph( hg = heterograph(
data_dict, data_dict,
{ntype: g.num_nodes(ntype) for ntype in g.ntypes}, {ntype: g.num_nodes(ntype) for ntype in g.ntypes},
......
...@@ -91,6 +91,9 @@ def start_sample_client_shuffle( ...@@ -91,6 +91,9 @@ def start_sample_client_shuffle(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
) )
assert (
dgl.ETYPE not in sampled_graph.edata
), "Etype should not be in homogeneous sampled graph."
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
src = orig_nid[src] src = orig_nid[src]
dst = orig_nid[dst] dst = orig_nid[dst]
...@@ -460,22 +463,36 @@ def check_rpc_sampling_shuffle( ...@@ -460,22 +463,36 @@ def check_rpc_sampling_shuffle(
assert p.exitcode == 0 assert p.exitcode == 0
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes): def start_hetero_sample_client(
rank,
tmpdir,
disable_shared_mem,
nodes,
use_graphbolt=False,
return_eids=False,
):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition( _, _, _, gpb, _, _, _ = load_partition(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["n1"].data assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data assert "feat" not in dist_graph.nodes["n3"].data
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
try: try:
sampled_graph = sample_neighbors(dist_graph, nodes, 3) # Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_neighbors(
dist_graph, nodes, 3, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
...@@ -528,7 +545,9 @@ def start_hetero_etype_sample_client( ...@@ -528,7 +545,9 @@ def start_hetero_etype_sample_client(
return block, gpb return block, gpb
def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): def check_rpc_hetero_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero() g = create_random_hetero()
...@@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): ...@@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): ...@@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
pserver_list.append(p) pserver_list.append(p)
block, gpb = start_hetero_sample_client( block, gpb = start_hetero_sample_client(
0, tmpdir, num_server > 1, nodes={"n3": [0, 10, 99, 66, 124, 208]} 0,
tmpdir,
num_server > 1,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling")
for p in pserver_list: for p in pserver_list:
p.join() p.join()
assert p.exitcode == 0 assert p.exitcode == 0
...@@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): ...@@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling. # These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src) shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst) shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src)) orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst)) orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid)) orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
# Check the node Ids and edge Ids. # Check the node Ids and edge Ids.
...@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype): ...@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype):
return deg return deg
def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): def check_rpc_hetero_sampling_empty_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero(empty=True) g = create_random_hetero(empty=True)
...@@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): ...@@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): ...@@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): ...@@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nids["n3"], "n3") deg = get_degrees(g, orig_nids["n3"], "n3")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_sample_client( block, gpb = start_hetero_sample_client(
0, tmpdir, num_server > 1, nodes={"n3": empty_nids} 0,
tmpdir,
num_server > 1,
nodes={"n3": empty_nids},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling")
for p in pserver_list: for p in pserver_list:
p.join() p.join()
assert p.exitcode == 0 assert p.exitcode == 0
...@@ -759,21 +813,35 @@ def create_random_bipartite(): ...@@ -759,21 +813,35 @@ def create_random_bipartite():
return g return g
def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes): def start_bipartite_sample_client(
rank,
tmpdir,
disable_shared_mem,
nodes,
use_graphbolt=False,
return_eids=False,
):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition( _, _, _, gpb, _, _, _ = load_partition(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["user"].data assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data assert "feat" in dist_graph.nodes["game"].data
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
sampled_graph = sample_neighbors(dist_graph, nodes, 3) # Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_neighbors(
dist_graph, nodes, 3, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
if sampled_graph.num_edges() > 0: if sampled_graph.num_edges() > 0:
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
dgl.distributed.exit_client() dgl.distributed.exit_client()
return block, gpb return block, gpb
...@@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client( ...@@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client(
return block, gpb return block, gpb
def check_rpc_bipartite_sampling_empty(tmpdir, num_server): def check_rpc_bipartite_sampling_empty(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_neighbors() which yields empty sample results""" """sample on bipartite via sample_neighbors() which yields empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): ...@@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): ...@@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): ...@@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
deg = get_degrees(g, orig_nids["game"], "game") deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0)
block, _ = start_bipartite_sample_client( block, _ = start_bipartite_sample_client(
0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]} 0,
tmpdir,
num_server > 1,
nodes={"game": empty_nids, "user": [1]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling") print("Done sampling")
...@@ -856,7 +940,9 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): ...@@ -856,7 +940,9 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
assert len(block.etypes) == len(g.etypes) assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): def check_rpc_bipartite_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_neighbors() which yields non-empty sample results""" """sample on bipartite via sample_neighbors() which yields non-empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): ...@@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): ...@@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): ...@@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nid_map["game"], "game") deg = get_degrees(g, orig_nid_map["game"], "game")
nids = F.nonzero_1d(deg > 0) nids = F.nonzero_1d(deg > 0)
block, gpb = start_bipartite_sample_client( block, gpb = start_bipartite_sample_client(
0, tmpdir, num_server > 1, nodes={"game": nids, "user": [0]} 0,
tmpdir,
num_server > 1,
nodes={"game": nids, "user": [0]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling") print("Done sampling")
for p in pserver_list: for p in pserver_list:
...@@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): ...@@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling. # These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src) shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst) shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src)) orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst)) orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid)) orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
# Check the node Ids and edge Ids. # Check the node Ids and edge Ids.
...@@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids): ...@@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt,", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_empty_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_empty_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
...@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server): ...@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_empty_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server) check_rpc_bipartite_sampling_empty(
Path(tmpdirname), num_server, use_graphbolt, return_eids
)
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_bipartite_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt, return_eids
)
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment