"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "e4edf27a97a8580cf50b636ad80a20455c26ed75"
Commit 485f6d3a authored by Mufei Li's avatar Mufei Li Committed by Minjie Wang
Browse files

[Doc] udf (#195)

* Docstring for udf

* Track udf docs

* Improve

* Improved

* Delete udf.py

* Improve
parent 9e826173
...@@ -9,3 +9,4 @@ API Reference ...@@ -9,3 +9,4 @@ API Reference
function function
traversal traversal
propagate propagate
udf
.. _apiudf:
User-defined function related data structures
==================================================
.. currentmodule:: dgl.udf
EdgeBatch
---------
The class that can represent a batch of edges.
.. autosummary::
:toctree: ../../generated/
EdgeBatch.src
EdgeBatch.dst
EdgeBatch.data
EdgeBatch.edges
EdgeBatch.batch_size
NodeBatch
---------
The class that can represent a batch of nodes.
.. autosummary::
:toctree: ../../generated/
NodeBatch.data
NodeBatch.mailbox
NodeBatch.nodes
NodeBatch.batch_size
"""User-defined function related data structures.""" """User-defined function related data structures."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import Mapping
from .base import ALL, is_all from .base import ALL, is_all
from . import backend as F from . import backend as F
from . import utils from . import utils
class EdgeBatch(object): class EdgeBatch(object):
"""The object that represents a batch of edges. """The class that can represent a batch of edges.
Parameters Parameters
---------- ----------
...@@ -16,12 +14,15 @@ class EdgeBatch(object): ...@@ -16,12 +14,15 @@ class EdgeBatch(object):
The graph object. The graph object.
edges : tuple of utils.Index edges : tuple of utils.Index
The edge tuple (u, v, eid). eid can be ALL The edge tuple (u, v, eid). eid can be ALL
src_data : dict of tensors src_data : dict
The src node features The src node features, in the form of ``dict``
edge_data : dict of tensors with ``str`` keys and ``tensor`` values
The edge features. edge_data : dict
The edge features, in the form of ``dict`` with
``str`` keys and ``tensor`` values
dst_data : dict of tensors dst_data : dict of tensors
The dst node features The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values
""" """
def __init__(self, g, edges, src_data, edge_data, dst_data): def __init__(self, g, edges, src_data, edge_data, dst_data):
self._g = g self._g = g
...@@ -36,8 +37,8 @@ class EdgeBatch(object): ...@@ -36,8 +37,8 @@ class EdgeBatch(object):
Returns Returns
------- -------
dict of str to tensors dict with str keys and tensor values
The feature data. Features of the source nodes.
""" """
return self._src_data return self._src_data
...@@ -47,8 +48,8 @@ class EdgeBatch(object): ...@@ -47,8 +48,8 @@ class EdgeBatch(object):
Returns Returns
------- -------
dict of str to tensors dict with str keys and tensor values
The feature data. Features of the destination nodes.
""" """
return self._dst_data return self._dst_data
...@@ -58,18 +59,21 @@ class EdgeBatch(object): ...@@ -58,18 +59,21 @@ class EdgeBatch(object):
Returns Returns
------- -------
dict of str to tensors dict with str keys and tensor values
The feature data. Features of the edges.
""" """
return self._edge_data return self._edge_data
def edges(self): def edges(self):
"""Return the edges contained in this batch. """Return the edges contained in this batch.
Returns Returns
------- -------
tuple of tensors tuple of three tensors
The edge tuple (u, v, eid). The edge tuple :math:`(src, dst, eid)`. :math:`src[i],
dst[i], eid[i]` separately specifies the source node,
destination node and the edge id for the ith edge
in the batch.
""" """
if is_all(self._edges[2]): if is_all(self._edges[2]):
self._edges[2] = utils.toindex(F.arange( self._edges[2] = utils.toindex(F.arange(
...@@ -78,7 +82,12 @@ class EdgeBatch(object): ...@@ -78,7 +82,12 @@ class EdgeBatch(object):
return (u.tousertensor(), v.tousertensor(), eid.tousertensor()) return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
def batch_size(self): def batch_size(self):
"""Return the number of edges in this edge batch.""" """Return the number of edges in this edge batch.
Returns
-------
int
"""
return len(self._edges[0]) return len(self._edges[0])
def __len__(self): def __len__(self):
...@@ -86,7 +95,7 @@ class EdgeBatch(object): ...@@ -86,7 +95,7 @@ class EdgeBatch(object):
return self.batch_size() return self.batch_size()
class NodeBatch(object): class NodeBatch(object):
"""The object that represents a batch of nodes. """The class that can represent a batch of nodes.
Parameters Parameters
---------- ----------
...@@ -94,10 +103,12 @@ class NodeBatch(object): ...@@ -94,10 +103,12 @@ class NodeBatch(object):
The graph object. The graph object.
nodes : utils.Index or ALL nodes : utils.Index or ALL
The node ids. The node ids.
data : dict of tensors data : dict
The node features The node features, in the form of ``dict``
msgs : dict of tensors, optional with ``str`` keys and ``tensor`` values
The messages. msgs : dict, optional
The messages, , in the form of ``dict``
with ``str`` keys and ``tensor`` values
""" """
def __init__(self, g, nodes, data, msgs=None): def __init__(self, g, nodes, data, msgs=None):
self._g = g self._g = g
...@@ -111,8 +122,8 @@ class NodeBatch(object): ...@@ -111,8 +122,8 @@ class NodeBatch(object):
Returns Returns
------- -------
dict of str to tensors dict with str keys and tensor values
The feature data. Features of the nodes.
""" """
return self._data return self._data
...@@ -120,18 +131,19 @@ class NodeBatch(object): ...@@ -120,18 +131,19 @@ class NodeBatch(object):
def mailbox(self): def mailbox(self):
"""Return the received messages. """Return the received messages.
If no messages received, a None will be returned. If no messages received, a ``None`` will be returned.
Returns Returns
------- -------
dict of str to tensors dict or None
The message data. The messages nodes received. If dict, the keys are
``str`` and the values are ``tensor``.
""" """
return self._msgs return self._msgs
def nodes(self): def nodes(self):
"""Return the nodes contained in this batch. """Return the nodes contained in this batch.
Returns Returns
------- -------
tensor tensor
...@@ -143,7 +155,12 @@ class NodeBatch(object): ...@@ -143,7 +155,12 @@ class NodeBatch(object):
return self._nodes.tousertensor() return self._nodes.tousertensor()
def batch_size(self): def batch_size(self):
"""Return the number of nodes in this node batch.""" """Return the number of nodes in this batch.
Returns
-------
int
"""
if is_all(self._nodes): if is_all(self._nodes):
return self._g.number_of_nodes() return self._g.number_of_nodes()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment