Commit 7aa494b3 authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[Bugfix][Doc] explain the batch dimension in the doc and many fix (#266)

* add more unit tests for mxnet.

* fix.

* explain the batch dimension.

* update doc.

* disable unit tests on DFS.

* fix graph traversal.
parent ac660f45
......@@ -15,6 +15,10 @@ There are two types of user-defined functions in DGL:
a batch of edges. The returned dictionary should have ``str`` type key and ``tensor``
type values.
Note: the size of the batch dimension is determined by the DGL framework
for good efficiency and small memory footprint. Users should not make
assumption in the batch dimension.
EdgeBatch
---------
......@@ -28,6 +32,7 @@ The class that can represent a batch of edges.
EdgeBatch.data
EdgeBatch.edges
EdgeBatch.batch_size
EdgeBatch.__len__
NodeBatch
---------
......@@ -41,3 +46,4 @@ The class that can represent a batch of nodes.
NodeBatch.mailbox
NodeBatch.nodes
NodeBatch.batch_size
NodeBatch.__len__
......@@ -91,7 +91,12 @@ class EdgeBatch(object):
return len(self._edges[0])
def __len__(self):
"""Return the number of edges in this edge batch."""
"""Return the number of edges in this edge batch.
Returns
-------
int
"""
return self.batch_size()
class NodeBatch(object):
......@@ -167,5 +172,10 @@ class NodeBatch(object):
return len(self._nodes)
def __len__(self):
"""Return the number of nodes in this node batch."""
"""Return the number of nodes in this node batch.
Returns
-------
int
"""
return self.batch_size()
......@@ -202,7 +202,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1];
const IdArray source = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0];
......@@ -221,7 +221,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1];
const IdArray source = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const bool reversed = args[2];
const bool has_reverse_edge = args[3];
const bool has_nontree_edge = args[4];
......
......@@ -69,5 +69,6 @@ def test_prop_nodes_topo():
if __name__ == '__main__':
test_prop_nodes_bfs()
test_prop_edges_dfs()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_prop_edges_dfs()
test_prop_nodes_topo()
......@@ -123,4 +123,5 @@ def test_dfs_labeled_edges(n=1000, example=False):
if __name__ == '__main__':
test_bfs()
test_topological_nodes()
test_dfs_labeled_edges()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_dfs_labeled_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment