Unverified Commit c4c9b830 authored by paoxiaode, committed by GitHub

[NN] GatedGCNConv (#5654)

parent 4c015632
@@ -26,6 +26,7 @@ Conv Layers
~dgl.nn.pytorch.conv.GINConv
~dgl.nn.pytorch.conv.GINEConv
~dgl.nn.pytorch.conv.GatedGraphConv
~dgl.nn.pytorch.conv.GatedGCNConv
~dgl.nn.pytorch.conv.GMMConv
~dgl.nn.pytorch.conv.ChebConv
~dgl.nn.pytorch.conv.AGNNConv
......
@@ -19,6 +19,7 @@ from .edgegatconv import EdgeGATConv
from .egatconv import EGATConv
from .egnnconv import EGNNConv
from .gatconv import GATConv
from .gatedgcnconv import GatedGCNConv
from .gatedgraphconv import GatedGraphConv
from .gatv2conv import GATv2Conv
from .gcn2conv import GCN2Conv
@@ -51,6 +52,7 @@ __all__ = [
"GINConv",
"GINEConv",
"GatedGraphConv",
"GatedGCNConv",
"GMMConv",
"ChebConv",
"AGNNConv",
......
"""Torch Module for GatedGCN layer"""
# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop
import torch
import torch.nn.functional as F
from torch import nn
from .... import function as fn
class GatedGCNConv(nn.Module):
r"""Gated graph convolutional layer from `Benchmarking Graph Neural Networks
<https://arxiv.org/abs/2003.00982>`__

    .. math::
        e_{ij}^{l+1} = D^l h_{i}^{l} + E^l h_{j}^{l} + C^l e_{ij}^{l}

        norm_{ij} = \sum_{j \in N_{i}} \sigma\left(e_{ij}^{l+1}\right) + \varepsilon

        \hat{e}_{ij}^{l+1} = \sigma\left(e_{ij}^{l+1}\right) / norm_{ij}

        h_{i}^{l+1} = A^l h_{i}^{l} + \sum_{j \in N_{i}} \hat{e}_{ij}^{l+1} \odot B^l h_{j}^{l}

    where :math:`h_{i}^{l}` is the feature of node :math:`i` at layer :math:`l`,
    :math:`e_{ij}^{l}` is the feature of edge :math:`ij` at layer :math:`l`,
    :math:`\sigma` is the sigmoid function, :math:`\varepsilon` is a small fixed
    constant for numerical stability, and :math:`A^l, B^l, C^l, D^l, E^l` are
    linear layers.
Parameters
----------
input_feats : int
        Input feature size; i.e., the number of dimensions of :math:`h_{i}^{l}`.
edge_feats: int
Edge feature size; i.e., the number of dimensions of :math:`e_{ij}^{l}`.
output_feats : int
Output feature size; i.e., the number of dimensions of :math:`h_{i}^{l+1}`.
    dropout : float, optional
        Dropout rate applied to the node and edge features. Default: ``0``.
    batch_norm : bool, optional
        Whether to apply batch normalization to the node and edge features.
        Default: ``True``.
residual : bool, optional
Whether to include residual connections. Default: ``True``.
activation : callable activation function/layer or None, optional
If not None, apply an activation function to the updated node features.
Default: ``F.relu``.
Example
-------
>>> import dgl
>>> import torch as th
>>> import torch.nn.functional as F
>>> from dgl.nn import GatedGCNConv
>>> num_nodes, num_edges = 8, 30
    >>> graph = dgl.rand_graph(num_nodes, num_edges)
>>> node_feats = th.rand(num_nodes, 20)
>>> edge_feats = th.rand(num_edges, 12)
>>> gatedGCN = GatedGCNConv(20, 12, 20)
>>> new_node_feats, new_edge_feats = gatedGCN(graph, node_feats, edge_feats)
>>> new_node_feats.shape, new_edge_feats.shape
(torch.Size([8, 20]), torch.Size([30, 20]))
"""
def __init__(
self,
input_feats,
edge_feats,
output_feats,
dropout=0,
batch_norm=True,
residual=True,
activation=F.relu,
):
super(GatedGCNConv, self).__init__()
self.dropout = nn.Dropout(dropout)
self.batch_norm = batch_norm
self.residual = residual
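        # Residual connections require the input and output sizes to match.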
if input_feats != output_feats or edge_feats != output_feats:
self.residual = False
# Linearly transform the node features.
self.A = nn.Linear(input_feats, output_feats, bias=True)
self.B = nn.Linear(input_feats, output_feats, bias=True)
self.D = nn.Linear(input_feats, output_feats, bias=True)
self.E = nn.Linear(input_feats, output_feats, bias=True)
# Linearly transform the edge features.
self.C = nn.Linear(edge_feats, output_feats, bias=True)
# Batch normalization on the node/edge features.
self.bn_node = nn.BatchNorm1d(output_feats)
self.bn_edge = nn.BatchNorm1d(output_feats)
self.activation = activation
def forward(self, graph, feat, edge_feat):
"""
Description
-----------
        Compute the gated graph convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
edge_feat : torch.Tensor
The input edge feature of shape :math:`(E, D_{edge})`,
where :math:`E` is the number of edges and :math:`D_{edge}`
is the size of the edge features.
Returns
-------
torch.Tensor
The output node feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
torch.Tensor
The output edge feature of shape :math:`(E, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
with graph.local_scope():
# For residual connection
h_in = feat
e_in = edge_feat
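            # Linear projections: A, B, D, E act on node features; C on edge features.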
graph.ndata["Ah"] = self.A(feat)
graph.ndata["Bh"] = self.B(feat)
graph.ndata["Dh"] = self.D(feat)
graph.ndata["Eh"] = self.E(feat)
graph.edata["Ce"] = self.C(edge_feat)
graph.apply_edges(fn.u_add_v("Dh", "Eh", "DEh"))
            # Updated edge feature: e_ij^{l+1} = Dh_i + Eh_j + Ce_ij.
graph.edata["e"] = graph.edata["DEh"] + graph.edata["Ce"]
graph.edata["sigma"] = torch.sigmoid(graph.edata["e"])
graph.update_all(
fn.u_mul_e("Bh", "sigma", "m"), fn.sum("m", "sum_sigma_h")
)
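            # Denominator: the sum of gates over incoming edges, for soft normalization.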
graph.update_all(fn.copy_e("sigma", "m"), fn.sum("m", "sum_sigma"))
graph.ndata["h"] = graph.ndata["Ah"] + graph.ndata[
"sum_sigma_h"
] / (graph.ndata["sum_sigma"] + 1e-6)
# Result of graph convolution.
feat = graph.ndata["h"]
edge_feat = graph.edata["e"]
# Batch normalization.
if self.batch_norm:
feat = self.bn_node(feat)
edge_feat = self.bn_edge(edge_feat)
# Non-linear activation.
if self.activation:
feat = self.activation(feat)
edge_feat = self.activation(edge_feat)
# Residual connection.
if self.residual:
feat = h_in + feat
edge_feat = e_in + edge_feat
feat = self.dropout(feat)
edge_feat = self.dropout(edge_feat)
return feat, edge_feat
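Since the layer returns updated features for both nodes and edges, GatedGCNConv modules can be stacked as long as the next layer's node and edge input sizes both match the previous layer's output size. A minimal sketch, assuming the layer behaves as implemented above; the model class and feature sizes are illustrative, not part of this commit:

import dgl
import torch
from torch import nn
from dgl.nn import GatedGCNConv

class TwoLayerGatedGCN(nn.Module):
    def __init__(self, in_feats, edge_feats, hidden_feats):
        super().__init__()
        # Layer 1 projects node and edge features to hidden_feats, so layer 2
        # sees hidden_feats everywhere and keeps its residual connections.
        self.conv1 = GatedGCNConv(in_feats, edge_feats, hidden_feats)
        self.conv2 = GatedGCNConv(hidden_feats, hidden_feats, hidden_feats)

    def forward(self, graph, node_feat, edge_feat):
        node_feat, edge_feat = self.conv1(graph, node_feat, edge_feat)
        return self.conv2(graph, node_feat, edge_feat)

graph = dgl.rand_graph(8, 30)
model = TwoLayerGatedGCN(20, 12, 32)
h, e = model(graph, torch.rand(8, 20), torch.rand(30, 12))
assert h.shape == (8, 32) and e.shape == (30, 32)

Note that conv2 keeps its residual branch because its node, edge, and output sizes are all hidden_feats, while conv1 has it disabled since 20 and 12 differ from 32.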
import io
import backend as F
import dgl.nn.pytorch as nn
import pytest
from utils import parametrize_idtype
from utils.graph_cases import get_cases
tmp_buffer = io.BytesIO()
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gatedgcn_conv(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gatedgcnconv = nn.GatedGCNConv(10, 10, 5)
feat = F.randn((g.num_nodes(), 10))
efeat = F.randn((g.num_edges(), 10))
gatedgcnconv = gatedgcnconv.to(ctx)
h, edge_h = gatedgcnconv(g, feat, efeat)
    # Currently we only check the output shapes.
assert h.shape == (g.number_of_dst_nodes(), 5)
assert edge_h.shape == (g.number_of_edges(), 5)
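The test above only verifies output shapes. For intuition, the aggregation performed by the two update_all calls in forward() can be reproduced with plain PyTorch index_add_ operations. A hedged cross-check sketch, assuming edges are given as (src, dst) index tensors and that Ah, Bh, e are the already-projected features; the helper name is hypothetical:

import torch

def gated_aggregate(src, dst, Ah, Bh, e, eps=1e-6):
    # Per-edge gates: sigma(e_ij).
    sigma = torch.sigmoid(e)
    # Numerator: sum of sigma_ij * Bh_j over each node's incoming edges.
    sum_sigma_h = torch.zeros_like(Ah).index_add_(0, dst, sigma * Bh[src])
    # Denominator: sum of gates, with eps for numerical stability.
    sum_sigma = torch.zeros_like(Ah).index_add_(0, dst, sigma)
    # h_i^{l+1} before batch norm, activation, residual, and dropout.
    return Ah + sum_sigma_h / (sum_sigma + eps)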