Commit 1f0f57b3 authored by Minjie Wang's avatar Minjie Wang
Browse files

fix bug in treelstm

parent cdf7334c
......@@ -91,7 +91,7 @@ def main(args):
mean_dur = np.mean(dur)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
"Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
epoch, step, loss.item(), acc.item() / args.batch_size,
epoch, step, loss.item(), acc.item() / len(label),
mean_dur, args.batch_size / mean_dur))
print("Epoch time(s):", time.time() - t_epoch)
......
......@@ -23,7 +23,7 @@ class ChildSumTreeLSTMCell(nn.Module):
self.ut = 0.
def message_func(self, src, edge):
return src
return {'h' : src['h'], 'c' : src['c']}
def reduce_func(self, node, msgs):
# equation (2)
......@@ -90,9 +90,9 @@ class TreeLSTM(nn.Module):
"""
g = graph
n = g.number_of_nodes()
g.register_message_func(self.cell.message_func, batchable=True)
g.register_reduce_func(self.cell.reduce_func, batchable=True)
g.register_apply_node_func(self.cell.apply_func, batchable=True)
g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_func)
# feed embedding
wordid = g.pop_n_repr('x')
mask = (wordid != dgl.data.SST.PAD_WORD)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment