"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "67e2f95cc4ff8c25f4d04f8bab46df02216527b2"
Commit 0d0f4436 authored by Minjie Wang's avatar Minjie Wang Committed by Da Zheng
Browse files

[Bugfix] MX utest traversal memory corruption (#312)

* WIP

* temp fix mx traversal memory crash bug
parent 3d446301
......@@ -40,8 +40,8 @@ def bfs_nodes_generator(graph, source, reversed=False):
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
ret = _CAPI_DGLBFSNodes(ghandle, source, reversed)
source = utils.toindex(source)
ret = _CAPI_DGLBFSNodes(ghandle, source.todgltensor(), reversed)
all_nodes = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
......@@ -80,8 +80,8 @@ def bfs_edges_generator(graph, source, reversed=False):
[tensor([0]), tensor([1, 2]), tensor([4, 5])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
ret = _CAPI_DGLBFSEdges(ghandle, source, reversed)
source = utils.toindex(source)
ret = _CAPI_DGLBFSEdges(ghandle, source.todgltensor(), reversed)
all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
......@@ -161,8 +161,8 @@ def dfs_edges_generator(graph, source, reversed=False):
[tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
ret = _CAPI_DGLDFSEdges(ghandle, source, reversed)
source = utils.toindex(source)
ret = _CAPI_DGLDFSEdges(ghandle, source.todgltensor(), reversed)
all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
......@@ -232,10 +232,10 @@ def dfs_labeled_edges_generator(
(tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
source = utils.toindex(source)
ret = _CAPI_DGLDFSLabeledEdges(
ghandle,
source,
source.todgltensor(),
reversed,
has_reverse_edge,
has_nontree_edge,
......
......@@ -23,7 +23,7 @@ def test_prop_nodes_bfs():
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def _test_prop_edges_dfs():
def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
......@@ -69,6 +69,5 @@ def test_prop_nodes_topo():
if __name__ == '__main__':
test_prop_nodes_bfs()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#_test_prop_edges_dfs()
test_prop_edges_dfs()
test_prop_nodes_topo()
......@@ -84,7 +84,7 @@ def test_topological_nodes(n=1000):
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))
DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def _test_dfs_labeled_edges(n=1000, example=False):
def test_dfs_labeled_edges(n=1000, example=False):
dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
......@@ -123,5 +123,4 @@ def _test_dfs_labeled_edges(n=1000, example=False):
if __name__ == '__main__':
test_bfs()
test_topological_nodes()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#_test_dfs_labeled_edges()
test_dfs_labeled_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment