Unverified Commit 5f44a4ef authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

Fix #1453 again (#2169)

* Fix reincarnation of #1453

* fix
parent 567c5acf
......@@ -1713,21 +1713,14 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
raise ValueError(
'Graph has more than one node type; please specify a dict for dst_nodes.')
dst_nodes = {g.ntypes[0]: dst_nodes}
dst_nodes = {
ntype: utils.toindex(nodes, g._idtype_str).tousertensor()
for ntype, nodes in dst_nodes.items()}
# dst_nodes is now a dict
dst_nodes_nd = []
for ntype in g.ntypes:
nodes = dst_nodes.get(ntype, None)
if nodes is not None:
dst_nodes_nd.append(F.to_dgl_nd(nodes))
else:
dst_nodes_nd.append(nd.NULL[g._idtype_str])
dst_node_ids = [
utils.toindex(dst_nodes.get(ntype, []), g._idtype_str).tousertensor()
for ntype in g.ntypes]
dst_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in dst_node_ids]
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_nodes_nd, include_dst_in_src)
g._graph, dst_node_ids_nd, include_dst_in_src)
# The new graph duplicates the original node types to SRC and DST sets.
new_ntypes = (g.ntypes, g.ntypes)
......@@ -1735,7 +1728,6 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
assert new_graph.is_unibipartite # sanity check
src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_nd]
dst_node_ids = [F.from_dgl_nd(dst) for dst in dst_nodes_nd]
edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd]
node_frames = utils.extract_node_subframes_for_block(g, src_node_ids, dst_node_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment