Unverified Commit 55af15d4 authored by caojy1998's avatar caojy1998 Committed by GitHub
Browse files

[BugFix]Use Python API for relabel_nodes (#5937)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-6-31.ap-northeast-1.compute.internal>
parent 9249be28
...@@ -163,6 +163,8 @@ def node_subgraph( ...@@ -163,6 +163,8 @@ def node_subgraph(
] ]
sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes) sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes)
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
if not relabel_nodes:
sgi = graph._graph.edge_subgraph(induced_edges, True)
# (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same # (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same
# bug in #1453. # bug in #1453.
induced_nodes_or_device = induced_nodes if relabel_nodes else device induced_nodes_or_device = induced_nodes if relabel_nodes else device
......
...@@ -404,7 +404,6 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph") ...@@ -404,7 +404,6 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> vids = args[1]; List<Value> vids = args[1];
bool relabel_nodes = args[2]; bool relabel_nodes = args[2];
CHECK(relabel_nodes) << "Node subgraph only supports relabel_nodes=True.";
std::vector<IdArray> vid_vec; std::vector<IdArray> vid_vec;
vid_vec.reserve(vids.size()); vid_vec.reserve(vids.size());
for (Value val : vids) { for (Value val : vids) {
......
...@@ -54,20 +54,32 @@ def test_edge_subgraph(): ...@@ -54,20 +54,32 @@ def test_edge_subgraph():
sg.edata["h"] = F.arange(0, sg.num_edges()) sg.edata["h"] = F.arange(0, sg.num_edges())
def test_subgraph(): @pytest.mark.parametrize("relabel_nodes", [True, False])
def test_subgraph_relabel_nodes(relabel_nodes):
g = generate_graph() g = generate_graph()
h = g.ndata["h"] h = g.ndata["h"]
l = g.edata["l"] l = g.edata["l"]
nid = [0, 2, 3, 6, 7, 9] nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid) sg = g.subgraph(nid, relabel_nodes=relabel_nodes)
eid = {2, 3, 4, 5, 10, 11, 12, 13, 16} eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
assert set(F.asnumpy(sg.edata[dgl.EID])) == eid assert set(F.asnumpy(sg.edata[dgl.EID])) == eid
eid = sg.edata[dgl.EID] eid = sg.edata[dgl.EID]
# the subgraph is empty initially except for NID/EID field # the subgraph is empty initially except for EID field
assert len(sg.ndata) == 2 # the subgraph is empty initially except for NID field if relabel_nodes
if relabel_nodes:
assert len(sg.ndata) == 2
assert len(sg.edata) == 2 assert len(sg.edata) == 2
sh = sg.ndata["h"] sh = sg.ndata["h"]
assert F.allclose(F.gather_row(h, F.tensor(nid)), sh) # The node number is not reduced if relabel_node=False.
# The subgraph keeps the same node information as the original graph.
if relabel_nodes:
assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)
else:
assert F.allclose(
F.gather_row(h, F.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])), sh
)
# The s,d,eid means the source node, destination node and edge id of the subgraph.
# The edges labeled 1 are those selected by the subgraph.
""" """
s, d, eid s, d, eid
0, 1, 0 0, 1, 0
...@@ -91,7 +103,10 @@ def test_subgraph(): ...@@ -91,7 +103,10 @@ def test_subgraph():
assert F.allclose(F.gather_row(l, eid), sg.edata["l"]) assert F.allclose(F.gather_row(l, eid), sg.edata["l"])
# update the node/edge features on the subgraph should NOT # update the node/edge features on the subgraph should NOT
# reflect to the parent graph. # reflect to the parent graph.
sg.ndata["h"] = F.zeros((6, D)) if relabel_nodes:
sg.ndata["h"] = F.zeros((6, D))
else:
sg.ndata["h"] = F.zeros((10, D))
assert F.allclose(h, g.ndata["h"]) assert F.allclose(h, g.ndata["h"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment