Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6ae93e5c
Unverified
Commit
6ae93e5c
authored
Dec 11, 2019
by
Quan (Andy) Gan
Committed by
GitHub
Dec 11, 2019
Browse files
[Bug] Fix #1088 (#1089)
* [Bug] Fix #1088 * fix * add comment
parent
48c7ec44
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
5 deletions
+15
-5
src/graph/heterograph.cc
src/graph/heterograph.cc
+6
-5
tests/compute/test_hetero_basics.py
tests/compute/test_hetero_basics.py
+9
-0
No files found.
src/graph/heterograph.cc
View file @
6ae93e5c
...
...
@@ -468,12 +468,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
HeteroGraphRef
hg
=
args
[
0
];
dgl_type_t
etype
=
args
[
1
];
if
(
hg
->
NumEdgeTypes
()
==
1
)
{
CHECK_EQ
(
etype
,
0
);
*
rv
=
hg
;
}
else
{
CHECK_LE
(
etype
,
hg
->
NumEdgeTypes
())
<<
"invalid edge type "
<<
etype
;
// Test if the heterograph is a unit graph. If so, return itself.
auto
bg
=
std
::
dynamic_pointer_cast
<
UnitGraph
>
(
hg
.
sptr
());
if
(
bg
!=
nullptr
)
*
rv
=
bg
;
else
*
rv
=
HeteroGraphRef
(
hg
->
GetRelationGraph
(
etype
));
}
});
DGL_REGISTER_GLOBAL
(
"heterograph_index._CAPI_DGLHeteroGetFlattenedGraph"
)
...
...
tests/compute/test_hetero_basics.py
View file @
6ae93e5c
...
...
@@ -762,6 +762,14 @@ def test_local_scope():
assert
F
.
allclose
(
g
.
edata
[
'w'
],
F
.
tensor
([[
1.
],
[
0.
]]))
foo
(
g
)
def
test_issue_1088
():
# This test ensures that message passing on a heterograph with one edge type
# would not crash (GitHub issue #1088).
import
dgl.function
as
fn
g
=
dgl
.
heterograph
({(
'U'
,
'E'
,
'V'
):
([
0
,
1
,
2
],
[
1
,
2
,
3
])})
g
.
nodes
[
'U'
].
data
[
'x'
]
=
F
.
randn
((
3
,
3
))
g
.
update_all
(
fn
.
copy_u
(
'x'
,
'm'
),
fn
.
sum
(
'm'
,
'y'
))
if
__name__
==
'__main__'
:
test_nx_conversion
()
test_batch_setter_getter
()
...
...
@@ -781,3 +789,4 @@ if __name__ == '__main__':
test_group_apply_edges
()
test_local_var
()
test_local_scope
()
test_issue_1088
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment