"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e222246b4e7b60db7fe5fd27dc187bce446b5b56"
Commit 7aa494b3 authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[Bugfix][Doc] explain the batch dimension in the doc and many fix (#266)

* add more unit tests for mxnet.

* fix.

* explain the batch dimension.

* update doc.

* disable unit tests on DFS.

* fix graph traversal.
parent ac660f45
...@@ -15,6 +15,10 @@ There are two types of user-defined functions in DGL: ...@@ -15,6 +15,10 @@ There are two types of user-defined functions in DGL:
a batch of edges. The returned dictionary should have ``str`` type key and ``tensor`` a batch of edges. The returned dictionary should have ``str`` type key and ``tensor``
type values. type values.
Note: the size of the batch dimension is determined by the DGL framework
for good efficiency and small memory footprint. Users should not make
assumption in the batch dimension.
EdgeBatch EdgeBatch
--------- ---------
...@@ -28,6 +32,7 @@ The class that can represent a batch of edges. ...@@ -28,6 +32,7 @@ The class that can represent a batch of edges.
EdgeBatch.data EdgeBatch.data
EdgeBatch.edges EdgeBatch.edges
EdgeBatch.batch_size EdgeBatch.batch_size
EdgeBatch.__len__
NodeBatch NodeBatch
--------- ---------
...@@ -41,3 +46,4 @@ The class that can represent a batch of nodes. ...@@ -41,3 +46,4 @@ The class that can represent a batch of nodes.
NodeBatch.mailbox NodeBatch.mailbox
NodeBatch.nodes NodeBatch.nodes
NodeBatch.batch_size NodeBatch.batch_size
NodeBatch.__len__
...@@ -91,7 +91,12 @@ class EdgeBatch(object): ...@@ -91,7 +91,12 @@ class EdgeBatch(object):
return len(self._edges[0]) return len(self._edges[0])
def __len__(self): def __len__(self):
"""Return the number of edges in this edge batch.""" """Return the number of edges in this edge batch.
Returns
-------
int
"""
return self.batch_size() return self.batch_size()
class NodeBatch(object): class NodeBatch(object):
...@@ -167,5 +172,10 @@ class NodeBatch(object): ...@@ -167,5 +172,10 @@ class NodeBatch(object):
return len(self._nodes) return len(self._nodes)
def __len__(self): def __len__(self):
"""Return the number of nodes in this node batch.""" """Return the number of nodes in this node batch.
Returns
-------
int
"""
return self.batch_size() return self.batch_size()
...@@ -202,7 +202,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -202,7 +202,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1]; const IdArray source = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const bool reversed = args[2]; const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array."; CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
...@@ -221,7 +221,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") ...@@ -221,7 +221,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1]; const IdArray source = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const bool reversed = args[2]; const bool reversed = args[2];
const bool has_reverse_edge = args[3]; const bool has_reverse_edge = args[3];
const bool has_nontree_edge = args[4]; const bool has_nontree_edge = args[4];
......
...@@ -69,5 +69,6 @@ def test_prop_nodes_topo(): ...@@ -69,5 +69,6 @@ def test_prop_nodes_topo():
if __name__ == '__main__': if __name__ == '__main__':
test_prop_nodes_bfs() test_prop_nodes_bfs()
test_prop_edges_dfs() #TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_prop_edges_dfs()
test_prop_nodes_topo() test_prop_nodes_topo()
...@@ -123,4 +123,5 @@ def test_dfs_labeled_edges(n=1000, example=False): ...@@ -123,4 +123,5 @@ def test_dfs_labeled_edges(n=1000, example=False):
if __name__ == '__main__': if __name__ == '__main__':
test_bfs() test_bfs()
test_topological_nodes() test_topological_nodes()
test_dfs_labeled_edges() #TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_dfs_labeled_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment