"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0437b16497041b83d454a08e2db7e56fb52560d5"
Unverified commit dc5035b1 authored by rudongyu, committed by GitHub

[NN] Add EGNN & PNA (#3901)



* add EGNN & PNA

* fix egnn issues

* fix pna issues

* update pna conv

* add doc strings

* update pnaconv

* fix unused args issue

* fix moment aggregation issue
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 27a6eb56
@@ -36,6 +36,8 @@ Conv Layers
~dgl.nn.pytorch.conv.GCN2Conv
~dgl.nn.pytorch.conv.HGTConv
~dgl.nn.pytorch.conv.GroupRevRes
~dgl.nn.pytorch.conv.EGNNConv
~dgl.nn.pytorch.conv.PNAConv

Dense Conv Layers
----------------------------------------
...
@@ -27,9 +27,12 @@ from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
from .hgtconv import HGTConv
from .grouprevres import GroupRevRes
from .egnnconv import EGNNConv
from .pnaconv import PNAConv

__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
           'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
           'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
           'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
           'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes', 'EGNNConv',
           'PNAConv']
"""Torch Module for E(n) Equivariant Graph Convolutional Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from .... import function as fn
class EGNNConv(nn.Module):
r"""Equivariant Graph Convolutional Layer from `E(n) Equivariant Graph
Neural Networks <https://arxiv.org/abs/2102.09844>`__
.. math::
m_{ij}=\phi_e(h_i^l, h_j^l, ||x_i^l-x_j^l||^2, a_{ij})
x_i^{l+1} = x_i^l + C\sum_{j\in\mathcal{N}(i)}(x_i^l-x_j^l)\phi_x(m_{ij})
m_i = \sum_{j\in\mathcal{N}(i)} m_{ij}
h_i^{l+1} = \phi_h(h_i^l, m_i)
where :math:`h_i`, :math:`x_i`, :math:`a_{ij}` are node features, coordinate
features, and edge features respectively. :math:`\phi_e`, :math:`\phi_h`, and
:math:`\phi_x` are two-layer MLPs. :math:`C` is a constant for normalization,
computed as :math:`1/|\mathcal{N}(i)|`.
Parameters
----------
in_size : int
Input feature size; i.e. the size of :math:`h_i^l`.
hidden_size : int
Hidden feature size; i.e. the size of hidden layer in the two-layer MLPs in
:math:`\phi_e, \phi_x, \phi_h`.
out_size : int
Output feature size; i.e. the size of :math:`h_i^{l+1}`.
edge_feat_size : int, optional
Edge feature size; i.e. the size of :math:`a_{ij}`. Default: 0.
Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EGNNConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> node_feat, coord_feat, edge_feat = th.ones(6, 10), th.ones(6, 3), th.ones(6, 2)
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)
"""
def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
super(EGNNConv, self).__init__()
self.in_size = in_size
self.hidden_size = hidden_size
self.out_size = out_size
self.edge_feat_size = edge_feat_size
act_fn = nn.SiLU()
# \phi_e
self.edge_mlp = nn.Sequential(
# +1 for the radial feature: ||x_i - x_j||^2
nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
act_fn,
nn.Linear(hidden_size, hidden_size),
act_fn
)
# \phi_h
self.node_mlp = nn.Sequential(
nn.Linear(in_size + hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, out_size)
)
# \phi_x
self.coord_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, 1, bias=False)
)
def message(self, edges):
"""message function for EGNN"""
# concat features for edge mlp
if self.edge_feat_size > 0:
f = torch.cat(
[edges.src['h'], edges.dst['h'], edges.data['radial'], edges.data['a']],
dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)
msg_h = self.edge_mlp(f)
msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
return {'msg_x': msg_x, 'msg_h': msg_h}
def forward(self, graph, node_feat, coord_feat, edge_feat=None):
r"""
Description
-----------
Compute EGNN layer.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size.
coord_feat : torch.Tensor
The coordinate feature of shape :math:`(N, h_x)`. :math:`N` is the
number of nodes, and :math:`h_x` can be any positive integer.
edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size.
Returns
-------
node_feat_out : torch.Tensor
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
is the same as out_size.
coord_feat_out: torch.Tensor
The output coordinate feature of shape :math:`(N, h_x)` where :math:`h_x`
is the same as the input coordinate feature dimension.
"""
with graph.local_scope():
# node feature
graph.ndata['h'] = node_feat
# coordinate feature
graph.ndata['x'] = coord_feat
# edge feature
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
# get coordinate diff & radial features
graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
# normalize coordinate difference
graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)
graph.apply_edges(self.message)
graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))
h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
h = self.node_mlp(
torch.cat([node_feat, h_neigh], dim=-1)
)
x = coord_feat + x_neigh
return h, x
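A quick way to sanity-check the equivariance claim in the docstring above is to rotate and translate the coordinate input and compare outputs: node features should be invariant, coordinates should transform the same way. The following is a minimal sketch, not part of the diff; it assumes a DGL build where `EGNNConv` is importable from `dgl.nn` (as this PR makes it) and uses `torch.linalg.qr` to draw a random orthogonal matrix.

# Hedged sketch: check E(n) equivariance of EGNNConv on a random graph.
import dgl
import torch
from dgl.nn import EGNNConv

g = dgl.rand_graph(6, 20)
h = torch.randn(6, 16)
x = torch.randn(6, 3)

conv = EGNNConv(in_size=16, hidden_size=32, out_size=16)

# Random rotation/reflection (orthogonal 3x3) and translation.
R, _ = torch.linalg.qr(torch.randn(3, 3))
t = torch.randn(1, 3)

with torch.no_grad():
    h1, x1 = conv(g, h, x)
    h2, x2 = conv(g, h, x @ R + t)

# Node features are invariant; coordinates are equivariant.
assert torch.allclose(h1, h2, atol=1e-4)
assert torch.allclose(x1 @ R + t, x2, atol=1e-4)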
"""Torch Module for Principal Neighbourhood Aggregation Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch
import torch.nn as nn
def aggregate_mean(h):
"""mean aggregation"""
return torch.mean(h, dim=1)
def aggregate_max(h):
"""max aggregation"""
return torch.max(h, dim=1)[0]
def aggregate_min(h):
"""min aggregation"""
return torch.min(h, dim=1)[0]
def aggregate_sum(h):
"""sum aggregation"""
return torch.sum(h, dim=1)
def aggregate_std(h):
"""standard deviation aggregation"""
return torch.sqrt(aggregate_var(h) + 1e-30)
def aggregate_var(h):
"""variance aggregation"""
h_mean_squares = torch.mean(h * h, dim=1)
h_mean = torch.mean(h, dim=1)
var = torch.relu(h_mean_squares - h_mean * h_mean)
return var
def _aggregate_moment(h, n):
"""moment aggregation: for each node (E[(X-E[X])^n])^{1/n}"""
h_mean = torch.mean(h, dim=1, keepdim=True)
h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)
rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1. / n)
return rooted_h_n
def aggregate_moment_3(h):
"""moment aggregation with n=3"""
return _aggregate_moment(h, n=3)
def aggregate_moment_4(h):
"""moment aggregation with n=4"""
return _aggregate_moment(h, n=4)
def aggregate_moment_5(h):
"""moment aggregation with n=5"""
return _aggregate_moment(h, n=5)
def scale_identity(h):
"""identity scaling (no scaling operation)"""
return h
def scale_amplification(h, D, delta):
"""amplification scaling"""
return h * (np.log(D + 1) / delta)
def scale_attenuation(h, D, delta):
"""attenuation scaling"""
return h * (delta / np.log(D + 1))
AGGREGATORS = {
'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min,
'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3,
'moment4': aggregate_moment_4, 'moment5': aggregate_moment_5
}
SCALERS = {
'identity': scale_identity,
'amplification': scale_amplification,
'attenuation': scale_attenuation
}
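The shapes these helpers expect are implicit in `reduce_func` below: the mailbox tensor is `(num_dst_nodes, degree, feat_size)`, aggregators reduce the degree dimension, and scalers rescale the result by a function of the degree. A small standalone illustration (not part of the diff; the toy shapes are assumptions):

# Illustration only: compose one aggregator and one scaler on a toy mailbox tensor.
import torch
msg = torch.randn(4, 3, 8)                           # 4 dst nodes, in-degree 3, 8 features
agg = AGGREGATORS['mean'](msg)                        # reduce dim=1 -> shape (4, 8)
out = SCALERS['amplification'](agg, D=3, delta=2.5)   # multiply by log(3 + 1) / 2.5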
class PNAConvTower(nn.Module):
"""A single PNA tower in PNA layers"""
def __init__(self, in_size, out_size, aggregators, scalers,
delta, dropout=0., edge_feat_size=0):
super(PNAConvTower, self).__init__()
self.in_size = in_size
self.out_size = out_size
self.aggregators = aggregators
self.scalers = scalers
self.delta = delta
self.edge_feat_size = edge_feat_size
self.M = nn.Linear(2 * in_size + edge_feat_size, in_size)
self.U = nn.Linear((len(aggregators) * len(scalers) + 1) * in_size, out_size)
self.dropout = nn.Dropout(dropout)
self.batchnorm = nn.BatchNorm1d(out_size)
def reduce_func(self, nodes):
"""reduce function for PNA layer:
tensordot of multiple aggregation and scaling operations"""
msg = nodes.mailbox['msg']
degree = msg.size(1)
h = torch.cat([aggregator(msg) for aggregator in self.aggregators], dim=1)
h = torch.cat([
scaler(h, D=degree, delta=self.delta) if scaler is not scale_identity else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
def message(self, edges):
"""message function for PNA layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f)}
def forward(self, graph, node_feat, edge_feat=None):
"""compute the forward pass of a single tower in PNA convolution layer"""
# calculate graph normalization factors
snorm_n = torch.cat(
[torch.ones(N, 1).to(node_feat) / N for N in graph.batch_num_nodes()],
dim=0
).sqrt()
with graph.local_scope():
graph.ndata['h'] = node_feat
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
graph.update_all(self.message, self.reduce_func)
h = self.U(
torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
)
h = h * snorm_n
return self.dropout(self.batchnorm(h))
class PNAConv(nn.Module):
r"""Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood Aggregation
for Graph Nets <https://arxiv.org/abs/2004.05718>`__
A PNA layer is composed of multiple PNA towers. Each tower takes a slice of the
input node features as input and performs message passing as below.
.. math::
h_i^{l+1} = U(h_i^l, \oplus_{(i,j)\in E} M(h_i^l, e_{i,j}, h_j^l))
where :math:`h_i` and :math:`e_{i,j}` are node features and edge features, respectively.
:math:`M` and :math:`U` are MLPs, taking the concatenation of input for computing
output features. :math:`\oplus` represents the combination of various aggregators
and scalers. Aggregators aggregate messages from neighbours and scalers scale the
aggregated messages in different ways. :math:`\oplus` concatenates the output features
of each combination.
The outputs of the towers are concatenated and fed into a linear mixing layer for the
final output.
Parameters
----------
in_size : int
Input feature size; i.e. the size of :math:`h_i^l`.
out_size : int
Output feature size; i.e. the size of :math:`h_i^{l+1}`.
aggregators : list of str
List of aggregation function names (each aggregator specifies a way to aggregate
messages from neighbours), selected from:
* ``mean``: the mean of neighbour messages
* ``max``: the maximum of neighbour messages
* ``min``: the minimum of neighbour messages
* ``std``: the standard deviation of neighbour messages
* ``var``: the variance of neighbour messages
* ``sum``: the sum of neighbour messages
* ``moment3``, ``moment4``, ``moment5``: the normalized moments aggregation
:math:`(E[(X-E[X])^n])^{1/n}`
scalers: list of str
List of scaler function names, selected from:
* ``identity``: no scaling
* ``amplification``: multiply the aggregated message by :math:`\log(d+1)/\delta`,
where :math:`d` is the degree of the node.
* ``attenuation``: multiply the aggregated message by :math:`\delta/\log(d+1)`
delta: float
The degree-related normalization factor used by the scalers, computed over the
training set as :math:`E[\log(d+1)]`, where :math:`d` is the degree of each node
in the training set.
dropout: float, optional
The dropout ratio. Default: 0.0.
num_towers: int, optional
The number of towers used. Default: 1. Note that in_size and out_size must be divisible
by num_towers.
edge_feat_size: int, optional
The edge feature size. Default: 0.
residual : bool, optional
Whether to add a residual connection for the output. Default: True. If in_size and
out_size of the PNA conv layer differ, this flag is forcibly set to False.
Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import PNAConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
super(PNAConv, self).__init__()
aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
scalers = [SCALERS[scale] for scale in scalers]
self.in_size = in_size
self.out_size = out_size
assert in_size % num_towers == 0, 'in_size must be divisible by num_towers'
assert out_size % num_towers == 0, 'out_size must be divisible by num_towers'
self.tower_in_size = in_size // num_towers
self.tower_out_size = out_size // num_towers
self.edge_feat_size = edge_feat_size
self.residual = residual
if self.in_size != self.out_size:
self.residual = False
self.towers = nn.ModuleList([
PNAConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.mixing_layer = nn.Sequential(
nn.Linear(out_size, out_size),
nn.LeakyReLU()
)
def forward(self, graph, node_feat, edge_feat=None):
r"""
Description
-----------
Compute PNA layer.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size.
edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size.
Returns
-------
torch.Tensor
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
should be the same as out_size.
"""
h_cat = torch.cat([
tower(
graph,
node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
edge_feat
)
for ti, tower in enumerate(self.towers)
], dim=1)
h_out = self.mixing_layer(h_cat)
# add residual connection
if self.residual:
h_out = h_out + node_feat
return h_out
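The `delta` parameter documented above is a dataset statistic, E[log(d+1)] over node degrees in the training set. The following hedged sketch (not part of the diff) estimates it from a stand-in training set and runs a forward pass; it assumes `PNAConv` is importable from `dgl.nn` after this change and uses `dgl.rand_graph` purely for illustration.

# Hedged sketch: estimate delta = E[log(d + 1)] over the training graphs,
# then run a PNAConv forward pass.
import dgl
import torch
from dgl.nn import PNAConv

train_graphs = [dgl.rand_graph(30, 100) for _ in range(10)]  # stand-in training set
log_degrees = torch.cat([
    torch.log(g.in_degrees().float() + 1) for g in train_graphs
])
delta = log_degrees.mean().item()

g = dgl.rand_graph(6, 20)
feat = torch.randn(6, 16)
conv = PNAConv(16, 16, ['mean', 'max', 'sum'], ['identity', 'amplification'], delta,
               num_towers=4)
out = conv(g, feat)  # shape (6, 16); residual connection applies since in_size == out_size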
@@ -1423,3 +1423,40 @@ def test_group_rev_res(idtype):
conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h)
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('edge_feat_size', [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
dev = F.ctx()
num_nodes = 5
num_edges = 20
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
h = th.randn(num_nodes, in_size).to(dev)
x = th.randn(num_nodes, 3).to(dev)
e = th.randn(num_edges, edge_feat_size).to(dev)
model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
model(g, h, x, e)
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('aggregators',
[['mean', 'max', 'sum'], ['min', 'std', 'var'], ['moment3', 'moment4', 'moment5']])
@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
@pytest.mark.parametrize('delta', [2.5, 7.4])
@pytest.mark.parametrize('dropout', [0., 0.1])
@pytest.mark.parametrize('num_towers', [1, 4])
@pytest.mark.parametrize('edge_feat_size', [16, 0])
@pytest.mark.parametrize('residual', [True, False])
def test_pna_conv(in_size, out_size, aggregators, scalers, delta,
dropout, num_towers, edge_feat_size, residual):
dev = F.ctx()
num_nodes = 5
num_edges = 20
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
h = th.randn(num_nodes, in_size).to(dev)
e = th.randn(num_edges, edge_feat_size).to(dev)
model = nn.PNAConv(in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e)