Unverified Commit 05c6c3c5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] Fix GINConv (#3692)



* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix lint

* lint

* Update

* Update

* lint fix

* Fix CI

* Fix

* Fix CI

* Update

* Fix

* Update

* Update

* Update ginconv.py

* Update test_nn.py
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-31-136.us-west-2.compute.internal>
parent c8fef629
...@@ -34,9 +34,9 @@ class GINConv(nn.Module): ...@@ -34,9 +34,9 @@ class GINConv(nn.Module):
---------- ----------
apply_func : callable activation function/layer or None apply_func : callable activation function/layer or None
If not None, apply this function to the updated node feature, If not None, apply this function to the updated node feature,
the :math:`f_\Theta` in the formula. the :math:`f_\Theta` in the formula, default: None.
aggregator_type : str aggregator_type : str
Aggregator type to use (``sum``, ``max`` or ``mean``). Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.
init_eps : float, optional init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``. Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional learn_eps : bool, optional
...@@ -90,8 +90,8 @@ class GINConv(nn.Module): ...@@ -90,8 +90,8 @@ class GINConv(nn.Module):
0.0000]], grad_fn=<ReluBackward0>) 0.0000]], grad_fn=<ReluBackward0>)
""" """
def __init__(self, def __init__(self,
apply_func, apply_func=None,
aggregator_type, aggregator_type='sum',
init_eps=0, init_eps=0,
learn_eps=False, learn_eps=False,
activation=None): activation=None):
...@@ -108,22 +108,6 @@ class GINConv(nn.Module): ...@@ -108,22 +108,6 @@ class GINConv(nn.Module):
else: else:
self.register_buffer('eps', th.FloatTensor([init_eps])) self.register_buffer('eps', th.FloatTensor([init_eps]))
self.reset_parameters()
def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
Note
----
The model parameters are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.apply_func.weight, gain=gain)
def forward(self, graph, feat, edge_weight=None): def forward(self, graph, feat, edge_weight=None):
r""" r"""
......
...@@ -789,6 +789,11 @@ def test_gin_conv(g, idtype, aggregator_type): ...@@ -789,6 +789,11 @@ def test_gin_conv(g, idtype, aggregator_type):
assert h.shape == (g.number_of_dst_nodes(), 12) assert h.shape == (g.number_of_dst_nodes(), 12)
gin = nn.GINConv(None, aggregator_type)
th.save(gin, tmp_buffer)
gin = gin.to(ctx)
h = gin(g, feat)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum']) @pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment