"""Classes for heterogeneous graphs."""
2
#pylint: disable= too-many-lines
3
from collections import defaultdict
Minjie Wang's avatar
Minjie Wang committed
4
from contextlib import contextmanager
5
import networkx as nx
Minjie Wang's avatar
Minjie Wang committed
6
7
8
9
import numpy as np

from . import graph_index
from . import heterograph_index
10
11
12
from . import utils
from . import backend as F
from . import init
Minjie Wang's avatar
Minjie Wang committed
13
14
from .runtime import ir, scheduler, Runtime, GraphAdapter
from .frame import Frame, FrameRef, frame_like, sync_frame_initializer
15
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
Mufei Li's avatar
Mufei Li committed
16
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
Mufei Li's avatar
Mufei Li committed
17
from .udf import NodeBatch, EdgeBatch
Minjie Wang's avatar
Minjie Wang committed
18
19
20
21
22
23

__all__ = [
    'DGLHeteroGraph',
    'combine_names',
]

class DGLHeteroGraph(object):
    """Base heterogeneous graph class.

    **Do NOT instantiate from this class directly; use** :mod:`conversion methods
    <dgl.convert>` **instead.**

    A Heterogeneous graph is defined as a graph with node types and edge
    types.

    If two edges share the same edge type, then their source nodes, as well
    as their destination nodes, also have the same type (the source node
    types don't have to be the same as the destination node types).

    Examples
    --------
    Suppose that we want to construct the following heterogeneous graph:

    .. graphviz::

       digraph G {
           Alice -> Bob [label=follows]
           Bob -> Carol [label=follows]
           Alice -> Tetris [label=plays]
           Bob -> Tetris [label=plays]
           Bob -> Minecraft [label=plays]
           Carol -> Minecraft [label=plays]
           Nintendo -> Tetris [label=develops]
           Mojang -> Minecraft [label=develops]
           {rank=source; Alice; Bob; Carol}
           {rank=sink; Nintendo; Mojang}
       }

    And suppose that one maps the users, games and developers to the following
    IDs:

    =========  =====  ===  =====
    User name  Alice  Bob  Carol
    =========  =====  ===  =====
    User ID    0      1    2
    =========  =====  ===  =====

    =========  ======  =========
    Game name  Tetris  Minecraft
    =========  ======  =========
    Game ID    0       1
    =========  ======  =========

    ==============  ========  ======
    Developer name  Nintendo  Mojang
    ==============  ========  ======
    Developer ID    0         1
    ==============  ========  ======

    One can construct the graph as follows:

    >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
    >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
    >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
    >>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])

    Or, equivalently:

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ...     ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
    ...     ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
    ...     })

    :func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
    data types including:

    * edge list
    * edge tuples
    * networkx graph
    * scipy sparse matrix

    Click the function names for more details.

    Then one can query the graph structure by specifying the ``ntype`` or ``etype`` arguments:

    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_edges('plays')
    4
    >>> g.out_degrees(etype='develops')  # out-degrees of source nodes of 'develops' relation
    tensor([1, 1])
    >>> g.in_edges(0, etype='develops')  # in-edges of destination node 0 of 'develops' relation
    (tensor([0]), tensor([0]))

    Or on the sliced graph for an edge type:

    >>> g['plays'].number_of_edges()
    4
    >>> g['develops'].out_degrees()
    tensor([1, 1])
    >>> g['develops'].in_edges(0)
    (tensor([0]), tensor([0]))

    Node type names must be distinct (no two types have the same name). Edge types could
    have the same name but they must be distinguishable by the ``(src_type, edge_type, dst_type)``
    triplet (called *canonical edge type*).

    For example, suppose a graph that has two types of relation "user-watches-movie"
    and "user-watches-TV" as follows:

    >>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie')
    >>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV')
    >>> GG = dgl.hetero_from_relations([g0, g1]) # Merge the two graphs

    To distinguish between the two "watches" edge types, one must specify a full triplet:

    >>> GG.number_of_edges(('user', 'watches', 'movie'))
    3
    >>> GG.number_of_edges(('user', 'watches', 'TV'))
    2
    >>> GG['user', 'watches', 'movie'].out_degrees()
    tensor([1, 2])

    Using only the single edge type string "watches" is ambiguous and will cause an error:

    >>> GG.number_of_edges('watches')  # AMBIGUOUS!!

    In many cases, there is only one type of nodes or one type of edges, and the ``ntype``
    and ``etype`` argument could be omitted. This is very common when using the sliced
    graph, which usually contains only one edge type, and sometimes only one node type:

    >>> g['follows'].number_of_nodes()  # OK!! because g['follows'] only has one node type 'user'
    3
    >>> g['plays'].number_of_nodes()  # ERROR!! There are two types 'user' and 'game'.
    >>> g['plays'].number_of_edges()  # OK!! because there is only one edge type 'plays'

    TODO(minjie): docstring about uni-directional bipartite graph

    Metagraph
    ---------
    For each heterogeneous graph, one can often infer the *metagraph*, the template of
    edge connections showing how many types of nodes and edges exist in the graph, and
    how each edge type could connect between node types.

    One can analyze the example gameplay graph above and figure out the metagraph as
    follows:

    .. graphviz::

       digraph G {
           User -> User [label=follows]
           User -> Game [label=plays]
           Developer -> Game [label=develops]
       }


    Parameters
    ----------
    gidx : HeteroGraphIndex
        Graph index object.
    ntypes : list of str, pair of list of str
        Node type list. ``ntypes[i]`` stores the name of node type i.
        If a pair is given, the graph created is a uni-directional bipartite graph,
        and its SRC node types and DST node types are given as in the pair.
    etypes : list of str
        Edge type list. ``etypes[i]`` stores the name of edge type i.
    node_frames : list of FrameRef, optional
        Node feature storage. If None, empty frame is created.
        Otherwise, ``node_frames[i]`` stores the node features
        of node type i. (default: None)
    edge_frames : list of FrameRef, optional
        Edge feature storage. If None, empty frame is created.
        Otherwise, ``edge_frames[i]`` stores the edge features
        of edge type i. (default: None)
    """
Da Zheng's avatar
Da Zheng committed
191
    # pylint: disable=unused-argument
Minjie Wang's avatar
Minjie Wang committed
192
193
194
195
196
    def __init__(self,
                 gidx,
                 ntypes,
                 etypes,
                 node_frames=None,
                 edge_frames=None):
        """Create a heterograph from a graph index, type names and optional frames.

        All of the work is delegated to :meth:`_init`. Unpickling
        (``__setstate__``) calls the same method, so both paths share one
        initialization routine. See the class docstring for the meaning of
        each argument.
        """
        self._init(gidx, ntypes, etypes, node_frames, edge_frames)
Da Zheng's avatar
Da Zheng committed
199

200
201
    def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
        """Init internal states.

        Shared by ``__init__`` and ``__setstate__``.

        Parameters
        ----------
        gidx : HeteroGraphIndex
            Graph index object.
        ntypes : list of str, or pair of list of str
            Node type names. If a ``(srctypes, dsttypes)`` pair is given, the
            metagraph must be uni-directional bipartite. SRC types get ids
            ``0 .. len(srctypes) - 1`` and DST types get the ids after them.
        etypes : list of str
            Edge type names; ``etypes[i]`` is the name of edge type ``i``.
        node_frames : list of FrameRef or None
            Per-node-type feature storage. ``None`` (or a ``None`` entry)
            creates an empty frame with one row per node of that type.
        edge_frames : list of FrameRef or None
            Per-edge-type feature storage, handled like ``node_frames``.

        Raises
        ------
        TypeError
            If ``ntypes`` is a tuple whose length is not 2.
        ValueError
            If ``ntypes`` is a pair but the metagraph is not a uni-directional
            bipartite graph.
        """
        self._graph = gidx

        # Handle node types
        if isinstance(ntypes, tuple):
            if len(ntypes) != 2:
                errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format(
                    ntypes)
                raise TypeError(errmsg)
            if not is_unibipartite(self._graph.metagraph):
                raise ValueError('Invalid input. The metagraph must be a uni-directional'
                                 ' bipartite graph.')
            # Node type ids: all SRC types first, then all DST types.
            self._ntypes = ntypes[0] + ntypes[1]
            self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])}
            self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])}
            self._is_unibipartite = True
        else:
            self._ntypes = ntypes
            # A None result is treated as "not uni-bipartite".
            src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
            self._is_unibipartite = (src_dst_map is not None)
            if self._is_unibipartite:
                self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
            else:
                # SRC and DST lookups share one name -> id map.
                self._srctypes_invmap = {t : i for i, t in enumerate(self._ntypes)}
                self._dsttypes_invmap = self._srctypes_invmap

        # Handle edge types
        self._etypes = etypes
        self._canonical_etypes = make_canonical_etypes(
            self._etypes, self._ntypes, self._graph.metagraph)

        # An internal map from etype to canonical etype tuple.
        # If two etypes have the same name, an empty tuple is stored instead to indicate
        # ambiguity.
        self._etype2canonical = {}
        for i, ety in enumerate(self._etypes):
            if ety in self._etype2canonical:
                self._etype2canonical[ety] = tuple()
            else:
                self._etype2canonical[ety] = self._canonical_etypes[i]
        self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}

        # Cached metagraph in networkx; built lazily by the ``metagraph`` property.
        self._nx_metagraph = None

        # node and edge frame
        if node_frames is None:
            node_frames = [None] * len(self._ntypes)
        node_frames = [FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
                       if frame is None else frame
                       for i, frame in enumerate(node_frames)]
        self._node_frames = node_frames

        if edge_frames is None:
            edge_frames = [None] * len(self._etypes)
        edge_frames = [FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
                       if frame is None else frame
                       for i, frame in enumerate(edge_frames)]
        self._edge_frames = edge_frames

        # message indicators
        # Per edge type: a message index created lazily by _get_msg_index, and a
        # message frame whose initializer is init.zero_initializer.
        self._msg_indices = [None] * len(self._etypes)
        self._msg_frames = []
        for i in range(len(self._etypes)):
            frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
            frame.set_initializer(init.zero_initializer)
            self._msg_frames.append(frame)
Da Zheng's avatar
Da Zheng committed
268

269
270
271
272
273
    def __getstate__(self):
        return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames

    def __setstate__(self, state):
        """Restore the graph from the tuple produced by ``__getstate__``.

        The tuple is forwarded positionally to ``_init``, so unpickling goes
        through the same initialization path as construction.
        """
        init_state = self._init
        init_state(*state)
Mufei Li's avatar
Mufei Li committed
274

Minjie Wang's avatar
Minjie Wang committed
275
    def _get_msg_index(self, etid):
276
        """Internal function for getting the message index array of the given edge type id."""
Minjie Wang's avatar
Minjie Wang committed
277
278
279
280
        if self._msg_indices[etid] is None:
            self._msg_indices[etid] = utils.zero_index(
                size=self._graph.number_of_edges(etid))
        return self._msg_indices[etid]
Da Zheng's avatar
Da Zheng committed
281

Minjie Wang's avatar
Minjie Wang committed
282
283
    def _set_msg_index(self, etid, index):
        self._msg_indices[etid] = index
Da Zheng's avatar
Da Zheng committed
284

Minjie Wang's avatar
Minjie Wang committed
285
286
287
288
289
290
291
292
293
294
295
296
297
298
    def __repr__(self):
        """Summarize the graph's structure.

        With exactly one node type and one edge type, the node/edge counts are
        printed together with the node and edge feature schemes. Otherwise the
        per-type counts are printed, keyed by node type name and by canonical
        edge type, followed by the metagraph edges.
        """
        if len(self.ntypes) == 1 and len(self.etypes) == 1:
            ret = ('Graph(num_nodes={node}, num_edges={edge},\n'
                   '      ndata_schemes={ndata}\n'
                   '      edata_schemes={edata})')
            return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
                              ndata=str(self.node_attr_schemes()),
                              edata=str(self.edge_attr_schemes()))
        ret = ('Graph(num_nodes={node},\n'
               '      num_edges={edge},\n'
               '      metagraph={meta})')
        nnode_dict = {ntype : self._graph.number_of_nodes(ntid)
                      for ntid, ntype in enumerate(self.ntypes)}
        nedge_dict = {self.canonical_etypes[etid] : self._graph.number_of_edges(etid)
                      for etid in range(len(self.etypes))}
        meta = str(self.metagraph.edges())
        return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)

    #################################################################
    # Mutation operations
    #################################################################

    def add_nodes(self, num, data=None, ntype=None):
        """Add ``num`` new nodes of node type ``ntype`` (with optional features ``data``).

        Mutation is currently not supported on heterographs, so this method
        always raises.

        Raises
        ------
        DGLError
            Always.
        """
        reason = 'Mutation is not supported in heterograph.'
        raise DGLError(reason)

    def add_edge(self, u, v, data=None, etype=None):
        """Add one edge of type ``etype`` from source node ``u`` to destination node ``v``.

        Mutation is currently not supported on heterographs, so this method
        always raises.

        Raises
        ------
        DGLError
            Always.
        """
        reason = 'Mutation is not supported in heterograph.'
        raise DGLError(reason)

    def add_edges(self, u, v, data=None, etype=None):
        """Add multiple edges of type ``etype``, one from ``u[i]`` to ``v[i]`` for every ``i``.

        ``u`` lists source nodes and ``v`` lists destination nodes of the
        relation ``etype``. Mutation is currently not supported on
        heterographs, so this method always raises.

        Raises
        ------
        DGLError
            Always.
        """
        reason = 'Mutation is not supported in heterograph.'
        raise DGLError(reason)

    #################################################################
    # Metagraph query
    #################################################################
Da Zheng's avatar
Da Zheng committed
335

336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
    @property
    def is_unibipartite(self):
        """Return whether the graph is a uni-bipartite graph.

        A uni-bipartite heterograph can further divide its node types into two sets:
        SRC and DST. All edges are from nodes in SRC to nodes in DST. The following APIs
        can be used to get the nodes and types that belong to SRC and DST sets:

        * :func:`srctype` and :func:`dsttype`
        * :func:`srcdata` and :func:`dstdata`
        * :func:`srcnodes` and :func:`dstnodes`

        Note that we allow two node types to have the same name as long as one
        belongs to SRC while the other belongs to DST. To distinguish them, prepend
        the name with ``"SRC/"`` or ``"DST/"`` when specifying a node type.
        """
        return self._is_unibipartite

354
    @property
Minjie Wang's avatar
Minjie Wang committed
355
    def ntypes(self):
Mufei Li's avatar
Mufei Li committed
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
        """Return the list of node types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.ntypes
        ['user', 'game']
        """
371
        return self._ntypes
Da Zheng's avatar
Da Zheng committed
372

373
    @property
Minjie Wang's avatar
Minjie Wang committed
374
    def etypes(self):
Mufei Li's avatar
Mufei Li committed
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
        """Return the list of edge types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.etypes
        ['follows', 'plays']
        """
390
        return self._etypes
Da Zheng's avatar
Da Zheng committed
391

Minjie Wang's avatar
Minjie Wang committed
392
393
394
395
396
    @property
    def canonical_etypes(self):
        """Return the list of canonical edge types of this graph.

        A canonical edge type is a tuple of string (src_type, edge_type, dst_type).
Mufei Li's avatar
Mufei Li committed
397
398
399
400
401
402
403
404
405
406
407
408
409

        Returns
        -------
        list of 3-tuples

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.canonical_etypes
        [('user', 'follows', 'user'), ('user', 'plays', 'game')]
Minjie Wang's avatar
Minjie Wang committed
410
411
412
        """
        return self._canonical_etypes

413
    @property
414
415
416
417
418
419
420
421
    def srctypes(self):
        """Return the node types in the SRC category. Return :attr:``ntypes`` if
        the graph is not a uni-bipartite graph.
        """
        if self.is_unibipartite:
            return sorted(list(self._srctypes_invmap.keys()))
        else:
            return self.ntypes
422
423

    @property
424
425
426
427
428
429
430
431
    def dsttypes(self):
        """Return the node types in the DST category. Return :attr:``ntypes`` if
        the graph is not a uni-bipartite graph.
        """
        if self.is_unibipartite:
            return sorted(list(self._dsttypes_invmap.keys()))
        else:
            return self.ntypes
432

Da Zheng's avatar
Da Zheng committed
433
434
    @property
    def metagraph(self):
435
436
437
438
        """Return the metagraph as networkx.MultiDiGraph.

        The nodes are labeled with node type names.
        The edges have their keys holding the edge type names.
Minjie Wang's avatar
Minjie Wang committed
439
440
441
442

        Returns
        -------
        networkx.MultiDiGraph
Mufei Li's avatar
Mufei Li committed
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> meta_g = g.metagraph

        The metagraph then has two nodes and two edges.

        >>> meta_g.nodes()
        NodeView(('user', 'game'))
        >>> meta_g.number_of_nodes()
        2
        >>> meta_g.edges()
        OutMultiEdgeDataView([('user', 'user'), ('user', 'game')])
        >>> meta_g.number_of_edges()
        2
Minjie Wang's avatar
Minjie Wang committed
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
        """
        if self._nx_metagraph is None:
            nx_graph = self._graph.metagraph.to_networkx()
            self._nx_metagraph = nx.MultiDiGraph()
            for u_v in nx_graph.edges:
                srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']]
                self._nx_metagraph.add_edge(srctype, dsttype, etype)
        return self._nx_metagraph

    def to_canonical_etype(self, etype):
        """Convert edge type to canonical etype: (srctype, etype, dsttype).

        The input can already be a canonical tuple.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type

        Returns
        -------
        tuple of str
Mufei Li's avatar
Mufei Li committed
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503

        Examples
        --------

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g3 = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'follows', 'game')
        >>> g = dgl.hetero_from_relations([g1, g2, g3])

        Get canonical edge types.

        >>> g.to_canonical_etype('plays')
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype(('user', 'plays', 'game'))
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype('follows')
        DGLError: Edge type "follows" is ambiguous.
        Please use canonical etype type in the form of (srctype, etype, dsttype)
504
        """
Minjie Wang's avatar
Minjie Wang committed
505
506
        if isinstance(etype, tuple):
            return etype
507
        else:
Minjie Wang's avatar
Minjie Wang committed
508
509
510
511
512
513
514
515
516
517
518
519
520
            ret = self._etype2canonical.get(etype, None)
            if ret is None:
                raise DGLError('Edge type "{}" does not exist.'.format(etype))
            if len(ret) == 0:
                raise DGLError('Edge type "%s" is ambiguous. Please use canonical etype '
                               'type in the form of (srctype, etype, dsttype)' % etype)
            return ret

    def get_ntype_id(self, ntype):
        """Return the id of the given node type.

        ntype can also be None. If so, there should be only one node type in the
        graph.
521

Minjie Wang's avatar
Minjie Wang committed
522
523
524
525
        Parameters
        ----------
        ntype : str
            Node type
Da Zheng's avatar
Da Zheng committed
526
527
528

        Returns
        -------
Minjie Wang's avatar
Minjie Wang committed
529
530
        int
        """
531
        if self.is_unibipartite and ntype is not None:
532
533
534
535
536
537
538
539
            # Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
            if ntype.startswith('SRC/'):
                return self.get_ntype_id_from_src(ntype[4:])
            elif ntype.startswith('DST/'):
                return self.get_ntype_id_from_dst(ntype[4:])
            # If there is no prefix, fallback to normal lookup.

        # Lookup both SRC and DST
Minjie Wang's avatar
Minjie Wang committed
540
        if ntype is None:
541
            if self.is_unibipartite or len(self._srctypes_invmap) != 1:
Minjie Wang's avatar
Minjie Wang committed
542
543
544
                raise DGLError('Node type name must be specified if there are more than one '
                               'node types.')
            return 0
545
        ntid = self._srctypes_invmap.get(ntype, self._dsttypes_invmap.get(ntype, None))
Minjie Wang's avatar
Minjie Wang committed
546
547
548
        if ntid is None:
            raise DGLError('Node type "{}" does not exist.'.format(ntype))
        return ntid
Da Zheng's avatar
Da Zheng committed
549

550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
    def get_ntype_id_from_src(self, ntype):
        """Return the id of the given SRC node type.

        ntype can also be None. If so, there should be only one node type in the
        SRC category. Callable even when the self graph is not uni-bipartite.

        Parameters
        ----------
        ntype : str
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if len(self._srctypes_invmap) != 1:
                raise DGLError('SRC node type name must be specified if there are more than one '
                               'SRC node types.')
569
            return next(iter(self._srctypes_invmap.values()))
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
        ntid = self._srctypes_invmap.get(ntype, None)
        if ntid is None:
            raise DGLError('SRC node type "{}" does not exist.'.format(ntype))
        return ntid

    def get_ntype_id_from_dst(self, ntype):
        """Return the id of the given DST node type.

        ntype can also be None. If so, there should be only one node type in the
        DST category. Callable even when the self graph is not uni-bipartite.

        Parameters
        ----------
        ntype : str
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if len(self._dsttypes_invmap) != 1:
                raise DGLError('DST node type name must be specified if there are more than one '
                               'DST node types.')
594
            return next(iter(self._dsttypes_invmap.values()))
595
596
597
598
599
        ntid = self._dsttypes_invmap.get(ntype, None)
        if ntid is None:
            raise DGLError('DST node type "{}" does not exist.'.format(ntype))
        return ntid

Minjie Wang's avatar
Minjie Wang committed
600
601
    def get_etype_id(self, etype):
        """Return the id of the given edge type.
602

Minjie Wang's avatar
Minjie Wang committed
603
604
605
606
607
608
609
        etype can also be None. If so, there should be only one edge type in the
        graph.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type
Da Zheng's avatar
Da Zheng committed
610

611
612
        Returns
        -------
Minjie Wang's avatar
Minjie Wang committed
613
614
615
616
617
618
619
620
621
622
623
        int
        """
        if etype is None:
            if self._graph.number_of_etypes() != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            return 0
        etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
        if etid is None:
            raise DGLError('Edge type "{}" does not exist.'.format(etype))
        return etid
Da Zheng's avatar
Da Zheng committed
624

Minjie Wang's avatar
Minjie Wang committed
625
626
627
    #################################################################
    # View
    #################################################################
Da Zheng's avatar
Da Zheng committed
628

629
    @property
    def nodes(self):
        """Return a node view that can be used to set/get feature
        data of a single node type.

        The view is constructed with :meth:`get_ntype_id`.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all users

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.zeros(3, 5)

        See Also
        --------
        ndata
        """
        return HeteroNodeView(self, self.get_ntype_id)

    @property
    def srcnodes(self):
        """Return a SRC node view that can be used to set/get feature
        data of a single node type.

        The view is constructed with :meth:`get_ntype_id_from_src`.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all users

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.srcnodes['user'].data['h'] = torch.zeros(2, 5)

        See Also
        --------
        srcdata
        """
        return HeteroNodeView(self, self.get_ntype_id_from_src)

    @property
    def dstnodes(self):
        """Return a DST node view that can be used to set/get feature
        data of a single node type.

        The view is constructed with :meth:`get_ntype_id_from_dst`.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all games

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.dstnodes['game'].data['h'] = torch.zeros(3, 5)

        See Also
        --------
        dstdata
        """
        return HeteroNodeView(self, self.get_ntype_id_from_dst)
Da Zheng's avatar
Da Zheng committed
688

689
    @property
    def ndata(self):
        """Return the data view of all the nodes.

        **Only works if the graph has one node type.**

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all nodes in a heterogeneous graph
        with only one node type:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.ndata['h'] = torch.zeros(3, 5)

        Raises
        ------
        DGLError
            If the graph is uni-bipartite or does not have exactly one node
            type (raised by ``get_ntype_id(None)``).

        See Also
        --------
        nodes
        """
        ntid = self.get_ntype_id(None)
        ntype = self.ntypes[0]
        return HeteroNodeDataView(self, ntype, ntid, ALL)
Da Zheng's avatar
Da Zheng committed
712

713
714
    @property
    def srcdata(self):
        """Return the data view of all nodes in the SRC category.

        Only works if the graph is either

        * Uni-bipartite and has one node type in the SRC category.

        * Non-uni-bipartite and has only one node type (in this case identical to
          :any:`DGLHeteroGraph.ndata`)

        Raises
        ------
        AssertionError
            If the SRC category does not contain exactly one node type.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all source nodes in a graph with only one edge type:

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.srcdata['h'] = torch.zeros(2, 5)

        This is equivalent to

        >>> g.nodes['user'].data['h'] = torch.zeros(2, 5)

        It also works on a more complex uni-bipartite graph:

        >>> g = dgl.heterograph({
        ...     ('user', 'plays', 'game'): [(0, 1), (1, 2)],
        ...     ('user', 'reads', 'book'): [(0, 1), (1, 0)],
        ...     })
        >>> print(g.is_unibipartite)
        True
        >>> g.srcdata['h'] = torch.zeros(2, 5)

        Notes
        -----
        This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.

        See Also
        --------
        nodes
        """
        err_msg = (
            'srcdata is only allowed when there is only one %s type.' %
            ('SRC' if self.is_unibipartite else 'node'))
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # because callers may depend on the AssertionError type.
        assert len(self.srctypes) == 1, err_msg
        ntype = self.srctypes[0]
        ntid = self.get_ntype_id_from_src(ntype)
        return HeteroNodeDataView(self, ntype, ntid, ALL)
762
763
764
765
766

    @property
    def dstdata(self):
        """Return the data view of all nodes in the DST category.

        Only works if the graph is either

        * Uni-bipartite and has one node type in the DST category.

        * Non-uni-bipartite and has only one node type (in this case identical to
          :any:`DGLHeteroGraph.ndata`)

        Raises
        ------
        AssertionError
            If the DST category does not contain exactly one node type.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all destination nodes in a graph with only one edge type:

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.dstdata['h'] = torch.zeros(3, 5)

        This is equivalent to

        >>> g.nodes['game'].data['h'] = torch.zeros(3, 5)

        It also works on a more complex uni-bipartite graph:

        >>> g = dgl.heterograph({
        ...     ('user', 'plays', 'game'): [(0, 1), (1, 2)],
        ...     ('store', 'sells', 'game'): [(0, 1), (1, 0)],
        ...     })
        >>> print(g.is_unibipartite)
        True
        >>> g.dstdata['h'] = torch.zeros(3, 5)

        Notes
        -----
        This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.

        See Also
        --------
        nodes
        """
        err_msg = (
            'dstdata is only allowed when there is only one %s type.' %
            ('DST' if self.is_unibipartite else 'node'))
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # because callers may depend on the AssertionError type.
        assert len(self.dsttypes) == 1, err_msg
        ntype = self.dsttypes[0]
        ntid = self.get_ntype_id_from_dst(ntype)
        return HeteroNodeDataView(self, ntype, ntid, ALL)
812

813
    @property
    def edges(self):
        """Edge view for reading/writing the feature data of one edge type.

        Index the returned view with an edge type to reach its ``data``.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all "play" relationships:

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edges['plays'].data['h'] = torch.zeros(3, 4)

        See Also
        --------
        edata
        """
        view = HeteroEdgeView(self)
        return view
832
833

    @property
    def edata(self):
        """Data view covering every edge of the graph's single edge type.

        **Only works if the graph has one edge type.** The view is created
        with ``None`` as the edge type and ``ALL`` as the edge selection.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all edges in a heterogeneous graph
        with only one edge type:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.edata['h'] = torch.zeros(2, 5)

        See Also
        --------
        edges
        """
        implicit_etype = None
        return HeteroEdgeDataView(self, implicit_etype, ALL)

    def _find_etypes(self, key):
        """Return indices into ``self._canonical_etypes`` that match ``key``.

        ``key`` is a ``(srctype, etype, dsttype)`` triple; a component equal to
        ``SLICE_FULL`` (the ``:`` slice) matches any value in that position.
        """
        def _accepts(pattern, value):
            # A full slice is a wildcard; otherwise require an exact match.
            return pattern == SLICE_FULL or pattern == value

        matched = []
        for idx, (srctype, etype, dsttype) in enumerate(self._canonical_etypes):
            if (_accepts(key[0], srctype) and _accepts(key[1], etype)
                    and _accepts(key[2], dsttype)):
                matched.append(idx)
        return matched

    def __getitem__(self, key):
        """Return the relation slice of this graph.

        A relation slice is accessed with ``self[srctype, etype, dsttype]``, where
        ``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
        slice (``:``) representing wildcard (i.e. any source/edge/destination type).
        A bare ``self[etype]`` is shorthand for ``self[:, etype, :]``.

        A relation slice is a homogeneous (with one node type and one edge type) or
        bipartite (with two node types and one edge type) graph, transformed from
        the original heterogeneous graph.

        If only one canonical edge type is found, then the returned relation
        slice is a subgraph induced from the original graph.  That is, it is
        equivalent to ``self.edge_type_subgraph(etype)``.  The node and edge features
        of the returned graph are shared with the original graph.

        If multiple canonical edge types are found, then the source/destination
        node types and the edge type are each a *concatenation* of the original
        types.  The new source/destination node type is named by
        :func:`dgl.combine_names() <dgl.combine_names>` called on the original
        source/destination types.  The source/destination node features are formed
        by concatenating the common features of the original source/destination
        types, therefore they are not shared with the original graph.  The edge
        type is handled similarly.  The parent node/edge types and IDs are stored
        in the ``NTYPE``/``NID`` node features and ``ETYPE``/``EID`` edge features.

        Raises
        ------
        DGLError
            If the key is not a string or a 3-tuple, or matches no edge type.
        """
        err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\
                  "to get view of one relation type. Use : to slice multiple types (e.g. " +\
                  "G['srctype', :, 'dsttype'])."

        orig_key = key
        if not isinstance(key, tuple):
            key = (SLICE_FULL, key, SLICE_FULL)

        if len(key) != 3:
            raise DGLError(err_msg)

        etypes = self._find_etypes(key)

        if len(etypes) == 0:
            raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))

        if len(etypes) == 1:
            # no ambiguity: return the unitgraph itself
            srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
            stid = self.get_ntype_id_from_src(srctype)
            etid = self.get_etype_id((srctype, etype, dsttype))
            dtid = self.get_ntype_id_from_dst(dsttype)
            new_g = self._graph.get_relation_graph(etid)

            # The frames are reused as-is, so features stay shared with self.
            if stid == dtid:
                new_ntypes = [srctype]
                new_nframes = [self._node_frames[stid]]
            else:
                new_ntypes = [srctype, dsttype]
                new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
            new_etypes = [etype]
            new_eframes = [self._edge_frames[etid]]

            return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
        else:
            flat = self._graph.flatten_relations(etypes)
            new_g = flat.graph

            # merge frames
            stids = flat.induced_srctype_set.asnumpy()
            dtids = flat.induced_dsttype_set.asnumpy()
            etids = flat.induced_etype_set.asnumpy()
            new_ntypes = [combine_names(self.ntypes, stids)]
            if new_g.number_of_ntypes() == 2:
                new_ntypes.append(combine_names(self.ntypes, dtids))
                new_nframes = [
                    combine_frames(self._node_frames, stids),
                    combine_frames(self._node_frames, dtids)]
            else:
                assert np.array_equal(stids, dtids)
                new_nframes = [combine_frames(self._node_frames, stids)]
            new_etypes = [combine_names(self.etypes, etids)]
            new_eframes = [combine_frames(self._edge_frames, etids)]

            # create new heterograph
            new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)

            src = new_ntypes[0]
            dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
            # put the parent node/edge type and IDs
            new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype)
            new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid)
            new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype)
            new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid)
            new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
            new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)

            return new_hg

    #################################################################
    # Graph query
    #################################################################

    def number_of_nodes(self, ntype=None):
        """Count the nodes of one node type.

        Parameters
        ----------
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        int
            The number of nodes

        Examples
        --------

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.number_of_nodes('user')
        3
        >>> g.number_of_nodes()
        3
        """
        ntid = self.get_ntype_id(ntype)
        return self._graph.number_of_nodes(ntid)
984

985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
    def number_of_src_nodes(self, ntype=None):
        """Count the nodes of one node type from the SRC category.

        The heterograph is usually a unidirectional bipartite graph. The type
        name is resolved with :meth:`get_ntype_id_from_src`.

        Parameters
        ----------
        ntype : str, optional
            Node type.
            If omitted, there should be only one node type in the SRC category.

        Returns
        -------
        int
            The number of nodes

        Examples
        --------
        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.number_of_src_nodes('user')
        2
        >>> g.number_of_src_nodes()
        2
        >>> g.number_of_nodes('user')
        2
        """
        src_ntid = self.get_ntype_id_from_src(ntype)
        return self._graph.number_of_nodes(src_ntid)

    def number_of_dst_nodes(self, ntype=None):
        """Count the nodes of one node type from the DST category.

        The heterograph is usually a unidirectional bipartite graph. The type
        name is resolved with :meth:`get_ntype_id_from_dst`.

        Parameters
        ----------
        ntype : str, optional
            Node type.
            If omitted, there should be only one node type in the DST category.

        Returns
        -------
        int
            The number of nodes

        Examples
        --------
        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.number_of_dst_nodes('game')
        3
        >>> g.number_of_dst_nodes()
        3
        >>> g.number_of_nodes('game')
        3
        """
        dst_ntid = self.get_ntype_id_from_dst(ntype)
        return self._graph.number_of_nodes(dst_ntid)

Minjie Wang's avatar
Minjie Wang committed
1041
    def number_of_edges(self, etype=None):
        """Return the number of edges of the given type in the heterograph.

        Parameters
        ----------
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        int
            The number of edges

        Examples
        --------

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.number_of_edges(('user', 'follows', 'user'))
        2
        >>> g.number_of_edges('follows')
        2
        >>> g.number_of_edges()
        2
        """
        return self._graph.number_of_edges(self.get_etype_id(etype))

    @property
    def is_multigraph(self):
        """Report whether the underlying graph index is a multigraph.

        Returns
        -------
        bool
            True if the graph is a multigraph, False otherwise.
        """
        hg_index = self._graph
        return hg_index.is_multigraph()
Minjie Wang's avatar
Minjie Wang committed
1077
1078
1079

    @property
    def is_readonly(self):
        """Report whether the underlying graph index is read-only.

        Returns
        -------
        bool
            True if the graph is readonly, False otherwise.
        """
        hg_index = self._graph
        return hg_index.is_readonly()
Da Zheng's avatar
Da Zheng committed
1088

Minjie Wang's avatar
Minjie Wang committed
1089
    def has_node(self, vid, ntype=None):
        """Check whether a node with the given ID exists for a node type.

        Parameters
        ----------
        vid : int
            The node ID.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        bool
            True if the node exists, False otherwise

        Examples
        --------
        >>> g.has_node(0, 'user')
        True
        >>> g.has_node(4, 'user')
        False

        See Also
        --------
        has_nodes
        """
        ntid = self.get_ntype_id(ntype)
        return self._graph.has_node(ntid, vid)
Da Zheng's avatar
Da Zheng committed
1117

Minjie Wang's avatar
Minjie Wang committed
1118
    def has_nodes(self, vids, ntype=None):
        """Whether the graph has nodes with ids and a particular type.

        Parameters
        ----------
        vids : list or tensor
            The array of node IDs.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph.

        Returns
        -------
        a : tensor
            Binary tensor indicating the existence of nodes with the specified ids and type.
            ``a[i]=1`` if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g.has_nodes([0, 1, 2, 3, 4], 'user')
        tensor([1, 1, 1, 0, 0])

        See Also
        --------
        has_node
        """
        vids = utils.toindex(vids)
        rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids)
        return rst.tousertensor()
Da Zheng's avatar
Da Zheng committed
1149

Minjie Wang's avatar
Minjie Wang committed
1150
    def has_edge_between(self, u, v, etype=None):
        """Check whether an edge ``(u, v)`` of type ``etype`` exists.

        Parameters
        ----------
        u : int
            The node ID of source type.
        v : int
            The node ID of destination type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        bool
            True if the edge is in the graph, False otherwise.

        Examples
        --------

        >>> g.has_edge_between(0, 1, ('user', 'plays', 'game'))
        True
        >>> g.has_edge_between(0, 2, ('user', 'plays', 'game'))
        False

        See Also
        --------
        has_edges_between
        """
        etid = self.get_etype_id(etype)
        return self._graph.has_edge_between(etid, u, v)
Da Zheng's avatar
Da Zheng committed
1181

Minjie Wang's avatar
Minjie Wang committed
1182
    def has_edges_between(self, u, v, etype=None):
        """Check, element-wise, whether edges ``(u[i], v[i])`` of type ``etype`` exist.

        Parameters
        ----------
        u : list, tensor
            The node ID array of source type.
        v : list, tensor
            The node ID array of destination type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        a : tensor
            Binary tensor indicating the existence of edges. ``a[i]=1`` if the graph
            contains edge ``(u[i], v[i])`` of type ``etype``, 0 otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g.has_edges_between([0, 0], [1, 2], ('user', 'plays', 'game'))
        tensor([1, 0])

        See Also
        --------
        has_edge_between
        """
        src_idx = utils.toindex(u)
        dst_idx = utils.toindex(v)
        exists = self._graph.has_edges_between(self.get_etype_id(etype), src_idx, dst_idx)
        return exists.tousertensor()
Da Zheng's avatar
Da Zheng committed
1216

Minjie Wang's avatar
Minjie Wang committed
1217
    def predecessors(self, v, etype=None):
        """List the nodes with an edge of type ``etype`` pointing into ``v``.

        Node `u` is a predecessor of `v` if an edge `(u, v)` with type `etype`
        exists in the graph.

        Parameters
        ----------
        v : int
            The destination node.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Array of predecessor node IDs with the specified edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
        >>> g = dgl.hetero_from_relations([plays_g, devs_g])
        >>> g.predecessors(0, 'plays')
        tensor([0, 1])
        >>> g.predecessors(0, 'develops')
        tensor([0])

        See Also
        --------
        successors
        """
        etid = self.get_etype_id(etype)
        preds = self._graph.predecessors(etid, v)
        return preds.tousertensor()
Da Zheng's avatar
Da Zheng committed
1254

Minjie Wang's avatar
Minjie Wang committed
1255
    def successors(self, v, etype=None):
        """List the nodes reached from ``v`` by an edge of type ``etype``.

        Node `u` is a successor of `v` if an edge `(v, u)` with type `etype` exists
        in the graph.

        Parameters
        ----------
        v : int
            The source node.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Array of successor node IDs with the specified edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> g.successors(0, 'plays')
        tensor([0])
        >>> g.successors(0, 'follows')
        tensor([1])

        See Also
        --------
        predecessors
        """
        etid = self.get_etype_id(etype)
        succs = self._graph.successors(etid, v)
        return succs.tousertensor()
Da Zheng's avatar
Da Zheng committed
1292

1293
    def edge_id(self, u, v, force_multi=None, return_array=False, etype=None):
        """Return the edge ID, or an array of edge IDs, between source node
        `u` and destination node `v`, with the specified edge type

        Parameters
        ----------
        u : int
            The node ID of source type.
        v : int
            The node ID of destination type.
        force_multi : bool, optional
            Deprecated (Will be deleted in the future).
            If given (not None), a deprecation warning is emitted and its value
            overrides ``return_array``. (Default: None)
        return_array : bool, optional
            If False, will return a single edge ID.
            If True, will always return an array. (Default: False)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        int or tensor
            The edge ID if ``return_array == False``.
            The edge ID array otherwise.

        Raises
        ------
        AssertionError
            If ``return_array`` is False and the number of edges found between
            `u` and `v` is not exactly one.

        Notes
        -----
        If multiple edges exist between `u` and `v`, use ``return_array=True``
        to get all of their IDs.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for edge id.

        >>> plays_g.edge_id(1, 2, etype=('user', 'plays', 'game'))
        2
        >>> g.edge_id(1, 2, return_array=True, etype=('user', 'follows', 'user'))
        tensor([1, 2])

        See Also
        --------
        edge_ids
        """
        idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
        if force_multi is not None:
            # The two literals are concatenated, so the separating space must
            # be kept explicitly.
            dgl_warning("force_multi will be deprecated. " \
                        "Please use return_array instead")
            return_array = force_multi

        if return_array:
            return idx.tousertensor()
        else:
            assert len(idx) == 1, "For return_array=False, there should be one and " \
                "only one edge between u and v, but get {} edges. " \
                "Please use return_array=True instead".format(len(idx))
            return idx[0]
Da Zheng's avatar
Da Zheng committed
1359

1360
    def edge_ids(self, u, v, force_multi=None, return_uv=False, etype=None):
        """Return all edge IDs between source node array `u` and destination
        node array `v` with the specified edge type.

        Parameters
        ----------
        u : list, tensor
            The node ID array of source type.
        v : list, tensor
            The node ID array of destination type.
        force_multi : bool, optional
            Deprecated (Will be deleted in the future).
            If given (not None), a deprecation warning is emitted and its value
            overrides ``return_uv``. (Default: None)
        return_uv : bool
            See the "Returns" for their effects. (Default: False)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        tensor, or (tensor, tensor, tensor)

            * If ``return_uv=False``, return a single edge ID array ``e``.
              ``e[i]`` is the edge ID between ``u[i]`` and ``v[i]``.

            * Otherwise, return three arrays ``(eu, ev, e)``.  ``e[i]`` is the ID
              of an edge between ``eu[i]`` and ``ev[i]``.  All edges between ``u[i]``
              and ``v[i]`` are returned.

        Raises
        ------
        AssertionError
            If ``return_uv=False`` and the total number of edges found is not
            ``max(len(u), len(v))``.

        Notes
        -----
        With ``return_uv=False`` one edge is expected per ``(u[i], v[i])`` pair.
        When the total number of edges found differs from ``max(len(u), len(v))``
        an ``AssertionError`` is raised instead of returning a partial result.
        For example, querying a single pair with no edge raises. In that case,
        use ``return_uv=True``, which returns every edge that was found.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for edge ids.

        >>> plays_g.edge_ids([1], [2], etype=('user', 'plays', 'game'))
        tensor([2])
        >>> g.edge_ids([1], [2], return_uv=True, etype=('user', 'follows', 'user'))
        (tensor([1, 1]), tensor([2, 2]), tensor([1, 2]))

        See Also
        --------
        edge_id
        """
        u = utils.toindex(u)
        v = utils.toindex(v)
        src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
        if force_multi is not None:
            dgl_warning("force_multi will be deprecated, " \
                        "Please use return_uv instead")
            return_uv = force_multi

        if return_uv:
            return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
        else:
            assert len(eid) == max(len(u), len(v)), "If return_uv=False, there should be one and " \
                "only one edge between each u and v, expect {} edges but get {}. " \
                "Please use return_uv=True instead".format(max(len(u), len(v)), len(eid))
            return eid.tousertensor()
Da Zheng's avatar
Da Zheng committed
1438

Minjie Wang's avatar
Minjie Wang committed
1439
    def find_edges(self, eid, etype=None):
        """Look up the endpoints of the given edges of one edge type.

        For every position ``i``, ``s[i]`` is the source node ID and ``d[i]``
        is the destination node ID of the edge ``eid[i]``.

        Parameters
        ----------
        eid : list, tensor
            The edge ID array.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            The source node ID array ``s``.
        tensor
            The destination node ID array ``d``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.find_edges([0, 2], ('user', 'plays', 'game'))
        (tensor([0, 1]), tensor([0, 2]))
        >>> g.find_edges([0, 2])
        (tensor([0, 1]), tensor([0, 2]))
        """
        edge_index = utils.toindex(eid)
        etype_id = self.get_etype_id(etype)
        # The third value returned by the graph index is not needed here.
        sources, destinations, _ = self._graph.find_edges(etype_id, edge_index)
        return sources.tousertensor(), destinations.tousertensor()
Da Zheng's avatar
Da Zheng committed
1472

Minjie Wang's avatar
Minjie Wang committed
1473
    def in_edges(self, v, form='uv', etype=None):
        """Return the inbound edges of the node(s) with the specified type.

        Parameters
        ----------
        v : int, list, tensor
            The node id(s) of destination type.
        form : str, optional
            The return form. Currently support:

            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor or (tensor, tensor, tensor) or (tensor, tensor)
            All inbound edges to ``v`` are returned.

            * If ``form='eid'``, return a tensor for the ids of the
              inbound edges of the nodes with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors
              ``(eu, ev, eid)``. ``eid[i]`` gives the ID of the
              edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``eu[i]`` is the source node of an edge to ``ev[i]``.

        Raises
        ------
        DGLError
            If ``form`` is not one of ``'eid'``, ``'all'`` or ``'uv'``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.in_edges([0, 2], form='eid')
        tensor([0, 2])
        >>> g.in_edges([0, 2], form='all')
        (tensor([0, 1]), tensor([0, 2]), tensor([0, 2]))
        >>> g.in_edges([0, 2], form='uv')
        (tensor([0, 1]), tensor([0, 2]))
        """
        v = utils.toindex(v)
        src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        if form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        if form == 'eid':
            return eid.tousertensor()
        # Build a single message: passing ``form`` as a second positional
        # argument would make the exception render as a tuple.
        raise DGLError('Invalid form: {}'.format(form))

Mufei Li's avatar
Mufei Li committed
1526
1527
    def out_edges(self, u, form='uv', etype=None):
        """Return the outbound edges of the node(s) with the specified type.

        Parameters
        ----------
        u : int, list, tensor
            The node id(s) of source type.
        form : str, optional
            The return form. Currently support:

            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor or (tensor, tensor, tensor) or (tensor, tensor)
            All outbound edges from ``u`` are returned.

            * If ``form='eid'``, return a tensor for the ids of the outbound edges
              of the nodes with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors ``(eu, ev, eid)``.
              ``eid[i]`` gives the ID of the edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``ev[i]`` is the destination node of the edge from ``eu[i]``.

        Raises
        ------
        DGLError
            If ``form`` is not one of ``'eid'``, ``'all'`` or ``'uv'``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.out_edges([0, 1], form='eid')
        tensor([0, 1, 2])
        >>> g.out_edges([0, 1], form='all')
        (tensor([0, 1, 1]), tensor([0, 1, 2]), tensor([0, 1, 2]))
        >>> g.out_edges([0, 1], form='uv')
        (tensor([0, 1, 1]), tensor([0, 1, 2]))
        """
        u = utils.toindex(u)
        src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        if form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        if form == 'eid':
            return eid.tousertensor()
        # Build a single message: passing ``form`` as a second positional
        # argument would make the exception render as a tuple.
        raise DGLError('Invalid form: {}'.format(form))

Minjie Wang's avatar
Minjie Wang committed
1577
    def all_edges(self, form='uv', order=None, etype=None):
        """Return all edges with the specified type.

        Parameters
        ----------
        form : str, optional
            The return form. Currently support:

            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        order : str or None
            The order of the returned edges. Currently support:

            - ``'srcdst'`` : sorted by their src and dst ids.
            - ``'eid'``    : sorted by edge Ids.
            - ``None``     : arbitrary order, default
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor or (tensor, tensor, tensor) or (tensor, tensor)

            * If ``form='eid'``, return a tensor for the ids of all edges
              with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors ``(eu, ev, eid)``.
              ``eid[i]`` gives the ID of the edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``ev[i]`` is the destination node of the edge from ``eu[i]``.

        Raises
        ------
        DGLError
            If ``form`` is not one of ``'eid'``, ``'all'`` or ``'uv'``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(1, 1), (0, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.all_edges(form='eid', order='srcdst')
        tensor([1, 0, 2])
        >>> g.all_edges(form='all', order='srcdst')
        (tensor([0, 1, 1]), tensor([0, 1, 2]), tensor([1, 0, 2]))
        >>> g.all_edges(form='uv', order='eid')
        (tensor([1, 0, 1]), tensor([1, 0, 2]))
        """
        src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        if form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        if form == 'eid':
            return eid.tousertensor()
        # Build a single message: passing ``form`` as a second positional
        # argument would make the exception render as a tuple.
        raise DGLError('Invalid form: {}'.format(form))

Minjie Wang's avatar
Minjie Wang committed
1631
    def in_degree(self, v, etype=None):
        """Count the edges of type ``etype`` that point into node ``v``.

        Parameters
        ----------
        v : int
            The node ID of destination type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        int
            The in-degree.

        Examples
        --------
        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.in_degree(0, 'plays')
        2
        >>> g.in_degree(0, 'follows')
        0

        See Also
        --------
        in_degrees
        """
        etype_id = self.get_etype_id(etype)
        return self._graph.in_degree(etype_id, v)
Da Zheng's avatar
Da Zheng committed
1668

Minjie Wang's avatar
Minjie Wang committed
1669
    def in_degrees(self, v=ALL, etype=None):
        """Return the in-degrees of several nodes under one edge type.

        Parameters
        ----------
        v : list, tensor, optional
            Node IDs of the destination type. When omitted (``ALL``), the
            in-degrees of every node of the destination type are returned.
        etype : str or tuple of str or None, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        d : tensor
            ``d[i]`` is the in-degree of node ``v[i]``, counting only edges
            of type ``etype``.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.in_degrees(0, 'plays')
        tensor([2])
        >>> g.in_degrees(etype='follows')
        tensor([0, 1, 2])

        See Also
        --------
        in_degree
        """
        etype_id = self.get_etype_id(etype)
        # Only the destination node type of this relation is needed.
        _, dst_type_id = self._graph.metagraph.find_edge(etype_id)
        nodes = (utils.toindex(slice(0, self._graph.number_of_nodes(dst_type_id)))
                 if is_all(v) else utils.toindex(v))
        return self._graph.in_degrees(etype_id, nodes).tousertensor()
Da Zheng's avatar
Da Zheng committed
1715

Mufei Li's avatar
Mufei Li committed
1716
1717
    def out_degree(self, u, etype=None):
        """Count the edges of type ``etype`` that leave node ``u``.

        Parameters
        ----------
        u : int
            The node ID of source type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        int
            The out-degree of node ``u`` with edges of type ``etype``.

        Examples
        --------
        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.out_degree(0, 'plays')
        1
        >>> g.out_degree(1, 'follows')
        2

        See Also
        --------
        out_degrees
        """
        etype_id = self.get_etype_id(etype)
        return self._graph.out_degree(etype_id, u)
1753

Mufei Li's avatar
Mufei Li committed
1754
1755
    def out_degrees(self, u=ALL, etype=None):
        """Return the out-degrees of several nodes under one edge type.

        Parameters
        ----------
        u : list, tensor, optional
            Node IDs of the source type. When omitted (``ALL``), the
            out-degrees of every node of the source type are returned.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        d : tensor
            ``d[i]`` is the out-degree of node ``u[i]``, counting only edges
            of type ``etype``.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.out_degrees(0, 'plays')
        tensor([1])
        >>> g.out_degrees(etype='follows')
        tensor([1, 2, 0])

        See Also
        --------
        out_degree
        """
        etype_id = self.get_etype_id(etype)
        # Only the source node type of this relation is needed.
        src_type_id, _ = self._graph.metagraph.find_edge(etype_id)
        nodes = (utils.toindex(slice(0, self._graph.number_of_nodes(src_type_id)))
                 if is_all(u) else utils.toindex(u))
        return self._graph.out_degrees(etype_id, nodes).tousertensor()
Minjie Wang's avatar
Minjie Wang committed
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821

    def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
        """Wrap a subgraph index into a new DGLHeteroGraph (internal helper).

        Parameters
        ----------
        sgi : object
            Subgraph index object; its ``graph`` attribute becomes the
            structure of the returned graph.
        induced_nodes : list
            One index per node type, in ``self.ntypes`` order, holding the
            parent-graph IDs of the kept nodes.
        induced_edges : list
            One index per edge type, in ``self.canonical_etypes`` order,
            holding the parent-graph IDs of the kept edges.

        Returns
        -------
        DGLHeteroGraph
            The subgraph, marked with ``is_subgraph = True``. Its feature
            frames are new frames built from the selected rows of this
            graph's frames. The parent IDs are stored as the ``NID`` / ``EID``
            node/edge data.
        """
        def _select_rows(frames, ids_per_type):
            # One freshly constructed frame per type, holding the induced rows.
            return [FrameRef(Frame(frames[type_id][ids], num_rows=len(ids)))
                    for type_id, ids in enumerate(ids_per_type)]

        node_frames = _select_rows(self._node_frames, induced_nodes)
        edge_frames = _select_rows(self._edge_frames, induced_edges)
        hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
        hsg.is_subgraph = True
        for ntype, parent_nids in zip(self.ntypes, induced_nodes):
            hsg.nodes[ntype].data[NID] = parent_nids.tousertensor()
        for etype, parent_eids in zip(self.canonical_etypes, induced_edges):
            hsg.edges[etype].data[EID] = parent_eids.tousertensor()
        return hsg
1822

Minjie Wang's avatar
Minjie Wang committed
1823
1824
    def subgraph(self, nodes):
        """Return the subgraph induced on given nodes.

        The metagraph of the returned subgraph is the same as the parent graph.
        Features are copied from the original graph.

        Parameters
        ----------
        nodes : dict[str->list or iterable]
            A dictionary mapping node types to node ID array for constructing
            subgraph. All nodes must exist in the graph. Node types missing
            from the dictionary contribute no nodes.

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

            The nodes and edges in the subgraph are relabeled using consecutive
            integers from 0.

            One can retrieve the mapping from subgraph node/edge ID to parent
            node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
            subgraph.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set node features
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs. Requesting nodes that do not exist raises an error:

        >>> g.subgraph({'user': [4, 5]})  # error: these nodes do not exist
        >>> sub_g = g.subgraph({'user': [1, 2]})
        >>> print(sub_g)
        Graph(num_nodes={'user': 2, 'game': 0},
              num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
              metagraph=[('user', 'game'), ('user', 'user')])

        Get the original node/edge indices.

        >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
        tensor([1, 2])
        >>> sub_g['follows'].edata[dgl.EID] # Get the edge indices in the raw graph
        tensor([1, 2])

        Get the copied node features.

        >>> sub_g.nodes['user'].data['h']
        tensor([[1.],
                [2.]])
        >>> sub_g.nodes['user'].data['h'] += 1
        >>> g.nodes['user'].data['h']          # Features are not shared.
        tensor([[0.],
                [1.],
                [2.]])

        See Also
        --------
        edge_subgraph
        """
        induced_nodes = []
        for ntype in self.ntypes:
            # Node types absent from ``nodes`` get an empty selection.
            induced_nodes.append(utils.toindex(nodes.get(ntype, [])))
        sub_index = self._graph.node_subgraph(induced_nodes)
        return self._create_hetero_subgraph(sub_index, induced_nodes, sub_index.induced_edges)
1896

Minjie Wang's avatar
Minjie Wang committed
1897
1898
    def edge_subgraph(self, edges, preserve_nodes=False):
        """Return the subgraph induced on given edges.

        The metagraph of the returned subgraph is the same as the parent graph.

        Features are copied from the original graph.

        Parameters
        ----------
        edges : dict[str->list or iterable]
            A dictionary mapping edge types to edge ID array for constructing
            subgraph. All edges must exist in the graph. Edge types missing
            from the dictionary contribute no edges.

            The edge types are characterized by triplets of
            ``(src type, etype, dst type)``; any key accepted by
            ``to_canonical_etype`` may be used.
        preserve_nodes : bool
            Whether to preserve all nodes or not. If false, all nodes
            without edges will be removed. (Default: False)

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

            The nodes and edges are relabeled using consecutive integers from 0.

            One can retrieve the mapping from subgraph node/edge ID to parent
            node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
            subgraph.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set edge features
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs. Requesting edges that do not exist raises an error:

        >>> g.edge_subgraph({('user', 'follows', 'user'): [5, 6]})  # error
        >>> sub_g = g.edge_subgraph({('user', 'follows', 'user'): [1, 2],
        ...                          ('user', 'plays', 'game'): [2]})
        >>> print(sub_g)
        Graph(num_nodes={'user': 2, 'game': 1},
              num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
              metagraph=[('user', 'game'), ('user', 'user')])

        Get the original node/edge indices.

        >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
        tensor([1, 2])
        >>> sub_g['plays'].edata[dgl.EID]   # Get the edge indices in the raw graph
        tensor([2])

        Get the copied edge features.

        >>> sub_g.edges['follows'].data['h']
        tensor([[1.],
                [2.]])
        >>> sub_g.edges['follows'].data['h'] += 1
        >>> g.edges['follows'].data['h']          # Features are not shared.
        tensor([[0.],
                [1.],
                [2.]])

        See Also
        --------
        subgraph
        """
        # Normalize user-supplied keys so lookups below use canonical triplets.
        canonical_edges = {self.to_canonical_etype(etype): eids
                           for etype, eids in edges.items()}
        induced_edges = [
            utils.toindex(canonical_edges.get(canonical_etype, []))
            for canonical_etype in self.canonical_etypes]
        sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
        return self._create_hetero_subgraph(sgi, sgi.induced_nodes, induced_edges)
1980

Minjie Wang's avatar
Minjie Wang committed
1981
1982
    def node_type_subgraph(self, ntypes):
        """Return the subgraph induced on given node types.

        The metagraph of the returned subgraph is the subgraph of the original
        metagraph induced from the node types: a relation is kept only when
        both its source and destination node types are selected.

        Features are shared with the original graph (the same frame objects
        are reused).

        Parameters
        ----------
        ntypes : list[str]
            The node types

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set node features
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> sub_g = g.node_type_subgraph(['user'])
        >>> print(sub_g)
        Graph(num_nodes=3, num_edges=3,
              ndata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)}
              edata_schemes={})

        Get the shared node features.

        >>> sub_g.nodes['user'].data['h']
        tensor([[0.],
                [1.],
                [2.]])
        >>> sub_g.nodes['user'].data['h'] += 1
        >>> g.nodes['user'].data['h']          # Features are shared.
        tensor([[1.],
                [2.],
                [3.]])

        See Also
        --------
        edge_type_subgraph
        """
        node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
        num_nodes_per_type = [self.number_of_nodes(ntype) for ntype in ntypes]
        # Position of every kept node type inside the new graph.
        new_ntype_id = {ntype: pos for pos, ntype in enumerate(ntypes)}
        src_type_ids, dst_type_ids, _ = self._graph.metagraph.edges('eid')

        meta_edges, rel_graphs, induced_etypes, edge_frames = [], [], [], []
        for etid in range(len(self._etypes)):
            srctype = self._ntypes[src_type_ids[etid]]
            dsttype = self._ntypes[dst_type_ids[etid]]
            # Drop relations touching a node type that is not kept.
            if srctype not in new_ntype_id or dsttype not in new_ntype_id:
                continue
            meta_edges.append((new_ntype_id[srctype], new_ntype_id[dsttype]))
            rel_graphs.append(self._graph.get_relation_graph(etid))
            induced_etypes.append(self.etypes[etid])
            edge_frames.append(self._edge_frames[etid])

        metagraph = graph_index.from_edge_list(meta_edges, True)
        hgidx = heterograph_index.create_heterograph_from_relations(
            metagraph, rel_graphs, utils.toindex(num_nodes_per_type))
        return DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
2059

Minjie Wang's avatar
Minjie Wang committed
2060
2061
    def edge_type_subgraph(self, etypes):
        """Return the subgraph induced on given edge types.

        The metagraph of the returned subgraph is the subgraph of the original metagraph
        induced from the edge types. Only node types that are the source or destination
        type of at least one requested edge type are kept. They appear in the same
        relative order as in the original graph.

        Features are shared with the original graph.

        Parameters
        ----------
        etypes : list[str or tuple]
            The edge types

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set edge features
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> sub_g = g.edge_type_subgraph(['follows'])
        >>> print(sub_g)
        Graph(num_nodes=3, num_edges=3,
              ndata_schemes={}
              edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})

        Get the shared edge features.

        >>> sub_g.edges['follows'].data['h']
        tensor([[0.],
                [1.],
                [2.]])
        >>> sub_g.edges['follows'].data['h'] += 1
        >>> g.edges['follows'].data['h']          # Features are shared.
        tensor([[1.],
                [2.],
                [3.]])

        See Also
        --------
        node_type_subgraph
        """
        etype_ids = [self.get_etype_id(etype) for etype in etypes]
        meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids))
        rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
        meta_src = meta_src.tonumpy()
        meta_dst = meta_dst.tonumpy()
        # Node type ids touched by the requested relations, in the parent graph's
        # order. Sorting explicitly (instead of iterating a set) makes the
        # node type order and ids of the result deterministic.
        induced_ntype_ids = sorted(set(meta_src) | set(meta_dst))
        # parent node type id -> node type id in the subgraph
        ntypes_invmap = {n: i for i, n in enumerate(induced_ntype_ids)}
        mapped_meta_src = [ntypes_invmap[v] for v in meta_src]
        mapped_meta_dst = [ntypes_invmap[v] for v in meta_dst]
        # Frames are shared (not copied) with the parent graph.
        node_frames = [self._node_frames[i] for i in induced_ntype_ids]
        edge_frames = [self._edge_frames[i] for i in etype_ids]
        induced_ntypes = [self._ntypes[i] for i in induced_ntype_ids]
        induced_etypes = [self._etypes[i] for i in etype_ids]   # get the "name" of edge type
        num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]

        metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True)
        hgidx = heterograph_index.create_heterograph_from_relations(
            metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type))
        hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
        return hg

2134
    def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None):
        """Return the adjacency matrix of edges of the given edge type.

        By default, a row of returned adjacency matrix represents the
        destination of an edge and the column represents the source.

        When transpose is True, a row represents the source and a column
        represents a destination.

        Parameters
        ----------
        transpose : bool, optional
            A flag to transpose the returned adjacency matrix. When left as
            None, a warning about the upcoming change of the default is
            emitted and the matrix is returned as if ``transpose=False``.
            (Default: None)
        ctx : context, optional
            The context of returned adjacency matrix. (Default: cpu)
        scipy_fmt : str, optional
            If specified, return a scipy sparse matrix in the given format.
            Otherwise, return a backend dependent sparse tensor. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        SparseTensor or scipy.sparse.spmatrix
            Adjacency matrix.

        Examples
        --------

        Instantiate a heterogeneous graph.

        >>> follows_g = dgl.graph([(0, 0), (1, 1)], 'user', 'follows')
        >>> devs_g = dgl.bipartite([(0, 0), (1, 2)], 'developer', 'develops', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, devs_g])

        Get a backend dependent sparse tensor. Here we use PyTorch for example.

        >>> g.adjacency_matrix(etype='develops')
        tensor(indices=tensor([[0, 2],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)

        Get a scipy coo sparse matrix.

        >>> g.adjacency_matrix(scipy_fmt='coo', etype='develops')
        <3x2 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in COOrdinate format>
        """
        if transpose is None:
            # The implicit default is being deprecated; warn before falling back.
            dgl_warning(
                "Currently adjacency_matrix() returns a matrix with destination as rows"
                " by default.  In 0.5 the result will have source as rows"
                " (i.e. transpose=True)")
            transpose = False

        etype_id = self.get_etype_id(etype)
        if scipy_fmt is not None:
            return self._graph.adjacency_matrix_scipy(etype_id, transpose, scipy_fmt, False)
        # The graph-index call returns an indexable result; only its first
        # element (the sparse tensor) is handed back to the caller.
        sparse_result = self._graph.adjacency_matrix(etype_id, transpose, ctx)
        return sparse_result[0]

    # Alias of ``adjacency_matrix``
    adj = adjacency_matrix
2199

Minjie Wang's avatar
Minjie Wang committed
2200
2201
2202
    def incidence_matrix(self, typestr, ctx=F.cpu(), etype=None):
        """Return the incidence matrix representation of edges with the given
        edge type.

        An incidence matrix is an n-by-m sparse matrix, where n is
        the number of nodes and m is the number of edges. Each nonzero
        value indicates whether the edge is incident to the node.

        There are three types of incidence matrices :math:`I`:

        * ``in``:

            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`
              (or :math:`v` is the dst node of :math:`e`);
            - :math:`I[v, e] = 0` otherwise.

        * ``out``:

            - :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`
              (or :math:`v` is the src node of :math:`e`);
            - :math:`I[v, e] = 0` otherwise.

        * ``both`` (only if source and destination node type are the same):

            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`;
            - :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`;
            - :math:`I[v, e] = 0` otherwise (including self-loop).

        Parameters
        ----------
        typestr : str
            Can be either ``in``, ``out`` or ``both``
        ctx : context, optional
            The context of returned incidence matrix. (Default: cpu)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        Framework SparseTensor
            The incidence matrix.

        Examples
        --------

        >>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
        >>> g.incidence_matrix('in')
        tensor(indices=tensor([[0, 2],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        >>> g.incidence_matrix('out')
        tensor(indices=tensor([[0, 1],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        >>> g.incidence_matrix('both')
        tensor(indices=tensor([[1, 2],
                               [1, 1]]),
               values=tensor([-1.,  1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        """
        etype_id = self.get_etype_id(etype)
        # Only the first element of the graph-index result (the sparse tensor) is exposed.
        incidence = self._graph.incidence_matrix(etype_id, typestr, ctx)
        return incidence[0]

    # Alias of ``incidence_matrix``
    inc = incidence_matrix
Da Zheng's avatar
Da Zheng committed
2269

Minjie Wang's avatar
Minjie Wang committed
2270
2271
2272
2273
2274
    #################################################################
    # Features
    #################################################################

    def node_attr_schemes(self, ntype=None):
        """Return the node feature schemes for the specified type.

        Each feature scheme is a named tuple that stores the shape and data type
        of the node feature.

        Parameters
        ----------
        ntype : str, optional
            The node type. Can be omitted if there is only one node
            type in the graph. Error will be raised otherwise.
            (Default: None)

        Returns
        -------
        dict of str to schemes
            The schemes of node feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.randn(3, 4)
        >>> g.node_attr_schemes('user')
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        edge_attr_schemes
        """
        ntype_id = self.get_ntype_id(ntype)
        node_frame = self._node_frames[ntype_id]
        return node_frame.schemes
Da Zheng's avatar
Da Zheng committed
2306

Minjie Wang's avatar
Minjie Wang committed
2307
    def edge_attr_schemes(self, etype=None):
        """Return the edge feature schemes for the specified type.

        Each feature scheme is a named tuple that stores the shape and data type
        of the edge feature.

        Parameters
        ----------
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        dict of str to schemes
            The schemes of edge feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
        >>> g.edge_attr_schemes(('user', 'plays', 'game'))
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        node_attr_schemes
        """
        etype_id = self.get_etype_id(etype)
        edge_frame = self._edge_frames[etype_id]
        return edge_frame.schemes
Da Zheng's avatar
Da Zheng committed
2338

Minjie Wang's avatar
Minjie Wang committed
2339
2340
    def set_n_initializer(self, initializer, field=None, ntype=None):
        """Set the initializer for empty node features.

        Initializer is a callable that returns a tensor given the shape, data type
        and device context.

        When a subset of the nodes are assigned a new feature, initializer is
        used to create feature for the rest of the nodes.

        Parameters
        ----------
        initializer : callable
            The initializer, mapping (shape, data type, context) to tensor.
        field : str, optional
            The feature field name. Default is to set an initializer for all the
            feature fields.
        ntype : str, optional
            The node type. Can be omitted if there is only one node
            type in the graph. Error will be raised otherwise.
            (Default: None)

        Note
        -----
        User defined initializer must follow the signature of
        :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`

        See Also
        --------
        set_e_initializer
        """
        # The initializer is stored on the frame of this node type only.
        node_frame = self._node_frames[self.get_ntype_id(ntype)]
        node_frame.set_initializer(initializer, field)
Da Zheng's avatar
Da Zheng committed
2371

Minjie Wang's avatar
Minjie Wang committed
2372
2373
    def set_e_initializer(self, initializer, field=None, etype=None):
        """Set the initializer for empty edge features.

        Initializer is a callable that returns a tensor given the shape, data
        type and device context.

        When a subset of the edges are assigned a new feature, initializer is
        used to create feature for the rest of the edges.

        Parameters
        ----------
        initializer : callable
            The initializer, mapping (shape, data type, context) to tensor.
        field : str, optional
            The feature field name. Default is to set an initializer for all the
            feature fields.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. Error will be raised otherwise.
            (Default: None)

        Note
        -----
        User defined initializer must follow the signature of
        :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`

        See Also
        --------
        set_n_initializer
        """
        # The initializer is stored on the frame of this edge type only.
        edge_frame = self._edge_frames[self.get_etype_id(etype)]
        edge_frame.set_initializer(initializer, field)
Da Zheng's avatar
Da Zheng committed
2404

Minjie Wang's avatar
Minjie Wang committed
2405
2406
    def _set_n_repr(self, ntid, u, data, inplace=False):
        """Internal API to set node features.

        `data` is a dictionary from the feature name to feature tensor. Each tensor
        is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
        and (D1, D2, ...) be the shape of the node representation tensor. The
        length of the given node ids must match B (i.e, len(u) == B).

        All update will be done out of place to work with autograd unless the
        inplace flag is true.

        Parameters
        ----------
        ntid : int
            Node type id.
        u : node, container or tensor
            The node(s).
        data : dict of tensor
            Node representation.
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If ``data`` is not dict-like, or if the first dimension of a feature
            tensor does not match the number of nodes being updated.
        """
        # Same sanity check as ``_set_e_repr``: reject non-dict data with a clear error
        # instead of failing later on ``data.items()``.
        if not utils.is_dict_like(data):
            raise DGLError('Expect dictionary type for feature data.'
                           ' Got "%s" instead.' % type(data))

        update_all = is_all(u)
        if update_all:
            num_nodes = self._graph.number_of_nodes(ntid)
        else:
            u = utils.toindex(u)
            num_nodes = len(u)
        for val in data.values():
            nfeats = F.shape(val)[0]
            if nfeats != num_nodes:
                raise DGLError('Expect number of features to match number of nodes (len(u)).'
                               ' Got %d and %d instead.' % (nfeats, num_nodes))

        if update_all:
            # Replace whole feature columns.
            for key, val in data.items():
                self._node_frames[ntid][key] = val
        else:
            # Update only the selected rows.
            self._node_frames[ntid].update_rows(u, data, inplace=inplace)
Da Zheng's avatar
Da Zheng committed
2444

Minjie Wang's avatar
Minjie Wang committed
2445
    def _get_n_repr(self, ntid, u):
        """Get node(s) representation of a single node type.

        The returned feature tensor batches multiple node features on the first dimension.

        Parameters
        ----------
        ntid : int
            Node type id.
        u : node, container or tensor
            The node(s).

        Returns
        -------
        dict
            Representation dict from feature name to feature tensor.
        """
        if not is_all(u):
            # Only the requested rows.
            return self._node_frames[ntid].select_rows(utils.toindex(u))
        # Every node: a plain dict of the full feature columns.
        return dict(self._node_frames[ntid])
Da Zheng's avatar
Da Zheng committed
2467

Minjie Wang's avatar
Minjie Wang committed
2468
2469
    def _pop_n_repr(self, ntid, key):
        """Internal API to get and remove the specified node feature.
Da Zheng's avatar
Da Zheng committed
2470
2471
2472

        Parameters
        ----------
Minjie Wang's avatar
Minjie Wang committed
2473
2474
        ntid : int
            Node type id.
Da Zheng's avatar
Da Zheng committed
2475
2476
2477
2478
2479
2480
2481
2482
        key : str
            The attribute name.

        Returns
        -------
        Tensor
            The popped representation
        """
Minjie Wang's avatar
Minjie Wang committed
2483
        return self._node_frames[ntid].pop(key)
Da Zheng's avatar
Da Zheng committed
2484

Minjie Wang's avatar
Minjie Wang committed
2485
2486
    def _set_e_repr(self, etid, edges, data, inplace=False):
        """Internal API to set edge(s) features.

        `data` is a dictionary from the feature name to feature tensor. Each tensor
        is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
        and (D1, D2, ...) be the shape of the edge representation tensor.

        All update will be done out of place to work with autograd unless the
        inplace flag is true.

        Parameters
        ----------
        etid : int
            Edge type id.
        edges : edges
            Edges can be either

            * A pair of endpoint nodes (u, v), where u is the node ID of source
              node type and v is that of destination node type.
            * A tensor of edge ids of the given type.

            The default value is all the edges.
        data : dict of tensor
            Edge representation. Non-dict values are rejected.
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If ``data`` is not dict-like, or if the first dimension of a feature
            tensor does not match the number of edges being updated.
        """
        # parse argument
        if is_all(edges):
            eid = ALL
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            _, _, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)

        # sanity check
        if not utils.is_dict_like(data):
            raise DGLError('Expect dictionary type for feature data.'
                           ' Got "%s" instead.' % type(data))

        if is_all(eid):
            num_edges = self._graph.number_of_edges(etid)
        else:
            eid = utils.toindex(eid)
            num_edges = len(eid)
        for val in data.values():
            nfeats = F.shape(val)[0]
            if nfeats != num_edges:
                raise DGLError('Expect number of features to match number of edges.'
                               ' Got %d and %d instead.' % (nfeats, num_edges))
        # set
        if is_all(eid):
            # update column
            for key, val in data.items():
                self._edge_frames[etid][key] = val
        else:
            # update row
            self._edge_frames[etid].update_rows(eid, data, inplace=inplace)
Da Zheng's avatar
Da Zheng committed
2548

Minjie Wang's avatar
Minjie Wang committed
2549
2550
    def _get_e_repr(self, etid, edges):
        """Internal API to get edge features.

        Parameters
        ----------
        etid : int
            Edge type id.
        edges : edges
            Edges can be a pair of endpoint nodes (u, v), or a
            tensor of edge ids. The default value is all the edges.

        Returns
        -------
        dict
            Representation dict
        """
        if is_all(edges):
            # Every edge: a plain dict of the full feature columns.
            return dict(self._edge_frames[etid])

        if isinstance(edges, tuple):
            src, dst = edges
            src_idx = utils.toindex(src)
            dst_idx = utils.toindex(dst)
            # Resolve (src, dst) pairs to edge ids; handles broadcasting and multigraph.
            _, _, edge_ids = self._graph.edge_ids(etid, src_idx, dst_idx)
        else:
            edge_ids = utils.toindex(edges)

        return self._edge_frames[etid].select_rows(utils.toindex(edge_ids))
Da Zheng's avatar
Da Zheng committed
2582

Minjie Wang's avatar
Minjie Wang committed
2583
    def _pop_e_repr(self, etid, key):
Da Zheng's avatar
Da Zheng committed
2584
2585
2586
2587
        """Get and remove the specified edge repr of a single edge type.

        Parameters
        ----------
Minjie Wang's avatar
Minjie Wang committed
2588
2589
        etid : int
            Edge type id.
Da Zheng's avatar
Da Zheng committed
2590
2591
2592
2593
2594
2595
2596
2597
        key : str
          The attribute name.

        Returns
        -------
        Tensor
            The popped representation
        """
Minjie Wang's avatar
Minjie Wang committed
2598
        self._edge_frames[etid].pop(key)
Da Zheng's avatar
Da Zheng committed
2599

Minjie Wang's avatar
Minjie Wang committed
2600
2601
2602
2603
2604
2605
2606
    #################################################################
    # Message passing
    #################################################################

    def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
        """Apply the function on the nodes with the same type to update their
        features.

        If None is provided for ``func``, nothing will happen.

        Parameters
        ----------
        func : callable or None
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        v : int or iterable of int or tensor, optional
            The (type-specific) node (ids) on which to apply ``func``. (Default: ALL)
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.ones(3, 5)
        >>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
        >>> g.nodes['user'].data['h']
        tensor([[2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.]])

        See Also
        --------
        apply_edges
        """
        ntid = self.get_ntype_id(ntype)
        # ALL expands to the full id range [0, number_of_nodes) of this node type.
        target = slice(0, self.number_of_nodes(ntype)) if is_all(v) else v
        target_nodes = utils.toindex(target)
        node_frame = self._node_frames[ntid]
        ntype_name = self._ntypes[ntid]
        with ir.prog() as prog:
            scheduler.schedule_apply_nodes(target_nodes, func, node_frame,
                                           inplace=inplace, ntype=ntype_name)
            Runtime.run(prog)

    def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
        """Apply the function on the edges with the same type to update their
        features.

        If None is provided for ``func``, nothing will happen (the edge type is
        still validated).

        Parameters
        ----------
        func : callable or None
            Apply function on the edge. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        edges : optional
            Edges on which to apply ``func``. See :func:`send` for valid
            edge specification. (Default: ALL)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
        >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
        >>> g.edges[('user', 'plays', 'game')].data['h']
        tensor([[2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.]])

        See Also
        --------
        apply_nodes
        group_apply_edges
        """
        etid = self.get_etype_id(etype)
        if func is None:
            # Documented contract: a missing UDF is a no-op. The check comes
            # after the edge type lookup so an invalid ``etype`` still raises.
            return
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(edges):
            u, v, _ = self._graph.edges(etid, 'eid')
            # A slice index covers every edge of this type without materializing ids.
            eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        with ir.prog() as prog:
            scheduler.schedule_apply_edges(
                AdaptedHeteroGraph(self, stid, dtid, etid),
                u, v, eid, func, inplace=inplace)
            Runtime.run(prog)

    def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
        """Group edges of a single edge type by one endpoint and update their
        features by applying ``func`` to each group.

        Because all edges share one edge type, they also share one source node
        type and one destination node type.

        Parameters
        ----------
        group_by : str
            Endpoint used to group the edges: ``'src'`` or ``'dst'``.
        func : callable
            An :mod:`Edge UDF <dgl.udf>`. Its input features have shape
            (bucket_size, degrees, *feature_shape), and it must return a dict
            whose values have the same shapes.
        edges : optional
            Edges to group and apply ``func`` on. See :func:`send` for valid
            edge specification. Default is all the edges.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If ``group_by`` is neither ``'src'`` nor ``'dst'``.

        Examples
        --------
        >>> g = dgl.graph([(0, 1), (0, 2), (1, 2)], 'user', 'follows')
        >>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1))
        >>> def softmax_feat(edges):
        >>>     return {'norm_feat': torch.softmax(edges.data['feat'], dim=1)}
        >>> g.group_apply_edges(group_by='src', func=softmax_feat)
        >>> g.edata['norm_feat']
        tensor([[0.3796],
                [0.6204],
                [1.0000]])

        See Also
        --------
        apply_edges
        """
        # Reject a bad grouping key before doing any graph lookups.
        if group_by not in ('src', 'dst'):
            raise DGLError("Group_by should be either src or dst")

        etid = self.get_etype_id(etype)
        src_tid, dst_tid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            src, dst, _ = self._graph.edges(etid, 'eid')
            edge_ids = utils.toindex(slice(0, self.number_of_edges(etype)))
        elif isinstance(edges, tuple):
            # Endpoint pairs are resolved to concrete edge ids; this expands
            # broadcasting and picks up every parallel edge in a multigraph.
            src_spec, dst_spec = edges
            src, dst, edge_ids = self._graph.edge_ids(
                etid, utils.toindex(src_spec), utils.toindex(dst_spec))
        else:
            edge_ids = utils.toindex(edges)
            src, dst, _ = self._graph.find_edges(etid, edge_ids)

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, src_tid, dst_tid, etid)
            scheduler.schedule_group_apply_edge(
                adapted, src, dst, edge_ids, func, group_by, inplace=inplace)
            Runtime.run(prog)

    def send(self, edges, message_func, etype=None):
        """Send messages along the given edges of one edge type.

        ``edges`` may be given as:

        * ``int`` : one edge, by its edge id (within the given edge type).
        * ``pair of int`` : one edge, by its endpoints (source node of the source
          node type, destination node of the destination node type).
        * ``int iterable`` / ``tensor`` : several edges, by their edge ids.
        * ``pair of int iterable`` / ``pair of tensors`` :
          several edges, by their endpoints.

        **Only works if the graph has one edge type** unless ``etype`` is given.
        For multiple types you can also use

        .. code::

           g['edgetype'].send(edges, message_func)

        The UDF produces messages on the edges; they are later available in
        the destination node's ``mailbox`` and are consumed by receiving.
        See :func:`recv` for an example.

        Calling ``send`` repeatedly on the same edge without a ``recv`` in
        between overwrites the earlier messages with the later ones.

        Parameters
        ----------
        edges : optional
            Edges on which to apply ``message_func``.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Notes
        -----
        On multigraphs, when :math:`u` and :math:`v` are given, messages are sent
        along every edge between :math:`u` and :math:`v`.

        Examples
        --------

        >>> import dgl.function as fn
        >>> import torch
        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Different ways for sending messages.

        >>> # Send the feature of source nodes along all edges
        >>> g.send(g.edges(), fn.copy_src('h', 'm'))
        >>> # Send the feature of source node along one edge specified by its id
        >>> g.send(0, fn.copy_src('h', 'm'))
        >>> # Send the feature of source node along one edge specified by its end points
        >>> g.send((0, 1), fn.copy_src('h', 'm'))
        >>> # Send the feature of source nodes along multiple edges specified by their ids
        >>> g.send([0, 1], fn.copy_src('h', 'm'))
        >>> # Send the feature of source nodes along multiple edges specified by their end points
        >>> g.send(([0, 1], [1, 2]), fn.copy_src('h', 'm'))
        """
        assert message_func is not None
        etid = self.get_etype_id(etype)
        src_tid, dst_tid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            edge_ids = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
            src, dst, _ = self._graph.edges(etid, 'eid')
        elif isinstance(edges, tuple):
            # Endpoint pairs are resolved to concrete edge ids; this expands
            # broadcasting and picks up every parallel edge in a multigraph.
            src_spec, dst_spec = edges
            src, dst, edge_ids = self._graph.edge_ids(
                etid, utils.toindex(src_spec), utils.toindex(dst_spec))
        else:
            edge_ids = utils.toindex(edges)
            src, dst, _ = self._graph.find_edges(etid, edge_ids)

        # Nothing to do when the selection is empty.
        if len(edge_ids) == 0:
            return

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, src_tid, dst_tid, etid)
            scheduler.schedule_send(adapted, src, dst, edge_ids, message_func)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
2858
2859

    def recv(self,
             v,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.

        It calculates:

        .. math::
            h_v^{new} = \sigma(f(\{m_{uv} | u\in\mathcal{N}_{t}(v)\}))

        where :math:`\mathcal{N}_t(v)` defines the predecessors of node(s) :math:`v` connected by
        edges of type :math:`t`, and :math:`m_{uv}` is the message on edge :math:`(u,v)`.

        * ``reduce_func`` specifies :math:`f`, e.g. summation or average.
        * ``apply_node_func`` specifies :math:`\sigma`, e.g. ReLU activation.

        Other notes:

        * `reduce_func` will be skipped for nodes with no incoming message.
        * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
        * If some ``v`` have no incoming message, their new feature value will be calculated
          by the column initializer (see :func:`set_n_initializer`). The feature shapes and
          dtypes will be inferred.
        * The node features will be updated by the result of the ``reduce_func``.
        * Messages are consumed once received.
        * The provided UDF may be called multiple times so it is recommended to provide
          function with no side effect.

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated. ``ALL`` selects every node of the
            destination node type of ``etype``.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Send and receive.

        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.recv(g.nodes('user'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(v):
            # ``dtid`` is an integer type id, so query the graph index directly
            # rather than the public ``number_of_nodes``, which takes a type name.
            v = F.arange(0, self._graph.number_of_nodes(dtid))
        elif isinstance(v, int):
            v = [v]
        v = utils.toindex(v)
        if len(v) == 0:
            # no vertex to be triggered.
            return
        with ir.prog() as prog:
            scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
2945

Mufei Li's avatar
Mufei Li committed
2946
    def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Receive messages from multiple edge types and perform aggregation.

        It calculates:

        .. math::

            \begin{align}
            h_{v, t}^{new} &= f\left(\left\{m_{uv} | u\in\mathcal{N}_{t}(v)\right\}\right)\\
            h_v^{new} &= \sigma\left(g\left(\left\{h_{v, t}^{new} | t\in T_e\right\}\right)\right)
            \end{align}

        * ``reducer_dict`` is a dictionary mapping edge type (str or tuple of str) to
          reduce functions :math:`f` of each type.
        * ``cross_reducer`` specifies :math:`g`.
        * ``apply_node_func`` specifies :math:`\sigma`.

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated. ``ALL`` selects every node of the
            receiving node type.
        reducer_dict : dict
            Mapping edge type (str or tuple of str) to either a reduce function
            (:mod:`Node UDF <dgl.udf>`) or a ``(reduce_func, apply_node_func)`` pair.
            A per-type ``apply_node_func`` is run as part of that type's receive,
            before the cross-type merge.
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If a value of ``reducer_dict`` is not a reduce function or a
            ``(reduce_func, apply_node_func)`` pair.

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

        Send and receive.

        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
        >>> g.multi_recv(g.nodes('user'), {'follows': fn.sum('m', 'h'),
        >>>              'attracts': fn.sum('m', 'h')}, "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])
        """
        # infer receive node type
        ntype = infer_ntype_from_dict(self, reducer_dict)
        ntid = self.get_ntype_id_from_dst(ntype)
        if is_all(v):
            # ``ntid`` is an integer type id, so query the graph index directly
            # rather than the public ``number_of_nodes``, which takes a type name.
            v = F.arange(0, self._graph.number_of_nodes(ntid))
        elif isinstance(v, int):
            v = [v]
        v = utils.toindex(v)
        if len(v) == 0:
            return
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        merge_order = []
        with ir.prog() as prog:
            for ety, args in reducer_dict.items():
                args = pad_tuple(args, 2)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be either '
                                   '(1) reduce_func or (2) (reduce_func, apply_node_func)')
                rfunc, afunc = args
                etid = self.get_etype_id(ety)
                stid, dtid = self._graph.metagraph.find_edge(etid)
                # Each edge type reduces into its own scratch frame; the frames
                # are combined by ``cross_reducer`` after the program runs.
                outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
                scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
                                        v, rfunc, afunc,
                                        inplace=inplace, outframe=outframe)
                all_out.append(outframe)
                merge_order.append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[ntid].update(merge_frames(all_out, cross_reducer, merge_order))
        # apply
        if apply_node_func is not None:
            self.apply_nodes(apply_node_func, v, ntype, inplace)
Minjie Wang's avatar
Minjie Wang committed
3038

Da Zheng's avatar
Da Zheng committed
3039
3040
    def send_and_recv(self,
                      edges,
                      message_func,
                      reduce_func,
                      apply_node_func=None,
                      etype=None,
                      inplace=False):
        """Send messages along edges of the specified type, and let destinations
        receive them.

        Optionally, apply a function to update the node features after "receive".

        This is a convenient combination for performing
        :mod:`send <dgl.DGLHeteroGraph.send>` along the ``edges`` and
        :mod:`recv <dgl.DGLHeteroGraph.recv>` for the destinations of the ``edges``.

        **Only works if the graph has one edge type** unless ``etype`` is given.
        For multiple types you can also use

        .. code::

           g['edgetype'].send_and_recv(edges, message_func, reduce_func,
                                       apply_node_func, inplace=inplace)

        Parameters
        ----------
        edges : See :func:`send` for valid edge specification.
            Edges on which to apply ``func``. ``ALL`` selects every edge of
            the given type, as in :func:`send`.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])

        Trigger "send" and "receive" separately.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.recv(g.nodes('user'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])

        Trigger "send" and "receive" in one call.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.send_and_recv(g['follows'].edges(), fn.copy_src('h', 'm'),
        >>>                 fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            # Accept ALL like ``send`` does; take concrete ids (in edge-id
            # order) straight from the graph index.
            u, v, eid = self._graph.edges(etid, 'eid')
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        if len(u) == 0:
            # no edges to be triggered
            return

        with ir.prog() as prog:
            scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                   (u, v, eid),
                                   message_func, reduce_func, apply_node_func,
                                   inplace=inplace)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
3138

Mufei Li's avatar
Mufei Li committed
3139
    def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Send and receive messages along multiple edge types and perform aggregation.

        Optionally, apply a function to update the node features after "receive".

        This is a convenient combination for performing multiple
        :mod:`send <dgl.DGLHeteroGraph.send>` along edges of different types and
        :mod:`multi_recv <dgl.DGLHeteroGraph.multi_recv>` for the destinations of all edges.

        Parameters
        ----------
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration, a tuple
            ``(edges, msg_func, reduce_func[, apply_node_func])``:

            * edges: See send() for valid edge specification.
                  Edges on which to pass messages. ``ALL`` selects every
                  edge of that type.
            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the node. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If a value of ``etype_dict`` does not match the tuple layout above.

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])

        Trigger send and recv separately.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
        >>> g.multi_recv(g.nodes('user'),
        >>>              {'follows': fn.sum('m', 'h'), 'attracts': fn.sum('m', 'h')}, "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])

        Trigger "send" and "receive" in one call.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.multi_send_and_recv(
        >>>     {'follows': (g['follows'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        >>>      'attracts': (g['attracts'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        >>> "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])
        """
        # infer receive node type
        ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id_from_dst(ntype)

        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        all_vs = []
        merge_order = []
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, _ = self._graph.metagraph.find_edge(etid)
                # Each edge type reduces into its own scratch frame; the frames
                # are combined by ``cross_reducer`` after the program runs.
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 4)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(edges, msg_func, reduce_func, [apply_node_func])')
                edges, mfunc, rfunc, afunc = args
                if is_all(edges):
                    # Accept ALL like ``send`` does; take concrete ids (in
                    # edge-id order) straight from the graph index.
                    u, v, eid = self._graph.edges(etid, 'eid')
                elif isinstance(edges, tuple):
                    u, v = edges
                    u = utils.toindex(u)
                    v = utils.toindex(v)
                    # Rewrite u, v to handle edge broadcasting and multigraph.
                    u, v, eid = self._graph.edge_ids(etid, u, v)
                else:
                    eid = utils.toindex(edges)
                    u, v, _ = self._graph.find_edges(etid, eid)
                all_vs.append(v)
                if len(u) == 0:
                    # no edges to be triggered
                    continue
                scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                       (u, v, eid),
                                       mfunc, rfunc, afunc,
                                       inplace=inplace, outframe=outframe)
                all_out.append(outframe)
                merge_order.append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
        # apply
        if apply_node_func is not None:
            # Apply once over the union of destinations touched by any edge type.
            dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
            self.apply_nodes(apply_node_func, dstnodes, ntype, inplace)
    def pull(self,
             v,
             message_func,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        """Pull messages from the node(s)' predecessors and then update their features.

        Optionally, apply a function to update the node features after receive.

        This is equivalent to :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>`
        on the incoming edges of ``v`` with the specified type.

        Other notes:

        * `reduce_func` will be skipped for nodes with no incoming messages.
        * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
        * If some ``v`` have no incoming message, their new feature value will be calculated
          by the column initializer (see :func:`set_n_initializer`). The feature shapes and
          dtypes will be inferred.
        * If ``v`` is empty, the call returns immediately without doing anything.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].pull(v, message_func, reduce_func, apply_node_func, inplace=inplace)

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated. They are nodes of the destination node
            type of ``etype``.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Pull.

        >>> g['follows'].pull(2, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [1.],
                [1.]])
        """
        # Resolve the single edge type and its (source, destination) node type ids.
        # This happens before the empty check, so an invalid ``etype`` is still
        # reported even when ``v`` is empty.
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        v = utils.toindex(v)
        if len(v) == 0:
            # No destination nodes were given: nothing to pull.
            return
        with ir.prog() as prog:
            scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v,
                                    message_func, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)

    def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Pull messages into the given nodes along several edge types, then merge them.

        This is equivalent to :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
        on the incoming edges of ``v`` with the specified types.

        Each edge type is processed separately into its own output frame. The
        per-type results are then combined with ``cross_reducer`` and written
        into the destination node type's feature storage. If ``v`` is empty,
        the call does nothing.

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration (3-tuples). Each 3-tuple represents
            (msg_func, reduce_func, apply_node_func):

            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable
            Apply function on the nodes, run on ``v`` after the cross-type merge.
            The function should be a :mod:`Node UDF <dgl.udf>`. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(1, 1), (1, 0)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])

        Pull.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.multi_pull(1,
        >>>              {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        >>>               'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        >>> "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [3.]])
        """
        v = utils.toindex(v)
        if len(v) == 0:
            return
        # The receiving node type is inferred from the edge types in the dict.
        recv_ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id_from_dst(recv_ntype)

        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        per_type_frames = []
        order_hints = []
        with ir.prog() as prog:
            for etype, type_args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, _ = self._graph.metagraph.find_edge(etid)
                # Fresh output frame shaped like the destination node frame.
                type_frame = FrameRef(frame_like(self._node_frames[dtid]._frame))
                padded = pad_tuple(type_args, 3)
                if padded is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(msg_func, reduce_func, [apply_node_func])')
                msg_fn, red_fn, type_apply_fn = padded
                scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                        v,
                                        msg_fn, red_fn, type_apply_fn,
                                        inplace=inplace, outframe=type_frame)
                per_type_frames.append(type_frame)
                # The edge type id serves as the merge-order hint.
                order_hints.append(etid)
            Runtime.run(prog)

        # Combine the per-type results with the cross-type reducer.
        merged = merge_frames(per_type_frames, cross_reducer, order_hints)
        self._node_frames[dtid].update(merged)

        if apply_node_func is None:
            return
        self.apply_nodes(apply_node_func, v, recv_ntype, inplace)

    def push(self,
             u,
             message_func,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        """Send messages from the given node(s) to their successors and update the successors.

        This is equivalent to performing
        :mod:`send_and_recv <DGLHeteroGraph.send_and_recv>` along the outbound
        edges from ``u``. If ``u`` is empty, the call does nothing.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].push(u, message_func, reduce_func, apply_node_func, inplace=inplace)

        Parameters
        ----------
        u : int, container or tensor
            The node(s) to push out messages.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g = dgl.graph([(0, 1), (0, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Push.

        >>> g['follows'].push(0, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [0.]])
        """
        # Look up the edge type first, so an invalid ``etype`` is reported even
        # when ``u`` is empty.
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        senders = utils.toindex(u)
        if len(senders) == 0:
            return

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, stid, dtid, etid)
            scheduler.schedule_push(adapted, senders,
                                    message_func, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)

    def update_all(self,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Send messages along every edge of one type and update every receiving node.

        Optionally, apply a function to update the node features after receive.

        This is equivalent to
        :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>` over all edges
        of the specified type.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].update_all(message_func, reduce_func, apply_node_func)

        Parameters
        ----------
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph.

        >>> g = dgl.graph([(0, 1), (1, 2), (2, 2)], 'user', 'follows')

        Update all.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g['follows'].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [3.]])
        """
        etid = self.get_etype_id(etype)
        src_tid, dst_tid = self._graph.metagraph.find_edge(etid)

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, src_tid, dst_tid, etid)
            scheduler.schedule_update_all(adapted,
                                          message_func, reduce_func,
                                          apply_node_func)
            Runtime.run(prog)

    def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):
        r"""Send and receive messages along all edges.

        This is equivalent to
        :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
        over all edges.

        Results are grouped by destination node type. For each destination type,
        the per-edge-type outputs are merged with ``cross_reducer`` and written
        into that type's node features.

        Parameters
        ----------
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration (3-tuples). Each 3-tuple represents
            (msg_func, reduce_func, apply_node_func):

            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. After merging, it is applied
            (not in place) to all nodes of every destination node type
            reached by an edge type in ``etype_dict``. (Default: None)

        Examples
        --------
        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1), (1, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

        Update all.

        >>> g.multi_update_all(
        >>>     {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        >>>      'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        >>> "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [4.]])
        """
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        # Output frames and merge-order hints, grouped by destination node type id.
        all_out = defaultdict(list)
        merge_order = defaultdict(list)
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, dtid = self._graph.metagraph.find_edge(etid)
                # Fresh output frame shaped like the destination node frame.
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 3)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(msg_func, reduce_func, [apply_node_func])')
                mfunc, rfunc, afunc = args
                scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                              mfunc, rfunc, afunc,
                                              outframe=outframe)
                all_out[dtid].append(outframe)
                merge_order[dtid].append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        for dtid, frames in all_out.items():
            # merge by cross_reducer
            self._node_frames[dtid].update(
                merge_frames(frames, cross_reducer, merge_order[dtid]))
            # apply
            if apply_node_func is not None:
                self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False)

    def prop_nodes(self,
                   nodes_generator,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`pull()` on nodes.

        The traversal order is specified by the ``nodes_generator``. It generates
        node frontiers, which is a list or a tensor of nodes. The nodes in the
        same frontier will be triggered together, while nodes in different frontiers
        will be triggered according to the generating order.

        Parameters
        ----------
        nodes_generator : iterable, each element is a list or a tensor of node ids
            The generator of node frontiers. It specifies which nodes perform
            :func:`pull` at each timestep.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph and perform multiple rounds of message passing.

        >>> g = dgl.graph(([0, 1, 2, 3], [2, 3, 4, 4]), 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
        >>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_src('h', 'm'),
        >>>                         fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[1.],
                [2.],
                [1.],
                [2.],
                [3.]])

        See Also
        --------
        prop_edges
        """
        # One pull per frontier, in the order the generator yields them, so later
        # frontiers see the features updated by earlier ones.
        for node_frontier in nodes_generator:
            self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype)

    def prop_edges(self,
                   edges_generator,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Propagate messages by calling :func:`send_and_recv()` on a sequence of edge frontiers.

        The traversal order is specified by the ``edges_generator``. It generates
        edge frontiers. The edge frontiers should be of *valid edges type*.
        See :func:`send` for more details.

        Edges in the same frontier will be triggered together, and edges in
        different frontiers will be triggered according to the generating order.

        Parameters
        ----------
        edges_generator : generator
            The generator of edge frontiers.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph and perform multiple rounds of message passing.

        >>> g = dgl.graph(([0, 1, 2, 3], [2, 3, 4, 4]), 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
        >>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_src('h', 'm'),
        >>>                         fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[1.],
                [2.],
                [1.],
                [2.],
                [3.]])

        See Also
        --------
        prop_nodes
        """
        trigger = self.send_and_recv
        # Frontiers are processed strictly in generation order.
        for frontier in edges_generator:
            trigger(frontier, message_func, reduce_func,
                    apply_node_func, etype=etype)

    #################################################################
    # Misc
    #################################################################
    def to_networkx(self, node_attrs=None, edge_attrs=None):
        """Convert this graph to networkx graph.

        The edge id will be saved as the 'id' edge attribute. The result is
        always a :class:`networkx.MultiDiGraph`, so parallel edges are kept.
        Only graphs with exactly one node type and one edge type are supported.

        Parameters
        ----------
        node_attrs : iterable of str, optional
            The node attributes to be copied.
        edge_attrs : iterable of str, optional
            The edge attributes to be copied.

        Returns
        -------
        networkx.MultiDiGraph
            The nx graph

        Examples
        --------

        .. note:: Here we use pytorch syntax for demo. The general idea applies
            to other frameworks with minor syntax change (e.g. replace
            ``torch.tensor`` with ``mxnet.ndarray``).

        >>> import torch as th
        >>> import dgl
        >>> g = dgl.graph([(0, 2), (1, 4), (3, 0), (4, 3)], 'user', 'follows')
        >>> g.nodes['user'].data['n1'] = th.randn(5, 10)
        >>> g.edges['follows'].data['e1'] = th.randn(4, 6)
        >>> nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])

        See Also
        --------
        dgl.to_networkx
        """
        # TODO(minjie): multi-type support
        assert len(self.ntypes) == 1, \
            'to_networkx only supports graphs with a single node type'
        assert len(self.etypes) == 1, \
            'to_networkx only supports graphs with a single edge type'
        src, dst = self.edges()
        src = F.asnumpy(src)
        dst = F.asnumpy(dst)
        # xiangsx: Always treat graph as multigraph
        nx_graph = nx.MultiDiGraph()
        nx_graph.add_nodes_from(range(self.number_of_nodes()))
        # Edges are added in edge-id order; the id is kept as the 'id' attribute
        # so edge features can be looked up again below.
        for eid, (u, v) in enumerate(zip(src, dst)):
            nx_graph.add_edge(u, v, id=eid)

        if node_attrs is not None:
            for nid, attr in nx_graph.nodes(data=True):
                feat_dict = self._get_n_repr(0, nid)
                # Drop the leading per-node dimension of each feature row.
                attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
        if edge_attrs is not None:
            for _, _, attr in nx_graph.edges(data=True):
                eid = attr['id']
                feat_dict = self._get_e_repr(0, eid)
                attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
        return nx_graph

    def filter_nodes(self, predicate, nodes=ALL, ntype=None):
        """Return a tensor of node IDs with the given node type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(nodes) -> tensor``.
            ``nodes`` are :class:`NodeBatch` objects as in :mod:`~dgl.udf`.
            The ``tensor`` returned should be a 1-D boolean tensor with
            each element indicating whether the corresponding node in
            the batch satisfies the predicate.
        nodes : int, iterable or tensor of ints
            The nodes to filter on. Default value is all the nodes.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Node ids indicating the nodes that satisfy the predicate. Always
            1-D, also when ``nodes`` is a single integer.

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> g = dgl.graph([], 'user', 'follows', num_nodes=4)
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
        tensor([1, 2])
        """
        ntid = self.get_ntype_id(ntype)
        if is_all(nodes):
            v = utils.toindex(slice(0, self._graph.number_of_nodes(ntid)))
        else:
            v = utils.toindex(nodes)

        n_repr = self._get_n_repr(ntid, v)
        nbatch = NodeBatch(v, n_repr, ntype=self.ntypes[ntid])
        n_mask = F.copy_to(predicate(nbatch), F.cpu())

        if is_all(nodes):
            return F.nonzero_1d(n_mask)
        if F.is_tensor(nodes):
            # Mask the caller's tensor directly, keeping its device and dtype.
            return F.boolean_mask(nodes, n_mask)
        # Ints and iterables: use the 1-D index built above. F.tensor on a
        # scalar would give a 0-d tensor, which cannot be boolean-masked.
        return F.boolean_mask(v.tousertensor(), n_mask)

    def filter_edges(self, predicate, edges=ALL, etype=None):
        """Return a tensor of edge IDs with the given edge type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(edges) -> tensor``.
            ``edges`` are :class:`EdgeBatch` objects as in :mod:`~dgl.udf`.
            The ``tensor`` returned should be a 1-D boolean tensor with
            each element indicating whether the corresponding edge in
            the batch satisfies the predicate.
        edges : valid edges type
            Edges on which to apply ``func``. See :func:`send` for valid
            edges type. Default value is all the edges.
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Edge ids indicating the edges that satisfy the predicate. When
            ``edges`` is a ``(u, v)`` pair, these are the IDs of the matching
            edges resolved from the pair.

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> g = dgl.graph([(0, 0), (0, 1), (1, 2), (2, 3)], 'user', 'follows')
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
        tensor([1, 2])
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(edges):
            u, v, _ = self._graph.edges(etid, 'eid')
            eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        src_data = self._get_n_repr(stid, u)
        edge_data = self._get_e_repr(etid, eid)
        dst_data = self._get_n_repr(dtid, v)
        ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data,
                           canonical_etype=self.canonical_etypes[etid])
        e_mask = F.copy_to(predicate(ebatch), F.cpu())

        if is_all(edges):
            return F.nonzero_1d(e_mask)
        if F.is_tensor(edges):
            # Plain edge-ID tensor: mask it directly, keeping its device and dtype.
            return F.boolean_mask(edges, e_mask)
        # (u, v) pairs: ``eid`` holds the edge IDs resolved by edge_ids and has
        # one entry per mask element. Masking F.tensor(edges) would wrongly mask
        # a 2-row (u, v) tensor. For scalars and lists, ``eid`` is a 1-D index,
        # whereas F.tensor on a scalar would give a 0-d tensor.
        return F.boolean_mask(eid.tousertensor(), e_mask)

    def to(self, ctx):  # pylint: disable=invalid-name
        """Move both ndata and edata to the target device (cpu/gpu).

        Framework agnostic. The feature tensors in this graph's node and edge
        frames are replaced in place; only feature data is moved (``self._graph``
        is not touched). The graph itself is returned for chaining.

        Parameters
        ----------
        ctx : framework-specific context object
            The context to move data to.

        Returns
        -------
        g : DGLHeteroGraph
          This graph (``self``), with its features moved to ``ctx``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import torch
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
        >>> g = g.to(torch.device('cuda:0'))
        """
        for frame in list(self._node_frames) + list(self._edge_frames):
            # Snapshot the keys so re-assigning columns cannot disturb iteration.
            for key in list(frame.keys()):
                frame[key] = F.copy_to(frame[key], ctx)
        return self

    def local_var(self):
        """Return a heterograph object that can be used in a local function scope.

        The returned graph object shares the feature data and graph structure of this graph.
        However, any out-of-place mutation to the feature data will not reflect to this graph,
        thus making it easier to use in a function scope.

        Feature initializers registered on this graph (per-column and default) are also
        used by the returned graph object.

        Returns
        -------
        DGLHeteroGraph
            The graph object that can be used as a local variable.

        Notes
        -----
        Internally, the returned graph shares the same feature tensors, but constructs a new
        dictionary structure (aka. Frame) so adding/removing feature tensors from the returned
        graph will not reflect to the original graph. However, inplace operations do change
        the shared tensor values, so will be reflected to the original graph. This function
        also has little overhead when the number of feature tensors in this graph is small.

        Examples
        --------
        The following example uses PyTorch backend.

        Avoid accidentally overriding existing feature data. This is quite common when
        implementing a NN module:

        >>> def foo(g):
        >>>     g = g.local_var()
        >>>     g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>     return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
        >>> newh = foo(g)        # get tensor of all ones
        >>> print(g.edata['h'])  # still get tensor of all zeros

        Automatically garbage collect locally-defined tensors without the need to manually
        ``pop`` the tensors.

        >>> def foo(g):
        >>>     g = g.local_var()
        >>>     # This 'h' feature will stay local and be GCed when the function exits
        >>>     g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>     return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> h = foo(g)
        >>> print('h' in g.edata)
        False

        See Also
        --------
        local_scope
        """
        # New Frame containers over the same underlying columns.
        local_node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
        local_edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
        # Use same per-column initializers and default initializer.
        # If registered, a column (based on key) initializer will be used first,
        # otherwise the default initializer will be used.
        for fr1, fr2 in zip(local_node_frames, self._node_frames):
            sync_frame_initializer(fr1._frame, fr2._frame)
        for fr1, fr2 in zip(local_edge_frames, self._edge_frames):
            sync_frame_initializer(fr1._frame, fr2._frame)
        return DGLHeteroGraph(self._graph, self.ntypes, self.etypes,
                              local_node_frames,
                              local_edge_frames)

    @contextmanager
    def local_scope(self):
        """Enter a local scope context for this graph.

        By entering a local scope, any out-of-place mutation to the feature data will
        not reflect to the original graph, thus making it easier to use in a function scope.
        The original feature frames are restored when the scope exits, including when
        the body of the ``with`` statement raises an exception.

        Feature initializers registered on this graph (per-column and default) are also
        used inside the local scope.

        Examples
        --------
        The following example uses PyTorch backend.

        Avoid accidentally overriding existing feature data. This is quite common when
        implementing a NN module:

        >>> def foo(g):
        >>>     with g.local_scope():
        >>>         g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>         return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
        >>> newh = foo(g)        # get tensor of all ones
        >>> print(g.edata['h'])  # still get tensor of all zeros

        Automatically garbage collect locally-defined tensors without the need to manually
        ``pop`` the tensors.

        >>> def foo(g):
        >>>     with g.local_scope():
        >>>         # This 'h' feature will stay local and be GCed when the function exits
        >>>         g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>         return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> h = foo(g)
        >>> print('h' in g.edata)
        False

        See Also
        --------
        local_var
        """
        old_nframes = self._node_frames
        old_eframes = self._edge_frames
        try:
            self._node_frames = [FrameRef(Frame(fr._frame)) for fr in old_nframes]
            self._edge_frames = [FrameRef(Frame(fr._frame)) for fr in old_eframes]
            # Use same per-column initializers and default initializer.
            # If registered, a column (based on key) initializer will be used first,
            # otherwise the default initializer will be used.
            for fr1, fr2 in zip(self._node_frames, old_nframes):
                sync_frame_initializer(fr1._frame, fr2._frame)
            for fr1, fr2 in zip(self._edge_frames, old_eframes):
                sync_frame_initializer(fr1._frame, fr2._frame)
            yield
        finally:
            # Without finally, an exception raised inside the ``with`` body would
            # skip this restore and leave the graph on the local frames forever.
            self._node_frames = old_nframes
            self._edge_frames = old_eframes

    def is_homograph(self):
        """Check whether the graph is homogeneous.

        Returns
        -------
        bool
            True when the graph has exactly one node type and exactly one edge type.
        """
        if len(self.ntypes) != 1:
            return False
        return len(self.etypes) == 1


############################################################
# Internal APIs
############################################################

def make_canonical_etypes(etypes, ntypes, metagraph):
    """Internal function to convert etype names to (srctype, etype, dsttype) triplets.

    Parameters
    ----------
    etypes : list of str
        Edge type names, indexed by metagraph edge ID.
    ntypes : list of str
        Node type names, indexed by metagraph node ID.
    metagraph : GraphIndex
        Meta graph whose nodes stand for node types and whose edges stand for
        edge types.

    Returns
    -------
    list of tuples (srctype, etype, dsttype)
        One triplet per metagraph edge, in the order given by ``metagraph.edges()``.

    Raises
    ------
    DGLError
        If the length of ``etypes`` or ``ntypes`` disagrees with the metagraph.
    """
    # Both name lists must line up one-to-one with the metagraph.
    num_meta_edges = metagraph.number_of_edges()
    if len(etypes) != num_meta_edges:
        raise DGLError('Length of edge type list must match the number of '
                       'edges in the metagraph. {} vs {}'.format(
                           len(etypes), num_meta_edges))
    num_meta_nodes = metagraph.number_of_nodes()
    if len(ntypes) != num_meta_nodes:
        raise DGLError('Length of nodes type list must match the number of '
                       'nodes in the metagraph. {} vs {}'.format(
                           len(ntypes), num_meta_nodes))
    src_ids, dst_ids, edge_ids = metagraph.edges()
    return [(ntypes[src_id], etypes[edge_id], ntypes[dst_id])
            for src_id, dst_id, edge_id in zip(src_ids, dst_ids, edge_ids)]

def is_unibipartite(graph):
    """Internal function that checks whether a graph is a uni-directional bipartite graph.

    A graph qualifies when no node appears both as the source of some edge and
    as the destination of some edge.

    Parameters
    ----------
    graph : GraphIndex
        Input graph

    Returns
    -------
    bool
        True if the graph is a uni-bipartite.
    """
    src_ids, dst_ids, _ = graph.edges()
    sources = set(src_ids.tonumpy())
    destinations = set(dst_ids.tonumpy())
    return not (sources & destinations)

def find_src_dst_ntypes(ntypes, metagraph):
    """Internal function to split ntypes into SRC and DST categories.

    If the metagraph is not a uni-bipartite graph (so that the SRC and DST categories
    are not well-defined), return None.

    For node types that are isolated (i.e, no relation is associated with it), they
    are assigned to the SRC category.

    Parameters
    ----------
    ntypes : list of str
        Node type list
    metagraph : GraphIndex
        Meta graph.

    Returns
    -------
    (dict[str, int], dict[str, int]) or None
        Node types belonging to SRC and DST categories. Each dictionary maps a
        type name to its type id. Return None if the metagraph is not
        uni-bipartite.
    """
    src, dst, _ = metagraph.edges()
    if not set(src.tonumpy()).isdisjoint(set(dst.tonumpy())):
        # Some type is both a source and a destination: no clean split exists.
        return None
    srctypes = {ntypes[tid]: tid for tid in src}
    dsttypes = {ntypes[tid]: tid for tid in dst}
    # Node types not touched by any relation default to the SRC category.
    for ntid, ntype in enumerate(ntypes):
        if ntype not in srctypes and ntype not in dsttypes:
            srctypes[ntype] = ntid
    return srctypes, dsttypes

def infer_ntype_from_dict(graph, etype_dict):
    """Infer the shared destination node type of the edge types keying a dictionary.

    Every edge type key must map to the same destination node type, which is
    returned. An empty dictionary yields None.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph used to resolve each key to its canonical edge type.
    etype_dict : dict
        Dictionary whose key is edge type

    Returns
    -------
    str or None
        The common destination node type, or None if ``etype_dict`` is empty.

    Raises
    ------
    DGLError
        If two keys have different destination node types.
    """
    inferred = None
    for ety in etype_dict:
        _, _, dst_type = graph.to_canonical_etype(ety)
        if inferred is None:
            inferred = dst_type
        elif inferred != dst_type:
            raise DGLError("Cannot infer destination node type from the dictionary. "
                           "A valid specification must make sure that all the edge "
                           "type keys share the same destination node type.")
    return inferred

def pad_tuple(tup, length, pad_val=None):
    """Pad the given tuple to the given length.

    A non-tuple input is first wrapped into a one-element tuple. Missing
    trailing slots are filled with ``pad_val``.

    Returns
    -------
    tuple or None
        The padded tuple, or None if the input is already longer than ``length``.
    """
    items = tup if isinstance(tup, tuple) else (tup,)
    missing = length - len(items)
    if missing < 0:
        return None
    return items + (pad_val,) * missing

def merge_frames(frames, reducer, order=None):
    """Merge input frames into one, resolving shared fields with ``reducer``.

    Parameters
    ----------
    frames : list[FrameRef]
        Input frames
    reducer : str
        One of "sum", "max", "min", "mean", "stack". Non-"stack" names are
        looked up on the backend module; an unknown name raises DGLError.
    order : list[int], optional
        Merge order hint. Useful for "stack" reducer.
        If provided, each integer indicates the relative order
        of the ``frames`` list. Frames are sorted according to this list
        in ascending order. Tie is not handled so make sure the order values
        are distinct.

    Returns
    -------
    FrameRef
        Merged frame. With a single input frame and a non-"stack" reducer,
        that input frame itself is returned.
    """
    if len(frames) == 1 and reducer != 'stack':
        # Directly return the only one input. Stack reducer requires
        # modifying tensor shape.
        return frames[0]

    if reducer == 'stack':
        # Stack order does not matter. However, it must be consistent!
        if order:
            assert len(order) == len(frames)
            ranked = sorted(zip(frames, order), key=lambda pair: pair[1])
            frames = [frm for frm, _ in ranked]

        def combine(tensors):
            """Stack the per-frame tensors along a new axis 1."""
            return F.stack(tensors, 1)
    else:
        reduce_fn = getattr(F, reducer, None)
        if reduce_fn is None:
            raise DGLError('Invalid cross type reducer. Must be one of '
                           '"sum", "max", "min", "mean" or "stack".')

        def combine(tensors):
            """Reduce the per-frame tensors; a lone tensor is passed through."""
            if len(tensors) > 1:
                return reduce_fn(F.stack(tensors, 0), 0)
            return tensors[0]

    merged = FrameRef(frame_like(frames[0]._frame))
    all_keys = set()
    for frm in frames:
        all_keys.update(frm.keys())
    for key in all_keys:
        # Only frames that actually have the column take part in the merge.
        merged[key] = combine([frm[key] for frm in frames if key in frm])
    return merged

def combine_frames(frames, ids):
    """Merge the selected frames into one frame, keeping only the common columns.

    Return None if there are no common columns.

    Parameters
    ----------
    frames : List[FrameRef]
        List of frames
    ids : List[int]
        List of frame IDs (indices into ``frames``) to combine, in order.

    Returns
    -------
    FrameRef or None
        The resulting frame, whose columns are the row-wise concatenation of
        the selected frames' common columns; None if they share no column.

    Raises
    ------
    DGLError
        If a common column has different schemes in different frames.
    """
    # find common columns and check if their schemes match
    schemes = dict(frames[ids[0]].schemes)
    for frame_id in ids:
        frame = frames[frame_id]
        # Iterate over a snapshot because non-common keys are deleted in the loop.
        for key, scheme in list(schemes.items()):
            if key in frame.schemes:
                if frame.schemes[key] != scheme:
                    raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
                                   (key, frame.schemes[key], scheme))
            else:
                del schemes[key]

    if not schemes:
        return None

    # Frames without rows contribute nothing to the concatenation and are
    # skipped. If every selected frame is empty, concatenate them anyway so the
    # result keeps the common columns (with zero rows) instead of calling
    # F.cat on an empty list.
    nonempty_ids = [i for i in ids if frames[i].num_rows > 0]
    cat_ids = nonempty_ids if nonempty_ids else list(ids)
    cols = {key: F.cat([frames[i][key] for i in cat_ids], dim=0) for key in schemes}
    return FrameRef(Frame(cols))

def combine_names(names, ids=None):
    """Combine the selected names into one new name.

    The chosen names are sorted and joined with ``'+'``.

    Parameters
    ----------
    names : list of str
        String names
    ids : numpy.ndarray, optional
        Indices of the names to select. All names are used when omitted.

    Returns
    -------
    str
    """
    chosen = names if ids is None else [names[i] for i in ids]
    return '+'.join(sorted(chosen))

class AdaptedHeteroGraph(GraphAdapter):
    """Adapt one relation of a DGLHeteroGraph to the interface required by the scheduler.

    Every query is answered for the single edge type ``etid`` together with its
    source node type ``stid`` and destination node type ``dtid``.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    stid : int
        Source node type id
    dtid : int
        Destination node type id
    etid : int
        Edge type id
    """
    def __init__(self, graph, stid, dtid, etid):
        self.graph = graph
        self.stid = stid
        self.dtid = dtid
        self.etid = etid

    @property
    def gidx(self):
        """Graph index of the wrapped heterograph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph._graph.number_of_nodes(self.stid)

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph._graph.number_of_nodes(self.dtid)

    def num_edges(self):
        """Number of edges."""
        return self.graph._graph.number_of_edges(self.etid)

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frames[self.stid]

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        return self.graph._node_frames[self.dtid]

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frames[self.etid]

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frames[self.etid]

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index(self.etid)

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(self.etid, val)

    def in_edges(self, nodes):
        """In-edges of ``nodes`` restricted to edge type ``etid``."""
        return self.graph._graph.in_edges(self.etid, nodes)

    def out_edges(self, nodes):
        """Out-edges of ``nodes`` restricted to edge type ``etid``."""
        return self.graph._graph.out_edges(self.etid, nodes)

    def edges(self, form):
        """All edges of type ``etid``, in the format given by ``form``."""
        return self.graph._graph.edges(self.etid, form)

    def get_immutable_gidx(self, ctx):
        """Delegate to the graph index's ``get_unitgraph`` for edge type ``etid``."""
        return self.graph._graph.get_unitgraph(self.etid, ctx)

    def bits_needed(self):
        """Delegate to the graph index's ``bits_needed`` for edge type ``etid``."""
        return self.graph._graph.bits_needed(self.etid)

    @property
    def canonical_etype(self):
        """Canonical edge type."""
        return self.graph.canonical_etypes[self.etid]