Commit 1f0f57b3 authored by Minjie Wang's avatar Minjie Wang
Browse files

fix bug in treelstm

parent cdf7334c
...@@ -91,7 +91,7 @@ def main(args): ...@@ -91,7 +91,7 @@ def main(args):
mean_dur = np.mean(dur) mean_dur = np.mean(dur)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | " print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
"Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format( "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
epoch, step, loss.item(), acc.item() / args.batch_size, epoch, step, loss.item(), acc.item() / len(label),
mean_dur, args.batch_size / mean_dur)) mean_dur, args.batch_size / mean_dur))
print("Epoch time(s):", time.time() - t_epoch) print("Epoch time(s):", time.time() - t_epoch)
......
...@@ -23,7 +23,7 @@ class ChildSumTreeLSTMCell(nn.Module): ...@@ -23,7 +23,7 @@ class ChildSumTreeLSTMCell(nn.Module):
self.ut = 0. self.ut = 0.
def message_func(self, src, edge): def message_func(self, src, edge):
return src return {'h' : src['h'], 'c' : src['c']}
def reduce_func(self, node, msgs): def reduce_func(self, node, msgs):
# equation (2) # equation (2)
...@@ -90,9 +90,9 @@ class TreeLSTM(nn.Module): ...@@ -90,9 +90,9 @@ class TreeLSTM(nn.Module):
""" """
g = graph g = graph
n = g.number_of_nodes() n = g.number_of_nodes()
g.register_message_func(self.cell.message_func, batchable=True) g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func, batchable=True) g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_func, batchable=True) g.register_apply_node_func(self.cell.apply_func)
# feed embedding # feed embedding
wordid = g.pop_n_repr('x') wordid = g.pop_n_repr('x')
mask = (wordid != dgl.data.SST.PAD_WORD) mask = (wordid != dgl.data.SST.PAD_WORD)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment