Unverified Commit 3564fdc5 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix][Model] fix treelstm model (#274)

* fix bug after moving batcher out of dgl.data

* disable mx utest
parent f491d6b9
import argparse import argparse
import collections
import time import time
import numpy as np import numpy as np
import torch as th import torch as th
...@@ -12,7 +13,8 @@ from dgl.data.tree import SST ...@@ -12,7 +13,8 @@ from dgl.data.tree import SST
from tree_lstm import TreeLSTM from tree_lstm import TreeLSTM
def batcher(dev): SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
def batcher(device):
def batcher_dev(batch): def batcher_dev(batch):
batch_trees = dgl.batch(batch) batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees, return SSTBatch(graph=batch_trees,
......
...@@ -23,7 +23,7 @@ def test_prop_nodes_bfs(): ...@@ -23,7 +23,7 @@ def test_prop_nodes_bfs():
assert np.allclose(g.ndata['x'].asnumpy(), assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]])) np.array([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def test_prop_edges_dfs(): def _test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5)) g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc) g.register_message_func(mfunc)
g.register_reduce_func(rfunc) g.register_reduce_func(rfunc)
...@@ -70,5 +70,5 @@ def test_prop_nodes_topo(): ...@@ -70,5 +70,5 @@ def test_prop_nodes_topo():
if __name__ == '__main__': if __name__ == '__main__':
test_prop_nodes_bfs() test_prop_nodes_bfs()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04. #TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_prop_edges_dfs() #_test_prop_edges_dfs()
test_prop_nodes_topo() test_prop_nodes_topo()
...@@ -84,7 +84,7 @@ def test_topological_nodes(n=1000): ...@@ -84,7 +84,7 @@ def test_topological_nodes(n=1000):
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv)) assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))
DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree'] DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def test_dfs_labeled_edges(n=1000, example=False): def _test_dfs_labeled_edges(n=1000, example=False):
dgl_g = dgl.DGLGraph() dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6) dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5]) dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
...@@ -124,4 +124,4 @@ if __name__ == '__main__': ...@@ -124,4 +124,4 @@ if __name__ == '__main__':
test_bfs() test_bfs()
test_topological_nodes() test_topological_nodes()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04. #TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_dfs_labeled_edges() #_test_dfs_labeled_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment