"...text-generation-inference.git" did not exist on "a60fa8406abd98d41e2bfafaf6f81f3dd6044b15"
Unverified Commit 248bece6 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] GINEConv (#3934)



* Update

* Update

* Update

* Update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent df7a6121
......@@ -23,6 +23,7 @@ Conv Layers
~dgl.nn.pytorch.conv.SGConv
~dgl.nn.pytorch.conv.APPNPConv
~dgl.nn.pytorch.conv.GINConv
~dgl.nn.pytorch.conv.GINEConv
~dgl.nn.pytorch.conv.GatedGraphConv
~dgl.nn.pytorch.conv.GMMConv
~dgl.nn.pytorch.conv.ChebConv
......
......@@ -9,6 +9,7 @@ from .gatconv import GATConv
from .gatv2conv import GATv2Conv
from .egatconv import EGATConv
from .ginconv import GINConv
from .gineconv import GINEConv
from .gmmconv import GMMConv
from .graphconv import GraphConv, EdgeWeightNorm
from .nnconv import NNConv
......@@ -31,8 +32,8 @@ from .egnnconv import EGNNConv
from .pnaconv import PNAConv
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes', 'EGNNConv',
'PNAConv']
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GINEConv',
'GatedGraphConv', 'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv',
'DenseSAGEConv', 'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv',
'TWIRLSConv', 'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes',
'EGNNConv', 'PNAConv']
"""Torch Module for Graph Isomorphism Network layer variant with edge features"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
import torch.nn.functional as F
from torch import nn
from .... import function as fn
from ....utils import expand_as_pair
class GINEConv(nn.Module):
r"""Graph Isomorphism Network with Edge Features, introduced by
`Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
.. math::
h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
\sum_{j\in\mathcal{N}(i)}\mathrm{ReLU}(h_j^{l} + e_{j,i}^{l})\right)
where :math:`e_{j,i}^{l}` is the edge feature.
Parameters
----------
apply_func : callable module or None
The :math:`f_\Theta` in the formula. If not None, it will be applied to
the updated node features. The default value is None.
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
Examples
--------
>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GINEConv
>>> g = dgl.graph(([0, 1, 2], [1, 1, 3]))
>>> in_feats = 10
>>> out_feats = 20
>>> nfeat = torch.randn(g.num_nodes(), in_feats)
>>> efeat = torch.randn(g.num_edges(), in_feats)
>>> conv = GINEConv(nn.Linear(in_feats, out_feats))
>>> res = conv(g, nfeat, efeat)
>>> print(res.shape)
torch.Size([4, 20])
"""
def __init__(self,
apply_func=None,
init_eps=0,
learn_eps=False):
super(GINEConv, self).__init__()
self.apply_func = apply_func
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
def message(self, edges):
r"""User-defined Message Function"""
return {'m': F.relu(edges.src['hn'] + edges.data['he'])}
def forward(self, graph, node_feat, edge_feat):
r"""Forward computation.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, it is the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input feature size requirement of ``apply_func``.
edge_feat : torch.Tensor
Edge feature. It is a tensor of shape :math:`(E, D_{in})` where :math:`E`
is the number of edges.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where
:math:`D_{out}` is the output feature size of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same
as :math:`D_{in}`.
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(node_feat, graph)
graph.srcdata['hn'] = feat_src
graph.edata['he'] = edge_feat
graph.update_all(self.message, fn.sum('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
......@@ -682,6 +682,29 @@ def test_gin_conv(g, idtype, aggregator_type):
gin = gin.to(ctx)
h = gin(g, feat)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gine_conv(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gine = nn.GINEConv(
th.nn.Linear(5, 12)
)
th.save(gine, tmp_buffer)
nfeat = F.randn((g.number_of_src_nodes(), 5))
efeat = F.randn((g.num_edges(), 5))
gine = gine.to(ctx)
h = gine(g, nfeat, efeat)
# test pickle
th.save(gine, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), 12)
gine = nn.GINEConv(None)
th.save(gine, tmp_buffer)
gine = gine.to(ctx)
h = gine(g, nfeat, efeat)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
......@@ -1441,7 +1464,7 @@ def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('aggregators',
@pytest.mark.parametrize('aggregators',
[['mean', 'max', 'sum'], ['min', 'std', 'var'], ['moment3', 'moment4', 'moment5']])
@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
@pytest.mark.parametrize('delta', [2.5, 7.4])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment