Unverified Commit 9e7fbf95 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] JumpingKnowledge (#3512)

* Update

* Fix
parent 3aef4677
......@@ -310,6 +310,13 @@ SegmentedKNNGraph
:members:
:show-inheritance:
JumpingKnowledge
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.utils.JumpingKnowledge
:members: forward, reset_parameters
:show-inheritance:
NodeEmbedding Module
----------------------------------------
......
......@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch.conv import GraphConv
from dgl.nn import GraphConv, JumpingKnowledge
class JKNet(nn.Module):
def __init__(self,
......@@ -13,7 +13,7 @@ class JKNet(nn.Module):
mode='cat',
dropout=0.):
super(JKNet, self).__init__()
self.mode = mode
self.dropout = nn.Dropout(dropout)
self.layers = nn.ModuleList()
......@@ -21,11 +21,13 @@ class JKNet(nn.Module):
for _ in range(num_layers):
self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))
if self.mode == 'lstm':
self.jump = JumpingKnowledge(mode, hid_dim, num_layers)
else:
self.jump = JumpingKnowledge(mode)
if self.mode == 'cat':
hid_dim = hid_dim * (num_layers + 1)
elif self.mode == 'lstm':
self.lstm = nn.LSTM(hid_dim, (num_layers * hid_dim) // 2, bidirectional=True, batch_first=True)
self.attn = nn.Linear(2 * ((num_layers * hid_dim) // 2), 1)
self.output = nn.Linear(hid_dim, out_dim)
self.reset_params()
......@@ -34,29 +36,15 @@ class JKNet(nn.Module):
self.output.reset_parameters()
for layers in self.layers:
layers.reset_parameters()
if self.mode == 'lstm':
self.lstm.reset_parameters()
self.attn.reset_parameters()
self.jump.reset_parameters()
def forward(self, g, feats):
feat_lst = []
for layer in self.layers:
feats = self.dropout(layer(g, feats))
feat_lst.append(feats)
if self.mode == 'cat':
out = torch.cat(feat_lst, dim=-1)
elif self.mode == 'max':
out = torch.stack(feat_lst, dim=-1).max(dim=-1)[0]
else:
# lstm
x = torch.stack(feat_lst, dim=1)
alpha, _ = self.lstm(x)
alpha = self.attn(alpha).squeeze(-1)
alpha = torch.softmax(alpha, dim=-1).unsqueeze(-1)
out = (x * alpha).sum(dim=1)
g.ndata['h'] = out
g.ndata['h'] = self.jump(feat_lst)
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
return self.output(g.ndata['h'])
......@@ -5,5 +5,5 @@ from .glob import *
from .softmax import *
from .factory import *
from .hetero import *
from .utils import Sequential, WeightBasis
from .utils import Sequential, WeightBasis, JumpingKnowledge
from .sparse_emb import NodeEmbedding
......@@ -282,3 +282,124 @@ class WeightBasis(nn.Module):
# generate all weights from bases
weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
return weight.view(self.num_outputs, *self.shape)
class JumpingKnowledge(nn.Module):
r"""
Description
-----------
The Jumping Knowledge aggregation module introduced in `Representation Learning on
Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It
aggregates the output representations of multiple GNN layers with
**concatenation**
.. math::
h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}
or **max pooling**
.. math::
\max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)
or **LSTM**
.. math::
\sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}
with attention scores :math:`\alpha_i^{(t)}` obtained from a BiLSTM
Parameters
----------
mode : str
The aggregation to apply. It can be 'cat', 'max', or 'lstm',
corresponding to the equations above in order.
in_feats : int, optional
This argument is only required if :attr:`mode` is ``'lstm'``.
The output representation size of a single GNN layer. Note that
all GNN layers need to have the same output representation size.
num_layers : int, optional
This argument is only required if :attr:`mode` is ``'lstm'``.
The number of GNN layers for output aggregation.
Examples
--------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import JumpingKnowledge
>>> # Output representations of two GNN layers
>>> num_nodes = 3
>>> in_feats = 4
>>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]
>>> # Case1
>>> model = JumpingKnowledge()
>>> model(feat_list).shape
torch.Size([3, 8])
>>> # Case2
>>> model = JumpingKnowledge(mode='max')
>>> model(feat_list).shape
torch.Size([3, 4])
>>> # Case3
>>> model = JumpingKnowledge(mode='max', in_feats=in_feats, num_layers=len(feat_list))
>>> model(feat_list).shape
torch.Size([3, 4])
"""
def __init__(self, mode='cat', in_feats=None, num_layers=None):
super(JumpingKnowledge, self).__init__()
assert mode in ['cat', 'max', 'lstm'], \
"Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
self.mode = mode
if mode == 'lstm':
assert in_feats is not None, 'in_feats is required for lstm mode'
assert num_layers is not None, 'num_layers is required for lstm mode'
hidden_size = (num_layers * in_feats) // 2
self.lstm = nn.LSTM(in_feats, hidden_size, bidirectional=True, batch_first=True)
self.att = nn.Linear(2 * hidden_size, 1)
def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters. This comes into effect only for the lstm mode.
"""
if self.mode == 'lstm':
self.lstm.reset_parameters()
self.att.reset_parameters()
def forward(self, feat_list):
r"""
Description
-----------
Aggregate output representations across multiple GNN layers.
Parameters
----------
feat_list : list[Tensor]
feat_list[i] is the output representations of a GNN layer.
Returns
-------
Tensor
The aggregated representations.
"""
if self.mode == 'cat':
return th.cat(feat_list, dim=-1)
elif self.mode == 'max':
return th.stack(feat_list, dim=-1).max(dim=-1)[0]
else:
# LSTM
stacked_feat_list = th.stack(feat_list, dim=1) # (N, num_layers, in_feats)
alpha, _ = self.lstm(stacked_feat_list)
alpha = self.att(alpha).squeeze(-1) # (N, num_layers)
alpha = th.softmax(alpha, dim=-1)
return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)
......@@ -1229,6 +1229,26 @@ def test_gnnexplainer(g, idtype, out_dim):
explainer = nn.GNNExplainer(model, num_hops=1)
feat_mask, edge_mask = explainer.explain_graph(g, feat)
def test_jumping_knowledge():
ctx = F.ctx()
num_layers = 2
num_nodes = 3
num_feats = 4
feat_list = [th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)]
model = nn.JumpingKnowledge('cat').to(ctx)
model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_layers * num_feats)
model = nn.JumpingKnowledge('max').to(ctx)
model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_feats)
model = nn.JumpingKnowledge('lstm', num_feats, num_layers).to(ctx)
model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_feats)
if __name__ == '__main__':
test_graph_conv()
test_graph_conv_e_weight()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment