"...libxsmm/src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c454d419cc5e036daaf8ebf73ccb82fa751a5cd0"
Unverified Commit d6957c28 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] fix incorrect _bias and bias usage (#4310)

parent b3242e90
...@@ -151,7 +151,7 @@ class GCN2Conv(nn.Module): ...@@ -151,7 +151,7 @@ class GCN2Conv(nn.Module):
nn.init.normal_(self.weight1) nn.init.normal_(self.weight1)
if not self._project_initial_features: if not self._project_initial_features:
nn.init.normal_(self.weight2) nn.init.normal_(self.weight2)
if self._bias is not None: if self._bias:
nn.init.zeros_(self.bias) nn.init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value): def set_allow_zero_in_degree(self, set_value):
...@@ -265,8 +265,8 @@ class GCN2Conv(nn.Module): ...@@ -265,8 +265,8 @@ class GCN2Conv(nn.Module):
feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta
) )
if self._bias is not None: if self._bias:
rst = rst + self._bias rst = rst + self.bias
if self._activation is not None: if self._activation is not None:
rst = self._activation(rst) rst = self._activation(rst)
......
...@@ -647,10 +647,11 @@ def test_appnp_conv_e_weight(g, idtype): ...@@ -647,10 +647,11 @@ def test_appnp_conv_e_weight(g, idtype):
@parametrize_idtype @parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gcn2conv_e_weight(g, idtype): @pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
ctx = F.ctx() ctx = F.ctx()
g = g.astype(idtype).to(ctx) g = g.astype(idtype).to(ctx)
gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5, gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5, bias=bias,
project_initial_features=True) project_initial_features=True)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_nodes(), 5))
eweight = F.ones((g.num_edges(), )) eweight = F.ones((g.num_edges(), ))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment