Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8798872f
Unverified
Commit
8798872f
authored
Oct 14, 2021
by
Rhett Ying
Committed by
GitHub
Oct 14, 2021
Browse files
[Bug] Do not skip graphconv even no edge exists (#3416)
parent
7c7b60be
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
42 additions
and
12 deletions
+42
-12
python/dgl/nn/mxnet/hetero.py
python/dgl/nn/mxnet/hetero.py
+0
-4
python/dgl/nn/pytorch/hetero.py
python/dgl/nn/pytorch/hetero.py
+0
-4
python/dgl/nn/tensorflow/hetero.py
python/dgl/nn/tensorflow/hetero.py
+0
-4
tests/mxnet/test_nn.py
tests/mxnet/test_nn.py
+14
-0
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+14
-0
tests/tensorflow/test_nn.py
tests/tensorflow/test_nn.py
+14
-0
No files found.
python/dgl/nn/mxnet/hetero.py
View file @
8798872f
...
...
@@ -165,8 +165,6 @@ class HeteroGraphConv(nn.Block):
src_inputs
,
dst_inputs
=
inputs
for
stype
,
etype
,
dtype
in
g
.
canonical_etypes
:
rel_graph
=
g
[
stype
,
etype
,
dtype
]
if
rel_graph
.
number_of_edges
()
==
0
:
continue
if
stype
not
in
src_inputs
or
dtype
not
in
dst_inputs
:
continue
dstdata
=
self
.
mods
[
etype
](
...
...
@@ -178,8 +176,6 @@ class HeteroGraphConv(nn.Block):
else
:
for
stype
,
etype
,
dtype
in
g
.
canonical_etypes
:
rel_graph
=
g
[
stype
,
etype
,
dtype
]
if
rel_graph
.
number_of_edges
()
==
0
:
continue
if
stype
not
in
inputs
:
continue
dstdata
=
self
.
mods
[
etype
](
...
...
python/dgl/nn/pytorch/hetero.py
View file @
8798872f
...
...
@@ -169,8 +169,6 @@ class HeteroGraphConv(nn.Module):
for
stype
,
etype
,
dtype
in
g
.
canonical_etypes
:
rel_graph
=
g
[
stype
,
etype
,
dtype
]
if
rel_graph
.
number_of_edges
()
==
0
:
continue
if
stype
not
in
src_inputs
or
dtype
not
in
dst_inputs
:
continue
dstdata
=
self
.
mods
[
etype
](
...
...
@@ -182,8 +180,6 @@ class HeteroGraphConv(nn.Module):
else
:
for
stype
,
etype
,
dtype
in
g
.
canonical_etypes
:
rel_graph
=
g
[
stype
,
etype
,
dtype
]
if
rel_graph
.
number_of_edges
()
==
0
:
continue
if
stype
not
in
inputs
:
continue
dstdata
=
self
.
mods
[
etype
](
...
...
python/dgl/nn/tensorflow/hetero.py
View file @
8798872f
...
...
@@ -169,8 +169,6 @@ class HeteroGraphConv(layers.Layer):
src_inputs
,
dst_inputs
=
inputs
for
stype
,
etype
,
dtype
in
g
.
canonical_etypes
:
rel_graph
=
g
[
stype
,
etype
,
dtype
]
if
rel_graph
.
number_of_edges
()
==
0
:
continue
if
stype
not
in
src_inputs
or
dtype
not
in
dst_inputs
:
continue
dstdata
=
self
.
mods
[
etype
](
...
...
@@ -182,8 +180,6 @@ class HeteroGraphConv(layers.Layer):
else
:
for
stype
,
etype
,
dtype
in
g
.
canonical_etypes
:
rel_graph
=
g
[
stype
,
etype
,
dtype
]
if
rel_graph
.
number_of_edges
()
==
0
:
continue
if
stype
not
in
inputs
:
continue
dstdata
=
self
.
mods
[
etype
](
...
...
tests/mxnet/test_nn.py
View file @
8798872f
...
...
@@ -788,6 +788,19 @@ def test_hetero_conv(agg, idtype):
assert
mod2
.
carg1
==
1
assert
mod3
.
carg1
==
0
#conv on graph without any edges
for
etype
in
g
.
etypes
:
g
=
dgl
.
remove_edges
(
g
,
g
.
edges
(
form
=
'eid'
,
etype
=
etype
),
etype
=
etype
)
assert
g
.
num_edges
()
==
0
h
=
conv
(
g
,
{
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
})
assert
set
(
h
.
keys
())
==
{
'user'
,
'game'
}
block
=
dgl
.
to_block
(
g
.
to
(
F
.
cpu
()),
{
'user'
:
[
0
,
1
,
2
,
3
],
'game'
:
[
0
,
1
,
2
,
3
],
'store'
:
[]}).
to
(
F
.
ctx
())
h
=
conv
(
block
,
({
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
},
{
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
[
0
:
0
]}))
assert
set
(
h
.
keys
())
==
{
'user'
,
'game'
}
if
__name__
==
'__main__'
:
test_graph_conv
()
test_gat_conv
()
...
...
@@ -809,3 +822,4 @@ if __name__ == '__main__':
test_simple_pool
()
test_rgcn
()
test_sequential
()
test_hetero_conv
()
tests/pytorch/test_nn.py
View file @
8798872f
...
...
@@ -1112,6 +1112,19 @@ def test_hetero_conv(agg, idtype):
assert
mod3
.
carg1
==
0
assert
mod3
.
carg2
==
1
#conv on graph without any edges
for
etype
in
g
.
etypes
:
g
=
dgl
.
remove_edges
(
g
,
g
.
edges
(
form
=
'eid'
,
etype
=
etype
),
etype
=
etype
)
assert
g
.
num_edges
()
==
0
h
=
conv
(
g
,
{
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
})
assert
set
(
h
.
keys
())
==
{
'user'
,
'game'
}
block
=
dgl
.
to_block
(
g
.
to
(
F
.
cpu
()),
{
'user'
:
[
0
,
1
,
2
,
3
],
'game'
:
[
0
,
1
,
2
,
3
],
'store'
:
[]}).
to
(
F
.
ctx
())
h
=
conv
(
block
,
({
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
},
{
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
[
0
:
0
]}))
assert
set
(
h
.
keys
())
==
{
'user'
,
'game'
}
if
__name__
==
'__main__'
:
test_graph_conv
()
test_graph_conv_e_weight
()
...
...
@@ -1140,3 +1153,4 @@ if __name__ == '__main__':
test_sequential
()
test_atomic_conv
()
test_cf_conv
()
test_hetero_conv
()
tests/tensorflow/test_nn.py
View file @
8798872f
...
...
@@ -502,6 +502,19 @@ def test_hetero_conv(agg, idtype):
assert
mod3
.
carg1
==
0
assert
mod3
.
carg2
==
1
#conv on graph without any edges
for
etype
in
g
.
etypes
:
g
=
dgl
.
remove_edges
(
g
,
g
.
edges
(
form
=
'eid'
,
etype
=
etype
),
etype
=
etype
)
assert
g
.
num_edges
()
==
0
h
=
conv
(
g
,
{
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
})
assert
set
(
h
.
keys
())
==
{
'user'
,
'game'
}
block
=
dgl
.
to_block
(
g
.
to
(
F
.
cpu
()),
{
'user'
:
[
0
,
1
,
2
,
3
],
'game'
:
[
0
,
1
,
2
,
3
],
'store'
:
[]}).
to
(
F
.
ctx
())
h
=
conv
(
block
,
({
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
},
{
'user'
:
uf
,
'game'
:
gf
,
'store'
:
sf
[
0
:
0
]}))
assert
set
(
h
.
keys
())
==
{
'user'
,
'game'
}
@
pytest
.
mark
.
parametrize
(
'out_dim'
,
[
1
,
2
])
def
test_dense_cheb_conv
(
out_dim
):
...
...
@@ -549,3 +562,4 @@ if __name__ == '__main__':
# test_dense_sage_conv()
test_dense_cheb_conv
()
# test_sequential()
test_hetero_conv
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment