Unverified commit 4ee0a8bd, authored by Rhett Ying and committed by GitHub
Browse files

[DistGB] return global eids from GB sampling on homograph (#7085)

parent badeaf19
......@@ -145,7 +145,8 @@ def _sample_neighbors_graphbolt(
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."
subgraph = g._sample_neighbors(nodes, fanout)
return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)
# 3. Map local node IDs to global node IDs.
local_src = subgraph.indices
......@@ -156,9 +157,11 @@ def _sample_neighbors_graphbolt(
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]
# [Rui][TODO] edge IDs are not supported yet.
global_eids = None
if return_eids:
global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]
return LocalSampledGraph(
global_src, global_dst, None, subgraph.type_per_edge
global_src, global_dst, global_eids, subgraph.type_per_edge
)
......
......@@ -75,6 +75,7 @@ def start_sample_client_shuffle(
orig_nid,
orig_eid,
use_graphbolt=False,
return_eids=False,
):
os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None
......@@ -95,7 +96,7 @@ def start_sample_client_shuffle(
dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
if use_graphbolt:
if use_graphbolt and not return_eids:
assert (
dgl.EID not in sampled_graph.edata
), "EID should not be in sampled graph if use_graphbolt=True."
......@@ -391,7 +392,7 @@ def test_rpc_sampling():
def check_rpc_sampling_shuffle(
tmpdir, num_server, num_groups=1, use_graphbolt=False
tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -408,6 +409,7 @@ def check_rpc_sampling_shuffle(
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -444,6 +446,7 @@ def check_rpc_sampling_shuffle(
orig_nids,
orig_eids,
use_graphbolt,
return_eids,
),
)
p.start()
......@@ -1015,12 +1018,16 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt):
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt=use_graphbolt
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment