# udf.py
"""User-defined function related data structures."""
from __future__ import absolute_import

class EdgeBatch(object):
    """The class that can represent a batch of edges.

    Instances bundle the edge ids of a batch together with the features of
    the batch's source nodes, edges and destination nodes. They are exposed
    read-only through properties.

    Parameters
    ----------
    edges : tuple of utils.Index
        The edge tuple (u, v, eid). eid can be ALL.
        NOTE(review): ``edges()`` calls ``eid.tousertensor()``
        unconditionally, so an ``ALL`` eid must also support that call —
        confirm against callers.
    src_data : dict
        The src node features, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    edge_data : dict
        The edge features, in the form of ``dict`` with
        ``str`` keys and ``tensor`` values
    dst_data : dict
        The dst node features, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    """
    def __init__(self, edges, src_data, edge_data, dst_data):
        self._edges = edges
        self._src_data = src_data
        self._edge_data = edge_data
        self._dst_data = dst_data

    @property
    def src(self):
        """Return the feature data of the source nodes.

        Returns
        -------
        dict with str keys and tensor values
            Features of the source nodes.
        """
        return self._src_data

    @property
    def dst(self):
        """Return the feature data of the destination nodes.

        Returns
        -------
        dict with str keys and tensor values
            Features of the destination nodes.
        """
        return self._dst_data

    @property
    def data(self):
        """Return the edge feature data.

        Returns
        -------
        dict with str keys and tensor values
            Features of the edges.
        """
        return self._edge_data

    def edges(self):
        """Return the edges contained in this batch.

        Returns
        -------
        tuple of three tensors
            The edge tuple :math:`(src, dst, eid)`. :math:`src[i],
            dst[i], eid[i]` separately specifies the source node,
            destination node and the edge id for the ith edge
            in the batch.
        """
        u, v, eid = self._edges
        # Convert each internal index object into a user-facing tensor.
        return (u.tousertensor(), v.tousertensor(), eid.tousertensor())

    def batch_size(self):
        """Return the number of edges in this edge batch.

        Returns
        -------
        int
        """
        # The source-node index has one entry per edge in the batch.
        return len(self._edges[0])

    def __len__(self):
        """Return the number of edges in this edge batch.

        Returns
        -------
        int
        """
        return self.batch_size()

class NodeBatch(object):
    """The class that can represent a batch of nodes.

    Instances bundle the node ids of a batch with their features and, if
    given, the messages they received. They are exposed read-only through
    properties.

    Parameters
    ----------
    nodes : utils.Index
        The node ids.
    data : dict
        The node features, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    msgs : dict, optional
        The messages, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    """
    def __init__(self, nodes, data, msgs=None):
        self._nodes = nodes
        self._data = data
        self._msgs = msgs

    @property
    def data(self):
        """Return the node feature data.

        Returns
        -------
        dict with str keys and tensor values
            Features of the nodes.
        """
        return self._data

    @property
    def mailbox(self):
        """Return the received messages.

        If no messages received, a ``None`` will be returned.

        Returns
        -------
        dict or None
            The messages nodes received. If dict, the keys are
            ``str`` and the values are ``tensor``.
        """
        return self._msgs

    def nodes(self):
        """Return the nodes contained in this batch.

        Returns
        -------
        tensor
            The nodes.
        """
        # Convert the internal index object into a user-facing tensor.
        return self._nodes.tousertensor()

    def batch_size(self):
        """Return the number of nodes in this batch.

        Returns
        -------
        int
        """
        return len(self._nodes)

    def __len__(self):
        """Return the number of nodes in this node batch.

        Returns
        -------
        int
        """
        return self.batch_size()