Unverified Commit 3564fdc5 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix][Model] fix treelstm model (#274)

* fix bug after moving batcher out of dgl.data

* disable mx utest
parent f491d6b9
import argparse
import collections
import time
import numpy as np
import torch as th
......@@ -12,7 +13,8 @@ from dgl.data.tree import SST
from tree_lstm import TreeLSTM
def batcher(dev):
SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
def batcher(device):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
......
......@@ -23,7 +23,7 @@ def test_prop_nodes_bfs():
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def test_prop_edges_dfs():
def _test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
......@@ -70,5 +70,5 @@ def test_prop_nodes_topo():
if __name__ == '__main__':
test_prop_nodes_bfs()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_prop_edges_dfs()
#_test_prop_edges_dfs()
test_prop_nodes_topo()
......@@ -84,7 +84,7 @@ def test_topological_nodes(n=1000):
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))
DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def test_dfs_labeled_edges(n=1000, example=False):
def _test_dfs_labeled_edges(n=1000, example=False):
dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
......@@ -124,4 +124,4 @@ if __name__ == '__main__':
test_bfs()
test_topological_nodes()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_dfs_labeled_edges()
#_test_dfs_labeled_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment