"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cb63febf2ee996e2132540e119923f50780eae06"
Unverified Commit 7a80faf1 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Bugfix] Correct the loss function in RGCN model (#1217)



* upd

* upd

* fix
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 032a08cc
...@@ -34,7 +34,7 @@ class EntityClassify(BaseRGCN): ...@@ -34,7 +34,7 @@ class EntityClassify(BaseRGCN):
def build_output_layer(self): def build_output_layer(self):
return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis", return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
self.num_bases, activation=partial(F.softmax, axis=1), self.num_bases, activation=None,
self_loop=self.use_self_loop) self_loop=self.use_self_loop)
def main(args): def main(args):
......
...@@ -239,7 +239,7 @@ class EntityClassify(nn.Module): ...@@ -239,7 +239,7 @@ class EntityClassify(nn.Module):
# h2o # h2o
self.layers.append(RelGraphConvHetero( self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis", self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=partial(F.softmax, dim=1), self.num_bases, activation=None,
self_loop=self.use_self_loop)) self_loop=self.use_self_loop))
def forward(self): def forward(self):
......
...@@ -39,7 +39,7 @@ class EntityClassify(BaseRGCN): ...@@ -39,7 +39,7 @@ class EntityClassify(BaseRGCN):
def build_output_layer(self): def build_output_layer(self):
return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis", return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
self.num_bases, activation=partial(F.softmax, dim=1), self.num_bases, activation=None,
self_loop=self.use_self_loop) self_loop=self.use_self_loop)
def main(args): def main(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment