"docs/vscode:/vscode.git/clone" did not exist on "a244de579cbd964912aa46446a0a9b2a2f595c87"
Unverified Commit 05c6c3c5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] Fix GINConv (#3692)



* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix lint

* lint

* Update

* Update

* lint fix

* Fix CI

* Fix

* Fix CI

* Update

* Fix

* Update

* Update

* Update ginconv.py

* Update test_nn.py
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-31-136.us-west-2.compute.internal>
parent c8fef629
......@@ -34,9 +34,9 @@ class GINConv(nn.Module):
----------
apply_func : callable activation function/layer or None
If not None, apply this function to the updated node feature,
the :math:`f_\Theta` in the formula.
the :math:`f_\Theta` in the formula, default: None.
aggregator_type : str
Aggregator type to use (``sum``, ``max`` or ``mean``).
Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
......@@ -90,8 +90,8 @@ class GINConv(nn.Module):
0.0000]], grad_fn=<ReluBackward0>)
"""
def __init__(self,
apply_func,
aggregator_type,
apply_func=None,
aggregator_type='sum',
init_eps=0,
learn_eps=False,
activation=None):
......@@ -108,22 +108,6 @@ class GINConv(nn.Module):
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
self.reset_parameters()
def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
Note
----
The model parameters are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.apply_func.weight, gain=gain)
def forward(self, graph, feat, edge_weight=None):
r"""
......
......@@ -788,6 +788,11 @@ def test_gin_conv(g, idtype, aggregator_type):
th.save(gin, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), 12)
gin = nn.GINConv(None, aggregator_type)
th.save(gin, tmp_buffer)
gin = gin.to(ctx)
h = gin(g, feat)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
......@@ -1383,4 +1388,4 @@ if __name__ == '__main__':
test_atomic_conv()
test_cf_conv()
test_hetero_conv()
test_twirls()
\ No newline at end of file
test_twirls()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment