Unverified Commit 7a80faf1 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Bugfix] Correct the loss function in RGCN model (#1217)



* upd

* upd

* fix
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 032a08cc
......@@ -34,7 +34,7 @@ class EntityClassify(BaseRGCN):
def build_output_layer(self):
return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
self.num_bases, activation=partial(F.softmax, axis=1),
self.num_bases, activation=None,
self_loop=self.use_self_loop)
def main(args):
......
......@@ -239,7 +239,7 @@ class EntityClassify(nn.Module):
# h2o
self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=partial(F.softmax, dim=1),
self.num_bases, activation=None,
self_loop=self.use_self_loop))
def forward(self):
......
......@@ -39,7 +39,7 @@ class EntityClassify(BaseRGCN):
def build_output_layer(self):
return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
self.num_bases, activation=partial(F.softmax, dim=1),
self.num_bases, activation=None,
self_loop=self.use_self_loop)
def main(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment