Commit 688a9228 authored by Da Zheng's avatar Da Zheng Committed by Sheng Zha
Browse files

fix. (#491)

parent e2e04329
......@@ -122,7 +122,7 @@ def main(args):
dur.append(time.time() - t0) # tok
if step > 0 and step % args.log_every == 0:
pred = pred.argmax(axis=1)
pred = pred.argmax(axis=1).astype(batch.label.dtype)
acc = (batch.label == pred).sum()
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
......@@ -139,7 +139,7 @@ def main(args):
n = g.number_of_nodes()
h = mx.nd.zeros((n, args.h_size), ctx=ctx)
c = mx.nd.zeros((n, args.h_size), ctx=ctx)
pred = model(batch, h, c).argmax(1)
pred = model(batch, h, c).argmax(1).astype(batch.label.dtype)
acc = (batch.label == pred).sum().asscalar()
accs.append([acc, len(batch.label)])
......@@ -175,7 +175,7 @@ def main(args):
n = g.number_of_nodes()
h = mx.nd.zeros((n, args.h_size), ctx=ctx)
c = mx.nd.zeros((n, args.h_size), ctx=ctx)
pred = model(batch, h, c).argmax(axis=1)
pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)
acc = (batch.label == pred).sum().asscalar()
accs.append([acc, len(batch.label)])
......
......@@ -118,7 +118,8 @@ class TreeLSTM(gluon.nn.Block):
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1)
wiou = self.cell.W_iou(self.dropout(embeds))
g.ndata['iou'] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment