"""Graph Convolutional Networks."""
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GraphConv

__all__ = ['GCN']


class GCNLayer(nn.Module):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph
    Convolutional Networks <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    out_feats : int
        Number of output node features.
    activation : activation function
        Default to be None.
    residual : bool
        Whether to use residual connection, default to be True.
    batchnorm : bool
        Whether to use batch normalization on the output, default to be True.
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, out_feats, activation=None,
                 residual=True, batchnorm=True, dropout=0.):
        super(GCNLayer, self).__init__()

        self.activation = activation
        # NOTE(review): ``norm=False`` is the legacy (pre-0.5) DGL boolean form;
        # newer DGL versions expect a string such as ``norm='none'`` — confirm
        # against the pinned DGL version before changing.
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=False, activation=activation)
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        if residual:
            # Linear projection so the residual matches out_feats.
            self.res_connection = nn.Linear(in_feats, out_feats)
        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match
              out_feats in initialization
        """
        new_feats = self.graph_conv(g, feats)
        if self.residual:
            res_feats = self.res_connection(feats)
            # Bug fix: activation defaults to None while residual defaults to
            # True, so unconditionally calling self.activation(...) crashed
            # with a TypeError for the default configuration. Apply the
            # activation to the residual branch only when one is provided.
            if self.activation is not None:
                res_feats = self.activation(res_feats)
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)

        if self.bn:
            new_feats = self.bn_layer(new_feats)

        return new_feats


class GCN(nn.Module):
    r"""GCN from `Semi-Supervised Classification with Graph Convolutional
    Networks <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after
        the i-th GCN layer. ``len(hidden_feats)`` equals the number of
        GCN layers. By default, we use ``[64, 64]``.
    activation : list of activation functions or None
        If None, no activation will be applied. If not None,
        ``activation[i]`` gives the activation function to be used for
        the i-th GCN layer. ``len(activation)`` equals the number of
        GCN layers. By default, ReLU is applied for all GCN layers.
    residual : list of bool
        ``residual[i]`` decides if residual connection is to be used for
        the i-th GCN layer. ``len(residual)`` equals the number of GCN
        layers. By default, residual connection is performed for each
        GCN layer.
    batchnorm : list of bool
        ``batchnorm[i]`` decides if batch normalization is to be applied
        on the output of the i-th GCN layer. ``len(batchnorm)`` equals
        the number of GCN layers. By default, batch normalization is
        applied for all GCN layers.
    dropout : list of float
        ``dropout[i]`` decides the dropout probability on the output of
        the i-th GCN layer. ``len(dropout)`` equals the number of GCN
        layers. By default, no dropout is performed for all layers.
    """
    def __init__(self, in_feats, hidden_feats=None, activation=None,
                 residual=None, batchnorm=None, dropout=None):
        super(GCN, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]

        n_layers = len(hidden_feats)
        # Fill in per-layer defaults so all configuration lists line up.
        if activation is None:
            activation = [F.relu for _ in range(n_layers)]
        if residual is None:
            residual = [True for _ in range(n_layers)]
        if batchnorm is None:
            batchnorm = [True for _ in range(n_layers)]
        if dropout is None:
            dropout = [0. for _ in range(n_layers)]
        lengths = [len(hidden_feats), len(activation), len(residual),
                   len(batchnorm), len(dropout)]
        assert len(set(lengths)) == 1, \
            'Expect the lengths of hidden_feats, activation, ' \
            'residual, batchnorm and dropout to be the same, ' \
            'got {}'.format(lengths)

        self.hidden_feats = hidden_feats
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(GCNLayer(in_feats, hidden_feats[i],
                                            activation[i], residual[i],
                                            batchnorm[i], dropout[i]))
            # The output size of layer i becomes the input size of layer i+1.
            in_feats = hidden_feats[i]

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals
              in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats