Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
5f44a4ef
Unverified
Commit
5f44a4ef
authored
Sep 10, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Sep 10, 2020
Browse files
Fix #1453 again (#2169)
* Fix reincarnation of #1453 * fix
parent
567c5acf
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
13 deletions
+5
-13
python/dgl/transform.py
python/dgl/transform.py
+5
-13
No files found.
python/dgl/transform.py
View file @
5f44a4ef
...
@@ -1713,21 +1713,14 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
...
@@ -1713,21 +1713,14 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
raise
ValueError
(
raise
ValueError
(
'Graph has more than one node type; please specify a dict for dst_nodes.'
)
'Graph has more than one node type; please specify a dict for dst_nodes.'
)
dst_nodes
=
{
g
.
ntypes
[
0
]:
dst_nodes
}
dst_nodes
=
{
g
.
ntypes
[
0
]:
dst_nodes
}
dst_nodes
=
{
ntype
:
utils
.
toindex
(
nodes
,
g
.
_idtype_str
).
tousertensor
()
for
ntype
,
nodes
in
dst_nodes
.
items
()}
# dst_nodes is now a dict
dst_node_ids
=
[
dst_nodes_nd
=
[]
utils
.
toindex
(
dst_nodes
.
get
(
ntype
,
[]),
g
.
_idtype_str
).
tousertensor
()
for
ntype
in
g
.
ntypes
:
for
ntype
in
g
.
ntypes
]
nodes
=
dst_nodes
.
get
(
ntype
,
None
)
dst_node_ids_nd
=
[
F
.
to_dgl_nd
(
nodes
)
for
nodes
in
dst_node_ids
]
if
nodes
is
not
None
:
dst_nodes_nd
.
append
(
F
.
to_dgl_nd
(
nodes
))
else
:
dst_nodes_nd
.
append
(
nd
.
NULL
[
g
.
_idtype_str
])
new_graph_index
,
src_nodes_nd
,
induced_edges_nd
=
_CAPI_DGLToBlock
(
new_graph_index
,
src_nodes_nd
,
induced_edges_nd
=
_CAPI_DGLToBlock
(
g
.
_graph
,
dst_nodes_nd
,
include_dst_in_src
)
g
.
_graph
,
dst_node
_id
s_nd
,
include_dst_in_src
)
# The new graph duplicates the original node types to SRC and DST sets.
# The new graph duplicates the original node types to SRC and DST sets.
new_ntypes
=
(
g
.
ntypes
,
g
.
ntypes
)
new_ntypes
=
(
g
.
ntypes
,
g
.
ntypes
)
...
@@ -1735,7 +1728,6 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
...
@@ -1735,7 +1728,6 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
assert
new_graph
.
is_unibipartite
# sanity check
assert
new_graph
.
is_unibipartite
# sanity check
src_node_ids
=
[
F
.
from_dgl_nd
(
src
)
for
src
in
src_nodes_nd
]
src_node_ids
=
[
F
.
from_dgl_nd
(
src
)
for
src
in
src_nodes_nd
]
dst_node_ids
=
[
F
.
from_dgl_nd
(
dst
)
for
dst
in
dst_nodes_nd
]
edge_ids
=
[
F
.
from_dgl_nd
(
eid
)
for
eid
in
induced_edges_nd
]
edge_ids
=
[
F
.
from_dgl_nd
(
eid
)
for
eid
in
induced_edges_nd
]
node_frames
=
utils
.
extract_node_subframes_for_block
(
g
,
src_node_ids
,
dst_node_ids
)
node_frames
=
utils
.
extract_node_subframes_for_block
(
g
,
src_node_ids
,
dst_node_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment