Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2570d412
Unverified
Commit
2570d412
authored
Jun 12, 2021
by
Jinjing Zhou
Committed by
GitHub
Jun 12, 2021
Browse files
[NN] Add fast path for GateGCNConv when it has only one edge type (#2994)
* fix gatedgcn * fix lint
parent
411bef54
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
16 deletions
+48
-16
python/dgl/nn/pytorch/conv/gatedgraphconv.py
python/dgl/nn/pytorch/conv/gatedgraphconv.py
+30
-16
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+18
-0
No files found.
python/dgl/nn/pytorch/conv/gatedgraphconv.py
View file @
2570d412
...
@@ -61,6 +61,7 @@ class GatedGraphConv(nn.Module):
...
@@ -61,6 +61,7 @@ class GatedGraphConv(nn.Module):
[ 0.6393, 0.3447, 0.3893, 0.4279, 0.3342, 0.3809, 0.0406, 0.5030,
[ 0.6393, 0.3447, 0.3893, 0.4279, 0.3342, 0.3809, 0.0406, 0.5030,
0.1342, 0.0425]], grad_fn=<AddBackward0>)
0.1342, 0.0425]], grad_fn=<AddBackward0>)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
in_feats
,
in_feats
,
out_feats
,
out_feats
,
...
@@ -110,7 +111,7 @@ class GatedGraphConv(nn.Module):
...
@@ -110,7 +111,7 @@ class GatedGraphConv(nn.Module):
"""
"""
self
.
_allow_zero_in_degree
=
set_value
self
.
_allow_zero_in_degree
=
set_value
def
forward
(
self
,
graph
,
feat
,
etypes
):
def
forward
(
self
,
graph
,
feat
,
etypes
=
None
):
"""
"""
Description
Description
...
@@ -125,9 +126,10 @@ class GatedGraphConv(nn.Module):
...
@@ -125,9 +126,10 @@ class GatedGraphConv(nn.Module):
The input feature of shape :math:`(N, D_{in})` where :math:`N`
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
input feature size.
etypes : torch.LongTensor
etypes : torch.LongTensor
, or None
The edge type tensor of shape :math:`(E,)` where :math:`E` is
The edge type tensor of shape :math:`(E,)` where :math:`E` is
the number of edges of the graph.
the number of edges of the graph. When there's only one edge type,
this argument can be skipped
Returns
Returns
-------
-------
...
@@ -139,21 +141,33 @@ class GatedGraphConv(nn.Module):
...
@@ -139,21 +141,33 @@ class GatedGraphConv(nn.Module):
assert
graph
.
is_homogeneous
,
\
assert
graph
.
is_homogeneous
,
\
"not a homogeneous graph; convert it with to_homogeneous "
\
"not a homogeneous graph; convert it with to_homogeneous "
\
"and pass in the edge type as argument"
"and pass in the edge type as argument"
assert
etypes
.
min
()
>=
0
and
etypes
.
max
()
<
self
.
_n_etypes
,
\
if
self
.
_n_etypes
!=
1
:
"edge type indices out of range [0, {})"
.
format
(
self
.
_n_etypes
)
assert
etypes
.
min
()
>=
0
and
etypes
.
max
()
<
self
.
_n_etypes
,
\
zero_pad
=
feat
.
new_zeros
((
feat
.
shape
[
0
],
self
.
_out_feats
-
feat
.
shape
[
1
]))
"edge type indices out of range [0, {})"
.
format
(
self
.
_n_etypes
)
zero_pad
=
feat
.
new_zeros
(
(
feat
.
shape
[
0
],
self
.
_out_feats
-
feat
.
shape
[
1
]))
feat
=
th
.
cat
([
feat
,
zero_pad
],
-
1
)
feat
=
th
.
cat
([
feat
,
zero_pad
],
-
1
)
for
_
in
range
(
self
.
_n_steps
):
for
_
in
range
(
self
.
_n_steps
):
graph
.
ndata
[
'h'
]
=
feat
if
self
.
_n_etypes
==
1
and
etypes
is
None
:
for
i
in
range
(
self
.
_n_etypes
):
# Fast path when graph has only one edge type
eids
=
th
.
nonzero
(
etypes
==
i
,
as_tuple
=
False
).
view
(
-
1
).
type
(
graph
.
idtype
)
graph
.
ndata
[
'h'
]
=
self
.
linears
[
0
](
feat
)
if
len
(
eids
)
>
0
:
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'a'
))
graph
.
apply_edges
(
a
=
graph
.
ndata
.
pop
(
'a'
)
# (N, D)
lambda
edges
:
{
'W_e*h'
:
self
.
linears
[
i
](
edges
.
src
[
'h'
])},
else
:
eids
graph
.
ndata
[
'h'
]
=
feat
)
for
i
in
range
(
self
.
_n_etypes
):
graph
.
update_all
(
fn
.
copy_e
(
'W_e*h'
,
'm'
),
fn
.
sum
(
'm'
,
'a'
))
eids
=
th
.
nonzero
(
a
=
graph
.
ndata
.
pop
(
'a'
)
# (N, D)
etypes
==
i
,
as_tuple
=
False
).
view
(
-
1
).
type
(
graph
.
idtype
)
if
len
(
eids
)
>
0
:
graph
.
apply_edges
(
lambda
edges
:
{
'W_e*h'
:
self
.
linears
[
i
](
edges
.
src
[
'h'
])},
eids
)
graph
.
update_all
(
fn
.
copy_e
(
'W_e*h'
,
'm'
),
fn
.
sum
(
'm'
,
'a'
))
a
=
graph
.
ndata
.
pop
(
'a'
)
# (N, D)
feat
=
self
.
gru
(
a
,
feat
)
feat
=
self
.
gru
(
a
,
feat
)
return
feat
return
feat
tests/pytorch/test_nn.py
View file @
2570d412
...
@@ -725,6 +725,23 @@ def test_gated_graph_conv(g, idtype):
...
@@ -725,6 +725,23 @@ def test_gated_graph_conv(g, idtype):
# current we only do shape check
# current we only do shape check
assert
h
.
shape
[
-
1
]
==
10
assert
h
.
shape
[
-
1
]
==
10
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
],
exclude
=
[
'zero-degree'
]))
def
test_gated_graph_conv_one_etype
(
g
,
idtype
):
ctx
=
F
.
ctx
()
g
=
g
.
astype
(
idtype
).
to
(
ctx
)
ggconv
=
nn
.
GatedGraphConv
(
5
,
10
,
5
,
1
)
etypes
=
th
.
zeros
(
g
.
number_of_edges
())
feat
=
F
.
randn
((
g
.
number_of_nodes
(),
5
))
ggconv
=
ggconv
.
to
(
ctx
)
etypes
=
etypes
.
to
(
ctx
)
h
=
ggconv
(
g
,
feat
,
etypes
)
h2
=
ggconv
(
g
,
feat
)
# current we only do shape check
assert
F
.
allclose
(
h
,
h2
)
assert
h
.
shape
[
-
1
]
==
10
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
],
exclude
=
[
'zero-degree'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
],
exclude
=
[
'zero-degree'
]))
def
test_nn_conv
(
g
,
idtype
):
def
test_nn_conv
(
g
,
idtype
):
...
@@ -1113,6 +1130,7 @@ if __name__ == '__main__':
...
@@ -1113,6 +1130,7 @@ if __name__ == '__main__':
test_gin_conv
()
test_gin_conv
()
test_agnn_conv
()
test_agnn_conv
()
test_gated_graph_conv
()
test_gated_graph_conv
()
test_gated_graph_conv_one_etype
()
test_nn_conv
()
test_nn_conv
()
test_gmm_conv
()
test_gmm_conv
()
test_dotgat_conv
()
test_dotgat_conv
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment