Unverified commit 9e7fbf95, authored Nov 19, 2021 by Mufei Li and committed via GitHub on Nov 19, 2021. Parent commit: 3aef4677.

[NN] JumpingKnowledge (#3512)

* Update
* Fix
Showing 5 changed files with 159 additions and 23 deletions (+159, -23):
* docs/source/api/python/nn.pytorch.rst (+7, -0)
* examples/pytorch/jknet/model.py (+10, -22)
* python/dgl/nn/pytorch/__init__.py (+1, -1)
* python/dgl/nn/pytorch/utils.py (+121, -0)
* tests/pytorch/test_nn.py (+20, -0)
docs/source/api/python/nn.pytorch.rst

@@ -310,6 +310,13 @@ SegmentedKNNGraph
     :members:
     :show-inheritance:
 
+JumpingKnowledge
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: dgl.nn.pytorch.utils.JumpingKnowledge
+    :members: forward, reset_parameters
+    :show-inheritance:
+
 NodeEmbedding Module
 ----------------------------------------

examples/pytorch/jknet/model.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl.function as fn
-from dgl.nn.pytorch.conv import GraphConv
+from dgl.nn import GraphConv, JumpingKnowledge
 
 
 class JKNet(nn.Module):
     def __init__(self,
@@ -21,11 +21,13 @@ class JKNet(nn.Module):
         for _ in range(num_layers):
             self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))
 
+        if self.mode == 'lstm':
+            self.jump = JumpingKnowledge(mode, hid_dim, num_layers)
+        else:
+            self.jump = JumpingKnowledge(mode)
+
         if self.mode == 'cat':
             hid_dim = hid_dim * (num_layers + 1)
-        elif self.mode == 'lstm':
-            self.lstm = nn.LSTM(hid_dim, (num_layers * hid_dim) // 2, bidirectional=True, batch_first=True)
-            self.attn = nn.Linear(2 * ((num_layers * hid_dim) // 2), 1)
 
         self.output = nn.Linear(hid_dim, out_dim)
         self.reset_params()
@@ -34,9 +36,7 @@ class JKNet(nn.Module):
         self.output.reset_parameters()
         for layers in self.layers:
             layers.reset_parameters()
-        if self.mode == 'lstm':
-            self.lstm.reset_parameters()
-            self.attn.reset_parameters()
+        self.jump.reset_parameters()
 
     def forward(self, g, feats):
         feat_lst = []
@@ -44,19 +44,7 @@ class JKNet(nn.Module):
             feats = self.dropout(layer(g, feats))
             feat_lst.append(feats)
 
-        if self.mode == 'cat':
-            out = torch.cat(feat_lst, dim=-1)
-        elif self.mode == 'max':
-            out = torch.stack(feat_lst, dim=-1).max(dim=-1)[0]
-        else:
-            # lstm
-            x = torch.stack(feat_lst, dim=1)
-            alpha, _ = self.lstm(x)
-            alpha = self.attn(alpha).squeeze(-1)
-            alpha = torch.softmax(alpha, dim=-1).unsqueeze(-1)
-            out = (x * alpha).sum(dim=1)
-        g.ndata['h'] = out
+        g.ndata['h'] = self.jump(feat_lst)
         g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
         return self.output(g.ndata['h'])

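As a reading aid (not part of the commit), here is a small sanity check showing that the new module reproduces the two deterministic aggregation branches deleted from JKNet.forward above; the 'lstm' branch relies on freshly initialized LSTM and attention weights, so its output is not directly comparable to the old inline version.

# Illustrative check: JumpingKnowledge reproduces the removed 'cat' and 'max' branches.
import torch
from dgl.nn import JumpingKnowledge

feat_lst = [torch.randn(5, 8) for _ in range(3)]  # 3 layer outputs, 5 nodes, 8 features

# 'cat' mode matches the removed torch.cat(feat_lst, dim=-1) branch
assert torch.equal(JumpingKnowledge('cat')(feat_lst),
                   torch.cat(feat_lst, dim=-1))

# 'max' mode matches the removed torch.stack(...).max(dim=-1)[0] branch
assert torch.equal(JumpingKnowledge('max')(feat_lst),
                   torch.stack(feat_lst, dim=-1).max(dim=-1)[0])
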
python/dgl/nn/pytorch/__init__.py

@@ -5,5 +5,5 @@ from .glob import *
 from .softmax import *
 from .factory import *
 from .hetero import *
-from .utils import Sequential, WeightBasis
+from .utils import Sequential, WeightBasis, JumpingKnowledge
 from .sparse_emb import NodeEmbedding

python/dgl/nn/pytorch/utils.py

@@ -282,3 +282,124 @@ class WeightBasis(nn.Module):
        # generate all weights from bases
        weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
        return weight.view(self.num_outputs, *self.shape)

The lines above are the existing tail of WeightBasis.forward; the entire JumpingKnowledge class below is newly appended at the end of the file. (The third docstring example originally passed mode='max'; it is corrected to mode='lstm', which is the mode that actually uses in_feats and num_layers.)

class JumpingKnowledge(nn.Module):
    r"""

    Description
    -----------
    The Jumping Knowledge aggregation module introduced in `Representation Learning on
    Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It
    aggregates the output representations of multiple GNN layers with

    **concatenation**

    .. math::

        h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}

    or **max pooling**

    .. math::

        \max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)

    or **LSTM**

    .. math::

        \sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}

    with attention scores :math:`\alpha_i^{(t)}` obtained from a BiLSTM.

    Parameters
    ----------
    mode : str
        The aggregation to apply. It can be 'cat', 'max', or 'lstm',
        corresponding to the equations above in order.
    in_feats : int, optional
        This argument is only required if :attr:`mode` is ``'lstm'``.
        The output representation size of a single GNN layer. Note that
        all GNN layers need to have the same output representation size.
    num_layers : int, optional
        This argument is only required if :attr:`mode` is ``'lstm'``.
        The number of GNN layers for output aggregation.

    Examples
    --------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import JumpingKnowledge

    >>> # Output representations of two GNN layers
    >>> num_nodes = 3
    >>> in_feats = 4
    >>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]

    >>> # Case1
    >>> model = JumpingKnowledge()
    >>> model(feat_list).shape
    torch.Size([3, 8])

    >>> # Case2
    >>> model = JumpingKnowledge(mode='max')
    >>> model(feat_list).shape
    torch.Size([3, 4])

    >>> # Case3
    >>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
    >>> model(feat_list).shape
    torch.Size([3, 4])
    """
    def __init__(self, mode='cat', in_feats=None, num_layers=None):
        super(JumpingKnowledge, self).__init__()
        assert mode in ['cat', 'max', 'lstm'], \
            "Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
        self.mode = mode

        if mode == 'lstm':
            assert in_feats is not None, 'in_feats is required for lstm mode'
            assert num_layers is not None, 'num_layers is required for lstm mode'
            hidden_size = (num_layers * in_feats) // 2
            self.lstm = nn.LSTM(in_feats, hidden_size,
                                bidirectional=True, batch_first=True)
            self.att = nn.Linear(2 * hidden_size, 1)

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters. This comes into effect only for the lstm mode.
        """
        if self.mode == 'lstm':
            self.lstm.reset_parameters()
            self.att.reset_parameters()

    def forward(self, feat_list):
        r"""

        Description
        -----------
        Aggregate output representations across multiple GNN layers.

        Parameters
        ----------
        feat_list : list[Tensor]
            feat_list[i] is the output representations of a GNN layer.

        Returns
        -------
        Tensor
            The aggregated representations.
        """
        if self.mode == 'cat':
            return th.cat(feat_list, dim=-1)
        elif self.mode == 'max':
            return th.stack(feat_list, dim=-1).max(dim=-1)[0]
        else:
            # LSTM
            stacked_feat_list = th.stack(feat_list, dim=1)  # (N, num_layers, in_feats)
            alpha, _ = self.lstm(stacked_feat_list)
            alpha = self.att(alpha).squeeze(-1)             # (N, num_layers)
            alpha = th.softmax(alpha, dim=-1)
            return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)

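The 'lstm' branch of forward implements the third equation in the docstring: the per-layer representations are stacked into a length-T sequence per node, the BiLSTM and the linear layer score each position, and a softmax turns the scores into the attention weights alpha_i^(t). A standalone sketch of that computation follows; it is illustrative only, the sizes N, T, D are made up, and the randomly initialized weights are not taken from the commit.

# Illustrative shapes: 5 nodes, 3 GNN layers, 8 features per layer.
import torch
import torch.nn as nn

N, T, D = 5, 3, 8
feat_list = [torch.randn(N, D) for _ in range(T)]

hidden_size = (T * D) // 2
lstm = nn.LSTM(D, hidden_size, bidirectional=True, batch_first=True)
att = nn.Linear(2 * hidden_size, 1)

x = torch.stack(feat_list, dim=1)           # (N, T, D): one length-T sequence per node
scores, _ = lstm(x)                          # (N, T, 2 * hidden_size)
alpha = att(scores).squeeze(-1)              # (N, T): one scalar score per layer
alpha = torch.softmax(alpha, dim=-1)         # attention weights, sum to 1 over layers
out = (x * alpha.unsqueeze(-1)).sum(dim=1)   # (N, D): weighted sum over layers
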
tests/pytorch/test_nn.py

@@ -1229,6 +1229,26 @@ def test_gnnexplainer(g, idtype, out_dim):
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)

The new test function below is added after test_gnnexplainer, right before the existing __main__ block:

def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)]

    model = nn.JumpingKnowledge('cat').to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge('max').to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge('lstm', num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

if __name__ == '__main__':
    test_graph_conv()
    test_graph_conv_e_weight()