"src/runtime/runtime_base.h" did not exist on "61fa3c6cf5441e6dd34e3d51e5f519308e5a1baf"
Commit 8c71f3f8 authored by Minjie Wang's avatar Minjie Wang
Browse files

Fix treelstm CPU

parent b7eb1659
......@@ -67,13 +67,16 @@ def main(args):
for epoch in range(args.epochs):
t_epoch = time.time()
for step, batch in enumerate(train_loader):
if cuda:
batch = _batch_to_cuda(batch)
g = batch.graph
n = g.number_of_nodes()
x = th.zeros((n, args.x_size)).cuda()
h = th.zeros((n, args.h_size)).cuda()
c = th.zeros((n, args.h_size)).cuda()
x = th.zeros((n, args.x_size))
h = th.zeros((n, args.h_size))
c = th.zeros((n, args.h_size))
if cuda:
batch = _batch_to_cuda(batch)
x = x.cuda()
h = h.cuda()
c = c.cuda()
if step >= 3:
t0 = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment