"src/vscode:/vscode.git/clone" did not exist on "658e24e86c4c52ee14244ab7a7113f5bf353186e"
Unverified Commit 51c65097 authored by Kamil Kamiński, committed by GitHub

[NN] Add EGATConv nn.module (#3425)



* added nn pytorch egatconv

* aligned with test build

* aligned with test build

* fixed white spaces

* fixed white spaces

* fixed white spaces

* added missing egatconv in imports

* added indentation in forward

* GATConv based implementation

* removed **kw_args

* added dgl relative imports

* PR corrections

* added DGL Error to EGATConv imports

* Update test_nn.py
Co-authored-by: Argusmocny <k.kaminski@cent.uw.edu.pl>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent a9c83bce
...@@ -270,7 +270,7 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedback
1. [**Covid-19 Detection from Chest X-ray and Patient Metadata using Graph Convolutional Neural Networks**](https://arxiv.org/abs/2105.09720), *Thosini Bamunu Mudiyanselage, Nipuna Senanayake, Chunyan Ji, Yi Pan, Yanqing Zhang*
-1. [**Graph neural networks and sequence embeddings enable the prediction and design of the cofactor specificity of Rossmann fold proteins**](https://www.biorxiv.org/content/10.1101/2021.05.05.440912v2), bioRxiv'21, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*
+1. [**Rossmann-toolbox: a deep learning-based protocol for the prediction and design of cofactor specificity in Rossmann fold proteins**](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbab371/6375059), Briefings in Bioinformatics, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*
1. [**LGESQL: Line Graph Enhanced Text-to-SQL Model with Mixed Local and Non-Local Relations**](https://arxiv.org/pdf/2106.01093.pdf), ACL'21, *Ruisheng Cao, Lu Chen, Zhi Chen, Yanbin Zhao, Su Zhu, Kai Yu*
......
...@@ -45,6 +45,14 @@ GATConv
    :members: forward
    :show-inheritance:

EGATConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.conv.EGATConv
    :members: forward
    :show-inheritance:

EdgeConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -6,6 +6,7 @@ from .appnpconv import APPNPConv
from .chebconv import ChebConv
from .edgeconv import EdgeConv
from .gatconv import GATConv
from .egatconv import EGATConv
from .ginconv import GINConv
from .gmmconv import GMMConv
from .graphconv import GraphConv, EdgeWeightNorm
...@@ -24,8 +25,8 @@ from .dotgatconv import DotGatConv
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
-__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
-           'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
+__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'EGATConv', 'TAGConv', 'RelGraphConv',
+           'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
            'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
            'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
            'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
"""Torch modules for graph attention networks with fully valuable edges (EGAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from torch.nn import init
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
# pylint: enable=W0235
class EGATConv(nn.Module):
r"""
Description
-----------
Apply Graph Attention Layer over input graph. EGAT is an extension
of regular `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
handling edge features, detailed description is available in `Rossmann-Toolbox
<https://pubmed.ncbi.nlm.nih.gov/34571541/>`__ (see supplementary data).
The difference appears in the method how unnormalized attention scores :math:`e_{ij}`
are obtained:
.. math::
e_{ij} &= \vec{F} (f_{ij}^{\prime})
f_{ij}^{\prime} &= \mathrm{LeakyReLU}\left(A [ h_{i} \| f_{ij} \| h_{j}]\right)
where :math:`f_{ij}^{\prime}` are edge features, :math:`\mathrm{A}` is weight matrix and
:math: `\vec{F}` is weight vector. After that resulting node features
:math:`h_{i}^{\prime}` are updated in the same way as in regular GAT.
    Parameters
    ----------
    in_node_feats : int
        Input node feature size :math:`h_{i}`.
    in_edge_feats : int
        Input edge feature size :math:`f_{ij}`.
    out_node_feats : int
        Output node feature size.
    out_edge_feats : int
        Output edge feature size :math:`f_{ij}^{\prime}`.
    num_heads : int
        Number of attention heads.
    bias : bool, optional
        If True, add a bias term to :math:`f_{ij}^{\prime}`. Default: ``True``.
    Examples
    --------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import EGATConv
    >>> num_nodes, num_edges = 8, 30
    >>> # generate a random graph
    >>> graph = dgl.rand_graph(num_nodes, num_edges)
    >>> node_feats = th.rand((num_nodes, 20))
    >>> edge_feats = th.rand((num_edges, 12))
    >>> egat = EGATConv(in_node_feats=20,
    ...                 in_edge_feats=12,
    ...                 out_node_feats=15,
    ...                 out_edge_feats=10,
    ...                 num_heads=3)
    >>> # forward pass
    >>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
    >>> new_node_feats.shape, new_edge_feats.shape
    (torch.Size([8, 3, 15]), torch.Size([30, 3, 10]))
    """
    def __init__(self,
                 in_node_feats,
                 in_edge_feats,
                 out_node_feats,
                 out_edge_feats,
                 num_heads,
                 bias=True):
        super().__init__()
        self._num_heads = num_heads
        self._out_node_feats = out_node_feats
        self._out_edge_feats = out_edge_feats
        self.fc_node = nn.Linear(in_node_feats, out_node_feats*num_heads, bias=True)
        self.fc_ni = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
        self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
        self.fc_nj = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
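        # one attention weight vector per head plays the role of \vec{F} in the
        # docstring formula: it reduces each transformed edge feature to a scalar score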
        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
        if bias:
            self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
        else:
            self.register_buffer('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        """
        Reinitialize learnable parameters.
        """
        gain = init.calculate_gain('relu')
        init.xavier_normal_(self.fc_node.weight, gain=gain)
        init.xavier_normal_(self.fc_ni.weight, gain=gain)
        init.xavier_normal_(self.fc_fij.weight, gain=gain)
        init.xavier_normal_(self.fc_nj.weight, gain=gain)
        init.xavier_normal_(self.attn, gain=gain)
        # guard against bias=False, where self.bias is registered as None
        if self.bias is not None:
            init.constant_(self.bias, 0)
    def forward(self, graph, nfeats, efeats, get_attention=False):
        r"""
        Compute new node and edge features.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        nfeats : torch.Tensor
            The input node feature of shape :math:`(N, D_{in})`
            where:
            :math:`D_{in}` is size of input node feature,
            :math:`N` is the number of nodes.
        efeats : torch.Tensor
            The input edge feature of shape :math:`(E, F_{in})`
            where:
            :math:`F_{in}` is size of input edge feature,
            :math:`E` is the number of edges.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        pair of torch.Tensor
            node output features followed by edge output features.
            The node output feature is of shape :math:`(N, H, D_{out})`,
            the edge output feature is of shape :math:`(E, H, F_{out})`
            where:
            :math:`H` is the number of heads,
            :math:`D_{out}` is size of output node feature,
            :math:`F_{out}` is size of output edge feature.
        torch.Tensor, optional
            The attention values of shape :math:`(E, H, 1)`.
            This is returned only when :attr:`get_attention` is ``True``.
        """
        with graph.local_scope():
            if (graph.in_degrees() == 0).any():
                raise DGLError('There are 0-in-degree nodes in the graph, '
                               'output for those nodes will be invalid. '
                               'This is harmful for some applications, '
                               'causing silent performance regression. '
                               'Adding self-loop on the input graph by '
                               'calling `g = dgl.add_self_loop(g)` will resolve '
                               'the issue.')
            # TODO allow node src and dst feats
            # calc edge attention
            # same trick as in dgl.nn.pytorch.GATConv, but also includes edge feats
            # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
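            # the projection A [h_i || f_ij || h_j] from the docstring is decomposed into
            # three separate linear maps whose outputs are summed, avoiding an explicit
            # per-edge concatenation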
            f_ni = self.fc_ni(nfeats)
            f_nj = self.fc_nj(nfeats)
            f_fij = self.fc_fij(efeats)
            graph.srcdata.update({'f_ni': f_ni})
            graph.dstdata.update({'f_nj': f_nj})
            # add ni, nj factors
            graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
            # add fij to node factor
            f_out = graph.edata.pop('f_tmp') + f_fij
            if self.bias is not None:
                f_out = f_out + self.bias
            f_out = nn.functional.leaky_relu(f_out)
            f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
            # compute attention factor
            e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
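            # normalize scores with a softmax over each node's incoming edges, as in GAT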
            graph.edata['a'] = edge_softmax(graph, e)
            graph.ndata['h_out'] = self.fc_node(nfeats).view(-1, self._num_heads,
                                                             self._out_node_feats)
            # calc weighted sum
            graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
                             fn.sum('m', 'h_out'))

            h_out = graph.ndata['h_out'].view(-1, self._num_heads, self._out_node_feats)
            if get_attention:
                return h_out, f_out, graph.edata.pop('a')
            else:
                return h_out, f_out
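Downstream layers usually expect 2-D features, so the per-head outputs returned above are typically merged before further use. The sketch below is illustrative only; the flatten/mean merge strategies are common conventions for GAT-style layers and are not part of this PR:

import dgl
import torch as th
from dgl.nn import EGATConv

num_nodes, num_edges = 8, 30
graph = dgl.rand_graph(num_nodes, num_edges)
node_feats = th.rand((num_nodes, 20))
edge_feats = th.rand((num_edges, 12))
egat = EGATConv(in_node_feats=20, in_edge_feats=12,
                out_node_feats=15, out_edge_feats=10, num_heads=3)
new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
# concatenate heads for an intermediate layer: (8, 3, 15) -> (8, 45)
h = new_node_feats.flatten(1)
# average heads for a final layer: (30, 3, 10) -> (30, 10)
f = new_edge_feats.mean(1)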
...@@ -564,6 +564,28 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(in_node_feats=10,
                       in_edge_feats=5,
                       out_node_feats=out_node_feats,
                       out_edge_feats=out_edge_feats,
                       num_heads=num_heads)
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)
    h, f, attn = egat(g, nfeat, efeat, True)
    th.save(egat, tmp_buffer)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
...@@ -1137,6 +1159,7 @@ if __name__ == '__main__':
    test_rgcn_sorted()
    test_tagconv()
    test_gat_conv()
    test_egat_conv()
    test_sage_conv()
    test_sgc_conv()
    test_appnp_conv()
......
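The new test exercises both forward-pass variants but does not assert output shapes. A hypothetical extension of test_egat_conv (not part of this PR), matching the shapes documented in the EGATConv docstring, could add:

    assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)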