Unverified Commit 5f44a4ef authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

Fix #1453 again (#2169)

* Fix reincarnation of #1453

* fix
parent 567c5acf
...@@ -1713,21 +1713,14 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True): ...@@ -1713,21 +1713,14 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
raise ValueError( raise ValueError(
'Graph has more than one node type; please specify a dict for dst_nodes.') 'Graph has more than one node type; please specify a dict for dst_nodes.')
dst_nodes = {g.ntypes[0]: dst_nodes} dst_nodes = {g.ntypes[0]: dst_nodes}
dst_nodes = {
ntype: utils.toindex(nodes, g._idtype_str).tousertensor()
for ntype, nodes in dst_nodes.items()}
# dst_nodes is now a dict dst_node_ids = [
dst_nodes_nd = [] utils.toindex(dst_nodes.get(ntype, []), g._idtype_str).tousertensor()
for ntype in g.ntypes: for ntype in g.ntypes]
nodes = dst_nodes.get(ntype, None) dst_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in dst_node_ids]
if nodes is not None:
dst_nodes_nd.append(F.to_dgl_nd(nodes))
else:
dst_nodes_nd.append(nd.NULL[g._idtype_str])
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock( new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_nodes_nd, include_dst_in_src) g._graph, dst_node_ids_nd, include_dst_in_src)
# The new graph duplicates the original node types to SRC and DST sets. # The new graph duplicates the original node types to SRC and DST sets.
new_ntypes = (g.ntypes, g.ntypes) new_ntypes = (g.ntypes, g.ntypes)
...@@ -1735,7 +1728,6 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True): ...@@ -1735,7 +1728,6 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
assert new_graph.is_unibipartite # sanity check assert new_graph.is_unibipartite # sanity check
src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_nd] src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_nd]
dst_node_ids = [F.from_dgl_nd(dst) for dst in dst_nodes_nd]
edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd] edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd]
node_frames = utils.extract_node_subframes_for_block(g, src_node_ids, dst_node_ids) node_frames = utils.extract_node_subframes_for_block(g, src_node_ids, dst_node_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment