Unverified Commit 9c106547 authored by RONANKI SWAMY SRIHARSHA, committed by GitHub

[NN] Added activation function as an optional parameter to GINConv (#3565)



* Added activation function as an optional parameter

* lint fixes

* Modified the input parameters in tandem with other classes

* lint corrections

* corrected tests

* Reverting back to the old interface

* lint corrections
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 4889c578
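
In short, `GINConv` gains an optional `activation` argument; when set, it is applied to the updated node features after `apply_func`. A minimal usage sketch of the new interface (the graph, feature sizes, and `lin` layer here are hypothetical, assuming a standard DGL/PyTorch setup):

    import dgl
    import torch as th
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl.nn import GINConv

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))          # hypothetical 3-node cycle
    feat = th.ones(3, 10)                          # hypothetical node features
    lin = nn.Linear(10, 10)                        # apply_func, as in the docstring example
    conv = GINConv(lin, 'max', activation=F.relu)  # the newly added optional argument
    res = conv(g, feat)                            # ReLU applied after apply_func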
@@ -41,9 +41,12 @@ class GINConv(nn.Module):
        Initial :math:`\epsilon` value, default: ``0``.
    learn_eps : bool, optional
        If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
@@ -67,15 +70,35 @@ class GINConv(nn.Module):
             0.8843, -0.8764],
            [-0.1804, 0.0758, -0.5159, 0.3569, -0.1408, -0.1395, -0.2387, 0.7773,
             0.5266, -0.4465]], grad_fn=<AddmmBackward>)
    >>> # With activation
    >>> from torch.nn.functional import relu
    >>> conv = GINConv(lin, 'max', activation=relu)
    >>> res = conv(g, feat)
    >>> res
    tensor([[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000,
             0.0000]], grad_fn=<ReluBackward0>)
""" """
def __init__(self, def __init__(self,
apply_func, apply_func,
aggregator_type, aggregator_type,
init_eps=0, init_eps=0,
learn_eps=False): learn_eps=False,
activation=None):
super(GINConv, self).__init__() super(GINConv, self).__init__()
self.apply_func = apply_func self.apply_func = apply_func
self._aggregator_type = aggregator_type self._aggregator_type = aggregator_type
self.activation = activation
if aggregator_type not in ('sum', 'max', 'mean'): if aggregator_type not in ('sum', 'max', 'mean'):
raise KeyError( raise KeyError(
'Aggregator type {} not recognized.'.format(aggregator_type)) 'Aggregator type {} not recognized.'.format(aggregator_type))
...@@ -85,6 +108,22 @@ class GINConv(nn.Module): ...@@ -85,6 +108,22 @@ class GINConv(nn.Module):
else: else:
self.register_buffer('eps', th.FloatTensor([init_eps])) self.register_buffer('eps', th.FloatTensor([init_eps]))
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The model parameters are initialized using Glorot normal initialization
        (``xavier_normal_``).
        """
        gain = nn.init.calculate_gain('relu')
        if self.apply_func is not None:  # apply_func may be None (see forward)
            nn.init.xavier_normal_(self.apply_func.weight, gain=gain)
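
For reference, `xavier_normal_` with this gain draws each weight from N(0, std^2) with std = gain * sqrt(2 / (fan_in + fan_out)); a quick sanity check on a hypothetical 10x10 layer:

    import math
    import torch.nn as nn

    lin = nn.Linear(10, 10)                # hypothetical apply_func
    gain = nn.init.calculate_gain('relu')  # sqrt(2), approx. 1.414
    nn.init.xavier_normal_(lin.weight, gain=gain)
    std = gain * math.sqrt(2.0 / (lin.in_features + lin.out_features))
    print(std)                             # approx. 0.447
    print(lin.weight.std().item())         # empirical std, roughly 0.447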
    def forward(self, graph, feat, edge_weight=None):
        r"""
@@ -129,4 +168,7 @@ class GINConv(nn.Module):
        rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
        if self.apply_func is not None:
            rst = self.apply_func(rst)
        # apply the optional activation to the updated node features
        if self.activation is not None:
            rst = self.activation(rst)
        return rst
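
Taken together, with a 'sum' aggregator the forward pass now computes h'_i = activation(apply_func((1 + eps) * h_i + sum over neighbors j of e_ji * h_j)). A dense-matrix sketch of that update, illustrative only and not DGL's actual message-passing implementation:

    import torch as th

    def gin_update_dense(adj, feat, eps, apply_func=None, activation=None):
        # adj: (N, N) dense adjacency; feat: (N, D) node features
        neigh = adj @ feat              # 'sum' aggregation over neighbors
        rst = (1 + eps) * feat + neigh  # GIN update rule
        if apply_func is not None:
            rst = apply_func(rst)       # e.g. an MLP or nn.Linear
        if activation is not None:
            rst = activation(rst)       # the newly added optional activation
        return rst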