Unverified Commit 650f6ee1 authored by Zihao Ye, committed by GitHub

[NN] Add commonly used GNN models from examples to dgl.nn modules. (#748)

* gat

* upd

* upd sage

* upd

* upd

* upd

* upd

* upd

* add gmmconv

* upd ggnn

* upd

* upd

* upd

* upd

* add citation examples

* add README

* fix cheb

* improve doc

* formula

* upd

* trigger

* lint

* lint

* upd

* add test for transform

* add test

* check

* upd

* improve doc

* shape check

* upd

* densechebconv, currently not correct (?)

* fix cheb

* fix

* upd

* upd sgc-reddit

* upd

* trigger
parent 8079d986
"""MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
from mxnet import gluon, nd
from mxnet.gluon import nn
......
"""Torch modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from . import utils
from ... import function as fn
from ...batched_graph import broadcast_nodes
from ...transform import laplacian_lambda_max
from .softmax import edge_softmax
__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv']
# pylint: disable=W0235
class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive.
(Identity has been supported since PyTorch 1.2; we will import
torch.nn.Identity directly in the future.)
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
"""Return input"""
return x
__all__ = ['GraphConv', 'TGConv', 'RelGraphConv']
# pylint: enable=W0235
class GraphConv(nn.Module):
r"""Apply graph convolution over an input signal.
......@@ -41,9 +63,9 @@ class GraphConv(nn.Module):
Parameters
----------
in_feats : int
Number of input features.
Input feature size.
out_feats : int
Number of output features.
Output feature size.
norm : bool, optional
If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
bias : bool, optional
......@@ -90,10 +112,10 @@ class GraphConv(nn.Module):
Notes
-----
* Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
* Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
Parameters
----------
......@@ -109,7 +131,7 @@ class GraphConv(nn.Module):
"""
graph = graph.local_var()
if self._norm:
norm = th.pow(graph.in_degrees().float(), -0.5)
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
feat = feat * norm
......@@ -150,8 +172,125 @@ class GraphConv(nn.Module):
summary += ', activation={_activation}'
return summary.format(**self.__dict__)
class TGConv(nn.Module):
r"""Apply Topology Adaptive Graph Convolutional Network
class GATConv(nn.Module):
r"""Apply `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
over an input signal.
.. math::
h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}
where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
node :math:`j`:
.. math::
\alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l})
e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
num_heads : int
Number of heads in Multi-Head Attention.
feat_drop : float, optional
Dropout rate on feature, defaults: ``0``.
attn_drop : float, optional
Dropout rate on attention weight, defaults: ``0``.
negative_slope : float, optional
LeakyReLU angle of negative slope.
residual : bool, optional
If True, use residual connection.
activation : callable activation function/layer or None, optional.
If not None, applies an activation function to the updated node features.
Default: ``None``.
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None):
super(GATConv, self).__init__()
self._num_heads = num_heads
self._in_feats = in_feats
self._out_feats = out_feats
self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual:
if in_feats != out_feats:
self.res_fc = nn.Linear(in_feats, num_heads * out_feats, bias=False)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.reset_parameters()
self.activation = activation
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def forward(self, feat, graph):
r"""Compute graph attention network layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature.
"""
graph = graph.local_var()
h = self.feat_drop(feat)
feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
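# NOTE: the attention logits a^T [W h_i || W h_j] decompose into
# (a_l . W h_i) + (a_r . W h_j); el/er below hold the two per-node terms,
# which u_add_v then sums on each edge instead of materializing the concat.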
el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.ndata.update({'ft': feat, 'el': el, 'er': er})
# compute edge attention
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.ndata['ft']
# residual
if self.res_fc is not None:
resval = self.res_fc(h).view(h.shape[0], -1, self._out_feats)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
return rst
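# A minimal usage sketch of the GATConv class defined above, mirroring the
# shape check in test_gat_conv further down (feature first, graph second in
# this version of the API):
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
gat = GATConv(in_feats=5, out_feats=2, num_heads=4)
feat = th.randn(100, 5)
out = gat(feat, g)            # (N, num_heads, out_feats)
assert out.shape == (100, 4, 2)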
class TAGConv(nn.Module):
r"""Topology Adaptive Graph Convolutional layer from paper `Topology
Adaptive Graph Convolutional Networks <https://arxiv.org/pdf/1710.10370.pdf>`__.
.. math::
\mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A}
......@@ -163,9 +302,9 @@ class TGConv(nn.Module):
Parameters
----------
in_feats : int
Number of input features.
Input feature size.
out_feats : int
Number of output features.
Output feature size.
k : int, optional
Number of hops :math:`K`. Default: ``2``.
bias: bool, optional
......@@ -185,7 +324,7 @@ class TGConv(nn.Module):
k=2,
bias=True,
activation=None):
super(TGConv, self).__init__()
super(TAGConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._k = k
......@@ -196,26 +335,29 @@ class TGConv(nn.Module):
def reset_parameters(self):
"""Reinitialize learnable parameters."""
self.lin.reset_parameters()
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.lin.weight, gain=gain)
def forward(self, feat, graph):
r"""Compute graph convolution
r"""Compute topology adaptive graph convolution.
Parameters
----------
feat : torch.Tensor
The input feature
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
graph = graph.local_var()
norm = th.pow(graph.in_degrees().float(), -0.5)
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
......@@ -380,7 +522,7 @@ class RelGraphConv(nn.Module):
return {'msg': msg}
def forward(self, g, x, etypes, norm=None):
"""Forward computation
""" Forward computation
Parameters
----------
......@@ -388,13 +530,13 @@ class RelGraphConv(nn.Module):
The graph.
x : torch.Tensor
Input node features. Could be either
- (|V|, D) dense tensor
- (|V|,) int64 vector, representing the categorical values of each
node. We then treat the input feature as an one-hot encoding feature.
* :math:`(|V|, D)` dense tensor
* :math:`(|V|,)` int64 vector, representing the categorical values of each
node. We then treat the input feature as an one-hot encoding feature.
etypes : torch.Tensor
Edge type tensor. Shape: (|E|,)
Edge type tensor. Shape: :math:`(|E|,)`
norm : torch.Tensor
Optional edge normalizer tensor. Shape: (|E|, 1)
Optional edge normalizer tensor. Shape: :math:`(|E|, 1)`
Returns
-------
......@@ -408,10 +550,8 @@ class RelGraphConv(nn.Module):
g.edata['norm'] = norm
if self.self_loop:
loop_message = utils.matmul_maybe_select(x, self.loop_weight)
# message passing
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
# apply bias and activation
node_repr = g.ndata['h']
if self.bias:
......@@ -421,5 +561,1117 @@ class RelGraphConv(nn.Module):
if self.activation:
node_repr = self.activation(node_repr)
node_repr = self.dropout(node_repr)
return node_repr
class SAGEConv(nn.Module):
r"""GraphSAGE layer from paper `Inductive Representation Learning on
Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.
.. math::
h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate}
\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) + b \right)
h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{(l+1)})
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
feat_drop : float
Dropout rate on features, default: ``0``.
aggregator_type : str
Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
"""
def __init__(self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(in_feats, in_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(in_feats, in_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(in_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(in_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
def _lstm_reducer(self, nodes):
"""LSTM reducer
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
is slow, we could accelerate this with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_feats)),
m.new_zeros((1, batch_size, self._in_feats)))
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
def forward(self, feat, graph):
r"""Compute GraphSAGE layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat
if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide by in_degrees + 1 (the +1 accounts for the node's own feature)
degs = graph.in_degrees().float()
degs = degs.to(feat.device)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = F.relu(self.fc_pool(feat))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
elif self._aggre_type == 'lstm':
graph.ndata['h'] = feat
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.ndata['neigh']
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self.norm is not None:
rst = self.norm(rst)
return rst
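# A minimal usage sketch of the SAGEConv class defined above with the 'pool'
# aggregator; the other aggregators ('mean', 'gcn', 'lstm') are exercised the
# same way in test_sage_conv below.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
sage = SAGEConv(5, 10, 'pool')
feat = th.randn(100, 5)
h = sage(feat, g)             # (N, out_feats)
assert h.shape == (100, 10)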
class GatedGraphConv(nn.Module):
r"""Gated Graph Convolution layer from paper `Gated Graph Sequence
Neural Networks <https://arxiv.org/pdf/1511.05493.pdf>`__.
.. math::
h_{i}^{0} & = [ x_i \| \mathbf{0} ]
a_{i}^{t} & = \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t}
h_{i}^{t+1} & = \mathrm{GRU}(a_{i}^{t}, h_{i}^{t})
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
n_steps : int
Number of recurrent steps.
n_etypes : int
Number of edge types.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
"""
def __init__(self,
in_feats,
out_feats,
n_steps,
n_etypes,
bias=True):
super(GatedGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._n_steps = n_steps
self.edge_embed = nn.Embedding(n_etypes, out_feats * out_feats)
self.gru = nn.GRUCell(out_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = init.calculate_gain('relu')
self.gru.reset_parameters()
init.xavier_normal_(self.edge_embed.weight, gain=gain)
def forward(self, feat, etypes, graph):
"""Compute Gated Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
etypes : torch.LongTensor
The edge type tensor of shape :math:`(E,)` where :math:`E` is
the number of edges of the graph.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1]))
feat = th.cat([feat, zero_pad], -1)
# NOTE(zihao): there is still room to optimize, we may do kernel fusion
# for such operations in the future.
graph.edata['w'] = self.edge_embed(etypes).view(-1, self._out_feats, self._out_feats)
for _ in range(self._n_steps):
graph.ndata['h'] = feat.unsqueeze(-1) # (N, D, 1)
graph.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'a'))
a = graph.ndata.pop('a').sum(dim=1) # (N, D)
feat = self.gru(a, feat)
return feat
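# A minimal usage sketch of GatedGraphConv; note that in_feats must not exceed
# out_feats, since the input is zero-padded up to the hidden size above.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
ggconv = GatedGraphConv(in_feats=5, out_feats=10, n_steps=5, n_etypes=3)
feat = th.randn(100, 5)
etypes = th.arange(g.number_of_edges()) % 3   # one type id per edge
h = ggconv(feat, etypes, g)   # (N, out_feats)
assert h.shape == (100, 10)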
class GMMConv(nn.Module):
r"""The Gaussian Mixture Model Convolution layer from `Geometric Deep
Learning on Graphs and Manifolds using Mixture Model CNNs
<http://openaccess.thecvf.com/content_cvpr_2017/papers/Monti_Geometric_Deep_Learning_CVPR_2017_paper.pdf>`__.
.. math::
h_i^{l+1} & = \mathrm{aggregate}\left(\left\{\frac{1}{K}
\sum_{k}^{K} w_k(u_{ij}) h_j^{l}, \forall j\in \mathcal{N}(i)\right\}\right)
w_k(u) & = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right)
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
dim : int
Dimensionality of pseudo-coordinate.
n_kernels : int
Number of kernels :math:`K`.
aggregator_type : str
Aggregator type (``sum``, ``mean``, ``max``).
residual : bool
If True, use residual connection inside this layer.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
"""
def __init__(self,
in_feats,
out_feats,
dim,
n_kernels,
aggregator_type,
residual=True,
bias=True):
super(GMMConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._dim = dim
self._n_kernels = n_kernels
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'mean':
self._reducer = fn.mean
elif aggregator_type == 'max':
self._reducer = fn.max
else:
raise KeyError("Aggregator type {} not recognized.".format(aggregator_type))
self.mu = nn.Parameter(th.Tensor(n_kernels, dim))
self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim))
self.fc = nn.Linear(in_feats, n_kernels * out_feats, bias=False)
if residual:
if in_feats != out_feats:
self.res_fc = nn.Linear(in_feats, out_feats, bias=False)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = init.calculate_gain('relu')
init.xavier_normal_(self.fc.weight, gain=gain)
if isinstance(self.res_fc, nn.Linear):
init.xavier_normal_(self.res_fc.weight, gain=gain)
init.normal_(self.mu.data, 0, 0.1)
init.normal_(self.inv_sigma.data, 1, 0.1)
if self.bias is not None:
init.zeros_(self.bias.data)
def forward(self, feat, pseudo, graph):
"""Compute Gaussian Mixture Model Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
pseudo : torch.Tensor
The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
:math:`E` is the number of edges of the graph and :math:`D_{u}`
is the dimensionality of pseudo coordinate.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
graph.ndata['h'] = self.fc(feat).view(-1, self._n_kernels, self._out_feats)
E = graph.number_of_edges()
# compute gaussian weight
gaussian = -0.5 * ((pseudo.view(E, 1, self._dim) -
self.mu.view(1, self._n_kernels, self._dim)) ** 2)
gaussian = gaussian * (self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2)
gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1)
graph.edata['w'] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
rst = graph.ndata['h'].sum(1)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat)
# bias
if self.bias is not None:
rst = rst + self.bias
return rst
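# A minimal usage sketch of GMMConv: pseudo-coordinates are per-edge vectors
# of dimensionality `dim` (here 3), mixed through `n_kernels` Gaussian kernels.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
gmmconv = GMMConv(5, 10, dim=3, n_kernels=4, aggregator_type='mean')
feat = th.randn(100, 5)
pseudo = th.randn(g.number_of_edges(), 3)
h = gmmconv(feat, pseudo, g)  # (N, out_feats)
assert h.shape == (100, 10)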
class GINConv(nn.Module):
r"""Graph Isomorphism Network layer from paper `How Powerful are Graph
Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__.
.. math::
h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
\mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
\right\}\right)\right)
Parameters
----------
apply_func : callable activation function/layer or None
If not None, apply this function to the updated node feature,
the :math:`f_\Theta` in the formula.
aggregator_type : str
Aggregator type to use (``sum``, ``max`` or ``mean``).
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter.
"""
def __init__(self,
apply_func,
aggregator_type,
init_eps=0,
learn_eps=False):
super(GINConv, self).__init__()
self.apply_func = apply_func
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'max':
self._reducer = fn.max
elif aggregator_type == 'mean':
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
def forward(self, feat, graph):
r"""Compute Graph Isomorphism Network layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D)` where :math:`D`
could be any positive integer, :math:`N` is the number
of nodes. If ``apply_func`` is not None, :math:`D` should
fit the input dimensionality requirement of ``apply_func``.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where
:math:`D_{out}` is the output dimensionality of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
"""
graph = graph.local_var()
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps) * feat + graph.ndata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
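# A minimal usage sketch of GINConv: apply_func (the f_Theta in the formula)
# is an arbitrary torch module, here a single Linear layer mapping 5 -> 12.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
gin = GINConv(th.nn.Linear(5, 12), 'sum')
feat = th.randn(100, 5)
h = gin(feat, g)              # (N, 12)
assert h.shape == (100, 12)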
class ChebConv(nn.Module):
r"""Chebyshev Spectral Graph Convolution layer from paper `Convolutional
Neural Networks on Graphs with Fast Localized Spectral Filtering
<https://arxiv.org/pdf/1606.09375.pdf>`__.
.. math::
h_i^{l+1} &= \sum_{k=0}^{K-1} W^{k, l}z_i^{k, l}
Z^{0, l} &= H^{l}
Z^{1, l} &= \hat{L} \cdot H^{l}
Z^{k, l} &= 2 \cdot \hat{L} \cdot Z^{k-1, l} - Z^{k-2, l}
\hat{L} &= 2\left(I - \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}\right)/\lambda_{max} - I
Parameters
----------
in_feats: int
Number of input features.
out_feats: int
Number of output features.
k : int
Chebyshev filter size.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
"""
def __init__(self,
in_feats,
out_feats,
k,
bias=True):
super(ChebConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self.fc = nn.ModuleList([
nn.Linear(in_feats, out_feats, bias=False) for _ in range(k)
])
self._k = k
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
if self.bias is not None:
init.zeros_(self.bias)
for module in self.fc.modules():
if isinstance(module, nn.Linear):
init.xavier_normal_(module.weight, init.calculate_gain('relu'))
if module.bias is not None:
init.zeros_(module.bias)
def forward(self, feat, graph, lambda_max=None):
r"""Compute ChebNet layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
with graph.local_scope():
norm = th.pow(
graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
if lambda_max is None:
lambda_max = laplacian_lambda_max(graph)
lambda_max = th.Tensor(lambda_max).to(feat.device)
if lambda_max.dim() < 1:
lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1)
# broadcast from (B, 1) to (N, 1)
lambda_max = broadcast_nodes(graph, lambda_max)
# T0(X)
Tx_0 = feat
rst = self.fc[0](Tx_0)
# T1(X)
if self._k > 1:
graph.ndata['h'] = Tx_0 * norm
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h = graph.ndata.pop('h') * norm
# Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
# = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1)
rst = rst + self.fc[1](Tx_1)
# Ti(x), i = 2...k
for i in range(2, self._k):
graph.ndata['h'] = Tx_1 * norm
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h = graph.ndata.pop('h') * norm
# Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
# = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
# (4 / lambda_max - 2) Tx_(k-1) -
# Tx_(k-2)
Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max - 2) - Tx_0
rst = rst + self.fc[i](Tx_2)
Tx_1, Tx_0 = Tx_2, Tx_1
# add bias
if self.bias is not None:
rst = rst + self.bias
return rst
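# A minimal usage sketch of ChebConv, assuming an undirected graph; lambda_max
# is estimated from the graph Laplacian when not supplied, so passing it
# explicitly (precomputed once) avoids repeating the eigensolve per call.
import dgl
import torch as th
import networkx as nx

g = dgl.DGLGraph(nx.erdos_renyi_graph(50, 0.2))
cheb = ChebConv(5, 2, k=3)
feat = th.randn(50, 5)
h = cheb(feat, g, lambda_max=dgl.laplacian_lambda_max(g))   # (N, out_feats)
assert h.shape == (50, 2)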
class SGConv(nn.Module):
r"""Simplifying Graph Convolution layer from paper `Simplifying Graph
Convolutional Networks <https://arxiv.org/pdf/1902.07153.pdf>`__.
.. math::
H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l}
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
k : int
Number of hops :math:`K`. Default: ``1``.
cached : bool
If True, the module would cache
.. math::
(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta
at the first forward call. This parameter should only be set to
``True`` in the transductive learning setting.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
"""
def __init__(self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None):
super(SGConv, self).__init__()
self.fc = nn.Linear(in_feats, out_feats, bias=bias)
self._cached = cached
self._cached_h = None
self._k = k
self.norm = norm
def forward(self, feat, graph):
r"""Compute Simplifying Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
Notes
-----
If ``cached`` is set to True, ``feat`` and ``graph`` should not change during
training, or you will get wrong results.
"""
graph = graph.local_var()
if self._cached_h is not None:
feat = self._cached_h
else:
# compute normalization
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5)
norm[th.isinf(norm)] = 0
norm = norm.to(feat.device).unsqueeze(1)
# compute (D^-1 A D) X
for _ in range(self._k):
feat = feat * norm
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
feat = feat * norm
if self.norm is not None:
feat = self.norm(feat)
# cache feature
if self._cached:
self._cached_h = feat
return self.fc(feat)
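# A minimal usage sketch of SGConv; with cached=True the propagated features
# are reused after the first call, so the graph and features must stay fixed.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
sgc = SGConv(5, 10, k=3, cached=True)
feat = th.randn(100, 5)
h = sgc(feat, g)              # (N, out_feats)
assert h.shape == (100, 10)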
class NNConv(nn.Module):
r"""Graph Convolution layer introduced in `Neural Message Passing
for Quantum Chemistry <https://arxiv.org/pdf/1704.01212.pdf>`__.
.. math::
h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{
f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right)
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
edge_func : callable activation function/layer
Maps each edge feature to a vector of shape
``(in_feats * out_feats)`` as weight to compute
messages.
Also is the :math:`f_\Theta` in the formula.
aggregator_type : str
Aggregator type to use (``sum``, ``mean`` or ``max``).
residual : bool, optional
If True, use residual connection. Default: ``False``.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
"""
def __init__(self,
in_feats,
out_feats,
edge_func,
aggregator_type,
residual=False,
bias=True):
super(NNConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self.edge_nn = edge_func
if aggregator_type == 'sum':
self.reducer = fn.sum
elif aggregator_type == 'mean':
self.reducer = fn.mean
elif aggregator_type == 'max':
self.reducer = fn.max
else:
raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type))
self._aggre_type = aggregator_type
if residual:
if in_feats != out_feats:
self.res_fc = nn.Linear(in_feats, out_feats, bias=False)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = init.calculate_gain('relu')
if self.bias is not None:
nn.init.zeros_(self.bias)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def forward(self, feat, efeat, graph):
r"""Compute MPNN Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
efeat : torch.Tensor
The edge feature of shape :math:`(E, *)`, should fit the input
shape requirement of ``edge_nn``.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
# (n, d_in, 1)
graph.ndata['h'] = feat.unsqueeze(-1)
# (n, d_in, d_out)
graph.edata['w'] = self.edge_nn(efeat).view(-1, self._in_feats, self._out_feats)
# (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh'))
rst = graph.ndata.pop('neigh').sum(dim=1) # (n, d_out)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat)
# bias
if self.bias is not None:
rst = rst + self.bias
return rst
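# A minimal usage sketch of NNConv, mirroring test_nn_conv below: edge_func
# maps each edge feature to an (in_feats * out_feats) vector that acts as a
# per-edge weight matrix for the messages.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
edge_func = th.nn.Linear(4, 5 * 10)
nnconv = NNConv(5, 10, edge_func, 'mean')
feat = th.randn(100, 5)
efeat = th.randn(g.number_of_edges(), 4)
h = nnconv(feat, efeat, g)    # (N, out_feats)
assert h.shape == (100, 10)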
class APPNPConv(nn.Module):
r"""Approximate Personalized Propagation of Neural Predictions
layer from paper `Predict then Propagate: Graph Neural Networks
meet Personalized PageRank <https://arxiv.org/pdf/1810.05997.pdf>`__.
.. math::
H^{0} & = X
H^{t+1} & = (1-\alpha)\left(\hat{D}^{-1/2}
\hat{A} \hat{D}^{-1/2} H^{t}\right) + \alpha H^{0}
Parameters
----------
k : int
Number of iterations :math:`K`.
alpha : float
The teleport probability :math:`\alpha`.
edge_drop : float, optional
Dropout rate on edges that controls the
messages received by each node. Default: ``0``.
"""
def __init__(self,
k,
alpha,
edge_drop=0.):
super(APPNPConv, self).__init__()
self._k = k
self._alpha = alpha
self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity()
def forward(self, feat, graph):
r"""Compute APPNP layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, *)` :math:`N` is the
number of nodes, and :math:`*` could be of any shape.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, *)` where :math:`*`
should be the same as input shape.
"""
graph = graph.local_var()
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
norm = norm.unsqueeze(-1).to(feat.device)
feat_0 = feat
for _ in range(self._k):
# normalization by src
feat = feat * norm
graph.ndata['h'] = feat
graph.edata['w'] = self.edge_drop(
th.ones(graph.number_of_edges(), 1).to(feat.device))
graph.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
# normalization by dst
feat = feat * norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0
return feat
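# A minimal usage sketch of APPNPConv: the layer has no weights of its own, it
# only propagates the input features for k steps with restart probability alpha.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
appnp = APPNPConv(k=10, alpha=0.1)
feat = th.randn(100, 5)
h = appnp(feat, g)            # same shape as the input
assert h.shape == (100, 5)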
class AGNNConv(nn.Module):
r"""Attention-based Graph Neural Network layer from paper `Attention-based
Graph Neural Network for Semi-Supervised Learning
<https://arxiv.org/abs/1803.03735>`__.
.. math::
H^{l+1} = P H^{l}
where :math:`P` is computed as:
.. math::
P_{ij} = \mathrm{softmax}_i ( \beta \cdot \cos(h_i^l, h_j^l))
Parameters
----------
init_beta : float, optional
The :math:`\beta` in the formula.
learn_beta : bool, optional
If True, :math:`\beta` will be learnable parameter.
"""
def __init__(self,
init_beta=1.,
learn_beta=True):
super(AGNNConv, self).__init__()
if learn_beta:
self.beta = nn.Parameter(th.Tensor([init_beta]))
else:
self.register_buffer('beta', th.Tensor([init_beta]))
def forward(self, feat, graph):
r"""Compute AGNN layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, *)` :math:`N` is the
number of nodes, and :math:`*` could be of any shape.
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, *)` where :math:`*`
should be the same as input shape.
"""
graph = graph.local_var()
graph.ndata['h'] = feat
graph.ndata['norm_h'] = F.normalize(feat, p=2, dim=-1)
# compute cosine distance
graph.apply_edges(fn.u_mul_v('norm_h', 'norm_h', 'cos'))
cos = graph.edata.pop('cos').sum(-1)
e = self.beta * cos
graph.edata['p'] = edge_softmax(graph, e)
graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
return graph.ndata.pop('h')
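# A minimal usage sketch of AGNNConv: a single scalar beta scales the cosine
# similarities, and the output keeps the input feature shape.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
agnn = AGNNConv(init_beta=1.0, learn_beta=True)
feat = th.randn(100, 5)
h = agnn(feat, g)             # same shape as the input
assert h.shape == (100, 5)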
class DenseGraphConv(nn.Module):
"""Graph Convolutional Network layer where the graph structure
is given by an adjacency matrix.
We recommend using this module when applying graph convolution
on dense graphs / k-hop graphs.
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
norm : bool
If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
See also
--------
GraphConv
"""
def __init__(self,
in_feats,
out_feats,
norm=True,
bias=True,
activation=None):
super(DenseGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()
self._activation = activation
def reset_parameters(self):
"""Reinitialize learnable parameters."""
init.xavier_uniform_(self.weight)
if self.bias is not None:
init.zeros_(self.bias)
def forward(self, feat, adj):
r"""Compute (Dense) Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
adj = adj.float().to(feat.device)
if self._norm:
in_degrees = adj.sum(dim=1)
norm = th.pow(in_degrees, -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
feat = feat * norm
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
feat = th.matmul(feat, self.weight)
rst = adj @ feat
else:
# aggregate first then mult W
rst = adj @ feat
rst = th.matmul(rst, self.weight)
if self._norm:
rst = rst * norm
if self.bias is not None:
rst = rst + self.bias
if self._activation is not None:
rst = self._activation(rst)
return rst
class DenseSAGEConv(nn.Module):
"""GraphSAGE layer where the graph structure is given by an
adjacency matrix.
We recommend using this module when applying GraphSAGE operations
on dense graphs / k-hop graphs.
Note that only the ``gcn`` aggregator is supported in DenseSAGEConv.
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
feat_drop : float, optional
Dropout rate on features. Default: 0.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
See also
--------
SAGEConv
"""
def __init__(self,
in_feats,
out_feats,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
super(DenseSAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
self.fc = nn.Linear(in_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_uniform_(self.fc.weight, gain=gain)
def forward(self, feat, adj):
r"""Compute (Dense) Graph SAGE layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
adj = adj.float().to(feat.device)
feat = self.feat_drop(feat)
in_degrees = adj.sum(dim=1).unsqueeze(-1)
h_neigh = (adj @ feat + feat) / (in_degrees + 1)
rst = self.fc(h_neigh)
# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self._norm is not None:
rst = self._norm(rst)
return rst
class DenseChebConv(nn.Module):
r"""Chebyshev Spectral Graph Convolution layer from paper `Convolutional
Neural Networks on Graphs with Fast Localized Spectral Filtering
<https://arxiv.org/pdf/1606.09375.pdf>`__.
We recommend using this module when applying ChebConv on dense
graphs / k-hop graphs.
Parameters
----------
in_feats: int
Number of input features.
out_feats: int
Number of output features.
k : int
Chebyshev filter size.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
See also
--------
ChebConv
"""
def __init__(self,
in_feats,
out_feats,
k,
bias=True):
super(DenseChebConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._k = k
self.W = nn.Parameter(th.Tensor(k, in_feats, out_feats))
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
if self.bias is not None:
init.zeros_(self.bias)
for i in range(self._k):
init.xavier_normal_(self.W[i], init.calculate_gain('relu'))
def forward(self, feat, adj):
r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
A = adj.to(feat)
num_nodes = A.shape[0]
in_degree = 1 / A.sum(dim=1).clamp(min=1).sqrt()
D_invsqrt = th.diag(in_degree)
I = th.eye(num_nodes).to(A)
L = I - D_invsqrt @ A @ D_invsqrt
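# th.eig returns (eigenvalues, eigenvectors); the eigenvalues come as an
# (n, 2) tensor of real/imaginary parts, so [:, 0] keeps the real parts.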
lambda_ = th.eig(L)[0][:, 0]
lambda_max = lambda_.max()
L_hat = 2 * L / lambda_max - I
Z = [th.eye(num_nodes).to(A)]
for i in range(1, self._k):
if i == 1:
Z.append(L_hat)
else:
Z.append(2 * L_hat @ Z[-1] - Z[-2])
Zs = th.stack(Z, 0) # (k, n, n)
Zh = (Zs @ feat.unsqueeze(0) @ self.W)
Zh = Zh.sum(0)
if self.bias is not None:
Zh = Zh + self.bias
return Zh
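# A minimal sketch of the sparse/dense correspondence checked in the tests
# below: DenseGraphConv applied to the dense adjacency should match GraphConv
# applied to the DGLGraph once the two layers share parameters.
import dgl
import torch as th
import scipy.sparse as sp

g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix().to_dense()
conv = GraphConv(5, 2, norm=False, bias=True)
dense_conv = DenseGraphConv(5, 2, norm=False, bias=True)
dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data
feat = th.randn(100, 5)
assert th.allclose(conv(feat, g), dense_conv(feat, adj))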
"""Torch modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch as th
import torch.nn as nn
import numpy as np
......@@ -178,17 +178,6 @@ class GlobalAttentionPooling(nn.Module):
super(GlobalAttentionPooling, self).__init__()
self.gate_nn = gate_nn
self.feat_nn = feat_nn
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
for p in self.gate_nn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
if self.feat_nn:
for p in self.feat_nn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat, graph):
r"""Compute global attention pooling.
......
"""Module for graph transformation methods."""
"""Module for graph transformation utilities."""
import numpy as np
from scipy import sparse
from ._ffi.function import _init_api
from .graph import DGLGraph
from .batched_graph import BatchedDGLGraph
from .graph_index import from_coo
from .batched_graph import BatchedDGLGraph, unbatch
from .backend import asnumpy, tensor
__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected']
__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max']
def line_graph(g, backtracking=True, shared=False):
......@@ -12,6 +19,7 @@ def line_graph(g, backtracking=True, shared=False):
Parameters
----------
g : dgl.DGLGraph
The input graph.
backtracking : bool, optional
Whether the returned line graph is backtracking.
shared : bool, optional
......@@ -26,6 +34,88 @@ def line_graph(g, backtracking=True, shared=False):
node_frame = g._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
def khop_adj(g, k):
"""Return the matrix of :math:`A^k` where :math:`A` is the adjacency matrix of :math:`g`,
where a row represents the destination and a column represents the source.
Parameters
----------
g : dgl.DGLGraph
The input graph.
k : int
The :math:`k` in :math:`A^k`.
Returns
-------
tensor
The returned tensor, dtype is ``np.float32``.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.add_edges([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0])
>>> dgl.khop_adj(g, 1)
tensor([[1., 0., 0., 0., 1.],
[1., 1., 0., 0., 0.],
[0., 1., 1., 0., 0.],
[0., 0., 1., 1., 0.],
[0., 0., 0., 1., 1.]])
>>> dgl.khop_adj(g, 3)
tensor([[1., 0., 1., 3., 3.],
[3., 1., 0., 1., 3.],
[3., 3., 1., 0., 1.],
[1., 3., 3., 1., 0.],
[0., 1., 3., 3., 1.]])
"""
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
return tensor(adj_k.todense().astype(np.float32))
def khop_graph(g, k):
"""Return the graph that includes all :math:`k`-hop neighbors of the given graph as edges.
The adjacency matrix of the returned graph is :math:`A^k`
(where :math:`A` is the adjacency matrix of :math:`g`).
Parameters
----------
g : dgl.DGLGraph
The input graph.
k : int
The :math:`k` in `k`-hop graph.
Returns
-------
dgl.DGLGraph
The returned ``DGLGraph``.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.add_edges([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0])
>>> dgl.khop_graph(g, 1)
DGLGraph(num_nodes=5, num_edges=10,
ndata_schemes={}
edata_schemes={})
>>> dgl.khop_graph(g, 3)
DGLGraph(num_nodes=5, num_edges=40,
ndata_schemes={}
edata_schemes={})
"""
n = g.number_of_nodes()
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
adj_k = adj_k.tocoo()
multiplicity = adj_k.data
row = np.repeat(adj_k.row, multiplicity)
col = np.repeat(adj_k.col, multiplicity)
# TODO(zihao): we should support creating multi-graph from scipy sparse matrix
# in the future.
return DGLGraph(from_coo(n, row, col, True, True))
def reverse(g, share_ndata=False, share_edata=False):
"""Return the reverse of a graph
......@@ -46,6 +136,7 @@ def reverse(g, share_ndata=False, share_edata=False):
Parameters
----------
g : dgl.DGLGraph
The input graph.
share_ndata: bool, optional
If True, the original graph and the reversed graph share memory for node attributes.
Otherwise the reversed graph will not be initialized with node attributes.
......@@ -169,4 +260,49 @@ def to_bidirected(g, readonly=True):
newgidx = _CAPI_DGLToBidirectedMutableGraph(g._graph)
return DGLGraph(newgidx)
def laplacian_lambda_max(g):
"""Return the largest eigenvalue of the normalized symmetric laplacian of g.
The eigenvalue of the normalized symmetric of any graph is less than or equal to 2,
ref: https://en.wikipedia.org/wiki/Laplacian_matrix#Properties
Parameters
----------
g : DGLGraph or BatchedDGLGraph
The input graph; it should be undirected.
Returns
-------
list
* If the input g is a DGLGraph, the returned value would be
a list with one element, indicating the largest eigenvalue of g.
* If the input g is a BatchedDGLGraph, the returned value would
be a list, where the i-th item indicates the largest eigenvalue
of i-th graph in g.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3])
>>> dgl.laplacian_lambda_max(g)
[1.809016994374948]
"""
if isinstance(g, BatchedDGLGraph):
g_arr = unbatch(g)
else:
g_arr = [g]
rst = []
for g_i in g_arr:
n = g_i.number_of_nodes()
adj = g_i.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
norm = sparse.diags(asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float)
laplacian = sparse.eye(n) - norm * adj * norm
rst.append(sparse.linalg.eigs(laplacian, 1, which='LM',
return_eigenvectors=False)[0].real)
return rst
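# A minimal usage sketch, assuming an undirected DGLGraph: the returned list
# can be computed once and passed to ChebConv's lambda_max argument so the
# eigendecomposition is not repeated on every forward call.
import dgl
import networkx as nx

g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
lambda_max = dgl.laplacian_lambda_max(g)   # list with one value per graph
assert lambda_max[0] <= 2 + 1e-6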
_init_api("dgl.transform")
......@@ -110,6 +110,11 @@ def min(x, dim):
def prod(x, dim):
"""Computes the prod of array elements over given axes"""
pass
def matmul(a, b):
"""Compute Matrix Multiplication between a and b"""
pass
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
......
......@@ -83,6 +83,9 @@ def min(x, dim):
def prod(x, dim):
return x.prod(dim)
def matmul(a, b):
return nd.dot(a, b)
record_grad = autograd.record
......
......@@ -79,6 +79,9 @@ def min(x, dim):
def prod(x, dim):
return x.prod(dim)
def matmul(a, b):
return a @ b
class record_grad(object):
def __init__(self):
pass
......
......@@ -112,6 +112,56 @@ def test_bidirected_graph():
_test(False, True)
_test(False, False)
def test_khop_graph():
N = 20
feat = F.randn((N, 5))
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for k in range(4):
g_k = dgl.khop_graph(g, k)
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use k-hop graph to do message passing for one time.
g_k.ndata['h'] = feat
g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_1 = g_k.ndata.pop('h')
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)
def test_khop_adj():
N = 20
feat = F.randn((N, 5))
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for k in range(3):
adj = F.tensor(dgl.khop_adj(g, k))
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use k-hop adj to do message passing for one time.
h_1 = F.matmul(adj, feat)
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)
def test_laplacian_lambda_max():
N = 20
eps = 1e-6
# test DGLGraph
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
l_max = dgl.laplacian_lambda_max(g)
assert (l_max[0] < 2 + eps)
# test BatchedDGLGraph
N_arr = [20, 30, 10, 12]
bg = dgl.batch([
dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for N in N_arr
])
l_max_arr = dgl.laplacian_lambda_max(bg)
assert len(l_max_arr) == len(N_arr)
for l_max in l_max_arr:
assert l_max < 2 + eps
if __name__ == '__main__':
test_line_graph()
test_no_backtracking()
......@@ -119,3 +169,6 @@ if __name__ == '__main__':
test_reverse_shared_frames()
test_simple_graph()
test_bidirected_graph()
test_khop_adj()
test_khop_graph()
test_laplacian_lambda_max()
......@@ -20,7 +20,7 @@ def test_graph_conv():
conv = nn.GraphConv(5, 2, norm=False, bias=True)
if F.gpu_ctx():
conv.cuda()
conv = conv.to(ctx)
print(conv)
# test#1: basic
h0 = F.ones((3, 5))
......@@ -37,7 +37,7 @@ def test_graph_conv():
conv = nn.GraphConv(5, 2)
if F.gpu_ctx():
conv.cuda()
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
......@@ -51,7 +51,7 @@ def test_graph_conv():
conv = nn.GraphConv(5, 2)
if F.gpu_ctx():
conv.cuda()
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
......@@ -81,15 +81,15 @@ def _S2AXWb(A, N, X, W, b):
return Y + b
def test_tgconv():
def test_tagconv():
g = dgl.DGLGraph(nx.path_graph(3))
ctx = F.ctx()
adj = g.adjacency_matrix(ctx=ctx)
norm = th.pow(g.in_degrees().float(), -0.5)
conv = nn.TGConv(5, 2, bias=True)
conv = nn.TAGConv(5, 2, bias=True)
if F.gpu_ctx():
conv.cuda()
conv = conv.to(ctx)
print(conv)
# test#1: basic
......@@ -102,27 +102,27 @@ def test_tgconv():
assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))
conv = nn.TGConv(5, 2)
conv = nn.TAGConv(5, 2)
if F.gpu_ctx():
conv.cuda()
conv = conv.to(ctx)
# test#2: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
assert h1.shape[-1] == 2
# test rest_parameters
# test reset_parameters
old_weight = deepcopy(conv.lin.weight.data)
conv.reset_parameters()
new_weight = conv.lin.weight.data
assert not F.allclose(old_weight, new_weight)
def test_set2set():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
if F.gpu_ctx():
s2s.cuda()
s2s = s2s.to(ctx)
print(s2s)
# test#1: basic
......@@ -139,11 +139,12 @@ def test_set2set():
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
def test_glob_att_pool():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
if F.gpu_ctx():
gap.cuda()
gap = gap.to(ctx)
print(gap)
# test#1: basic
......@@ -158,6 +159,7 @@ def test_glob_att_pool():
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
def test_simple_pool():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
......@@ -168,6 +170,12 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
if F.gpu_ctx():
sum_pool = sum_pool.to(ctx)
avg_pool = avg_pool.to(ctx)
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h0 = h0.to(ctx)
h1 = sum_pool(h0, g)
assert F.allclose(h1, F.sum(h0, 0))
h1 = avg_pool(h0, g)
......@@ -181,6 +189,8 @@ def test_simple_pool():
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = F.randn((bg.number_of_nodes(), 5))
if F.gpu_ctx():
h0 = h0.to(ctx)
h1 = sum_pool(h0, bg)
truth = th.stack([F.sum(h0[:15], 0),
......@@ -210,15 +220,16 @@ def test_simple_pool():
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(15))
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
if F.gpu_ctx():
st_enc_0.cuda()
st_enc_1.cuda()
st_dec.cuda()
st_enc_0 = st_enc_0.to(ctx)
st_enc_1 = st_enc_1.to(ctx)
st_dec = st_dec.to(ctx)
print(st_enc_0, st_enc_1, st_dec)
# test#1: basic
......@@ -354,6 +365,207 @@ def test_rgcn():
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
def test_gat_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gat = nn.GATConv(5, 2, 4)
feat = F.randn((100, 5))
if F.gpu_ctx():
gat = gat.to(ctx)
feat = feat.to(ctx)
h = gat(feat, g)
assert h.shape[-1] == 2 and h.shape[-2] == 4
def test_sage_conv():
for aggre_type in ['mean', 'pool', 'gcn', 'lstm']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
if F.gpu_ctx():
sage = sage.to(ctx)
feat = feat.to(ctx)
h = sage(feat, g)
assert h.shape[-1] == 10
def test_sgc_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
# not cached
sgc = nn.SGConv(5, 10, 3)
feat = F.randn((100, 5))
if F.gpu_ctx():
sgc = sgc.to(ctx)
feat = feat.to(ctx)
h = sgc(feat, g)
assert h.shape[-1] == 10
# cached
sgc = nn.SGConv(5, 10, 3, True)
if F.gpu_ctx():
sgc = sgc.to(ctx)
h_0 = sgc(feat, g)
h_1 = sgc(feat + 1, g)
assert F.allclose(h_0, h_1)
assert h_0.shape[-1] == 10
def test_appnp_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
appnp = nn.APPNPConv(10, 0.1)
feat = F.randn((100, 5))
if F.gpu_ctx():
appnp = appnp.to(ctx)
feat = feat.to(ctx)
h = appnp(feat, g)
assert h.shape[-1] == 5
def test_gin_conv():
for aggregator_type in ['mean', 'max', 'sum']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gin = nn.GINConv(
th.nn.Linear(5, 12),
aggregator_type
)
feat = F.randn((100, 5))
if F.gpu_ctx():
gin = gin.to(ctx)
feat = feat.to(ctx)
h = gin(feat, g)
assert h.shape[-1] == 12
def test_agnn_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
agnn = nn.AGNNConv(1)
feat = F.randn((100, 5))
if F.gpu_ctx():
agnn = agnn.to(ctx)
feat = feat.to(ctx)
h = agnn(feat, g)
assert h.shape[-1] == 5
def test_gated_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
ggconv = nn.GatedGraphConv(5, 10, 5, 3)
etypes = th.arange(g.number_of_edges()) % 3
feat = F.randn((100, 5))
if F.gpu_ctx():
ggconv = ggconv.to(ctx)
feat = feat.to(ctx)
etypes = etypes.to(ctx)
h = ggconv(feat, etypes, g)
# currently we only do a shape check
assert h.shape[-1] == 10
def test_nn_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
edge_func = th.nn.Linear(4, 5 * 10)
nnconv = nn.NNConv(5, 10, edge_func, 'mean')
feat = F.randn((100, 5))
efeat = F.randn((g.number_of_edges(), 4))
if F.gpu_ctx():
nnconv = nnconv.to(ctx)
feat = feat.to(ctx)
efeat = efeat.to(ctx)
h = nnconv(feat, efeat, g)
# currently we only do shape check
assert h.shape[-1] == 10
def test_gmm_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
feat = F.randn((100, 5))
pseudo = F.randn((g.number_of_edges(), 3))
if F.gpu_ctx():
gmmconv = gmmconv.to(ctx)
feat = feat.to(ctx)
pseudo = pseudo.to(ctx)
h = gmmconv(feat, pseudo, g)
# currently we only do shape check
assert h.shape[-1] == 10
def test_dense_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
conv = nn.GraphConv(5, 2, norm=False, bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True)
dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
conv = conv.to(ctx)
dense_conv = dense_conv.to(ctx)
feat = feat.to(ctx)
out_conv = conv(feat, g)
out_dense_conv = dense_conv(feat, adj)
assert F.allclose(out_conv, out_dense_conv)
def test_dense_sage_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
sage = nn.SAGEConv(5, 2, 'gcn',)
dense_sage = nn.DenseSAGEConv(5, 2)
dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
sage = sage.to(ctx)
dense_sage = dense_sage.to(ctx)
feat = feat.to(ctx)
out_sage = sage(feat, g)
out_dense_sage = dense_sage(feat, adj)
assert F.allclose(out_sage, out_dense_sage)
def test_dense_cheb_conv():
for k in range(1, 4):
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
cheb = nn.ChebConv(5, 2, k)
dense_cheb = nn.DenseChebConv(5, 2, k)
for i in range(len(cheb.fc)):
dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
if cheb.bias is not None:
dense_cheb.bias.data = cheb.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
cheb = cheb.to(ctx)
dense_cheb = dense_cheb.to(ctx)
feat = feat.to(ctx)
out_cheb = cheb(feat, g)
out_dense_cheb = dense_cheb(feat, adj)
assert F.allclose(out_cheb, out_dense_cheb)
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
......@@ -362,3 +574,17 @@ if __name__ == '__main__':
test_simple_pool()
test_set_trans()
test_rgcn()
test_tagconv()
test_gat_conv()
test_sage_conv()
test_sgc_conv()
test_appnp_conv()
test_gin_conv()
test_agnn_conv()
test_gated_graph_conv()
test_nn_conv()
test_gmm_conv()
test_dense_graph_conv()
test_dense_sage_conv()
test_dense_cheb_conv()