Unverified Commit dc5035b1 authored by rudongyu, committed by GitHub

[NN] Add EGNN & PNA (#3901)



* add EGNN & PNA

* fix egnn issues

* fix pna issues

* update pna conv

* add doc strings

* update pnaconv

* fix unused args issue

* fix moment aggregation issue
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 27a6eb56
......@@ -36,6 +36,8 @@ Conv Layers
~dgl.nn.pytorch.conv.GCN2Conv
~dgl.nn.pytorch.conv.HGTConv
~dgl.nn.pytorch.conv.GroupRevRes
~dgl.nn.pytorch.conv.EGNNConv
~dgl.nn.pytorch.conv.PNAConv
Dense Conv Layers
----------------------------------------
......
......@@ -27,9 +27,12 @@ from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
from .hgtconv import HGTConv
from .grouprevres import GroupRevRes
from .egnnconv import EGNNConv
from .pnaconv import PNAConv
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes']
'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes', 'EGNNConv',
'PNAConv']
"""Torch Module for E(n) Equivariant Graph Convolutional Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from .... import function as fn
class EGNNConv(nn.Module):
r"""Equivariant Graph Convolutional Layer from `E(n) Equivariant Graph
Neural Networks <https://arxiv.org/abs/2102.09844>`__
.. math::
m_{ij}=\phi_e(h_i^l, h_j^l, ||x_i^l-x_j^l||^2, a_{ij})
x_i^{l+1} = x_i^l + C\sum_{j\in\mathcal{N}(i)}(x_i^l-x_j^l)\phi_x(m_{ij})
m_i = \sum_{j\in\mathcal{N}(i)} m_{ij}
h_i^{l+1} = \phi_h(h_i^l, m_i)
where :math:`h_i`, :math:`x_i`, :math:`a_{ij}` are node features, coordinate
features, and edge features respectively. :math:`\phi_e`, :math:`\phi_h`, and
:math:`\phi_x` are two-layer MLPs. :math:`C` is a constant for normalization,
computed as :math:`1/|\mathcal{N}(i)|`.
Parameters
----------
in_size : int
Input feature size; i.e. the size of :math:`h_i^l`.
hidden_size : int
Hidden feature size; i.e. the size of the hidden layer in the two-layer MLPs in
:math:`\phi_e, \phi_x, \phi_h`.
out_size : int
Output feature size; i.e. the size of :math:`h_i^{l+1}`.
edge_feat_size : int, optional
Edge feature size; i.e. the size of :math:`a_{ij}`. Default: 0.
Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EGNNConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> node_feat, coord_feat, edge_feat = th.ones(6, 10), th.ones(6, 3), th.ones(6, 2)
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)
"""
def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
super(EGNNConv, self).__init__()
self.in_size = in_size
self.hidden_size = hidden_size
self.out_size = out_size
self.edge_feat_size = edge_feat_size
act_fn = nn.SiLU()
# \phi_e
self.edge_mlp = nn.Sequential(
# +1 for the radial feature: ||x_i - x_j||^2
nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
act_fn,
nn.Linear(hidden_size, hidden_size),
act_fn
)
# \phi_h
self.node_mlp = nn.Sequential(
nn.Linear(in_size + hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, out_size)
)
# \phi_x
self.coord_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, 1, bias=False)
)
def message(self, edges):
"""message function for EGNN"""
# concat features for edge mlp
if self.edge_feat_size > 0:
f = torch.cat(
[edges.src['h'], edges.dst['h'], edges.data['radial'], edges.data['a']],
dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)
msg_h = self.edge_mlp(f)
msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
return {'msg_x': msg_x, 'msg_h': msg_h}
def forward(self, graph, node_feat, coord_feat, edge_feat=None):
r"""
Description
-----------
Compute EGNN layer.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size.
coord_feat : torch.Tensor
The coordinate feature of shape :math:`(N, h_x)`. :math:`N` is the
number of nodes, and :math:`h_x` can be any positive integer.
edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size.
Returns
-------
node_feat_out : torch.Tensor
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
is the same as out_size.
coord_feat_out: torch.Tensor
The output coordinate feature of shape :math:`(N, h_x)` where :math:`h_x`
is the same as the input coordinate feature dimension.
"""
with graph.local_scope():
# node feature
graph.ndata['h'] = node_feat
# coordinate feature
graph.ndata['x'] = coord_feat
# edge feature
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
# get coordinate diff & radial features
graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
# normalize coordinate difference
graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)
graph.apply_edges(self.message)
graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))
h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
h = self.node_mlp(
torch.cat([node_feat, h_neigh], dim=-1)
)
x = coord_feat + x_neigh
return h, x
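# A hedged sanity check for the layer above: EGNNConv should be E(n)-equivariant,
# i.e. rotating and translating the input coordinates rotates and translates the
# output coordinates the same way while leaving the node features unchanged. The
# helper below is an illustrative sketch only; its name, the random orthogonal
# transform construction, and the tolerance are assumptions, not part of the layer.
def _check_egnn_equivariance(atol=1e-4):
    """Illustrative sketch: numerically verify E(n) equivariance of EGNNConv."""
    import dgl
    g = dgl.rand_graph(6, 20)
    h = torch.randn(6, 10)
    x = torch.randn(6, 3)
    conv = EGNNConv(10, 16, 10)
    # random orthogonal matrix (rotation or reflection) and translation
    Q, _ = torch.linalg.qr(torch.randn(3, 3))
    t = torch.randn(3)
    h_out, x_out = conv(g, h, x)
    h_trans, x_trans = conv(g, h, x @ Q + t)
    # node features are invariant; coordinates transform equivariantly
    assert torch.allclose(h_out, h_trans, atol=atol)
    assert torch.allclose(x_out @ Q + t, x_trans, atol=atol)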
"""Torch Module for Principal Neighbourhood Aggregation Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch
import torch.nn as nn
def aggregate_mean(h):
"""mean aggregation"""
return torch.mean(h, dim=1)
def aggregate_max(h):
"""max aggregation"""
return torch.max(h, dim=1)[0]
def aggregate_min(h):
"""min aggregation"""
return torch.min(h, dim=1)[0]
def aggregate_sum(h):
"""sum aggregation"""
return torch.sum(h, dim=1)
def aggregate_std(h):
"""standard deviation aggregation"""
return torch.sqrt(aggregate_var(h) + 1e-30)
def aggregate_var(h):
"""variance aggregation"""
h_mean_squares = torch.mean(h * h, dim=1)
h_mean = torch.mean(h, dim=1)
var = torch.relu(h_mean_squares - h_mean * h_mean)
return var
def _aggregate_moment(h, n):
"""moment aggregation: for each node (E[(X-E[X])^n])^{1/n}"""
h_mean = torch.mean(h, dim=1, keepdim=True)
h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)
rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1. / n)
return rooted_h_n
def aggregate_moment_3(h):
"""moment aggregation with n=3"""
return _aggregate_moment(h, n=3)
def aggregate_moment_4(h):
"""moment aggregation with n=4"""
return _aggregate_moment(h, n=4)
def aggregate_moment_5(h):
"""moment aggregation with n=5"""
return _aggregate_moment(h, n=5)
def scale_identity(h):
"""identity scaling (no scaling operation)"""
return h
def scale_amplification(h, D, delta):
"""amplification scaling"""
return h * (np.log(D + 1) / delta)
def scale_attenuation(h, D, delta):
"""attenuation scaling"""
return h * (delta / np.log(D + 1))
AGGREGATORS = {
'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min,
'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3,
'moment4': aggregate_moment_4, 'moment5': aggregate_moment_5
}
SCALERS = {
'identity': scale_identity,
'amplification': scale_amplification,
'attenuation': scale_attenuation
}
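# A hedged illustration of how the aggregators and scalers above combine: apply
# them by hand to a toy mailbox tensor of shape (num_nodes, degree, feat_size),
# mirroring what PNAConvTower.reduce_func below does. The helper name, shapes,
# and delta value are assumptions made for illustration.
def _toy_aggregate_and_scale():
    """Illustrative sketch: combine aggregators and scalers on a toy mailbox."""
    num_nodes, degree, feat_size = 4, 3, 8
    delta = 2.5  # stand-in for E[log(d + 1)] over a hypothetical training set
    msg = torch.randn(num_nodes, degree, feat_size)
    aggregators = [AGGREGATORS[name] for name in ['mean', 'max', 'std']]
    scalers = [SCALERS[name] for name in ['identity', 'amplification']]
    # aggregate along the neighbour dimension, then concatenate: (4, 3 * 8)
    h = torch.cat([agg(msg) for agg in aggregators], dim=1)
    # scale each aggregated block and concatenate again: (4, 2 * 3 * 8)
    h = torch.cat([
        scaler(h, D=degree, delta=delta) if scaler is not scale_identity else h
        for scaler in scalers
    ], dim=1)
    return h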
class PNAConvTower(nn.Module):
"""A single PNA tower in PNA layers"""
def __init__(self, in_size, out_size, aggregators, scalers,
delta, dropout=0., edge_feat_size=0):
super(PNAConvTower, self).__init__()
self.in_size = in_size
self.out_size = out_size
self.aggregators = aggregators
self.scalers = scalers
self.delta = delta
self.edge_feat_size = edge_feat_size
self.M = nn.Linear(2 * in_size + edge_feat_size, in_size)
self.U = nn.Linear((len(aggregators) * len(scalers) + 1) * in_size, out_size)
self.dropout = nn.Dropout(dropout)
self.batchnorm = nn.BatchNorm1d(out_size)
def reduce_func(self, nodes):
"""reduce function for PNA layer:
tensordot of multiple aggregation and scaling operations"""
msg = nodes.mailbox['msg']
degree = msg.size(1)
h = torch.cat([aggregator(msg) for aggregator in self.aggregators], dim=1)
h = torch.cat([
scaler(h, D=degree, delta=self.delta) if scaler is not scale_identity else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
def message(self, edges):
"""message function for PNA layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f)}
def forward(self, graph, node_feat, edge_feat=None):
"""compute the forward pass of a single tower in PNA convolution layer"""
# calculate graph normalization factors
snorm_n = torch.cat(
[torch.ones(N, 1).to(node_feat) / N for N in graph.batch_num_nodes()],
dim=0
).sqrt()
with graph.local_scope():
graph.ndata['h'] = node_feat
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
graph.update_all(self.message, self.reduce_func)
h = self.U(
torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
)
h = h * snorm_n
return self.dropout(self.batchnorm(h))
class PNAConv(nn.Module):
r"""Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood Aggregation
for Graph Nets <https://arxiv.org/abs/2004.05718>`__
A PNA layer is composed of multiple PNA towers. Each tower takes a split of the
input features as input and performs message passing as follows.
.. math::
h_i^{l+1} = U(h_i^l, \oplus_{(i,j)\in E}M(h_i^l, e_{i,j}, h_j^l))
where :math:`h_i` and :math:`e_{i,j}` are node features and edge features, respectively.
:math:`M` and :math:`U` are MLPs that take the concatenation of their inputs to compute
output features. :math:`\oplus` represents the combination of various aggregators and
scalers: aggregators aggregate messages from neighbours, scalers scale the aggregated
messages in different ways, and :math:`\oplus` concatenates the output features of each
aggregator-scaler combination.
The outputs of the towers are concatenated and fed into a linear mixing layer to produce
the final output.
Parameters
----------
in_size : int
Input feature size; i.e. the size of :math:`h_i^l`.
out_size : int
Output feature size; i.e. the size of :math:`h_i^{l+1}`.
aggregators : list of str
List of aggregation function names (each aggregator specifies a way to aggregate
messages from neighbours), selected from:
* ``mean``: the mean of neighbour messages
* ``max``: the maximum of neighbour messages
* ``min``: the minimum of neighbour messages
* ``std``: the standard deviation of neighbour messages
* ``var``: the variance of neighbour messages
* ``sum``: the sum of neighbour messages
* ``moment3``, ``moment4``, ``moment5``: the normalized moments aggregation
:math:`(E[(X-E[X])^n])^{1/n}`
scalers: list of str
List of scaler function names, selected from:
* ``identity``: no scaling
* ``amplification``: multiply the aggregated message by :math:`\log(d+1)/\delta`,
where :math:`d` is the degree of the node.
* ``attenuation``: multiply the aggregated message by :math:`\delta/\log(d+1)`
delta: float
The degree-related normalization factor computed over the training set and used by the
scalers, defined as :math:`E[\log(d+1)]`, where :math:`d` is the degree of each node in
the training set; see the estimation sketch after this class.
dropout: float, optional
The dropout ratio. Default: 0.0.
num_towers: int, optional
The number of towers used. Default: 1. Note that in_size and out_size must be divisible
by num_towers.
edge_feat_size: int, optional
The edge feature size. Default: 0.
residual : bool, optional
Whether to add a residual connection to the output. Default: True. If in_size and
out_size of the PNA conv layer are not the same, this flag is forced to False.
Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import PNAConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
super(PNAConv, self).__init__()
aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
scalers = [SCALERS[scale] for scale in scalers]
self.in_size = in_size
self.out_size = out_size
assert in_size % num_towers == 0, 'in_size must be divisible by num_towers'
assert out_size % num_towers == 0, 'out_size must be divisible by num_towers'
self.tower_in_size = in_size // num_towers
self.tower_out_size = out_size // num_towers
self.edge_feat_size = edge_feat_size
self.residual = residual
if self.in_size != self.out_size:
self.residual = False
self.towers = nn.ModuleList([
PNAConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.mixing_layer = nn.Sequential(
nn.Linear(out_size, out_size),
nn.LeakyReLU()
)
def forward(self, graph, node_feat, edge_feat=None):
r"""
Description
-----------
Compute PNA layer.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size.
edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size.
Returns
-------
torch.Tensor
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
is the same as out_size.
"""
h_cat = torch.cat([
tower(
graph,
node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
edge_feat
)
for ti, tower in enumerate(self.towers)
], dim=1)
h_out = self.mixing_layer(h_cat)
# add residual connection
if self.residual:
h_out = h_out + node_feat
return h_out
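# A hedged usage sketch for the layer above: `delta` is meant to be estimated from
# the training graphs as E[log(d + 1)] rather than chosen arbitrarily. The stand-in
# graphs, the use of in-degrees, and the helper name below are assumptions made
# for illustration.
def _pna_delta_example():
    """Illustrative sketch: estimate delta from training graphs and run PNAConv."""
    import dgl
    train_graphs = [dgl.rand_graph(6, 20), dgl.rand_graph(8, 30)]  # stand-in data
    log_degrees = torch.cat(
        [(g.in_degrees().float() + 1).log() for g in train_graphs]
    )
    delta = log_degrees.mean().item()
    conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'],
                   delta, num_towers=2)
    # PNAConv also works on a batched graph; the per-tower graph-size
    # normalization uses graph.batch_num_nodes() internally.
    bg = dgl.batch(train_graphs)
    feat = torch.ones(bg.num_nodes(), 10)
    return conv(bg, feat)  # shape (14, 10)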
......@@ -1423,3 +1423,40 @@ def test_group_rev_res(idtype):
conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h)
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('edge_feat_size', [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
dev = F.ctx()
num_nodes = 5
num_edges = 20
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
h = th.randn(num_nodes, in_size).to(dev)
x = th.randn(num_nodes, 3).to(dev)
e = th.randn(num_edges, edge_feat_size).to(dev)
model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
model(g, h, x, e)
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('aggregators',
[['mean', 'max', 'sum'], ['min', 'std', 'var'], ['moment3', 'moment4', 'moment5']])
@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
@pytest.mark.parametrize('delta', [2.5, 7.4])
@pytest.mark.parametrize('dropout', [0., 0.1])
@pytest.mark.parametrize('num_towers', [1, 4])
@pytest.mark.parametrize('edge_feat_size', [16, 0])
@pytest.mark.parametrize('residual', [True, False])
def test_pna_conv(in_size, out_size, aggregators, scalers, delta,
dropout, num_towers, edge_feat_size, residual):
dev = F.ctx()
num_nodes = 5
num_edges = 20
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
h = th.randn(num_nodes, in_size).to(dev)
e = th.randn(num_edges, edge_feat_size).to(dev)
model = nn.PNAConv(in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e)