".github/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5440cbd34ea5a0f370b7ec6a6ed4d6b5fdbcf67a"
Unverified commit 4ee0a8bd, authored by Rhett Ying, committed by GitHub
Browse files

[DistGB] return global eids from GB sampling on homograph (#7085)

parent badeaf19
...@@ -145,7 +145,8 @@ def _sample_neighbors_graphbolt( ...@@ -145,7 +145,8 @@ def _sample_neighbors_graphbolt(
# [Rui][TODO] Support multiple fanouts. # [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout." assert fanout.numel() == 1, "Expect a single fanout."
subgraph = g._sample_neighbors(nodes, fanout) return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)
# 3. Map local node IDs to global node IDs. # 3. Map local node IDs to global node IDs.
local_src = subgraph.indices local_src = subgraph.indices
...@@ -156,9 +157,11 @@ def _sample_neighbors_graphbolt( ...@@ -156,9 +157,11 @@ def _sample_neighbors_graphbolt(
global_src = global_nid_mapping[local_src] global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst] global_dst = global_nid_mapping[local_dst]
# [Rui][TODO] edge IDs are not supported yet. global_eids = None
if return_eids:
global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]
return LocalSampledGraph( return LocalSampledGraph(
global_src, global_dst, None, subgraph.type_per_edge global_src, global_dst, global_eids, subgraph.type_per_edge
) )
......
...@@ -75,6 +75,7 @@ def start_sample_client_shuffle( ...@@ -75,6 +75,7 @@ def start_sample_client_shuffle(
orig_nid, orig_nid,
orig_eid, orig_eid,
use_graphbolt=False, use_graphbolt=False,
return_eids=False,
): ):
os.environ["DGL_GROUP_ID"] = str(group_id) os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None gpb = None
...@@ -95,7 +96,7 @@ def start_sample_client_shuffle( ...@@ -95,7 +96,7 @@ def start_sample_client_shuffle(
dst = orig_nid[dst] dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes() assert sampled_graph.num_nodes() == g.num_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst))) assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
if use_graphbolt: if use_graphbolt and not return_eids:
assert ( assert (
dgl.EID not in sampled_graph.edata dgl.EID not in sampled_graph.edata
), "EID should not be in sampled graph if use_graphbolt=True." ), "EID should not be in sampled graph if use_graphbolt=True."
...@@ -391,7 +392,7 @@ def test_rpc_sampling(): ...@@ -391,7 +392,7 @@ def test_rpc_sampling():
def check_rpc_sampling_shuffle( def check_rpc_sampling_shuffle(
tmpdir, num_server, num_groups=1, use_graphbolt=False tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
): ):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -408,6 +409,7 @@ def check_rpc_sampling_shuffle( ...@@ -408,6 +409,7 @@ def check_rpc_sampling_shuffle(
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -444,6 +446,7 @@ def check_rpc_sampling_shuffle( ...@@ -444,6 +446,7 @@ def check_rpc_sampling_shuffle(
orig_nids, orig_nids,
orig_eids, orig_eids,
use_graphbolt, use_graphbolt,
return_eids,
), ),
) )
p.start() p.start()
...@@ -1015,12 +1018,16 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): ...@@ -1015,12 +1018,16 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True]) @pytest.mark.parametrize("use_graphbolt", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt): @pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle( check_rpc_sampling_shuffle(
Path(tmpdirname), num_server, use_graphbolt=use_graphbolt Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment