"""User-defined function related data structures."""
from __future__ import absolute_import

from .base import ALL, is_all
from . import backend as F
from . import utils

class EdgeBatch(object):
    """The class that can represent a batch of edges.

    Parameters
    ----------
    g : DGLGraph
        The graph object.
    edges : tuple of utils.Index
        The edge tuple (u, v, eid). eid can be ALL, which stands for
        every edge id of ``g``.
    src_data : dict
        The src node features, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    edge_data : dict
        The edge features, in the form of ``dict`` with
        ``str`` keys and ``tensor`` values
    dst_data : dict
        The dst node features, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    """
    def __init__(self, g, edges, src_data, edge_data, dst_data):
        self._g = g
        self._edges = edges
        self._src_data = src_data
        self._edge_data = edge_data
        self._dst_data = dst_data

    @property
    def src(self):
        """Return the feature data of the source nodes.

        Returns
        -------
        dict with str keys and tensor values
            Features of the source nodes.
        """
        return self._src_data

    @property
    def dst(self):
        """Return the feature data of the destination nodes.

        Returns
        -------
        dict with str keys and tensor values
            Features of the destination nodes.
        """
        return self._dst_data

    @property
    def data(self):
        """Return the edge feature data.

        Returns
        -------
        dict with str keys and tensor values
            Features of the edges.
        """
        return self._edge_data

    def edges(self):
        """Return the edges contained in this batch.

        If the edge ids were given as ALL, they are resolved to
        ``0 .. g.number_of_edges() - 1``, and the result is cached on
        the batch.

        Returns
        -------
        tuple of three tensors
            The edge tuple :math:`(src, dst, eid)`. :math:`src[i],
            dst[i], eid[i]` separately specifies the source node,
            destination node and the edge id for the ith edge
            in the batch.
        """
        u, v, eid = self._edges
        if is_all(eid):
            eid = utils.toindex(F.arange(0, self._g.number_of_edges()))
            # ``edges`` is a tuple and cannot be assigned to by index, so
            # rebuild it to cache the resolved edge ids.
            self._edges = (u, v, eid)
        return (u.tousertensor(), v.tousertensor(), eid.tousertensor())

    def batch_size(self):
        """Return the number of edges in this edge batch.

        The count is taken from the source-node ids, which are always an
        explicit index (only ``eid`` may be ALL).

        Returns
        -------
        int
        """
        return len(self._edges[0])

    def __len__(self):
        """Return the number of edges in this edge batch."""
        return self.batch_size()

class NodeBatch(object):
    """The class that can represent a batch of nodes.

    Parameters
    ----------
    g : DGLGraph
        The graph object.
    nodes : utils.Index or ALL
        The node ids. ALL stands for every node of ``g``.
    data : dict
        The node features, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    msgs : dict, optional
        The messages, in the form of ``dict``
        with ``str`` keys and ``tensor`` values
    """
    def __init__(self, g, nodes, data, msgs=None):
        self._g = g
        self._nodes = nodes
        self._data = data
        self._msgs = msgs

    @property
    def data(self):
        """Return the node feature data.

        Returns
        -------
        dict with str keys and tensor values
            Features of the nodes.
        """
        return self._data

    @property
    def mailbox(self):
        """Return the received messages.

        If no messages were received, ``None`` is returned.

        Returns
        -------
        dict or None
            The messages nodes received. If dict, the keys are
            ``str`` and the values are ``tensor``.
        """
        return self._msgs

    def nodes(self):
        """Return the nodes contained in this batch.

        If the node ids were given as ALL, they are resolved to
        ``0 .. g.number_of_nodes() - 1``, and the result is cached on
        the batch.

        Returns
        -------
        tensor
            The nodes.
        """
        if is_all(self._nodes):
            self._nodes = utils.toindex(F.arange(
                0, self._g.number_of_nodes()))
        return self._nodes.tousertensor()

    def batch_size(self):
        """Return the number of nodes in this batch.

        Returns
        -------
        int
        """
        if is_all(self._nodes):
            # ALL covers every node of the graph; avoid materializing ids.
            return self._g.number_of_nodes()
        return len(self._nodes)

    def __len__(self):
        """Return the number of nodes in this node batch."""
        return self.batch_size()