"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "24563ca654f6574dae93aeece8eeef69e39097e5"
Commit 688a9228 authored by Da Zheng's avatar Da Zheng Committed by Sheng Zha
Browse files

fix. (#491)

parent e2e04329
...@@ -122,7 +122,7 @@ def main(args): ...@@ -122,7 +122,7 @@ def main(args):
dur.append(time.time() - t0) # tok dur.append(time.time() - t0) # tok
if step > 0 and step % args.log_every == 0: if step > 0 and step % args.log_every == 0:
pred = pred.argmax(axis=1) pred = pred.argmax(axis=1).astype(batch.label.dtype)
acc = (batch.label == pred).sum() acc = (batch.label == pred).sum()
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0] root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]) root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
...@@ -139,7 +139,7 @@ def main(args): ...@@ -139,7 +139,7 @@ def main(args):
n = g.number_of_nodes() n = g.number_of_nodes()
h = mx.nd.zeros((n, args.h_size), ctx=ctx) h = mx.nd.zeros((n, args.h_size), ctx=ctx)
c = mx.nd.zeros((n, args.h_size), ctx=ctx) c = mx.nd.zeros((n, args.h_size), ctx=ctx)
pred = model(batch, h, c).argmax(1) pred = model(batch, h, c).argmax(1).astype(batch.label.dtype)
acc = (batch.label == pred).sum().asscalar() acc = (batch.label == pred).sum().asscalar()
accs.append([acc, len(batch.label)]) accs.append([acc, len(batch.label)])
...@@ -175,7 +175,7 @@ def main(args): ...@@ -175,7 +175,7 @@ def main(args):
n = g.number_of_nodes() n = g.number_of_nodes()
h = mx.nd.zeros((n, args.h_size), ctx=ctx) h = mx.nd.zeros((n, args.h_size), ctx=ctx)
c = mx.nd.zeros((n, args.h_size), ctx=ctx) c = mx.nd.zeros((n, args.h_size), ctx=ctx)
pred = model(batch, h, c).argmax(axis=1) pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)
acc = (batch.label == pred).sum().asscalar() acc = (batch.label == pred).sum().asscalar()
accs.append([acc, len(batch.label)]) accs.append([acc, len(batch.label)])
......
...@@ -118,7 +118,8 @@ class TreeLSTM(gluon.nn.Block): ...@@ -118,7 +118,8 @@ class TreeLSTM(gluon.nn.Block):
g.register_apply_node_func(self.cell.apply_node_func) g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding # feed embedding
embeds = self.embedding(batch.wordid * batch.mask) embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1) wiou = self.cell.W_iou(self.dropout(embeds))
g.ndata['iou'] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype)
g.ndata['h'] = h g.ndata['h'] = h
g.ndata['c'] = c g.ndata['c'] = c
# propagate # propagate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment