Unverified Commit d6957c28 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] fix incorrect _bias and bias usage (#4310)

parent b3242e90
......@@ -151,7 +151,7 @@ class GCN2Conv(nn.Module):
nn.init.normal_(self.weight1)
if not self._project_initial_features:
nn.init.normal_(self.weight2)
if self._bias is not None:
if self._bias:
nn.init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value):
......@@ -265,8 +265,8 @@ class GCN2Conv(nn.Module):
feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta
)
if self._bias is not None:
rst = rst + self._bias
if self._bias:
rst = rst + self.bias
if self._activation is not None:
rst = self._activation(rst)
......
......@@ -647,10 +647,11 @@ def test_appnp_conv_e_weight(g, idtype):
@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gcn2conv_e_weight(g, idtype):
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5,
gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5, bias=bias,
project_initial_features=True)
feat = F.randn((g.number_of_nodes(), 5))
eweight = F.ones((g.num_edges(), ))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment