heterograph.py 163 KB
Newer Older
Da Zheng's avatar
Da Zheng committed
1
"""Classes for heterogeneous graphs."""
2
#pylint: disable= too-many-lines
3
from collections import defaultdict
Minjie Wang's avatar
Minjie Wang committed
4
from contextlib import contextmanager
5
import copy
6
import networkx as nx
Minjie Wang's avatar
Minjie Wang committed
7
8
9
10
import numpy as np

from . import graph_index
from . import heterograph_index
11
12
13
from . import utils
from . import backend as F
from . import init
Minjie Wang's avatar
Minjie Wang committed
14
from .runtime import ir, scheduler, Runtime, GraphAdapter
15
from .frame import Frame, FrameRef, frame_like
16
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
Mufei Li's avatar
Mufei Li committed
17
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
Mufei Li's avatar
Mufei Li committed
18
from .udf import NodeBatch, EdgeBatch
Minjie Wang's avatar
Minjie Wang committed
19
20
21
22
23
24

# Public names exported by ``from <this module> import *``.
__all__ = [
    'DGLHeteroGraph',
    'combine_names',
]

class DGLHeteroGraph(object):
    """Base heterogeneous graph class.

Mufei Li's avatar
Mufei Li committed
25
26
    **Do NOT instantiate from this class directly; use** :mod:`conversion methods
    <dgl.convert>` **instead.**
Minjie Wang's avatar
Minjie Wang committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

    A Heterogeneous graph is defined as a graph with node types and edge
    types.

    If two edges share the same edge type, then their source nodes, as well
    as their destination nodes, also have the same type (the source node
    types don't have to be the same as the destination node types).

    Examples
    --------
    Suppose that we want to construct the following heterogeneous graph:

    .. graphviz::

       digraph G {
           Alice -> Bob [label=follows]
           Bob -> Carol [label=follows]
           Alice -> Tetris [label=plays]
           Bob -> Tetris [label=plays]
           Bob -> Minecraft [label=plays]
           Carol -> Minecraft [label=plays]
           Nintendo -> Tetris [label=develops]
           Mojang -> Minecraft [label=develops]
           {rank=source; Alice; Bob; Carol}
           {rank=sink; Nintendo; Mojang}
       }

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
54
    And suppose that one maps the users, games and developers to the following
Minjie Wang's avatar
Minjie Wang committed
55
56
    IDs:

Mufei Li's avatar
Mufei Li committed
57
58
59
60
61
    =========  =====  ===  =====
    User name  Alice  Bob  Carol
    =========  =====  ===  =====
    User ID    0      1    2
    =========  =====  ===  =====
Minjie Wang's avatar
Minjie Wang committed
62

Mufei Li's avatar
Mufei Li committed
63
64
65
66
67
    =========  ======  =========
    Game name  Tetris  Minecraft
    =========  ======  =========
    Game ID    0       1
    =========  ======  =========
Minjie Wang's avatar
Minjie Wang committed
68

Mufei Li's avatar
Mufei Li committed
69
70
71
72
73
    ==============  ========  ======
    Developer name  Nintendo  Mojang
    ==============  ========  ======
    Developer ID    0         1
    ==============  ========  ======
Minjie Wang's avatar
Minjie Wang committed
74
75
76
77
78

    One can construct the graph as follows:

    >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
    >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
Mufei Li's avatar
Mufei Li committed
79
80
    >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
    >>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
Minjie Wang's avatar
Minjie Wang committed
81

82
83
84
85
86
87
88
89
    Or equivalently

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ...     ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
    ...     ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
    ...     })

Minjie Wang's avatar
Minjie Wang committed
90
    :func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
Mufei Li's avatar
Mufei Li committed
91
92
93
94
95
96
97
98
    data types including:

    * edge list
    * edge tuples
    * networkx graph
    * scipy sparse matrix

    Click the function names for more details.
Minjie Wang's avatar
Minjie Wang committed
99
100
101
102
103
104
105
106
107
108
109
110
111

    Then one can query the graph structure by specifying the ``ntype`` or ``etype`` arguments:

    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_edges('plays')
    4
    >>> g.out_degrees(etype='develops')  # out-degrees of source nodes of 'develops' relation
    tensor([1, 1])
    >>> g.in_edges(0, etype='develops')  # in-edges of destination node 0 of 'develops' relation
    (tensor([0]), tensor([0]))

    Or on the sliced graph for an edge type:
112

Minjie Wang's avatar
Minjie Wang committed
113
114
115
116
117
    >>> g['plays'].number_of_edges()
    4
    >>> g['develops'].out_degrees()
    tensor([1, 1])
    >>> g['develops'].in_edges(0)
Mufei Li's avatar
Mufei Li committed
118
    (tensor([0]), tensor([0]))
Minjie Wang's avatar
Minjie Wang committed
119
120
121
122
123
124
125
126
127
128

    Node type names must be distinct (no two types have the same name). Edge types could
    have the same name but they must be distinguishable by the ``(src_type, edge_type, dst_type)``
    triplet (called *canonical edge type*).

    For example, suppose a graph that has two types of relation "user-watches-movie"
    and "user-watches-TV" as follows:

    >>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie')
    >>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV')
Mufei Li's avatar
Mufei Li committed
129
    >>> GG = dgl.hetero_from_relations([g0, g1]) # Merge the two graphs
Minjie Wang's avatar
Minjie Wang committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

    To distinguish between the two "watches" edge types, one must specify the full triplet:

    >>> GG.number_of_edges(('user', 'watches', 'movie'))
    3
    >>> GG.number_of_edges(('user', 'watches', 'TV'))
    2
    >>> GG['user', 'watches', 'movie'].out_degrees()
    tensor([1, 2])

    Using only the edge type name "watches" is ambiguous and raises an error:

    >>> GG.number_of_edges('watches')  # AMBIGUOUS!!

    In many cases, there is only one type of nodes or one type of edges, and the ``ntype``
    and ``etype`` argument could be omitted. This is very common when using the sliced
    graph, which usually contains only one edge type, and sometimes only one node type:
147

Minjie Wang's avatar
Minjie Wang committed
148
149
150
151
    >>> g['follows'].number_of_nodes()  # OK!! because g['follows'] only has one node type 'user'
    3
    >>> g['plays'].number_of_nodes()  # ERROR!! There are two types 'user' and 'game'.
    >>> g['plays'].number_of_edges()  # OK!! because there is only one edge type 'plays'
Da Zheng's avatar
Da Zheng committed
152

153
154
    TODO(minjie): docstring about uni-directional bipartite graph

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
    Metagraph
    ---------
    For each heterogeneous graph, one can often infer the *metagraph*, the template of
    edge connections showing how many types of nodes and edges exist in the graph, and
    how each edge type could connect between node types.

    One can analyze the example gameplay graph above and figure out the metagraph as
    follows:

    .. graphviz::

       digraph G {
           User -> User [label=follows]
           User -> Game [label=plays]
           Developer -> Game [label=develops]
       }


Da Zheng's avatar
Da Zheng committed
173
174
    Parameters
    ----------
Minjie Wang's avatar
Minjie Wang committed
175
176
    gidx : HeteroGraphIndex
        Graph index object.
177
    ntypes : list of str, pair of list of str
Mufei Li's avatar
Mufei Li committed
178
        Node type list. ``ntypes[i]`` stores the name of node type i.
179
180
        If a pair is given, the graph created is a uni-directional bipartite graph,
        and its SRC node types and DST node types are given as in the pair.
Minjie Wang's avatar
Minjie Wang committed
181
    etypes : list of str
Mufei Li's avatar
Mufei Li committed
182
        Edge type list. ``etypes[i]`` stores the name of edge type i.
Minjie Wang's avatar
Minjie Wang committed
183
    node_frames : list of FrameRef, optional
Mufei Li's avatar
Mufei Li committed
184
185
186
        Node feature storage. If None, empty frame is created.
        Otherwise, ``node_frames[i]`` stores the node features
        of node type i. (default: None)
Minjie Wang's avatar
Minjie Wang committed
187
    edge_frames : list of FrameRef, optional
Mufei Li's avatar
Mufei Li committed
188
189
190
        Edge feature storage. If None, empty frame is created.
        Otherwise, ``edge_frames[i]`` stores the edge features
        of edge type i. (default: None)
Minjie Wang's avatar
Minjie Wang committed
191
    """
Da Zheng's avatar
Da Zheng committed
192
    # pylint: disable=unused-argument
Minjie Wang's avatar
Minjie Wang committed
193
194
195
196
197
    def __init__(self,
                 gidx,
                 ntypes,
                 etypes,
                 node_frames=None,
198
199
                 edge_frames=None):
        self._init(gidx, ntypes, etypes, node_frames, edge_frames)
Da Zheng's avatar
Da Zheng committed
200

201
202
    def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
        """Init internal states."""
Minjie Wang's avatar
Minjie Wang committed
203
        self._graph = gidx
204
        self._canonical_etypes = None
205
206
207
208
209
210
211
212
213
214
215
216
217
218

        # Handle node types
        if isinstance(ntypes, tuple):
            if len(ntypes) != 2:
                errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format(
                    ntypes)
                raise TypeError(errmsg)
            if not is_unibipartite(self._graph.metagraph):
                raise ValueError('Invalid input. The metagraph must be a uni-directional'
                                 ' bipartite graph.')
            self._ntypes = ntypes[0] + ntypes[1]
            self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])}
            self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])}
            self._is_unibipartite = True
219
220
            if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
                self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])]
221
222
223
224
225
226
227
228
229
230
231
        else:
            self._ntypes = ntypes
            src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
            self._is_unibipartite = (src_dst_map is not None)
            if self._is_unibipartite:
                self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
            else:
                self._srctypes_invmap = {t : i for i, t in enumerate(self._ntypes)}
                self._dsttypes_invmap = self._srctypes_invmap

        # Handle edge types
232
        self._etypes = etypes
233
234
235
        if self._canonical_etypes is None:
            self._canonical_etypes = make_canonical_etypes(
                self._etypes, self._ntypes, self._graph.metagraph)
236

Minjie Wang's avatar
Minjie Wang committed
237
        # An internal map from etype to canonical etype tuple.
238
239
        # If two etypes have the same name, an empty tuple is stored instead to indicate
        # ambiguity.
Minjie Wang's avatar
Minjie Wang committed
240
        self._etype2canonical = {}
241
        for i, ety in enumerate(self._etypes):
Minjie Wang's avatar
Minjie Wang committed
242
243
244
245
246
            if ety in self._etype2canonical:
                self._etype2canonical[ety] = tuple()
            else:
                self._etype2canonical[ety] = self._canonical_etypes[i]
        self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}
Da Zheng's avatar
Da Zheng committed
247

248
249
250
        # Cached metagraph in networkx
        self._nx_metagraph = None

Minjie Wang's avatar
Minjie Wang committed
251
252
253
254
255
256
257
        # node and edge frame
        if node_frames is None:
            node_frames = [None] * len(self._ntypes)
        node_frames = [FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
                       if frame is None else frame
                       for i, frame in enumerate(node_frames)]
        self._node_frames = node_frames
Da Zheng's avatar
Da Zheng committed
258

Minjie Wang's avatar
Minjie Wang committed
259
260
261
262
263
264
        if edge_frames is None:
            edge_frames = [None] * len(self._etypes)
        edge_frames = [FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
                       if frame is None else frame
                       for i, frame in enumerate(edge_frames)]
        self._edge_frames = edge_frames
Da Zheng's avatar
Da Zheng committed
265

Minjie Wang's avatar
Minjie Wang committed
266
267
268
269
270
271
272
        # message indicators
        self._msg_indices = [None] * len(self._etypes)
        self._msg_frames = []
        for i in range(len(self._etypes)):
            frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
            frame.set_initializer(init.zero_initializer)
            self._msg_frames.append(frame)
Da Zheng's avatar
Da Zheng committed
273

274
275
276
277
    def __getstate__(self):
        return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames

    def __setstate__(self, state):
278
279
280
281
282
283
284
285
286
287
288
289
290
        # Compatibility check
        # TODO: version the storage
        if isinstance(state, tuple) and len(state) == 5:
            # DGL 0.4.3+
            self._init(*state)
        elif isinstance(state, dict):
            # DGL 0.4.2-
            dgl_warning("The object is pickled with DGL version 0.4.2-.  "
                        "Some of the original attributes are ignored.")
            self._init(state['_graph'], state['_ntypes'], state['_etypes'], state['_node_frames'],
                       state['_edge_frames'])
        else:
            raise IOError("Unrecognized pickle format.")
Mufei Li's avatar
Mufei Li committed
291

Minjie Wang's avatar
Minjie Wang committed
292
    def _get_msg_index(self, etid):
293
        """Internal function for getting the message index array of the given edge type id."""
Minjie Wang's avatar
Minjie Wang committed
294
295
296
297
        if self._msg_indices[etid] is None:
            self._msg_indices[etid] = utils.zero_index(
                size=self._graph.number_of_edges(etid))
        return self._msg_indices[etid]
Da Zheng's avatar
Da Zheng committed
298

Minjie Wang's avatar
Minjie Wang committed
299
300
    def _set_msg_index(self, etid, index):
        self._msg_indices[etid] = index
Da Zheng's avatar
Da Zheng committed
301

Minjie Wang's avatar
Minjie Wang committed
302
303
304
305
306
307
308
309
310
311
312
313
314
315
    def __repr__(self):
        if len(self.ntypes) == 1 and len(self.etypes) == 1:
            ret = ('Graph(num_nodes={node}, num_edges={edge},\n'
                   '      ndata_schemes={ndata}\n'
                   '      edata_schemes={edata})')
            return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
                              ndata=str(self.node_attr_schemes()),
                              edata=str(self.edge_attr_schemes()))
        else:
            ret = ('Graph(num_nodes={node},\n'
                   '      num_edges={edge},\n'
                   '      metagraph={meta})')
            nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i)
                          for i in range(len(self.ntypes))}
316
            nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
Minjie Wang's avatar
Minjie Wang committed
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
                          for i in range(len(self.etypes))}
            meta = str(self.metagraph.edges())
            return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)

    #################################################################
    # Mutation operations
    #################################################################

    def add_nodes(self, num, data=None, ntype=None):
        """Add ``num`` new nodes of node type ``ntype``.

        Not supported for heterographs: this method always raises.

        Raises
        ------
        DGLError
            Always, because heterograph mutation is not implemented.
        """
        message = 'Mutation is not supported in heterograph.'
        raise DGLError(message)

    def add_edge(self, u, v, data=None, etype=None):
        """Add one edge of type ``etype`` from node ``u`` (of the source node
        type) to node ``v`` (of the destination node type).

        Not supported for heterographs: this method always raises.

        Raises
        ------
        DGLError
            Always, because heterograph mutation is not implemented.
        """
        raise DGLError('Mutation is not supported in heterograph.')

    def add_edges(self, u, v, data=None, etype=None):
        """Add multiple edges of type ``etype`` from the source nodes ``u``
        to the destination nodes ``v``. One edge is added for every pair
        ``(u[i], v[i])``.

        Not supported for heterographs: this method always raises.

        Raises
        ------
        DGLError
            Always, because heterograph mutation is not implemented.
        """
        raise DGLError('Mutation is not supported in heterograph.')

    #################################################################
    # Metagraph query
    #################################################################
Da Zheng's avatar
Da Zheng committed
352

353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
    @property
    def is_unibipartite(self):
        """Return whether the graph is a uni-bipartite graph.

        A uni-bipartite heterograph can further divide its node types into two sets:
        SRC and DST. All edges are from nodes in SRC to nodes in DST. The following APIs
        can be used to get the nodes and types that belong to SRC and DST sets:

        * :func:`srctype` and :func:`dsttype`
        * :func:`srcdata` and :func:`dstdata`
        * :func:`srcnodes` and :func:`dstnodes`

        Note that we allow two node types to have the same name as long as one
        belongs to SRC while the other belongs to DST. To distinguish them, prepend
        the name with ``"SRC/"`` or ``"DST/"`` when specifying a node type.
        """
        return self._is_unibipartite

371
    @property
Minjie Wang's avatar
Minjie Wang committed
372
    def ntypes(self):
Mufei Li's avatar
Mufei Li committed
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
        """Return the list of node types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.ntypes
        ['user', 'game']
        """
388
        return self._ntypes
Da Zheng's avatar
Da Zheng committed
389

390
    @property
Minjie Wang's avatar
Minjie Wang committed
391
    def etypes(self):
Mufei Li's avatar
Mufei Li committed
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
        """Return the list of edge types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.etypes
        ['follows', 'plays']
        """
407
        return self._etypes
Da Zheng's avatar
Da Zheng committed
408

Minjie Wang's avatar
Minjie Wang committed
409
410
411
412
413
    @property
    def canonical_etypes(self):
        """Return the list of canonical edge types of this graph.

        A canonical edge type is a tuple of string (src_type, edge_type, dst_type).
Mufei Li's avatar
Mufei Li committed
414
415
416
417
418
419
420
421
422
423
424
425
426

        Returns
        -------
        list of 3-tuples

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.canonical_etypes
        [('user', 'follows', 'user'), ('user', 'plays', 'game')]
Minjie Wang's avatar
Minjie Wang committed
427
428
429
        """
        return self._canonical_etypes

430
    @property
431
432
433
434
435
436
437
438
    def srctypes(self):
        """Return the node types in the SRC category. Return :attr:``ntypes`` if
        the graph is not a uni-bipartite graph.
        """
        if self.is_unibipartite:
            return sorted(list(self._srctypes_invmap.keys()))
        else:
            return self.ntypes
439
440

    @property
441
442
443
444
445
446
447
448
    def dsttypes(self):
        """Return the node types in the DST category. Return :attr:``ntypes`` if
        the graph is not a uni-bipartite graph.
        """
        if self.is_unibipartite:
            return sorted(list(self._dsttypes_invmap.keys()))
        else:
            return self.ntypes
449

Da Zheng's avatar
Da Zheng committed
450
451
    @property
    def metagraph(self):
452
453
454
455
        """Return the metagraph as networkx.MultiDiGraph.

        The nodes are labeled with node type names.
        The edges have their keys holding the edge type names.
Minjie Wang's avatar
Minjie Wang committed
456
457
458
459

        Returns
        -------
        networkx.MultiDiGraph
Mufei Li's avatar
Mufei Li committed
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> meta_g = g.metagraph

        The metagraph then has two nodes and two edges.

        >>> meta_g.nodes()
        NodeView(('user', 'game'))
        >>> meta_g.number_of_nodes()
        2
        >>> meta_g.edges()
        OutMultiEdgeDataView([('user', 'user'), ('user', 'game')])
        >>> meta_g.number_of_edges()
        2
Minjie Wang's avatar
Minjie Wang committed
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
        """
        if self._nx_metagraph is None:
            nx_graph = self._graph.metagraph.to_networkx()
            self._nx_metagraph = nx.MultiDiGraph()
            for u_v in nx_graph.edges:
                srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']]
                self._nx_metagraph.add_edge(srctype, dsttype, etype)
        return self._nx_metagraph

    def to_canonical_etype(self, etype):
        """Convert edge type to canonical etype: (srctype, etype, dsttype).

        The input can already be a canonical tuple.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type

        Returns
        -------
        tuple of str
Mufei Li's avatar
Mufei Li committed
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520

        Examples
        --------

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g3 = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'follows', 'game')
        >>> g = dgl.hetero_from_relations([g1, g2, g3])

        Get canonical edge types.

        >>> g.to_canonical_etype('plays')
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype(('user', 'plays', 'game'))
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype('follows')
        DGLError: Edge type "follows" is ambiguous.
        Please use canonical etype type in the form of (srctype, etype, dsttype)
521
        """
Minjie Wang's avatar
Minjie Wang committed
522
523
        if isinstance(etype, tuple):
            return etype
524
        else:
Minjie Wang's avatar
Minjie Wang committed
525
526
527
528
529
530
531
532
533
534
535
536
537
            ret = self._etype2canonical.get(etype, None)
            if ret is None:
                raise DGLError('Edge type "{}" does not exist.'.format(etype))
            if len(ret) == 0:
                raise DGLError('Edge type "%s" is ambiguous. Please use canonical etype '
                               'type in the form of (srctype, etype, dsttype)' % etype)
            return ret

    def get_ntype_id(self, ntype):
        """Return the id of the given node type.

        ntype can also be None. If so, there should be only one node type in the
        graph.
538

Minjie Wang's avatar
Minjie Wang committed
539
540
541
542
        Parameters
        ----------
        ntype : str
            Node type
Da Zheng's avatar
Da Zheng committed
543
544
545

        Returns
        -------
Minjie Wang's avatar
Minjie Wang committed
546
547
        int
        """
548
        if self.is_unibipartite and ntype is not None:
549
550
551
552
553
554
555
556
            # Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
            if ntype.startswith('SRC/'):
                return self.get_ntype_id_from_src(ntype[4:])
            elif ntype.startswith('DST/'):
                return self.get_ntype_id_from_dst(ntype[4:])
            # If there is no prefix, fallback to normal lookup.

        # Lookup both SRC and DST
Minjie Wang's avatar
Minjie Wang committed
557
        if ntype is None:
558
            if self.is_unibipartite or len(self._srctypes_invmap) != 1:
Minjie Wang's avatar
Minjie Wang committed
559
560
561
                raise DGLError('Node type name must be specified if there are more than one '
                               'node types.')
            return 0
562
        ntid = self._srctypes_invmap.get(ntype, self._dsttypes_invmap.get(ntype, None))
Minjie Wang's avatar
Minjie Wang committed
563
564
565
        if ntid is None:
            raise DGLError('Node type "{}" does not exist.'.format(ntype))
        return ntid
Da Zheng's avatar
Da Zheng committed
566

567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
    def get_ntype_id_from_src(self, ntype):
        """Return the id of the given SRC node type.

        ntype can also be None. If so, there should be only one node type in the
        SRC category. Callable even when the self graph is not uni-bipartite.

        Parameters
        ----------
        ntype : str
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if len(self._srctypes_invmap) != 1:
                raise DGLError('SRC node type name must be specified if there are more than one '
                               'SRC node types.')
586
            return next(iter(self._srctypes_invmap.values()))
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
        ntid = self._srctypes_invmap.get(ntype, None)
        if ntid is None:
            raise DGLError('SRC node type "{}" does not exist.'.format(ntype))
        return ntid

    def get_ntype_id_from_dst(self, ntype):
        """Return the id of the given DST node type.

        ntype can also be None. If so, there should be only one node type in the
        DST category. Callable even when the self graph is not uni-bipartite.

        Parameters
        ----------
        ntype : str
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if len(self._dsttypes_invmap) != 1:
                raise DGLError('DST node type name must be specified if there are more than one '
                               'DST node types.')
611
            return next(iter(self._dsttypes_invmap.values()))
612
613
614
615
616
        ntid = self._dsttypes_invmap.get(ntype, None)
        if ntid is None:
            raise DGLError('DST node type "{}" does not exist.'.format(ntype))
        return ntid

Minjie Wang's avatar
Minjie Wang committed
617
618
    def get_etype_id(self, etype):
        """Return the id of the given edge type.
619

Minjie Wang's avatar
Minjie Wang committed
620
621
622
623
624
625
626
        etype can also be None. If so, there should be only one edge type in the
        graph.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type
Da Zheng's avatar
Da Zheng committed
627

628
629
        Returns
        -------
Minjie Wang's avatar
Minjie Wang committed
630
631
632
633
634
635
636
637
638
639
640
        int
        """
        if etype is None:
            if self._graph.number_of_etypes() != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            return 0
        etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
        if etid is None:
            raise DGLError('Edge type "{}" does not exist.'.format(etype))
        return etid
Da Zheng's avatar
Da Zheng committed
641

Minjie Wang's avatar
Minjie Wang committed
642
643
644
    #################################################################
    # View
    #################################################################
Da Zheng's avatar
Da Zheng committed
645

646
    @property
    def nodes(self):
        """Node view for reading and writing the features of one node type.

        Index the view with a node type name (e.g. ``g.nodes['user']``). Names
        are resolved with :meth:`get_ntype_id`.

        Examples
        --------
        The following example uses PyTorch backend.

        Set feature ``'h'`` on every ``user`` node:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.zeros(3, 5)

        See Also
        --------
        ndata
        """
        resolve_ntype = self.get_ntype_id
        return HeteroNodeView(self, resolve_ntype)

    @property
    def srcnodes(self):
        """Return a SRC node view that can be used to set/get feature
        data of a single node type.

        Node type names are resolved with :meth:`get_ntype_id_from_src`, so
        only SRC node types can be indexed.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all users

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.srcnodes['user'].data['h'] = torch.zeros(2, 5)

        See Also
        --------
        srcdata
        """
        return HeteroNodeView(self, self.get_ntype_id_from_src)

    @property
    def dstnodes(self):
        """Return a DST node view that can be used to set/get feature
        data of a single node type.

        Node type names are resolved with :meth:`get_ntype_id_from_dst`, so
        only DST node types can be indexed.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all games

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.dstnodes['game'].data['h'] = torch.zeros(3, 5)

        See Also
        --------
        dstdata
        """
        return HeteroNodeView(self, self.get_ntype_id_from_dst)
Da Zheng's avatar
Da Zheng committed
705

706
    @property
    def ndata(self):
        """Data view spanning every node of the graph's only node type.

        **Only works if the graph has one node type.**

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all nodes in a heterogeneous graph
        with only one node type:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.ndata['h'] = torch.zeros(3, 5)

        See Also
        --------
        nodes
        """
        # Asking for the id of the ``None`` node type is the single-type check;
        # presumably get_ntype_id raises when several node types exist -- confirm.
        type_id = self.get_ntype_id(None)
        sole_ntype = self.ntypes[0]
        return HeteroNodeDataView(self, sole_ntype, type_id, ALL)
Da Zheng's avatar
Da Zheng committed
729

730
731
    @property
    def srcdata(self):
        """Return the data view of all nodes in the SRC category.

        Only works if the graph is either

        * Uni-bipartite and has one node type in the SRC category.

        * Non-uni-bipartite and has only one node type (in this case identical to
        :any:`DGLHeteroGraph.ndata`)

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all source nodes in a graph with only one edge type:

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.srcdata['h'] = torch.zeros(2, 5)

        This is equivalent to

        >>> g.nodes['user'].data['h'] = torch.zeros(2, 5)

        Also works on more complex uni-bipartite graphs

        >>> g = dgl.heterograph({
        ...     ('user', 'plays', 'game'): [(0, 1), (1, 2)],
        ...     ('user', 'reads', 'book'): [(0, 1), (1, 0)],
        ...     })
        >>> print(g.is_unibipartite)
        True
        >>> g.srcdata['h'] = torch.zeros(2, 5)

        Notes
        -----
        This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.

        See Also
        --------
        nodes
        """
        # Word the error after the category that actually has multiple types.
        err_msg = (
            'srcdata is only allowed when there is only one %s type.' %
            ('SRC' if self.is_unibipartite else 'node'))
        assert len(self.srctypes) == 1, err_msg
        ntype = self.srctypes[0]
        ntid = self.get_ntype_id_from_src(ntype)
        return HeteroNodeDataView(self, ntype, ntid, ALL)
779
780
781
782
783

    @property
    def dstdata(self):
        """Return the data view of all nodes in the DST category.

        Only works if the graph is either

        * Uni-bipartite and has one node type in the DST category.

        * Non-uni-bipartite and has only one node type (in this case identical to
        :any:`DGLHeteroGraph.ndata`)

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all destination nodes in a graph with only one edge type:

        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.dstdata['h'] = torch.zeros(3, 5)

        This is equivalent to

        >>> g.nodes['game'].data['h'] = torch.zeros(3, 5)

        Also works on more complex uni-bipartite graphs

        >>> g = dgl.heterograph({
        ...     ('user', 'plays', 'game'): [(0, 1), (1, 2)],
        ...     ('store', 'sells', 'game'): [(0, 1), (1, 0)],
        ...     })
        >>> print(g.is_unibipartite)
        True
        >>> g.dstdata['h'] = torch.zeros(3, 5)

        Notes
        -----
        This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.

        See Also
        --------
        nodes
        """
        # Word the error after the category that actually has multiple types.
        err_msg = (
            'dstdata is only allowed when there is only one %s type.' %
            ('DST' if self.is_unibipartite else 'node'))
        assert len(self.dsttypes) == 1, err_msg
        ntype = self.dsttypes[0]
        ntid = self.get_ntype_id_from_dst(ntype)
        return HeteroNodeDataView(self, ntype, ntid, ALL)
829

830
    @property
    def edges(self):
        """Edge view for getting/setting features of a single edge type.

        Index the view with an edge type (e.g. ``g.edges['plays']``) and use
        its ``data`` attribute to read or write that type's edge features.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all "plays" relationships:

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edges['plays'].data['h'] = torch.zeros(3, 4)

        See Also
        --------
        edata
        """
        edge_view = HeteroEdgeView(self)
        return edge_view
849
850

    @property
    def edata(self):
        """Data view spanning every edge of the graph's only edge type.

        **Only works if the graph has one edge type.**

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all edges in a heterogeneous graph
        with only one edge type:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.edata['h'] = torch.zeros(2, 5)

        See Also
        --------
        edges
        """
        # No explicit edge type is passed (``None``); the view selects ALL edges.
        unspecified_etype = None
        return HeteroEdgeDataView(self, unspecified_etype, ALL)

    def _find_etypes(self, key):
        """Return positions of the canonical edge types matched by ``key``.

        Parameters
        ----------
        key : tuple
            ``(srctype, etype, dsttype)`` pattern. Each entry is either a type
            name or ``SLICE_FULL``, which matches anything in that position.

        Returns
        -------
        list of int
            Indices into ``self._canonical_etypes`` whose source, edge and
            destination types all match ``key``, in ascending order.
        """
        def _accepts(position, candidate):
            wanted = key[position]
            return wanted == SLICE_FULL or wanted == candidate

        matched = []
        for index, (srctype, etype, dsttype) in enumerate(self._canonical_etypes):
            if _accepts(0, srctype) and _accepts(1, etype) and _accepts(2, dsttype):
                matched.append(index)
        return matched

    def __getitem__(self, key):
        """Return the relation slice of this graph.

        A relation slice is accessed with ``self[srctype, etype, dsttype]``, where
        ``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
        slice (``:``) representing wildcard (i.e. any source/edge/destination type).
        A non-tuple key ``self[etype]`` is treated as ``self[:, etype, :]``.

        A relation slice is a homogeneous (with one node type and one edge type) or
        bipartite (with two node types and one edge type) graph, transformed from
        the original heterogeneous graph.

        If there is only one canonical edge type found, then the returned relation
        slice would be a subgraph induced from the original graph.  That is, it is
        equivalent to ``self.edge_type_subgraph(etype)``.  The node and edge features
        of the returned graph would be shared with the original graph.

        If there are multiple canonical edge types found, then the source/edge/destination
        node types would be a *concatenation* of original node/edge types.  The
        new source/destination node type would have the concatenation determined by
        :func:`dgl.combine_names() <dgl.combine_names>` called on original source/destination
        types as its name.  The source/destination node would be formed by concatenating the
        common features of the original source/destination types, therefore they are not
        shared with the original graph.  Edge type is similar.

        Raises
        ------
        DGLError
            If a tuple key does not have exactly three elements, or if no
            canonical edge type matches the key.
        """
        err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\
                  "to get view of one relation type. Use : to slice multiple types (e.g. " +\
                  "G['srctype', :, 'dsttype'])."

        orig_key = key
        if not isinstance(key, tuple):
            # A bare edge type name matches any source and destination type.
            key = (SLICE_FULL, key, SLICE_FULL)

        if len(key) != 3:
            raise DGLError(err_msg)

        etypes = self._find_etypes(key)

        if not etypes:
            raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))

        if len(etypes) == 1:
            # no ambiguity: return the unitgraph itself
            srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
            stid = self.get_ntype_id_from_src(srctype)
            etid = self.get_etype_id((srctype, etype, dsttype))
            dtid = self.get_ntype_id_from_dst(dsttype)
            new_g = self._graph.get_relation_graph(etid)

            if stid == dtid:
                new_ntypes = [srctype]
                new_nframes = [self._node_frames[stid]]
            else:
                new_ntypes = ([srctype], [dsttype])
                new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
            new_etypes = [etype]
            new_eframes = [self._edge_frames[etid]]

            # Frames are passed through directly, so features stay shared.
            return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)

        # Several relations match: flatten them into one relation whose types
        # are the combined names of the originals.
        flat = self._graph.flatten_relations(etypes)
        new_g = flat.graph

        # merge frames
        stids = flat.induced_srctype_set.asnumpy()
        dtids = flat.induced_dsttype_set.asnumpy()
        etids = flat.induced_etype_set.asnumpy()
        new_ntypes = [combine_names(self.ntypes, stids)]
        if new_g.number_of_ntypes() == 2:
            new_ntypes.append(combine_names(self.ntypes, dtids))
            new_nframes = [
                combine_frames(self._node_frames, stids),
                combine_frames(self._node_frames, dtids)]
        else:
            assert np.array_equal(stids, dtids)
            new_nframes = [combine_frames(self._node_frames, stids)]
        new_etypes = [combine_names(self.etypes, etids)]
        new_eframes = [combine_frames(self._edge_frames, etids)]

        # create new heterograph
        new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)

        src = new_ntypes[0]
        dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
        # put the parent node/edge type and IDs
        new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype)
        new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid)
        new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype)
        new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid)
        new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
        new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)

        return new_hg

    #################################################################
    # Graph query
    #################################################################

    def number_of_nodes(self, ntype=None):
        """Count the nodes that belong to one node type.

        Parameters
        ----------
        ntype : str, optional
            Node type to count. May be left out when the graph has a single
            node type. (Default: None)

        Returns
        -------
        int
            How many nodes of ``ntype`` the graph holds.

        Examples
        --------

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.number_of_nodes('user')
        3
        >>> g.number_of_nodes()
        3
        """
        type_id = self.get_ntype_id(ntype)
        return self._graph.number_of_nodes(type_id)
1001

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
    def number_of_src_nodes(self, ntype=None):
        """Count the nodes of a node type from the SRC category.

        Typically used on a unidirectional bipartite graph.

        Parameters
        ----------
        ntype : str, optional
            SRC node type to count. May be left out when the SRC category
            contains exactly one node type.

        Returns
        -------
        int
            How many nodes of that SRC type the graph holds.

        Examples
        --------
        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.number_of_src_nodes('user')
        2
        >>> g.number_of_src_nodes()
        2
        >>> g.number_of_nodes('user')
        2
        """
        src_type_id = self.get_ntype_id_from_src(ntype)
        return self._graph.number_of_nodes(src_type_id)

    def number_of_dst_nodes(self, ntype=None):
        """Count the nodes of a node type from the DST category.

        Typically used on a unidirectional bipartite graph.

        Parameters
        ----------
        ntype : str, optional
            DST node type to count. May be left out when the DST category
            contains exactly one node type.

        Returns
        -------
        int
            How many nodes of that DST type the graph holds.

        Examples
        --------
        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.number_of_dst_nodes('game')
        3
        >>> g.number_of_dst_nodes()
        3
        >>> g.number_of_nodes('game')
        3
        """
        dst_type_id = self.get_ntype_id_from_dst(ntype)
        return self._graph.number_of_nodes(dst_type_id)

Minjie Wang's avatar
Minjie Wang committed
1058
    def number_of_edges(self, etype=None):
        """Return the number of edges of the given type in the heterograph.

        Parameters
        ----------
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        int
            The number of edges

        Examples
        --------

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.number_of_edges(('user', 'follows', 'user'))
        2
        >>> g.number_of_edges('follows')
        2
        >>> g.number_of_edges()
        2
        """
        return self._graph.number_of_edges(self.get_etype_id(etype))

    @property
    def is_multigraph(self):
        """Report whether parallel edges are allowed/present in the graph.

        Returns
        -------
        bool
            True when the underlying graph index reports a multigraph,
            False otherwise.
        """
        gidx = self._graph
        return gidx.is_multigraph()
Minjie Wang's avatar
Minjie Wang committed
1094
1095
1096

    @property
    def is_readonly(self):
        """Report whether the graph structure is read-only.

        Returns
        -------
        bool
            True when the underlying graph index reports read-only,
            False otherwise.
        """
        gidx = self._graph
        return gidx.is_readonly()
Da Zheng's avatar
Da Zheng committed
1105

1106
1107
1108
1109
1110
1111
1112
1113
    @property
    def idtype(self):
        """Backend dtype used for the graph's node and edge IDs.

        Returns
        -------
        backend dtype object
            th.int32/th.int64 or tf.int32/tf.int64 etc.

        See Also
        --------
        long
        int
        """
        # The index stores the dtype by name; look the object up on the backend.
        dtype_name = self._graph.dtype
        return getattr(F, dtype_name)

    @property
    def _idtype_str(self):
        """Name of the dtype of the graph index.

        Unlike :attr:`idtype`, this is the plain dtype *name* stored on the
        graph index (``idtype`` resolves it via ``getattr(F, name)``).

        Returns
        -------
        str
            The dtype name, e.g. ``'int32'`` or ``'int64'``.
        """
        return self._graph.dtype

Minjie Wang's avatar
Minjie Wang committed
1133
    def has_node(self, vid, ntype=None):
        """Check whether a node with the given ID exists for one node type.

        Parameters
        ----------
        vid : int
            The node ID.
        ntype : str, optional
            Node type to look in. May be left out when the graph has a single
            node type. (Default: None)

        Returns
        -------
        bool
            True if the node exists, False otherwise

        Examples
        --------
        >>> g.has_node(0, 'user')
        True
        >>> g.has_node(4, 'user')
        False

        See Also
        --------
        has_nodes
        """
        type_id = self.get_ntype_id(ntype)
        return self._graph.has_node(type_id, vid)
Da Zheng's avatar
Da Zheng committed
1161

Minjie Wang's avatar
Minjie Wang committed
1162
    def has_nodes(self, vids, ntype=None):
        """Whether the graph has nodes with ids and a particular type.

        Parameters
        ----------
        vids : list or tensor
            The array of node IDs.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        a : tensor
            Binary tensor indicating the existence of nodes with the specified ids and type.
            ``a[i]=1`` if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g.has_nodes([0, 1, 2, 3, 4], 'user')
        tensor([1, 1, 1, 0, 0])

        See Also
        --------
        has_node
        """
        # Convert the IDs to an index using the graph's own ID dtype.
        vids = utils.toindex(vids, self._idtype_str)
        rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids)
        return rst.tousertensor()
Da Zheng's avatar
Da Zheng committed
1193

Minjie Wang's avatar
Minjie Wang committed
1194
    def has_edge_between(self, u, v, etype=None):
        """Check whether an edge ``(u, v)`` of type ``etype`` exists.

        Parameters
        ----------
        u : int
            The node ID of source type.
        v : int
            The node ID of destination type.
        etype : str or tuple of str, optional
            Edge type to look in. May be left out when the graph has a single
            edge type.

        Returns
        -------
        bool
            True if the edge is in the graph, False otherwise.

        Examples
        --------

        >>> g.has_edge_between(0, 1, ('user', 'plays', 'game'))
        True
        >>> g.has_edge_between(0, 2, ('user', 'plays', 'game'))
        False

        See Also
        --------
        has_edges_between
        """
        rel_id = self.get_etype_id(etype)
        return self._graph.has_edge_between(rel_id, u, v)
Da Zheng's avatar
Da Zheng committed
1225

Minjie Wang's avatar
Minjie Wang committed
1226
    def has_edges_between(self, u, v, etype=None):
        """Check, pairwise, whether edges ``(u[i], v[i])`` of type ``etype`` exist.

        Parameters
        ----------
        u : list, tensor
            The node ID array of source type.
        v : list, tensor
            The node ID array of destination type.
        etype : str or tuple of str, optional
            Edge type to look in. May be left out when the graph has a single
            edge type.

        Returns
        -------
        a : tensor
            Binary tensor indicating the existence of edges. ``a[i]=1`` if the graph
            contains edge ``(u[i], v[i])`` of type ``etype``, 0 otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g.has_edges_between([0, 0], [1, 2], ('user', 'plays', 'game'))
        tensor([1, 0])

        See Also
        --------
        has_edge_between
        """
        # Both endpoint arrays are converted with the graph's ID dtype.
        src_idx = utils.toindex(u, self._idtype_str)
        dst_idx = utils.toindex(v, self._idtype_str)
        rel_id = self.get_etype_id(etype)
        found = self._graph.has_edges_between(rel_id, src_idx, dst_idx)
        return found.tousertensor()
Da Zheng's avatar
Da Zheng committed
1260

Minjie Wang's avatar
Minjie Wang committed
1261
    def predecessors(self, v, etype=None):
        """Return the nodes with an edge of type `etype` pointing into `v`.

        Node `u` counts as a predecessor of `v` when an edge `(u, v)` of type
        `etype` is present in the graph.

        Parameters
        ----------
        v : int
            The destination node.
        etype : str or tuple of str, optional
            Edge type to follow backwards. May be left out when the graph has
            a single edge type. (Default: None)

        Returns
        -------
        tensor
            Array of predecessor node IDs with the specified edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
        >>> g = dgl.hetero_from_relations([plays_g, devs_g])
        >>> g.predecessors(0, 'plays')
        tensor([0, 1])
        >>> g.predecessors(0, 'develops')
        tensor([0])

        See Also
        --------
        successors
        """
        rel_id = self.get_etype_id(etype)
        pred_index = self._graph.predecessors(rel_id, v)
        return pred_index.tousertensor()
Da Zheng's avatar
Da Zheng committed
1298

Minjie Wang's avatar
Minjie Wang committed
1299
    def successors(self, v, etype=None):
        """Return the nodes reached from `v` by an edge of type `etype`.

        Node `u` counts as a successor of `v` when an edge `(v, u)` of type
        `etype` is present in the graph.

        Parameters
        ----------
        v : int
            The source node.
        etype : str or tuple of str, optional
            Edge type to follow. May be left out when the graph has a single
            edge type. (Default: None)

        Returns
        -------
        tensor
            Array of successor node IDs with the specified edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> g.successors(0, 'plays')
        tensor([0])
        >>> g.successors(0, 'follows')
        tensor([1])

        See Also
        --------
        predecessors
        """
        # Validate `v` against the graph's ID dtype before querying the index.
        check_same_dtype(self._idtype_str, v)
        rel_id = self.get_etype_id(etype)
        succ_index = self._graph.successors(rel_id, v)
        return succ_index.tousertensor()
Da Zheng's avatar
Da Zheng committed
1337

1338
    def edge_id(self, u, v, force_multi=None, return_array=False, etype=None):
        """Return the edge ID, or an array of edge IDs, between source node
        `u` and destination node `v`, with the specified edge type

        Parameters
        ----------
        u : int
            The node ID of source type.
        v : int
            The node ID of destination type.
        force_multi : bool, optional
            Deprecated (Will be deleted in the future).
            If given, it overrides ``return_array``: if False, will return a
            single edge ID; if True, will always return an array.
            (Default: None)
        return_array : bool, optional
            If False, will return a single edge ID.
            If True, will always return an array. (Default: False)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        int or tensor
            The edge ID if ``return_array == False``.
            The edge ID array otherwise.

        Raises
        ------
        AssertionError
            If ``return_array`` is False and the number of edges between
            `u` and `v` is not exactly one.

        Notes
        -----
        If multiple edges exist between `u` and `v` and return_array is False,
        the result is undefined.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for edge id.

        >>> plays_g.edge_id(1, 2, etype=('user', 'plays', 'game'))
        2
        >>> g.edge_id(1, 2, return_array=True, etype=('user', 'follows', 'user'))
        tensor([1, 2])

        See Also
        --------
        edge_ids
        """
        idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
        if force_multi is not None:
            dgl_warning("force_multi will be deprecated. " \
                        "Please use return_array instead")
            # The deprecated flag still takes precedence for old callers.
            return_array = force_multi

        if return_array:
            return idx.tousertensor()
        assert len(idx) == 1, "For return_array=False, there should be one and " \
            "only one edge between u and v, but got {} edges. " \
            "Please use return_array=True instead".format(len(idx))
        return idx[0]
Da Zheng's avatar
Da Zheng committed
1404

1405
    def edge_ids(self, u, v, force_multi=None, return_uv=False, etype=None):
Da Zheng's avatar
Da Zheng committed
1406
        """Return all edge IDs between source node array `u` and destination
Mufei Li's avatar
Mufei Li committed
1407
        node array `v` with the specified edge type.
Da Zheng's avatar
Da Zheng committed
1408
1409
1410
1411
1412
1413
1414

        Parameters
        ----------
        u : list, tensor
            The node ID array of source type.
        v : list, tensor
            The node ID array of destination type.
Minjie Wang's avatar
Minjie Wang committed
1415
        force_multi : bool, optional
1416
            Deprecated (Will be deleted in the future).
Mufei Li's avatar
Mufei Li committed
1417
1418
            Whether to always treat the graph as a multigraph. See the
            "Returns" for their effects. (Default: False)
1419
1420
        return_uv : bool
            See the "Returns" for their effects. (Default: False)
Minjie Wang's avatar
Minjie Wang committed
1421
1422
1423
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.
Da Zheng's avatar
Da Zheng committed
1424
1425
1426
1427

        Returns
        -------
        tensor, or (tensor, tensor, tensor)
Mufei Li's avatar
Mufei Li committed
1428

1429
1430
            * If ``return_uv=False``, return a single edge ID array ``e``.
            ``e[i]`` is the edge ID between ``u[i]`` and ``v[i]``.
Mufei Li's avatar
Mufei Li committed
1431
1432
1433
1434

            * Otherwise, return three arrays ``(eu, ev, e)``.  ``e[i]`` is the ID
            of an edge between ``eu[i]`` and ``ev[i]``.  All edges between ``u[i]``
            and ``v[i]`` are returned.
Da Zheng's avatar
Da Zheng committed
1435
1436
1437

        Notes
        -----
1438
        If the graph is a simple graph, ``return_uv=False``, and no edge
Mufei Li's avatar
Mufei Li committed
1439
1440
        exists between some pairs of ``u[i]`` and ``v[i]``, the result is undefined
        and an empty tensor is returned.
Da Zheng's avatar
Da Zheng committed
1441

1442
1443
1444
        If the graph is a multi graph, ``return_uv=False``, and multi edges
        exist between some pairs of `u[i]` and `v[i]`, the result is undefined.

Da Zheng's avatar
Da Zheng committed
1445
1446
1447
1448
        Examples
        --------
        The following example uses PyTorch backend.

Mufei Li's avatar
Mufei Li committed
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for edge ids.

        >>> plays_g.edge_ids([0], [2], etype=('user', 'plays', 'game'))
        tensor([], dtype=torch.int64)
        >>> plays_g.edge_ids([1], [2], etype=('user', 'plays', 'game'))
        tensor([2])
1461
        >>> g.edge_ids([1], [2], return_uv=True, etype=('user', 'follows', 'user'))
Mufei Li's avatar
Mufei Li committed
1462
        (tensor([1, 1]), tensor([2, 2]), tensor([1, 2]))
Da Zheng's avatar
Da Zheng committed
1463
1464
1465
1466
1467

        See Also
        --------
        edge_id
        """
1468
1469
1470
1471
        check_same_dtype(self._idtype_str, u)
        check_same_dtype(self._idtype_str, v)
        u = utils.toindex(u, self._idtype_str)
        v = utils.toindex(v, self._idtype_str)
Minjie Wang's avatar
Minjie Wang committed
1472
        src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
1473
1474
1475
1476
1477
1478
        if force_multi is not None:
            dgl_warning("force_multi will be deprecated, " \
                        "Please use return_uv instead")
            return_uv = force_multi

        if return_uv:
1479
1480
            return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
        else:
1481
1482
1483
            assert len(eid) == max(len(u), len(v)), "If return_uv=False, there should be one and " \
                "only one edge between each u and v, expect {} edges but get {}. " \
                "Please use return_uv=True instead".format(max(len(u), len(v)), len(eid))
1484
            return eid.tousertensor()
Da Zheng's avatar
Da Zheng committed
1485

Minjie Wang's avatar
Minjie Wang committed
1486
    def find_edges(self, eid, etype=None):
Mufei Li's avatar
Mufei Li committed
1487
1488
1489
        """Given an edge ID array with the specified type, return the source
        and destination node ID array ``s`` and ``d``.  ``s[i]`` and ``d[i]``
        are source and destination node ID for edge ``eid[i]``.
Da Zheng's avatar
Da Zheng committed
1490
1491
1492
1493
1494

        Parameters
        ----------
        eid : list, tensor
            The edge ID array.
Minjie Wang's avatar
Minjie Wang committed
1495
1496
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1497
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509

        Returns
        -------
        tensor
            The source node ID array.
        tensor
            The destination node ID array.

        Examples
        --------
        The following example uses PyTorch backend.

Mufei Li's avatar
Mufei Li committed
1510
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
Minjie Wang's avatar
Minjie Wang committed
1511
        >>> g.find_edges([0, 2], ('user', 'plays', 'game'))
Mufei Li's avatar
Mufei Li committed
1512
1513
1514
        (tensor([0, 1]), tensor([0, 2]))
        >>> g.find_edges([0, 2])
        (tensor([0, 1]), tensor([0, 2]))
Da Zheng's avatar
Da Zheng committed
1515
        """
1516
1517
        check_same_dtype(self._idtype_str, eid)
        eid = utils.toindex(eid, self._idtype_str)
Minjie Wang's avatar
Minjie Wang committed
1518
        src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
1519
        return src.tousertensor(), dst.tousertensor()
Da Zheng's avatar
Da Zheng committed
1520

Minjie Wang's avatar
Minjie Wang committed
1521
    def in_edges(self, v, form='uv', etype=None):
Mufei Li's avatar
Mufei Li committed
1522
        """Return the inbound edges of the node(s) with the specified type.
Da Zheng's avatar
Da Zheng committed
1523
1524
1525
1526

        Parameters
        ----------
        v : int, list, tensor
Mufei Li's avatar
Mufei Li committed
1527
            The node id(s) of destination type.
Da Zheng's avatar
Da Zheng committed
1528
1529
1530
        form : str, optional
            The return form. Currently support:

Mufei Li's avatar
Mufei Li committed
1531
1532
1533
            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
Minjie Wang's avatar
Minjie Wang committed
1534
1535
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1536
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1537
1538
1539

        Returns
        -------
Mufei Li's avatar
Mufei Li committed
1540
        tensor or (tensor, tensor, tensor) or (tensor, tensor)
Da Zheng's avatar
Da Zheng committed
1541
            All inbound edges to ``v`` are returned.
Mufei Li's avatar
Mufei Li committed
1542
1543
1544
1545
1546
1547
1548
1549

            * If ``form='eid'``, return a tensor for the ids of the
              inbound edges of the nodes with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors
              ``(eu, ev, eid)``. ``eid[i]`` gives the ID of the
              edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``eu[i]`` is the source node of an edge to ``ev[i]``.
Da Zheng's avatar
Da Zheng committed
1550
1551
1552
1553
1554

        Examples
        --------
        The following example uses PyTorch backend.

Mufei Li's avatar
Mufei Li committed
1555
1556
1557
1558
1559
1560
1561
        >>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.in_edges([0, 2], form='eid')
        tensor([0, 2])
        >>> g.in_edges([0, 2], form='all')
        (tensor([0, 1]), tensor([0, 2]), tensor([0, 2]))
        >>> g.in_edges([0, 2], form='uv')
        (tensor([0, 1]), tensor([0, 2]))
Da Zheng's avatar
Da Zheng committed
1562
        """
1563
1564
        check_same_dtype(self._idtype_str, v)
        v = utils.toindex(v, self._idtype_str)
Minjie Wang's avatar
Minjie Wang committed
1565
        src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
1566
1567
1568
1569
1570
1571
1572
1573
1574
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        elif form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        elif form == 'eid':
            return eid.tousertensor()
        else:
            raise DGLError('Invalid form:', form)

Mufei Li's avatar
Mufei Li committed
1575
1576
    def out_edges(self, u, form='uv', etype=None):
        """Return the outbound edges of the node(s) with the specified type.
Da Zheng's avatar
Da Zheng committed
1577
1578
1579

        Parameters
        ----------
Mufei Li's avatar
Mufei Li committed
1580
1581
        u : int, list, tensor
            The node id(s) of source type.
Da Zheng's avatar
Da Zheng committed
1582
1583
1584
        form : str, optional
            The return form. Currently support:

Mufei Li's avatar
Mufei Li committed
1585
1586
1587
            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
Minjie Wang's avatar
Minjie Wang committed
1588
1589
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1590
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1591
1592
1593

        Returns
        -------
Mufei Li's avatar
Mufei Li committed
1594
1595
1596
1597
1598
1599
1600
1601
1602
        tensor or (tensor, tensor, tensor) or (tensor, tensor)
            All outbound edges from ``u`` are returned.

            * If ``form='eid'``, return a tensor for the ids of the outbound edges
              of the nodes with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors ``(eu, ev, eid)``.
              ``eid[i]`` gives the ID of the edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``ev[i]`` is the destination node of the edge from ``eu[i]``.
Da Zheng's avatar
Da Zheng committed
1603
1604
1605
1606

        Examples
        --------

Mufei Li's avatar
Mufei Li committed
1607
1608
1609
1610
1611
1612
1613
        >>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.out_edges([0, 1], form='eid')
        tensor([0, 1, 2])
        >>> g.out_edges([0, 1], form='all')
        (tensor([0, 1, 1]), tensor([0, 1, 2]), tensor([0, 1, 2]))
        >>> g.out_edges([0, 1], form='uv')
        (tensor([0, 1, 1]), tensor([0, 1, 2]))
Da Zheng's avatar
Da Zheng committed
1614
        """
1615
1616
        check_same_dtype(self._idtype_str, u)
        u = utils.toindex(u, self._idtype_str)
Mufei Li's avatar
Mufei Li committed
1617
        src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
1618
1619
1620
1621
1622
1623
1624
1625
1626
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        elif form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        elif form == 'eid':
            return eid.tousertensor()
        else:
            raise DGLError('Invalid form:', form)

Minjie Wang's avatar
Minjie Wang committed
1627
    def all_edges(self, form='uv', order=None, etype=None):
Mufei Li's avatar
Mufei Li committed
1628
        """Return all edges with the specified type.
Da Zheng's avatar
Da Zheng committed
1629
1630
1631
1632
1633
1634

        Parameters
        ----------
        form : str, optional
            The return form. Currently support:

Mufei Li's avatar
Mufei Li committed
1635
1636
1637
1638
            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        order : str or None
Da Zheng's avatar
Da Zheng committed
1639
1640
            The order of the returned edges. Currently support:

Mufei Li's avatar
Mufei Li committed
1641
1642
1643
            - ``'srcdst'`` : sorted by their src and dst ids.
            - ``'eid'``    : sorted by edge Ids.
            - ``None``     : arbitrary order, default
Minjie Wang's avatar
Minjie Wang committed
1644
1645
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1646
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1647
1648
1649

        Returns
        -------
Mufei Li's avatar
Mufei Li committed
1650
1651
1652
1653
1654
1655
1656
1657
        tensor or (tensor, tensor, tensor) or (tensor, tensor)

            * If ``form='eid'``, return a tensor for the ids of all edges
              with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors ``(eu, ev, eid)``.
              ``eid[i]`` gives the ID of the edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``ev[i]`` is the destination node of the edge from ``eu[i]``.
Da Zheng's avatar
Da Zheng committed
1658
1659
1660
1661
1662

        Examples
        --------
        The following example uses PyTorch backend.

Mufei Li's avatar
Mufei Li committed
1663
1664
1665
1666
1667
1668
1669
        >>> g = dgl.bipartite([(1, 1), (0, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.all_edges(form='eid', order='srcdst')
        tensor([1, 0, 2])
        >>> g.all_edges(form='all', order='srcdst')
        (tensor([0, 1, 1]), tensor([0, 1, 2]), tensor([1, 0, 2]))
        >>> g.all_edges(form='uv', order='eid')
        (tensor([1, 0, 1]), tensor([1, 0, 2]))
Da Zheng's avatar
Da Zheng committed
1670
        """
Minjie Wang's avatar
Minjie Wang committed
1671
        src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
1672
1673
1674
1675
1676
1677
1678
1679
1680
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        elif form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        elif form == 'eid':
            return eid.tousertensor()
        else:
            raise DGLError('Invalid form:', form)

Minjie Wang's avatar
Minjie Wang committed
1681
    def in_degree(self, v, etype=None):
Mufei Li's avatar
Mufei Li committed
1682
        """Return the in-degree of node ``v`` with edges of type ``etype``.
Da Zheng's avatar
Da Zheng committed
1683
1684
1685
1686
1687

        Parameters
        ----------
        v : int
            The node ID of destination type.
Minjie Wang's avatar
Minjie Wang committed
1688
1689
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1690
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1691
1692
1693
1694
1695
1696
1697
1698

        Returns
        -------
        int
            The in-degree.

        Examples
        --------
Mufei Li's avatar
Mufei Li committed
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.in_degree(0, 'plays')
Da Zheng's avatar
Da Zheng committed
1709
        2
Mufei Li's avatar
Mufei Li committed
1710
1711
        >>> g.in_degree(0, 'follows')
        0
Da Zheng's avatar
Da Zheng committed
1712
1713
1714
1715
1716

        See Also
        --------
        in_degrees
        """
Minjie Wang's avatar
Minjie Wang committed
1717
        return self._graph.in_degree(self.get_etype_id(etype), v)
Da Zheng's avatar
Da Zheng committed
1718

Minjie Wang's avatar
Minjie Wang committed
1719
    def in_degrees(self, v=ALL, etype=None):
Mufei Li's avatar
Mufei Li committed
1720
        """Return the in-degrees of nodes v with edges of type ``etype``.
Da Zheng's avatar
Da Zheng committed
1721
1722
1723
1724

        Parameters
        ----------
        v : list, tensor, optional.
Mufei Li's avatar
Mufei Li committed
1725
1726
1727
            The node ID array of the destination type. Default is to return the
            degrees of all nodes.
        etype : str or tuple of str or None, optional
Minjie Wang's avatar
Minjie Wang committed
1728
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1729
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1730
1731
1732
1733

        Returns
        -------
        d : tensor
Mufei Li's avatar
Mufei Li committed
1734
1735
            The in-degree array. ``d[i]`` gives the in-degree of node ``v[i]``
            with edges of type ``etype``.
Da Zheng's avatar
Da Zheng committed
1736
1737
1738
1739
1740

        Examples
        --------
        The following example uses PyTorch backend.

Mufei Li's avatar
Mufei Li committed
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.in_degrees(0, 'plays')
        tensor([2])
        >>> g.in_degrees(etype='follows')
        tensor([0, 1, 2])
Da Zheng's avatar
Da Zheng committed
1753
1754
1755
1756
1757

        See Also
        --------
        in_degree
        """
1758
        check_same_dtype(self._idtype_str, v)
Minjie Wang's avatar
Minjie Wang committed
1759
1760
        etid = self.get_etype_id(etype)
        _, dtid = self._graph.metagraph.find_edge(etid)
1761
        if is_all(v):
1762
            v = utils.toindex(slice(0, self._graph.number_of_nodes(dtid)), self._idtype_str)
1763
        else:
1764
            v = utils.toindex(v, self._idtype_str)
Minjie Wang's avatar
Minjie Wang committed
1765
        return self._graph.in_degrees(etid, v).tousertensor()
Da Zheng's avatar
Da Zheng committed
1766

Mufei Li's avatar
Mufei Li committed
1767
1768
    def out_degree(self, u, etype=None):
        """Return the out-degree of node `u` with edges of type ``etype``.
Da Zheng's avatar
Da Zheng committed
1769
1770
1771

        Parameters
        ----------
Mufei Li's avatar
Mufei Li committed
1772
        u : int
Da Zheng's avatar
Da Zheng committed
1773
            The node ID of source type.
Minjie Wang's avatar
Minjie Wang committed
1774
1775
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1776
            in the graph. (Default: None)
Da Zheng's avatar
Da Zheng committed
1777
1778
1779
1780

        Returns
        -------
        int
Mufei Li's avatar
Mufei Li committed
1781
            The out-degree of node `u` with edges of type ``etype``.
Da Zheng's avatar
Da Zheng committed
1782
1783
1784

        Examples
        --------
Mufei Li's avatar
Mufei Li committed
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.out_degree(0, 'plays')
Da Zheng's avatar
Da Zheng committed
1795
        1
Mufei Li's avatar
Mufei Li committed
1796
1797
        >>> g.out_degree(1, 'follows')
        2
Da Zheng's avatar
Da Zheng committed
1798

1799
1800
1801
1802
        See Also
        --------
        out_degrees
        """
Mufei Li's avatar
Mufei Li committed
1803
        return self._graph.out_degree(self.get_etype_id(etype), u)
1804

Mufei Li's avatar
Mufei Li committed
1805
1806
    def out_degrees(self, u=ALL, etype=None):
        """Return the out-degrees of nodes u with edges of type ``etype``.
1807
1808
1809

        Parameters
        ----------
Mufei Li's avatar
Mufei Li committed
1810
        u : list, tensor
1811
1812
            The node ID array of source type. Default is to return the degrees
            of all the nodes.
Minjie Wang's avatar
Minjie Wang committed
1813
1814
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
Mufei Li's avatar
Mufei Li committed
1815
            in the graph. (Default: None)
1816
1817
1818
1819

        Returns
        -------
        d : tensor
Mufei Li's avatar
Mufei Li committed
1820
1821
            The out-degree array. ``d[i]`` gives the out-degree of node ``u[i]``
            with edges of type ``etype``.
1822
1823
1824
1825
1826

        Examples
        --------
        The following example uses PyTorch backend.

Mufei Li's avatar
Mufei Li committed
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.out_degrees(0, 'plays')
        tensor([1])
        >>> g.out_degrees(etype='follows')
        tensor([1, 2, 0])
1839
1840
1841
1842
1843

        See Also
        --------
        out_degree
        """
1844
        check_same_dtype(self._idtype_str, u)
Minjie Wang's avatar
Minjie Wang committed
1845
1846
        etid = self.get_etype_id(etype)
        stid, _ = self._graph.metagraph.find_edge(etid)
Mufei Li's avatar
Mufei Li committed
1847
        if is_all(u):
1848
            u = utils.toindex(slice(0, self._graph.number_of_nodes(stid)), self._idtype_str)
1849
        else:
1850
            u = utils.toindex(u, self._idtype_str)
Mufei Li's avatar
Mufei Li committed
1851
        return self._graph.out_degrees(etid, u).tousertensor()
Minjie Wang's avatar
Minjie Wang committed
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873

    def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
        """Internal function to create a subgraph."""
        node_frames = [
            FrameRef(Frame(
                self._node_frames[i][induced_nodes_of_ntype],
                num_rows=len(induced_nodes_of_ntype)))
            for i, induced_nodes_of_ntype in enumerate(induced_nodes)]
        edge_frames = [
            FrameRef(Frame(
                self._edge_frames[i][induced_edges_of_etype],
                num_rows=len(induced_edges_of_etype)))
            for i, induced_edges_of_etype in enumerate(induced_edges)]

        hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
        hsg.is_subgraph = True
        for ntype, induced_nid in zip(self.ntypes, induced_nodes):
            hsg.nodes[ntype].data[NID] = induced_nid.tousertensor()
        for etype, induced_eid in zip(self.canonical_etypes, induced_edges):
            hsg.edges[etype].data[EID] = induced_eid.tousertensor()

        return hsg
1874

Minjie Wang's avatar
Minjie Wang committed
1875
1876
    def subgraph(self, nodes):
        """Return the subgraph induced on given nodes.
1877

Minjie Wang's avatar
Minjie Wang committed
1878
1879
        The metagraph of the returned subgraph is the same as the parent graph.
        Features are copied from the original graph.
1880

Minjie Wang's avatar
Minjie Wang committed
1881
1882
        Parameters
        ----------
Mufei Li's avatar
Mufei Li committed
1883
1884
1885
        nodes : dict[str->list or iterable]
            A dictionary mapping node types to node ID array for constructing
            subgraph. All nodes must exist in the graph.
1886

Minjie Wang's avatar
Minjie Wang committed
1887
1888
1889
1890
        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.
Mufei Li's avatar
Mufei Li committed
1891
1892
1893
1894

            The nodes and edges in the subgraph are relabeled using consecutive
            integers from 0.

Minjie Wang's avatar
Minjie Wang committed
1895
            One can retrieve the mapping from subgraph node/edge ID to parent
Mufei Li's avatar
Mufei Li committed
1896
            node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
Minjie Wang's avatar
Minjie Wang committed
1897
            subgraph.
Mufei Li's avatar
Mufei Li committed
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set node features
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> g.subgraph({'user': [4, 5]})
        An error occurs as these nodes do not exist.
        >>> sub_g = g.subgraph({'user': [1, 2]})
        >>> print(sub_g)
        Graph(num_nodes={'user': 2, 'game': 0},
1918
              num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
Mufei Li's avatar
Mufei Li committed
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
              metagraph=[('user', 'game'), ('user', 'user')])

        Get the original node/edge indices.

        >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
        tensor([1, 2])
        >>> sub_g['follows'].edata[dgl.EID] # Get the edge indices in the raw graph
        tensor([1, 2])

        Get the copied node features.

        >>> sub_g.nodes['user'].data['h']
        tensor([[1.],
                [2.]])
        >>> sub_g.nodes['user'].data['h'] += 1
        >>> g.nodes['user'].data['h']          # Features are not shared.
        tensor([[0.],
                [1.],
                [2.]])

        See Also
        --------
        edge_subgraph
Minjie Wang's avatar
Minjie Wang committed
1942
        """
1943
1944
1945
        check_same_dtype(self._idtype_str, nodes)
        induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str)
                         for ntype in self.ntypes]
Minjie Wang's avatar
Minjie Wang committed
1946
1947
        sgi = self._graph.node_subgraph(induced_nodes)
        induced_edges = sgi.induced_edges
1948

Minjie Wang's avatar
Minjie Wang committed
1949
        return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
1950

Minjie Wang's avatar
Minjie Wang committed
1951
1952
    def edge_subgraph(self, edges, preserve_nodes=False):
        """Return the subgraph induced on given edges.
1953

Minjie Wang's avatar
Minjie Wang committed
1954
        The metagraph of the returned subgraph is the same as the parent graph.
1955

Minjie Wang's avatar
Minjie Wang committed
1956
        Features are copied from the original graph.
1957

Minjie Wang's avatar
Minjie Wang committed
1958
1959
        Parameters
        ----------
Mufei Li's avatar
Mufei Li committed
1960
1961
1962
1963
1964
1965
1966
1967
1968
        edges : dict[str->list or iterable]
            A dictionary mapping edge types to edge ID array for constructing
            subgraph. All edges must exist in the subgraph.

            The edge types are characterized by triplets of
            ``(src type, etype, dst type)``.
        preserve_nodes : bool
            Whether to preserve all nodes or not. If false, all nodes
            without edges will be removed. (Default: False)
1969

Minjie Wang's avatar
Minjie Wang committed
1970
1971
1972
1973
        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.
Mufei Li's avatar
Mufei Li committed
1974
1975
1976

            The nodes and edges are relabeled using consecutive integers from 0.

Minjie Wang's avatar
Minjie Wang committed
1977
            One can retrieve the mapping from subgraph node/edge ID to parent
Mufei Li's avatar
Mufei Li committed
1978
            node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
Minjie Wang's avatar
Minjie Wang committed
1979
            subgraph.
Mufei Li's avatar
Mufei Li committed
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set edge features
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> g.edge_subgraph({('user', 'follows', 'user'): [5, 6]})
        An error occurs as these edges do not exist.
        >>> sub_g = g.edge_subgraph({('user', 'follows', 'user'): [1, 2],
        >>>                          ('user', 'plays', 'game'): [2]})
        >>> print(sub_g)
        Graph(num_nodes={'user': 2, 'game': 1},
2001
              num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
Mufei Li's avatar
Mufei Li committed
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
              metagraph=[('user', 'game'), ('user', 'user')])

        Get the original node/edge indices.

        >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
        tensor([1, 2])
        >>> sub_g['plays'].edata[dgl.EID]   # Get the edge indices in the raw graph
        tensor([2])

        Get the copied node features.

        >>> sub_g.edges['follows'].data['h']
        tensor([[1.],
                [2.]])
        >>> sub_g.edges['follows'].data['h'] += 1
        >>> g.edges['follows'].data['h']          # Features are not shared.
        tensor([[0.],
                [1.],
                [2.]])

        See Also
        --------
        subgraph
Minjie Wang's avatar
Minjie Wang committed
2025
        """
2026
        check_idtype_dict(self._idtype_str, edges)
Minjie Wang's avatar
Minjie Wang committed
2027
2028
        edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
        induced_edges = [
2029
            utils.toindex(edges.get(canonical_etype, []), self._idtype_str)
Minjie Wang's avatar
Minjie Wang committed
2030
2031
2032
            for canonical_etype in self.canonical_etypes]
        sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
        induced_nodes = sgi.induced_nodes
2033

Minjie Wang's avatar
Minjie Wang committed
2034
        return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
2035

Minjie Wang's avatar
Minjie Wang committed
2036
2037
    def node_type_subgraph(self, ntypes):
        """Return the subgraph induced on the given node types.

        The metagraph of the result is the subgraph of this graph's metagraph
        induced by ``ntypes``: an edge type is kept only when both its source
        and its destination node types are listed in ``ntypes``.

        The node and edge feature frames of the result are the very same frame
        objects used by this graph, so features are shared with the original
        graph.

        Parameters
        ----------
        ntypes : list[str]
            The node types to keep. Their order becomes the node type order of
            the returned graph.

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set node features
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> sub_g = g.node_type_subgraph(['user'])
        >>> print(sub_g)
        Graph(num_nodes=3, num_edges=3,
              ndata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)}
              edata_schemes={})

        Get the shared node features.

        >>> sub_g.nodes['user'].data['h']
        tensor([[0.],
                [1.],
                [2.]])
        >>> sub_g.nodes['user'].data['h'] += 1
        >>> g.nodes['user'].data['h']          # Features are shared.
        tensor([[1.],
                [2.],
                [3.]])

        See Also
        --------
        edge_type_subgraph
        """
        # Frames and node counts of the kept node types, in ``ntypes`` order.
        kept_node_frames = [self._node_frames[self.get_ntype_id(nt)] for nt in ntypes]
        kept_num_nodes = [self.number_of_nodes(nt) for nt in ntypes]
        # Position of each kept node type inside the new metagraph.
        new_pos = {nt: pos for pos, nt in enumerate(ntypes)}

        kept_meta_edges = []
        kept_rel_graphs = []
        kept_etypes = []
        kept_edge_frames = []
        src_tids, dst_tids, _ = self._graph.metagraph.edges('eid')
        num_etypes = len(self._etypes)
        for etid in range(num_etypes):
            src_name = self._ntypes[src_tids[etid]]
            dst_name = self._ntypes[dst_tids[etid]]
            if src_name not in new_pos or dst_name not in new_pos:
                continue
            kept_meta_edges.append((new_pos[src_name], new_pos[dst_name]))
            kept_rel_graphs.append(self._graph.get_relation_graph(etid))
            kept_etypes.append(self.etypes[etid])
            kept_edge_frames.append(self._edge_frames[etid])

        new_metagraph = graph_index.from_edge_list(kept_meta_edges, True)
        # Per-type node counts are always passed as int64, whatever the graph idtype.
        new_index = heterograph_index.create_heterograph_from_relations(
            new_metagraph, kept_rel_graphs, utils.toindex(kept_num_nodes, "int64"))
        return DGLHeteroGraph(new_index, ntypes, kept_etypes,
                              kept_node_frames, kept_edge_frames)
2116

Minjie Wang's avatar
Minjie Wang committed
2117
2118
    def edge_type_subgraph(self, etypes):
        """Return the subgraph induced on the given edge types.

        The metagraph of the returned subgraph is the subgraph of the original
        metagraph induced by the edge types. The node types of the result are
        the endpoint types of the selected edge types. They are listed in the
        same relative order as in the original graph.

        The node and edge feature frames of the result are the same frame
        objects used by this graph, so features are shared with the original
        graph.

        Parameters
        ----------
        etypes : list[str or tuple]
            The edge types

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set edge features
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> sub_g = g.edge_type_subgraph(['follows'])
        >>> print(sub_g)
        Graph(num_nodes=3, num_edges=3,
              ndata_schemes={}
              edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})

        Get the shared edge features.

        >>> sub_g.edges['follows'].data['h']
        tensor([[0.],
                [1.],
                [2.]])
        >>> sub_g.edges['follows'].data['h'] += 1
        >>> g.edges['follows'].data['h']          # Features are shared.
        tensor([[1.],
                [2.],
                [3.]])

        See Also
        --------
        node_type_subgraph
        """
        etype_ids = [self.get_etype_id(etype) for etype in etypes]
        # meta graph is homograph, still using int64
        meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids, "int64"))
        rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
        meta_src = meta_src.tonumpy().tolist()
        meta_dst = meta_dst.tonumpy().tolist()

        # Sort the involved node type ids. Iterating a bare set has no ordering
        # guarantee (e.g. {1, 8} iterates as 8, 1). Sorting keeps the induced
        # node types in the same relative order as the original graph and
        # makes the result deterministic.
        induced_ntids = sorted(set(meta_src) | set(meta_dst))
        ntypes_invmap = {ntid: pos for pos, ntid in enumerate(induced_ntids)}
        mapped_meta_src = [ntypes_invmap[v] for v in meta_src]
        mapped_meta_dst = [ntypes_invmap[v] for v in meta_dst]

        node_frames = [self._node_frames[i] for i in induced_ntids]
        edge_frames = [self._edge_frames[i] for i in etype_ids]
        induced_ntypes = [self._ntypes[i] for i in induced_ntids]
        induced_etypes = [self._etypes[i] for i in etype_ids]   # get the "name" of edge type
        num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]

        metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True)
        # num_nodes_per_type should be int64
        hgidx = heterograph_index.create_heterograph_from_relations(
            metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64"))
        return DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)

2193
    def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None):
        """Return the adjacency matrix of the edges of the given edge type.

        With ``transpose=False`` each row of the result corresponds to the
        destination of an edge and each column to its source. With
        ``transpose=True`` rows correspond to sources and columns to
        destinations.

        If ``transpose`` is left as ``None``, a warning about the upcoming
        default change is emitted and ``transpose=False`` is used.

        Parameters
        ----------
        transpose : bool, optional
            A flag to transpose the returned adjacency matrix. (Default: False)
        ctx : context, optional
            The context of returned adjacency matrix. (Default: cpu)
        scipy_fmt : str, optional
            If specified, return a scipy sparse matrix in the given format.
            Otherwise, return a backend dependent sparse tensor. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        SparseTensor or scipy.sparse.spmatrix
            Adjacency matrix.

        Examples
        --------

        Instantiate a heterogeneous graph.

        >>> follows_g = dgl.graph([(0, 0), (1, 1)], 'user', 'follows')
        >>> devs_g = dgl.bipartite([(0, 0), (1, 2)], 'developer', 'develops', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, devs_g])

        Get a backend dependent sparse tensor. Here we use PyTorch for example.

        >>> g.adjacency_matrix(etype='develops')
        tensor(indices=tensor([[0, 2],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)

        Get a scipy coo sparse matrix.

        >>> g.adjacency_matrix(scipy_fmt='coo', etype='develops')
        <3x2 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in COOrdinate format>
        """
        if transpose is None:
            # The default orientation is scheduled to change; warn callers who
            # rely on it implicitly.
            dgl_warning(
                "Currently adjacency_matrix() returns a matrix with destination as rows"
                " by default.  In 0.5 the result will have source as rows"
                " (i.e. transpose=True)")
            transpose = False

        etype_id = self.get_etype_id(etype)
        if scipy_fmt is not None:
            return self._graph.adjacency_matrix_scipy(etype_id, transpose, scipy_fmt, False)
        # The graph index returns a tuple; only its first element is the matrix.
        sparse_adj = self._graph.adjacency_matrix(etype_id, transpose, ctx)
        return sparse_adj[0]

    # Alias of ``adjacency_matrix``
    adj = adjacency_matrix
2258

Minjie Wang's avatar
Minjie Wang committed
2259
2260
2261
    def incidence_matrix(self, typestr, ctx=F.cpu(), etype=None):
        """Return the incidence matrix of the edges of the given edge type.

        An incidence matrix is an n-by-m sparse matrix, where n is the number
        of nodes and m is the number of edges. Each nonzero value indicates
        that the edge is incident to the node.

        There are three types of incidence matrices :math:`I`:

        * ``in``:

            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`
              (or :math:`v` is the dst node of :math:`e`);
            - :math:`I[v, e] = 0` otherwise.

        * ``out``:

            - :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`
              (or :math:`v` is the src node of :math:`e`);
            - :math:`I[v, e] = 0` otherwise.

        * ``both`` (only if source and destination node type are the same):

            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`;
            - :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`;
            - :math:`I[v, e] = 0` otherwise (including self-loop).

        Parameters
        ----------
        typestr : str
            Can be either ``in``, ``out`` or ``both``
        ctx : context, optional
            The context of returned incidence matrix. (Default: cpu)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        Framework SparseTensor
            The incidence matrix.

        Examples
        --------

        >>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
        >>> g.incidence_matrix('in')
        tensor(indices=tensor([[0, 2],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        >>> g.incidence_matrix('out')
        tensor(indices=tensor([[0, 1],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        >>> g.incidence_matrix('both')
        tensor(indices=tensor([[1, 2],
                               [1, 1]]),
               values=tensor([-1.,  1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        """
        etype_id = self.get_etype_id(etype)
        # The graph index returns a tuple; only its first element is the matrix.
        inc_result = self._graph.incidence_matrix(etype_id, typestr, ctx)
        return inc_result[0]

    # Alias of ``incidence_matrix``
    inc = incidence_matrix
Da Zheng's avatar
Da Zheng committed
2328

Minjie Wang's avatar
Minjie Wang committed
2329
2330
2331
2332
2333
    #################################################################
    # Features
    #################################################################

    def node_attr_schemes(self, ntype=None):
        """Return the node feature schemes of the given node type.

        Each scheme is a named tuple holding the shape and data type of one
        node feature column.

        Parameters
        ----------
        ntype : str, optional
            The node type. Can be omitted if there is only one node
            type in the graph. Error will be raised otherwise.
            (Default: None)

        Returns
        -------
        dict of str to schemes
            The schemes of node feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.randn(3, 4)
        >>> g.node_attr_schemes('user')
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        edge_attr_schemes
        """
        ntype_id = self.get_ntype_id(ntype)
        node_frame = self._node_frames[ntype_id]
        return node_frame.schemes
Da Zheng's avatar
Da Zheng committed
2365

Minjie Wang's avatar
Minjie Wang committed
2366
    def edge_attr_schemes(self, etype=None):
        """Return the edge feature schemes of the given edge type.

        Each scheme is a named tuple holding the shape and data type of one
        edge feature column.

        Parameters
        ----------
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        dict of str to schemes
            The schemes of edge feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
        >>> g.edge_attr_schemes(('user', 'plays', 'game'))
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        node_attr_schemes
        """
        etype_id = self.get_etype_id(etype)
        edge_frame = self._edge_frames[etype_id]
        return edge_frame.schemes
Da Zheng's avatar
Da Zheng committed
2397

Minjie Wang's avatar
Minjie Wang committed
2398
2399
    def set_n_initializer(self, initializer, field=None, ntype=None):
        """Set the initializer for empty node features of one node type.

        An initializer is a callable that produces a tensor from a shape, a
        data type and a device context. When only a subset of the nodes is
        assigned a new feature, the initializer fills in the feature for the
        remaining nodes.

        Parameters
        ----------
        initializer : callable
            The initializer, mapping (shape, data type, context) to tensor.
        field : str, optional
            The feature field name. By default the initializer is set for all
            feature fields.
        ntype : str, optional
            The node type. Can be omitted if there is only one node
            type in the graph. Error will be raised otherwise.
            (Default: None)

        Note
        -----
        User defined initializer must follow the signature of
        :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`

        See Also
        --------
        set_e_initializer
        """
        target_frame = self._node_frames[self.get_ntype_id(ntype)]
        target_frame.set_initializer(initializer, field)
Da Zheng's avatar
Da Zheng committed
2430

Minjie Wang's avatar
Minjie Wang committed
2431
2432
    def set_e_initializer(self, initializer, field=None, etype=None):
        """Set the initializer for empty edge features of one edge type.

        An initializer is a callable that produces a tensor from a shape, a
        data type and a device context. When only a subset of the edges is
        assigned a new feature, the initializer fills in the feature for the
        remaining edges.

        Parameters
        ----------
        initializer : callable
            The initializer, mapping (shape, data type, context) to tensor.
        field : str, optional
            The feature field name. By default the initializer is set for all
            feature fields.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. Error will be raised otherwise.
            (Default: None)

        Note
        -----
        User defined initializer must follow the signature of
        :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`

        See Also
        --------
        set_n_initializer
        """
        target_frame = self._edge_frames[self.get_etype_id(etype)]
        target_frame.set_initializer(initializer, field)
Da Zheng's avatar
Da Zheng committed
2463

Minjie Wang's avatar
Minjie Wang committed
2464
2465
    def _set_n_repr(self, ntid, u, data, inplace=False):
        """Internal API to set node features of a single node type.

        ``data`` maps feature names to feature tensors. Each tensor has shape
        ``(B, D1, D2, ...)``, where ``B`` is the number of nodes being updated
        and ``(D1, D2, ...)`` is the per-node feature shape. ``B`` is
        ``len(u)``, or the number of nodes of this type when ``u`` is ALL.

        All updates are done out of place to work with autograd, unless
        ``inplace`` is True.

        Parameters
        ----------
        ntid : int
            Node type id.
        u : node, container or tensor
            The node(s). With ALL, whole feature columns are replaced.
        data : dict of tensor
            Node representation.
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If ``data`` is not dict-like, or if the first dimension of a
            feature tensor does not match the number of nodes being updated.
        """
        if is_all(u):
            num_nodes = self._graph.number_of_nodes(ntid)
        else:
            u = utils.toindex(u, self._idtype_str)
            num_nodes = len(u)

        # Same sanity check as ``_set_e_repr``: reject non-mapping input with a
        # clear DGLError instead of an AttributeError raised by ``.items()``.
        if not utils.is_dict_like(data):
            raise DGLError('Expect dictionary type for feature data.'
                           ' Got "%s" instead.' % type(data))

        for val in data.values():
            nfeats = F.shape(val)[0]
            if nfeats != num_nodes:
                raise DGLError('Expect number of features to match number of nodes (len(u)).'
                               ' Got %d and %d instead.' % (nfeats, num_nodes))

        if is_all(u):
            # Updating every node: replace each feature column wholesale.
            for key, val in data.items():
                self._node_frames[ntid][key] = val
        else:
            self._node_frames[ntid].update_rows(u, data, inplace=inplace)
Da Zheng's avatar
Da Zheng committed
2503

Minjie Wang's avatar
Minjie Wang committed
2504
    def _get_n_repr(self, ntid, u):
        """Get the feature(s) of node(s) of a single node type.

        Multiple node features are batched along the first dimension of each
        returned tensor.

        Parameters
        ----------
        ntid : int
            Node type id.
        u : node, container or tensor
            The node(s). With ALL, every node of the type is selected.

        Returns
        -------
        dict
            Representation dict from feature name to feature tensor.
        """
        node_frame = self._node_frames[ntid]
        if not is_all(u):
            node_ids = utils.toindex(u, self._idtype_str)
            return node_frame.select_rows(node_ids)
        # All nodes: hand back a plain dict view of the whole frame.
        return dict(node_frame)
Da Zheng's avatar
Da Zheng committed
2526

Minjie Wang's avatar
Minjie Wang committed
2527
2528
    def _pop_n_repr(self, ntid, key):
        """Internal API to get and remove the specified node feature.
Da Zheng's avatar
Da Zheng committed
2529
2530
2531

        Parameters
        ----------
Minjie Wang's avatar
Minjie Wang committed
2532
2533
        ntid : int
            Node type id.
Da Zheng's avatar
Da Zheng committed
2534
2535
2536
2537
2538
2539
2540
2541
        key : str
            The attribute name.

        Returns
        -------
        Tensor
            The popped representation
        """
Minjie Wang's avatar
Minjie Wang committed
2542
        return self._node_frames[ntid].pop(key)
Da Zheng's avatar
Da Zheng committed
2543

Minjie Wang's avatar
Minjie Wang committed
2544
2545
    def _set_e_repr(self, etid, edges, data, inplace=False):
        """Internal API to set the feature(s) of edges of a single edge type.

        ``data`` maps feature names to feature tensors. Each tensor has shape
        ``(B, D1, D2, ...)``, where ``B`` is the number of edges being updated
        and ``(D1, D2, ...)`` is the per-edge feature shape.

        All updates are done out of place to work with autograd, unless
        ``inplace`` is True.

        Parameters
        ----------
        etid : int
            Edge type id.
        edges : edges
            Edges can be either

            * A pair of endpoint nodes (u, v), where u is the node ID of source
              node type and v is that of destination node type.
            * A tensor of edge ids of the given type.

            The default value is all the edges.
        data : tensor or dict of tensor
            Edge representation.
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If ``data`` is not dict-like, or if the first dimension of a
            feature tensor does not match the number of edges being updated.
        """
        # Resolve ``edges`` to either ALL or an index of edge ids.
        if is_all(edges):
            eid = ALL
        elif isinstance(edges, tuple):
            src, dst = edges
            src = utils.toindex(src, self._idtype_str)
            dst = utils.toindex(dst, self._idtype_str)
            # edge_ids resolves (u, v) pairs, handling broadcasting and multigraph.
            _, _, eid = self._graph.edge_ids(etid, src, dst)
        else:
            eid = utils.toindex(edges, self._idtype_str)

        if not utils.is_dict_like(data):
            raise DGLError('Expect dictionary type for feature data.'
                           ' Got "%s" instead.' % type(data))

        whole_columns = is_all(eid)
        if whole_columns:
            num_edges = self._graph.number_of_edges(etid)
        else:
            eid = utils.toindex(eid, self._idtype_str)
            num_edges = len(eid)

        for feat in data.values():
            nfeats = F.shape(feat)[0]
            if nfeats != num_edges:
                raise DGLError('Expect number of features to match number of edges.'
                               ' Got %d and %d instead.' % (nfeats, num_edges))

        edge_frame = self._edge_frames[etid]
        if not whole_columns:
            # Partial update: write only the selected rows.
            edge_frame.update_rows(eid, data, inplace=inplace)
            return
        # Full update: replace each feature column wholesale.
        for name, feat in data.items():
            edge_frame[name] = feat
Da Zheng's avatar
Da Zheng committed
2607

Minjie Wang's avatar
Minjie Wang committed
2608
2609
    def _get_e_repr(self, etid, edges):
        """Internal API to get the features of edges of a single edge type.

        Parameters
        ----------
        etid : int
            Edge type id.
        edges : edges
            Edges can be a pair of endpoint nodes (u, v), or a
            tensor of edge ids. The default value is all the edges.

        Returns
        -------
        dict
            Representation dict
        """
        edge_frame = self._edge_frames[etid]
        if is_all(edges):
            # All edges: hand back a plain dict view of the whole frame.
            return dict(edge_frame)

        if isinstance(edges, tuple):
            src, dst = edges
            src = utils.toindex(src, self._idtype_str)
            dst = utils.toindex(dst, self._idtype_str)
            # edge_ids resolves (u, v) pairs, handling broadcasting and multigraph.
            _, _, edge_ids = self._graph.edge_ids(etid, src, dst)
        else:
            edge_ids = utils.toindex(edges, self._idtype_str)

        edge_ids = utils.toindex(edge_ids, self._idtype_str)
        return edge_frame.select_rows(edge_ids)
Da Zheng's avatar
Da Zheng committed
2641

Minjie Wang's avatar
Minjie Wang committed
2642
    def _pop_e_repr(self, etid, key):
Da Zheng's avatar
Da Zheng committed
2643
2644
2645
2646
        """Get and remove the specified edge repr of a single edge type.

        Parameters
        ----------
Minjie Wang's avatar
Minjie Wang committed
2647
2648
        etid : int
            Edge type id.
Da Zheng's avatar
Da Zheng committed
2649
2650
2651
2652
2653
2654
2655
2656
        key : str
          The attribute name.

        Returns
        -------
        Tensor
            The popped representation
        """
Minjie Wang's avatar
Minjie Wang committed
2657
        self._edge_frames[etid].pop(key)
Da Zheng's avatar
Da Zheng committed
2658

Minjie Wang's avatar
Minjie Wang committed
2659
2660
2661
2662
2663
2664
2665
    #################################################################
    # Message passing
    #################################################################

    def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
        """Apply a function to nodes of a single type to update their features.

        If None is provided for ``func``, nothing will happen.

        Parameters
        ----------
        func : callable or None
            Function applied to the nodes. It should be a
            :mod:`Node UDF <dgl.udf>`.
        v : int or iterable of int or tensor, optional
            The (type-specific) node id(s) on which to apply ``func``.
            (Default: ALL)
        ntype : str, optional
            The node type. May be omitted when the graph has only one node
            type. (Default: None)
        inplace : bool, optional
            If True, the update is done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.ones(3, 5)
        >>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
        >>> g.nodes['user'].data['h']
        tensor([[2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.]])

        See Also
        --------
        apply_edges
        """
        check_same_dtype(self._idtype_str, v)
        ntid = self.get_ntype_id(ntype)
        # ALL expands to the contiguous range covering every node of this type.
        targets = slice(0, self.number_of_nodes(ntype)) if is_all(v) else v
        node_ids = utils.toindex(targets, self._idtype_str)
        with ir.prog() as prog:
            scheduler.schedule_apply_nodes(node_ids, func, self._node_frames[ntid],
                                           inplace=inplace, ntype=self._ntypes[ntid])
            Runtime.run(prog)

    def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
        """Apply a function to edges of a single type to update their features.

        If None is provided for ``func``, nothing will happen.

        Parameters
        ----------
        func : callable or None
            Function applied to the edges. It should be an
            :mod:`Edge UDF <dgl.udf>`.
        edges : optional
            Edges on which to apply ``func``. See :func:`send` for valid
            edge specifications. (Default: ALL)
        etype : str or tuple of str, optional
            The edge type. May be omitted when the graph has only one edge
            type. (Default: None)
        inplace: bool, optional
            If True, the update is done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
        >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
        >>> g.edges[('user', 'plays', 'game')].data['h']
        tensor([[2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.]])

        See Also
        --------
        apply_nodes
        group_apply_edges
        """
        check_same_dtype(self._idtype_str, edges)
        idtype = self._idtype_str
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            src, dst, _ = self._graph.edges(etid, 'eid')
            edge_ids = utils.toindex(slice(0, self.number_of_edges(etype)), idtype)
        elif isinstance(edges, tuple):
            src_nodes, dst_nodes = edges
            # edge_ids rewrites the endpoints too, handling broadcasting and
            # multigraphs (all edges between a node pair are returned).
            src, dst, edge_ids = self._graph.edge_ids(
                etid,
                utils.toindex(src_nodes, idtype),
                utils.toindex(dst_nodes, idtype))
        else:
            edge_ids = utils.toindex(edges, idtype)
            src, dst, _ = self._graph.find_edges(etid, edge_ids)

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, stid, dtid, etid)
            scheduler.schedule_apply_edges(adapted, src, dst, edge_ids, func,
                                           inplace=inplace)
            Runtime.run(prog)

    def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
        """Group edges by their source or destination nodes and apply a function
        to each group to update the edge features. All edges share one edge type
        (and therefore one source and one destination node type).

        Parameters
        ----------
        group_by : str
            How to group the edges. Must be either ``'src'`` or ``'dst'``.
        func : callable
            Function applied to the grouped edges. It should be an
            :mod:`Edge UDF <dgl.udf>`. The edge data it receives has shape
            (bucket_size, degrees, *feature_shape), and it must return a dict
            whose values have the same shapes.
        edges : optional
            Edges on which to group and apply ``func``. See :func:`send` for
            valid edge specifications. Default is all the edges.
        etype : str or tuple of str, optional
            The edge type. May be omitted when the graph has only one edge
            type. (Default: None)
        inplace: bool, optional
            If True, the update is done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If ``group_by`` is neither ``'src'`` nor ``'dst'``.

        Examples
        --------
        >>> g = dgl.graph([(0, 1), (0, 2), (1, 2)], 'user', 'follows')
        >>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1))
        >>> def softmax_feat(edges):
        >>>     return {'norm_feat': torch.softmax(edges.data['feat'], dim=1)}
        >>> g.group_apply_edges(group_by='src', func=softmax_feat)
        >>> g.edata['norm_feat']
        tensor([[0.3796],
                [0.6204],
                [1.0000]])

        See Also
        --------
        apply_edges
        """
        check_same_dtype(self._idtype_str, edges)
        if group_by not in ('src', 'dst'):
            raise DGLError("Group_by should be either src or dst")

        idtype = self._idtype_str
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            src, dst, _ = self._graph.edges(etid, 'eid')
            edge_ids = utils.toindex(slice(0, self.number_of_edges(etype)), idtype)
        elif isinstance(edges, tuple):
            src_nodes, dst_nodes = edges
            # edge_ids rewrites the endpoints too, handling broadcasting and
            # multigraphs (all edges between a node pair are returned).
            src, dst, edge_ids = self._graph.edge_ids(
                etid,
                utils.toindex(src_nodes, idtype),
                utils.toindex(dst_nodes, idtype))
        else:
            edge_ids = utils.toindex(edges, idtype)
            src, dst, _ = self._graph.find_edges(etid, edge_ids)

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, stid, dtid, etid)
            scheduler.schedule_group_apply_edge(adapted, src, dst, edge_ids,
                                                func, group_by, inplace=inplace)
            Runtime.run(prog)

    def send(self, edges, message_func, etype=None):
        """Send messages along the given edges of a single edge type.

        ``edges`` can be any of the following types:

        * ``int`` : Specify one edge using its edge id (of the given edge type).
        * ``pair of int`` : Specify one edge using its endpoints (of source node type
          and destination node type respectively).
        * ``int iterable`` / ``tensor`` : Specify multiple edges using their edge ids.
        * ``pair of int iterable`` / ``pair of tensors`` :
          Specify multiple edges using their endpoints.

        **Only works if the graph has one edge type** unless ``etype`` is given.
        For multiple types, you can also use

        .. code::

           g['edgetype'].send(edges, message_func)

        The UDF returns messages on the edges, which can later be fetched in
        the destination node's ``mailbox``. Receiving consumes the messages.
        See :func:`recv` for an example.

        If ``send`` is triggered several times on the same edge without a ``recv``
        in between, messages generated by the later ``send`` overwrite the
        earlier ones.

        Parameters
        ----------
        edges : optional
            Edges on which to apply ``message_func``.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        etype : str or tuple of str, optional
            The edge type. May be omitted when the graph has only one edge
            type. (Default: None)

        Notes
        -----
        On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent
        along all edges between :math:`u` and :math:`v`.

        Examples
        --------

        >>> import dgl.function as fn
        >>> import torch
        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Different ways for sending messages.

        >>> # Send the feature of source nodes along all edges
        >>> g.send(g.edges(), fn.copy_src('h', 'm'))
        >>> # Send the feature of source node along one edge specified by its id
        >>> g.send(0, fn.copy_src('h', 'm'))
        >>> # Send the feature of source node along one edge specified by its end points
        >>> g.send((0, 1), fn.copy_src('h', 'm'))
        >>> # Send the feature of source nodes along multiple edges specified by their ids
        >>> g.send([0, 1], fn.copy_src('h', 'm'))
        >>> # Send the feature of source nodes along multiple edges specified by their end points
        >>> g.send(([0, 1], [1, 2]), fn.copy_src('h', 'm'))
        """
        check_same_dtype(self._idtype_str, edges)
        assert message_func is not None
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        idtype = self._idtype_str

        if is_all(edges):
            edge_ids = utils.toindex(slice(0, self._graph.number_of_edges(etid)), idtype)
            src, dst, _ = self._graph.edges(etid, 'eid')
        elif isinstance(edges, tuple):
            src_nodes, dst_nodes = edges
            # edge_ids rewrites the endpoints too, handling broadcasting and
            # multigraphs (all edges between a node pair are returned).
            src, dst, edge_ids = self._graph.edge_ids(
                etid,
                utils.toindex(src_nodes, idtype),
                utils.toindex(dst_nodes, idtype))
        else:
            edge_ids = utils.toindex(edges, idtype)
            src, dst, _ = self._graph.find_edges(etid, edge_ids)

        if not len(edge_ids):
            # Nothing to send along.
            return

        with ir.prog() as prog:
            adapted = AdaptedHeteroGraph(self, stid, dtid, etid)
            scheduler.schedule_send(adapted, src, dst, edge_ids, message_func)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
2921
2922

    def recv(self,
             v,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.

        It calculates:

        .. math::
            h_v^{new} = \sigma(f(\{m_{uv} | u\in\mathcal{N}_{t}(v)\}))

        where :math:`\mathcal{N}_t(v)` defines the predecessors of node(s) :math:`v` connected by
        edges of type :math:`t`, and :math:`m_{uv}` is the message on edge :math:`(u,v)`.

        * ``reduce_func`` specifies :math:`f`, e.g. summation or average.
        * ``apply_node_func`` specifies :math:`\sigma`, e.g. ReLU activation.

        Other notes:

        * ``reduce_func`` will be skipped for nodes with no incoming message.
        * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
        * If some ``v`` have no incoming message, their new feature value will be calculated
          by the column initializer (see :func:`set_n_initializer`). The feature shapes and
          dtypes will be inferred.
        * The node features will be updated by the result of the ``reduce_func``.
        * Messages are consumed once received.
        * The provided UDF may be called multiple times, so it is recommended to provide
          a function with no side effects.

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated. ``ALL`` selects every destination node of
            the edge type.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Send and receive.

        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.recv(g.nodes('user'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])
        """
        check_same_dtype(self._idtype_str, v)
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(v):
            # ``dtid`` is an integer node type id, so ask the underlying graph
            # index (as ``send`` does with ``etid``). The public
            # ``number_of_nodes`` is called with type names elsewhere in this
            # class and cannot be relied on to accept an integer id.
            v = F.arange(0, self._graph.number_of_nodes(dtid), self._idtype_str)
        elif isinstance(v, int):
            v = [v]
        v = utils.toindex(v, dtype=self._idtype_str)
        if len(v) == 0:
            # no vertex to be triggered.
            return
        with ir.prog() as prog:
            scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
3009

Mufei Li's avatar
Mufei Li committed
3010
    def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Receive messages from multiple edge types and perform aggregation.

        It calculates:

        .. math::

            \begin{align}
            h_{v, t}^{new} &= f\left(\left\{m_{uv} | u\in\mathcal{N}_{t}(v)\right\}\right)\\
            h_v^{new} &= \sigma\left(g\left(\left\{h_{v, t}^{new} | t\in T_e\right\}\right)\right)
            \end{align}

        * ``reducer_dict`` is a dictionary mapping edge type (str or tuple of str) to
          the reduce function :math:`f` of each type.
        * ``cross_reducer`` specifies :math:`g`.
        * ``apply_node_func`` specifies :math:`\sigma`.

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated. ``ALL`` selects every node of the
            receiving node type.
        reducer_dict : dict
            Mapping edge type (str or tuple of str) to either a reduce function
            (:mod:`Node UDF <dgl.udf>`) or a ``(reduce_func, apply_node_func)``
            tuple applied for that edge type only.
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If a value of ``reducer_dict`` is neither a reduce function nor a
            ``(reduce_func, apply_node_func)`` tuple.

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

        Send and receive.

        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
        >>> g.multi_recv(g.nodes('user'), {'follows': fn.sum('m', 'h'),
        >>>              'attracts': fn.sum('m', 'h')}, "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])
        """
        check_same_dtype(self._idtype_str, v)
        # infer receive node type
        ntype = infer_ntype_from_dict(self, reducer_dict)
        ntid = self.get_ntype_id_from_dst(ntype)
        if is_all(v):
            # ``ntid`` is an integer node type id, so ask the underlying graph
            # index directly instead of the public ``number_of_nodes``, which
            # is called with type names elsewhere in this class.
            v = F.arange(0, self._graph.number_of_nodes(ntid), self._idtype_str)
        elif isinstance(v, int):
            v = [v]
        v = utils.toindex(v, self._idtype_str)
        if len(v) == 0:
            return
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        merge_order = []
        with ir.prog() as prog:
            for ety, args in reducer_dict.items():
                # Each edge type reduces into its own scratch frame; the frames
                # are merged by ``cross_reducer`` afterwards.
                outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
                args = pad_tuple(args, 2)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be either '
                                   '(1) reduce_func or (2) (reduce_func, apply_node_func)')
                rfunc, afunc = args
                etid = self.get_etype_id(ety)
                stid, dtid = self._graph.metagraph.find_edge(etid)
                scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
                                        v, rfunc, afunc,
                                        inplace=inplace, outframe=outframe)
                all_out.append(outframe)
                merge_order.append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[ntid].update(merge_frames(all_out, cross_reducer, merge_order))
        # apply
        if apply_node_func is not None:
            self.apply_nodes(apply_node_func, v, ntype, inplace)
Minjie Wang's avatar
Minjie Wang committed
3103

Da Zheng's avatar
Da Zheng committed
3104
3105
    def send_and_recv(self,
                      edges,
                      message_func,
                      reduce_func,
                      apply_node_func=None,
                      etype=None,
                      inplace=False):
        """Send messages along edges of the specified type, and let destinations
        receive them.

        Optionally, apply a function to update the node features after "receive".

        This is a convenient combination for performing
        :mod:`send <dgl.DGLHeteroGraph.send>` along the ``edges`` and
        :mod:`recv <dgl.DGLHeteroGraph.recv>` for the destinations of the ``edges``.

        **Only works if the graph has one edge type** unless ``etype`` is given.
        For multiple types, you can also use

        .. code::

           g['edgetype'].send_and_recv(edges, message_func, reduce_func,
                                       apply_node_func, inplace=inplace)

        Parameters
        ----------
        edges : int, pair of int, iterable, tensor, pair of iterables/tensors, or ALL
            Edges on which to apply ``message_func``. See :func:`send` for
            valid edge specifications. ``ALL`` selects every edge of the type.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace: bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])

        Trigger "send" and "receive" separately.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.recv(g.nodes('user'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])

        Trigger "send" and "receive" in one call.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.send_and_recv(g['follows'].edges(), fn.copy_src('h', 'm'),
        >>>                 fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])
        """
        # Validate id dtype like every other message-passing entry point
        # (send, recv, apply_nodes, apply_edges, ...).
        check_same_dtype(self._idtype_str, edges)
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            # Accept ALL exactly as ``send`` does; the docs above defer to
            # ``send`` for the edge specification.
            eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)), self._idtype_str)
            u, v, _ = self._graph.edges(etid, 'eid')
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u, self._idtype_str)
            v = utils.toindex(v, self._idtype_str)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges, self._idtype_str)
            u, v, _ = self._graph.find_edges(etid, eid)

        if len(u) == 0:
            # no edges to be triggered
            return

        with ir.prog() as prog:
            scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                   (u, v, eid),
                                   message_func, reduce_func, apply_node_func,
                                   inplace=inplace)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
3203

Mufei Li's avatar
Mufei Li committed
3204
    def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Send and receive messages along multiple edge types and perform aggregation.

        Optionally, apply a function to update the node features after "receive".

        This is a convenient combination for performing multiple
        :mod:`send <dgl.DGLHeteroGraph.send>` along edges of different types and
        :mod:`multi_recv <dgl.DGLHeteroGraph.multi_recv>` for the destinations of all edges.

        The destination node type is inferred from ``etype_dict``. For every edge
        type, messages are computed and reduced into a separate temporary frame.
        The per-type results are then combined with ``cross_reducer`` (edge type
        ids are used as the merge order hint) and written into the destination
        node features. Edge types whose given edge set is empty are skipped.

        Parameters
        ----------
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration (4-tuples). Each 4-tuple represents
            (edges, msg_func, reduce_func, apply_node_func); the trailing
            ``apply_node_func`` may be omitted:

            * edges: See send() for valid edge specification.
                  Edges on which to pass messages. Either a pair ``(u, v)``
                  of source and destination nodes, or edge ids.
            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the node. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)

        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. It is applied once, after the
            cross-type merge, to the unique destination nodes of all the
            given edges. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If a per-type configuration is not a valid
            ``(edges, msg_func, reduce_func, [apply_node_func])`` tuple.

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])

        Trigger send and recv separately.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
        >>> g.multi_recv(g.nodes('user'),
        ...              {'follows': fn.sum('m', 'h'), 'attracts': fn.sum('m', 'h')}, "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])

        Trigger "send" and "receive" in one call.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.multi_send_and_recv(
        ...     {'follows': (g['follows'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        ...      'attracts': (g['attracts'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        ...     "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])
        """
        # infer receive node type
        ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id_from_dst(ntype)

        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        all_vs = []
        merge_order = []
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, _ = self._graph.metagraph.find_edge(etid)
                # Each edge type reduces into its own frame; frames are merged below.
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 4)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(edges, msg_func, reduce_func, [apply_node_func])')
                edges, mfunc, rfunc, afunc = args
                if isinstance(edges, tuple):
                    u, v = edges
                    u = utils.toindex(u, self._idtype_str)
                    v = utils.toindex(v, self._idtype_str)
                    # Rewrite u, v to handle edge broadcasting and multigraph.
                    u, v, eid = self._graph.edge_ids(etid, u, v)
                else:
                    eid = utils.toindex(edges, self._idtype_str)
                    u, v, _ = self._graph.find_edges(etid, eid)
                # Destinations are collected even for empty types so the final
                # apply step sees every requested destination.
                all_vs.append(v)
                if len(u) == 0:
                    # no edges to be triggered
                    continue
                scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                       (u, v, eid),
                                       mfunc, rfunc, afunc,
                                       inplace=inplace, outframe=outframe)
                all_out.append(outframe)
                merge_order.append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
        # apply
        if apply_node_func is not None:
            dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
            self.apply_nodes(apply_node_func, dstnodes, ntype, inplace)
Minjie Wang's avatar
Minjie Wang committed
3322

Da Zheng's avatar
Da Zheng committed
3323
3324
    def pull(self,
             v,
             message_func,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        """Pull messages from the node(s)' predecessors and then update their features.

        Optionally, apply a function to update the node features after receive.

        This is equivalent to :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>`
        on the incoming edges of ``v`` with the specified type.

        Other notes:

        * `reduce_func` will be skipped for nodes with no incoming messages.
        * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
        * If some ``v`` have no incoming message, their new feature value will be calculated
          by the column initializer (see :func:`set_n_initializer`). The feature shapes and
          dtypes will be inferred.
        * If ``v`` is empty, the call returns immediately without doing anything.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].pull(v, message_func, reduce_func, apply_node_func, inplace=inplace)

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Pull.

        >>> g['follows'].pull(2, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [1.],
                [1.]])
        """
        # Reject node ids whose dtype differs from the graph's id type.
        check_same_dtype(self._idtype_str, v)
        # only one type of edges
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        v = utils.toindex(v, self._idtype_str)
        if len(v) == 0:
            return
        with ir.prog() as prog:
            scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v,
                                    message_func, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
3407

Mufei Li's avatar
Mufei Li committed
3408
    def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Pull and receive messages of the given nodes along multiple edge types
        and perform aggregation.

        This is equivalent to :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
        on the incoming edges of ``v`` with the specified types.

        The destination node type is inferred from ``etype_dict``. Each edge type
        is pulled into a separate temporary frame, and the frames are combined
        with ``cross_reducer`` (edge type ids are used as the merge order hint).
        If ``v`` is empty, the call returns immediately without doing anything.

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration (3-tuples). Each 3-tuple represents
            (msg_func, reduce_func, apply_node_func); the trailing
            ``apply_node_func`` may be omitted:

            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)

        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. It is applied to ``v`` after the
            cross-type merge. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Raises
        ------
        DGLError
            If a per-type configuration is not a valid
            ``(msg_func, reduce_func, [apply_node_func])`` tuple.

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(1, 1), (1, 0)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])

        Pull.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.multi_pull(1,
        ...              {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        ...               'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        ...              "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [3.]])
        """
        # Reject node ids whose dtype differs from the graph's id type.
        check_same_dtype(self._idtype_str, v)
        v = utils.toindex(v, self._idtype_str)
        if len(v) == 0:
            return
        # infer receive node type
        ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id_from_dst(ntype)
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        merge_order = []
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, _ = self._graph.metagraph.find_edge(etid)
                # Each edge type reduces into its own frame; frames are merged below.
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 3)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(msg_func, reduce_func, [apply_node_func])')
                mfunc, rfunc, afunc = args
                scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                        v,
                                        mfunc, rfunc, afunc,
                                        inplace=inplace, outframe=outframe)
                all_out.append(outframe)
                merge_order.append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
        # apply
        if apply_node_func is not None:
            self.apply_nodes(apply_node_func, v, ntype, inplace)
Minjie Wang's avatar
Minjie Wang committed
3500

Da Zheng's avatar
Da Zheng committed
3501
3502
    def push(self,
             u,
             message_func,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        """Send message from the node(s) to their successors and update them.

        This is equivalent to performing
        :mod:`send_and_recv <DGLHeteroGraph.send_and_recv>` along the outbound
        edges from ``u``. If ``u`` is empty, the call returns immediately
        without doing anything.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].push(u, message_func, reduce_func, apply_node_func, inplace=inplace)

        Parameters
        ----------
        u : int, container or tensor
            The node(s) to push out messages.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g = dgl.graph([(0, 1), (0, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Push.

        >>> g['follows'].push(0, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [0.]])
        """
        # Reject node ids whose dtype differs from the graph's id type.
        check_same_dtype(self._idtype_str, u)
        # only one type of edges
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        u = utils.toindex(u, self._idtype_str)
        if len(u) == 0:
            return
        with ir.prog() as prog:
            scheduler.schedule_push(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    u,
                                    message_func, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
3574
3575

    def update_all(self,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Send messages through all edges and update all nodes.

        Optionally, apply a function to update the node features after receive.

        This is equivalent to
        :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>` over all edges
        of the specified type.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].update_all(message_func, reduce_func, apply_node_func)

        Parameters
        ----------
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph.

        >>> g = dgl.graph([(0, 1), (1, 2), (2, 2)], 'user', 'follows')

        Update all.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g['follows'].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [3.]])
        """
        # only one type of edges
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        with ir.prog() as prog:
            scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                          message_func, reduce_func,
                                          apply_node_func)
            Runtime.run(prog)
Da Zheng's avatar
Da Zheng committed
3637

Mufei Li's avatar
Mufei Li committed
3638
    def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):
        r"""Send and receive messages along all edges.

        This is equivalent to
        :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
        over all edges.

        Unlike :func:`multi_send_and_recv` and :func:`multi_pull`, the edge types in
        ``etype_dict`` may have different destination node types. The results are
        grouped by destination node type. Within each group, the per-edge-type
        frames are combined with ``cross_reducer`` (edge type ids are used as the
        merge order hint).

        Parameters
        ----------
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration (3-tuples). Each 3-tuple represents
            (msg_func, reduce_func, apply_node_func); the trailing
            ``apply_node_func`` may be omitted:

            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)

        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. After merging, it is applied (not
            in place) to all nodes of every destination node type that
            appears in ``etype_dict``. (Default: None)

        Raises
        ------
        DGLError
            If a per-type configuration is not a valid
            ``(msg_func, reduce_func, [apply_node_func])`` tuple.

        Examples
        --------
        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1), (1, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

        Update all.

        >>> g.multi_update_all(
        ...     {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        ...      'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        ...     "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [4.]])
        """
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        # Output frames and merge hints are grouped by destination node type id.
        all_out = defaultdict(list)
        merge_order = defaultdict(list)
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, dtid = self._graph.metagraph.find_edge(etid)
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 3)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(msg_func, reduce_func, [apply_node_func])')
                mfunc, rfunc, afunc = args
                scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                              mfunc, rfunc, afunc,
                                              outframe=outframe)
                all_out[dtid].append(outframe)
                merge_order[dtid].append(etid)  # use edge type id as merge order hint
            Runtime.run(prog)
        for dtid, frames in all_out.items():
            # merge by cross_reducer
            self._node_frames[dtid].update(
                merge_frames(frames, cross_reducer, merge_order[dtid]))
            # apply
            if apply_node_func is not None:
                self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False)
Minjie Wang's avatar
Minjie Wang committed
3724
3725
3726
3727
3728
3729
3730

    def prop_nodes(self,
                   nodes_generator,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`pull()` on nodes.

        The traversal order is specified by the ``nodes_generator``. It generates
        node frontiers, which is a list or a tensor of nodes. The nodes in the
        same frontier will be triggered together, while nodes in different frontiers
        will be triggered according to the generating order.

        Parameters
        ----------
        nodes_generator : iterable, each element is a list or a tensor of node ids
            The generator of node frontiers. It specifies which nodes perform
            :func:`pull` at each timestep.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        See Also
        --------
        prop_edges

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph and perform multiple rounds of message passing.

        >>> g = dgl.graph(([0, 1, 2, 3], [2, 3, 4, 4]), 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
        >>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_src('h', 'm'),
        ...                         fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[1.],
                [2.],
                [1.],
                [2.],
                [3.]])
        """
        # Frontiers are pulled one after another, so later frontiers see the
        # features written by earlier ones.
        for node_frontier in nodes_generator:
            self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype)
Da Zheng's avatar
Da Zheng committed
3781

Minjie Wang's avatar
Minjie Wang committed
3782
3783
3784
3785
3786
3787
    def prop_edges(self,
                   edges_generator,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`send_and_recv()` on edges.

        The traversal order is specified by the ``edges_generator``. It generates
        edge frontiers. The edge frontiers should be of *valid edges type*.
        See :func:`send` for more details.

        Edges in the same frontier will be triggered together, and edges in
        different frontiers will be triggered according to the generating order.

        Parameters
        ----------
        edges_generator : iterable
            An iterable (e.g. a generator or a list) of edge frontiers. Each
            frontier is passed unchanged to :func:`send_and_recv`.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph and perform multiple rounds of message passing.

        >>> g = dgl.graph(([0, 1, 2, 3], [2, 3, 4, 4]), 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
        >>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_src('h', 'm'),
        ...                         fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[1.],
                [2.],
                [1.],
                [2.],
                [3.]])

        See Also
        --------
        prop_nodes
        """
        # Frontiers are consumed one at a time, in the order the iterable
        # yields them; each one triggers a full send_and_recv round.
        for edge_frontier in edges_generator:
            self.send_and_recv(edge_frontier, message_func, reduce_func,
                               apply_node_func, etype=etype)
Da Zheng's avatar
Da Zheng committed
3841

Minjie Wang's avatar
Minjie Wang committed
3842
3843
3844
    #################################################################
    # Misc
    #################################################################
Da Zheng's avatar
Da Zheng committed
3845

Minjie Wang's avatar
Minjie Wang committed
3846
3847
    def to_networkx(self, node_attrs=None, edge_attrs=None):
        """Convert this graph to a networkx graph.

        Only graphs with exactly one node type and one edge type are supported.
        The result is always a :class:`networkx.MultiDiGraph`, so parallel edges
        are preserved. The edge id is saved as the ``'id'`` edge attribute.

        Parameters
        ----------
        node_attrs : iterable of str, optional
            The node attributes to be copied.
        edge_attrs : iterable of str, optional
            The edge attributes to be copied.

        Returns
        -------
        networkx.MultiDiGraph
            The nx graph

        Raises
        ------
        AssertionError
            If the graph has more than one node type or edge type (checked
            with ``assert``).

        Examples
        --------

        .. note:: Here we use pytorch syntax for demo. The general idea applies
            to other frameworks with minor syntax change (e.g. replace
            ``torch.tensor`` with ``mxnet.ndarray``).

        >>> import torch as th
        >>> g = DGLGraph()
        >>> g.add_nodes(5, {'n1': th.randn(5, 10)})
        >>> g.add_edges([0,1,3,4], [2,4,0,3], {'e1': th.randn(4, 6)})
        >>> nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])

        See Also
        --------
        dgl.to_networkx
        """
        # TODO(minjie): multi-type support
        assert len(self.ntypes) == 1, \
            'to_networkx only supports graphs with a single node type.'
        assert len(self.etypes) == 1, \
            'to_networkx only supports graphs with a single edge type.'
        src, dst = self.edges()
        src = F.asnumpy(src)
        dst = F.asnumpy(dst)
        # xiangsx: Always treat graph as multigraph
        nx_graph = nx.MultiDiGraph()
        nx_graph.add_nodes_from(range(self.number_of_nodes()))
        # Edges are added in edge-id order; the id is stored so that edge
        # features can be looked up again below.
        for eid, (u, v) in enumerate(zip(src, dst)):
            nx_graph.add_edge(u, v, id=eid)

        if node_attrs is not None:
            for nid, attr in nx_graph.nodes(data=True):
                feat_dict = self._get_n_repr(0, nid)
                attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
        if edge_attrs is not None:
            for _, _, attr in nx_graph.edges(data=True):
                eid = attr['id']
                feat_dict = self._get_e_repr(0, eid)
                attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
        return nx_graph

    def filter_nodes(self, predicate, nodes=ALL, ntype=None):
        """Return a tensor of node IDs of the given node type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(nodes) -> tensor``.
            ``nodes`` are :class:`NodeBatch` objects as in :mod:`~dgl.udf`.
            The ``tensor`` returned should be a 1-D boolean tensor with
            each element indicating whether the corresponding node in
            the batch satisfies the predicate.
        nodes : int, iterable or tensor of ints
            The nodes to filter on. Default value is all the nodes.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Node ids indicating the nodes that satisfy the predicate.

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn
        >>> g = dgl.graph([], 'user', 'follows', num_nodes=4)
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
        tensor([1, 2])
        """
        check_same_dtype(self._idtype_str, nodes)
        ntid = self.get_ntype_id(ntype)
        select_all = is_all(nodes)
        if select_all:
            num_nodes = self._graph.number_of_nodes(ntid)
            index = utils.toindex(slice(0, num_nodes), self._idtype_str)
        else:
            index = utils.toindex(nodes, self._idtype_str)

        batch = NodeBatch(index, self._get_n_repr(ntid, index), ntype=self.ntypes[ntid])
        # The predicate result is brought to CPU before it is used as a mask.
        mask = F.copy_to(predicate(batch), F.cpu())

        if not select_all:
            return F.boolean_mask(F.tensor(nodes), mask)
        return F.nonzero_1d(mask)

    def filter_edges(self, predicate, edges=ALL, etype=None):
        """Return a tensor of edge IDs of the given edge type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(edges) -> tensor``.
            ``edges`` are :class:`EdgeBatch` objects as in :mod:`~dgl.udf`.
            The ``tensor`` returned should be a 1-D boolean tensor with
            each element indicating whether the corresponding edge in
            the batch satisfies the predicate.
        edges : valid edges type
            Edges on which to apply ``predicate``. See :func:`send` for valid
            edges type. Default value is all the edges. When given as a
            ``(src, dst)`` tuple, the returned tensor holds the matching
            edge IDs.
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Edge ids indicating the edges that satisfy the predicate.

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn
        >>> g = dgl.graph([(0, 0), (0, 1), (1, 2), (2, 3)], 'user', 'follows')
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
        tensor([1, 2])
        """
        check_same_dtype(self._idtype_str, edges)
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(edges):
            u, v, _ = self._graph.edges(etid, 'eid')
            eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)), self._idtype_str)
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u, self._idtype_str)
            v = utils.toindex(v, self._idtype_str)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges, self._idtype_str)
            u, v, _ = self._graph.find_edges(etid, eid)

        src_data = self._get_n_repr(stid, u)
        edge_data = self._get_e_repr(etid, eid)
        dst_data = self._get_n_repr(dtid, v)
        ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data,
                           canonical_etype=self.canonical_etypes[etid])
        e_mask = F.copy_to(predicate(ebatch), F.cpu())

        if is_all(edges):
            return F.nonzero_1d(e_mask)
        if isinstance(edges, tuple):
            # A (src, dst) pair would tensorize to a 2xN tensor that does not
            # line up with the per-edge mask. Select from the resolved edge
            # IDs instead: they are exactly the IDs the EdgeBatch (and hence
            # the mask) was built from, in the same order.
            return F.boolean_mask(eid.tousertensor(), e_mask)
        edges = F.tensor(edges)
        return F.boolean_mask(edges, e_mask)

    def to(self, ctx):  # pylint: disable=invalid-name
        """Move all node and edge feature data to the target device (CPU/GPU).

        This is framework agnostic: every feature tensor is moved with the
        backend's ``copy_to``. The feature frames of this graph are updated in
        place and this same graph object is returned; only the feature frames
        are touched by this method.

        Parameters
        ----------
        ctx : framework-specific context object
            The context to move data to.

        Returns
        -------
        g : DGLHeteroGraph
            This graph (``self``), with its features moved to ``ctx``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import torch
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
        >>> g = g.to(torch.device('cuda:0'))
        """
        # Node frames first, then edge frames, each column replaced in turn.
        for frame_list in (self._node_frames, self._edge_frames):
            for frame in frame_list:
                for column in frame.keys():
                    frame[column] = F.copy_to(frame[column], ctx)
        return self
Da Zheng's avatar
Da Zheng committed
4047

Minjie Wang's avatar
Minjie Wang committed
4048
    def local_var(self):
        """Return a heterograph object that can be used in a local function scope.

        The returned graph object shares the feature data and graph structure of this graph.
        However, any out-place mutation to the feature data will not reflect to this graph,
        thus making it easier to use in a function scope.

        If initializers are set on this graph, the local graph object uses the same
        initializers for node features and edge features.

        Returns
        -------
        DGLHeteroGraph
            The graph object that can be used as a local variable.

        Notes
        -----
        Internally, the returned graph shares the same feature tensors, but constructs a new
        dictionary structure (aka. Frame) so adding/removing feature tensors from the returned
        graph will not reflect to the original graph. However, inplace operations do change
        the shared tensor values, so will be reflected to the original graph. This function
        also has little overhead when the number of feature tensors in this graph is small.

        Examples
        --------
        The following example uses PyTorch backend.

        Avoid accidentally overriding existing feature data. This is quite common when
        implementing a NN module:

        >>> def foo(g):
        ...     g = g.local_var()
        ...     g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        ...     return g.edata['h']
        ...
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
        >>> newh = foo(g)        # get tensor of all ones
        >>> print(g.edata['h'])  # still get tensor of all zeros

        Automatically garbage collect locally-defined tensors without the need to manually
        ``pop`` the tensors.

        >>> def foo(g):
        ...     g = g.local_var()
        ...     # This 'h' feature will stay local and be GCed when the function exits
        ...     g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        ...     return g.edata['h']
        ...
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> h = foo(g)
        >>> print('h' in g.edata)
        False

        See Also
        --------
        local_scope
        """
        local_node_frames = [fr.clone() for fr in self._node_frames]
        local_edge_frames = [fr.clone() for fr in self._edge_frames]
        # Shallow copy: every attribute (including the graph structure) is
        # shared with ``self`` except the two frame lists replaced below.
        ret = copy.copy(self)
        ret._node_frames = local_node_frames
        ret._edge_frames = local_edge_frames
        return ret
Minjie Wang's avatar
Minjie Wang committed
4112
4113
4114
4115
4116
4117
4118
4119
4120
4121
4122
4123
4124
4125
4126
4127
4128
4129
4130
4131

    @contextmanager
    def local_scope(self):
        """Enter a local scope context for this graph.

        By entering a local scope, any out-place mutation to the feature data will
        not reflect to the original graph, thus making it easier to use in a function scope.

        If initializers are set on this graph, the local scope uses the same
        initializers for node features and edge features.

        The original feature frames are restored when the scope exits, including
        when it is left because of an exception.

        Examples
        --------
        The following example uses PyTorch backend.

        Avoid accidentally overriding existing feature data. This is quite common when
        implementing a NN module:

        >>> def foo(g):
        ...     with g.local_scope():
        ...         g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        ...         return g.edata['h']
        ...
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
        >>> newh = foo(g)        # get tensor of all ones
        >>> print(g.edata['h'])  # still get tensor of all zeros

        Automatically garbage collect locally-defined tensors without the need to manually
        ``pop`` the tensors.

        >>> def foo(g):
        ...     with g.local_scope():
        ...         # This 'h' feature will stay local and be GCed when the function exits
        ...         g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        ...         return g.edata['h']
        ...
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> h = foo(g)
        >>> print('h' in g.edata)
        False

        See Also
        --------
        local_var
        """
        old_nframes = self._node_frames
        old_eframes = self._edge_frames
        # Build both clones before swapping anything in, so a failure while
        # cloning cannot leave the graph half-switched.
        local_nframes = [fr.clone() for fr in old_nframes]
        local_eframes = [fr.clone() for fr in old_eframes]
        self._node_frames = local_nframes
        self._edge_frames = local_eframes
        try:
            yield
        finally:
            # Restore even if the body raised; otherwise the graph would keep
            # the scoped frames permanently.
            self._node_frames = old_nframes
            self._edge_frames = old_eframes

4166
4167
4168
4169
    def is_homograph(self):
        """Return True if the graph has exactly one node type and one edge type."""
        return len(self.ntypes) == len(self.etypes) == 1

4170
4171
4172
4173
4174
4175
4176
4177
4178
4179
4180
4181
4182
4183
4184
4185
4186
4187
4188
4189
4190
4191
4192
4193
4194
4195
4196
4197
4198
4199
4200
4201
4202
4203
4204
4205
4206
4207
4208
4209
4210
4211
4212
4213
4214
4215
4216
4217
4218
4219
4220
4221
4222
4223
    def format_in_use(self, etype=None, return_all=False):
        """Return the sparse formats in use of the given edge/relation type.

        Parameters
        ----------
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        return_all : bool, optional
            Currently not used by this method; the result is the same
            regardless of its value. (Default: False)

        Returns
        -------
        list of string
            Return all the formats currently in use (could be multiple).

        See Also
        --------
        restrict_format
        to_format
        """
        # NOTE(review): ``return_all`` is accepted but ignored here.
        return self._graph.format_in_use(self.get_etype_id(etype))

    def restrict_format(self, etype=None):
        """Return the allowed sparse formats of the given edge/relation type.

        Parameters
        ----------
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        string : 'any', 'coo', 'csr', or 'csc'
            The restricted sparse format. 'any' indicates that all sparse
            formats are allowed.

        See Also
        --------
        format_in_use
        to_format
        """
        return self._graph.restrict_format(self.get_etype_id(etype))

    def to_format(self, restrict_format):
        """Return a cloned graph whose storage is restricted to the given format.

        Passing 'any' relaxes the restriction on the returned graph. The
        returned graph shares the node/edge data of this graph.

        Parameters
        ----------
        restrict_format : string
            Desired restrict format ('any', 'coo', 'csr', 'csc').

        Returns
        -------
        DGLHeteroGraph
            A new graph.

        See Also
        --------
        format_in_use
        restrict_format
        """
        # The underlying index is converted; the frame lists are passed through
        # as-is, so feature data is shared with this graph.
        converted_index = self._graph.to_format(restrict_format)
        return DGLHeteroGraph(converted_index, self.ntypes, self.etypes,
                              self._node_frames, self._edge_frames)

4224
4225
4226
4227
4228
4229
4230
4231
4232
4233
4234
4235
4236
4237
4238
4239
4240
4241
4242
    def long(self):
        """Return a heterograph object that uses int64 as its index dtype.

        The returned graph shares ndata and edata with this graph; this graph
        itself is not modified.

        Returns
        -------
        DGLHeteroGraph
            The graph object

        Examples
        --------

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game',
        ...                   index_dtype='int32')
        >>> g_long = g.long() # Convert g to int64 indexed, not changing the original `g`

        See Also
        --------
        int
        idtype
        """
        index64 = self._graph.asbits(64)
        return DGLHeteroGraph(index64, self.ntypes, self.etypes,
                              self._node_frames, self._edge_frames)

    def int(self):
        """Return a heterograph object that uses int32 as its index dtype.

        The returned graph shares ndata and edata with this graph; this graph
        itself is not modified.

        Returns
        -------
        DGLHeteroGraph
            The graph object

        Examples
        --------

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game',
        ...                   index_dtype='int64')
        >>> g_int = g.int() # Convert g to int32 indexed, not changing the original `g`

        See Also
        --------
        long
        idtype
        """
        index32 = self._graph.asbits(32)
        return DGLHeteroGraph(index32, self.ntypes, self.etypes,
                              self._node_frames, self._edge_frames)

Minjie Wang's avatar
Minjie Wang committed
4274
4275
4276
4277
4278
4279
4280
4281
4282
4283
4284
4285
4286
4287
4288
4289
4290
4291
4292
4293
4294
4295
4296
4297
4298
4299
4300
4301
4302
4303
4304
4305
4306
############################################################
# Internal APIs
############################################################

def make_canonical_etypes(etypes, ntypes, metagraph):
    """Internal function to convert etype names into (srctype, etype, dsttype) triples.

    Parameters
    ----------
    etypes : list of str
        Edge type list, indexed by metagraph edge id.
    ntypes : list of str
        Node type list, indexed by metagraph node id.
    metagraph : GraphIndex
        Meta graph.

    Returns
    -------
    list of tuples (srctype, etype, dsttype)
        One triple per metagraph edge, in the order returned by
        ``metagraph.edges()``.
    """
    num_meta_edges = metagraph.number_of_edges()
    if len(etypes) != num_meta_edges:
        raise DGLError('Length of edge type list must match the number of '
                       'edges in the metagraph. {} vs {}'.format(
                           len(etypes), num_meta_edges))
    num_meta_nodes = metagraph.number_of_nodes()
    if len(ntypes) != num_meta_nodes:
        raise DGLError('Length of nodes type list must match the number of '
                       'nodes in the metagraph. {} vs {}'.format(
                           len(ntypes), num_meta_nodes))
    src_ids, dst_ids, edge_ids = metagraph.edges()
    return [(ntypes[src_id], etypes[edge_id], ntypes[dst_id])
            for src_id, dst_id, edge_id in zip(src_ids, dst_ids, edge_ids)]

4307
4308
4309
4310
4311
4312
4313
4314
4315
4316
4317
4318
4319
4320
4321
4322
4323
4324
4325
4326
4327
4328
4329
4330
4331
4332
4333
4334
4335
4336
4337
4338
4339
4340
4341
4342
4343
4344
4345
4346
4347
4348
4349
4350
4351
4352
4353
4354
4355
4356
4357
4358
def is_unibipartite(graph):
    """Internal function that tells whether a graph is a uni-directional bipartite graph.

    A graph qualifies when no node appears both as a source and as a
    destination of some edge.

    Parameters
    ----------
    graph : GraphIndex
        Input graph

    Returns
    -------
    bool
        True if the graph is a uni-bipartite.
    """
    sources, destinations, _ = graph.edges()
    shared = set(sources.tonumpy()) & set(destinations.tonumpy())
    return not shared

def find_src_dst_ntypes(ntypes, metagraph):
    """Internal function to split ntypes into SRC and DST categories.

    If the metagraph is not a uni-bipartite graph (so that the SRC and DST categories
    are not well-defined), return None.

    For node types that are isolated (i.e, no relation is associated with it), they
    are assigned to the SRC category.

    Parameters
    ----------
    ntypes : list of str
        Node type list
    metagraph : GraphIndex
        Meta graph.

    Returns
    -------
    (dict[str, int], dict[str, int]) or None
        Node types belonging to SRC and DST categories. Types are stored in
        a dictionary from type name to type id. Return None if the graph is
        not uni-bipartite.
    """
    src, dst, _ = metagraph.edges()
    # A node type that is both a source and a destination of some relation
    # makes the SRC/DST split ill-defined.
    if not set(src.tonumpy()).isdisjoint(set(dst.tonumpy())):
        return None
    srctypes = {ntypes[tid]: tid for tid in src}
    dsttypes = {ntypes[tid]: tid for tid in dst}
    # handle isolated node types
    for ntid, ntype in enumerate(ntypes):
        if ntype not in srctypes and ntype not in dsttypes:
            srctypes[ntype] = ntid
    return srctypes, dsttypes

Minjie Wang's avatar
Minjie Wang committed
4359
4360
4361
4362
4363
4364
4365
4366
4367
4368
4369
4370
4371
4372
4373
4374
4375
4376
4377
4378
4379
4380
4381
4382
4383
4384
4385
4386
4387
4388
4389
4390
4391
4392
4393
4394
4395
4396
4397
4398
4399
4400
4401
4402
def infer_ntype_from_dict(graph, etype_dict):
    """Infer the shared destination node type of a dictionary keyed by edge type.

    Every edge type used as a key must have the same destination node type;
    that node type is returned. An empty dictionary yields None.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    etype_dict : dict
        Dictionary whose key is edge type

    Returns
    -------
    str
        Node type

    Raises
    ------
    DGLError
        If two keys have different destination node types.
    """
    common_dsttype = None
    for key in etype_dict:
        _, _, dsttype = graph.to_canonical_etype(key)
        if common_dsttype is None:
            common_dsttype = dsttype
        elif common_dsttype != dsttype:
            raise DGLError("Cannot infer destination node type from the dictionary. "
                           "A valid specification must make sure that all the edge "
                           "type keys share the same destination node type.")
    return common_dsttype

def pad_tuple(tup, length, pad_val=None):
    """Pad a tuple with ``pad_val`` up to ``length`` elements.

    A non-tuple input is treated as a one-element tuple. Returns None when
    the input is already longer than ``length``; an input of exactly
    ``length`` elements is returned as-is.
    """
    items = tup if isinstance(tup, tuple) else (tup,)
    shortfall = length - len(items)
    if shortfall < 0:
        return None
    if shortfall == 0:
        return items
    return items + (pad_val,) * shortfall

4403
def merge_frames(frames, reducer, order=None):
    """Merge input frames into one. Resolve conflict fields using reducer.

    Parameters
    ----------
    frames : list[FrameRef]
        Input frames
    reducer : str
        One of "sum", "max", "min", "mean", "stack"
    order : list[Int], optional
        Merge order hint. Useful for "stack" reducer.
        If provided, each integer indicates the relative order
        of the ``frames`` list. Frames are sorted according to this list
        in ascending order. Tie is not handled so make sure the order values
        are distinct.

    Returns
    -------
    FrameRef
        Merged frame. When ``frames`` has a single element and ``reducer`` is
        not "stack", that element itself is returned.

    Raises
    ------
    DGLError
        If ``reducer`` is not one of the names listed above (only checked
        when a merge is actually performed).
    """
    if len(frames) == 1 and reducer != 'stack':
        # Directly return the only one input. Stack reducer requires
        # modifying tensor shape.
        return frames[0]
    if reducer == 'stack':
        # Stack order does not matter. However, it must be consistent!
        if order:
            assert len(order) == len(frames)
            sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])
            frames = list(zip(*sorted_with_key))[0]
        def merger(flist):
            # Stack along a new axis 1, one slot per contributing frame.
            return F.stack(flist, 1)
    else:
        # Only the documented reducer names are looked up on the backend; a
        # bare getattr would also accept unrelated backend functions.
        if reducer in ('sum', 'max', 'min', 'mean'):
            redfn = getattr(F, reducer, None)
        else:
            redfn = None
        if redfn is None:
            raise DGLError('Invalid cross type reducer. Must be one of '
                           '"sum", "max", "min", "mean" or "stack".')
        def merger(flist):
            return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]

    ret = FrameRef(frame_like(frames[0]._frame))
    # Collect column names in first-seen order (deterministic, unlike a set).
    keys = dict.fromkeys(k for frm in frames for k in frm.keys())
    for k in keys:
        # Only frames that actually contain the column contribute to it.
        flist = [frm[k] for frm in frames if k in frm]
        ret[k] = merger(flist)
    return ret

def combine_frames(frames, ids):
    """Concatenate the selected frames row-wise, keeping only their common columns.

    Return None if there is no common columns.

    Parameters
    ----------
    frames : List[FrameRef]
        List of frames
    ids : List[int]
        List of frame IDs

    Returns
    -------
    FrameRef
        The resulting frame

    Raises
    ------
    DGLError
        If a common column has different schemes in two selected frames.
    """
    # Start from the first selected frame's schemes and keep only the columns
    # present in every selected frame.
    common = dict(frames[ids[0]].schemes.items())
    for frame_id in ids:
        frame_schemes = frames[frame_id].schemes
        for key, scheme in list(common.items()):
            if key not in frame_schemes:
                del common[key]
            elif frame_schemes[key] != scheme:
                raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
                               (key, frame_schemes[key], scheme))

    if not common:
        return None

    def _nonempty_columns(key):
        # Frames without any rows contribute nothing to the concatenation.
        return [frames[i][key] for i in ids if frames[i].num_rows > 0]

    columns = {key: F.cat(_nonempty_columns(key), dim=0) for key in common}
    return FrameRef(Frame(columns))

def combine_names(names, ids=None):
    """Build a single name by joining the chosen names with ``'+'``.

    The chosen names are sorted alphabetically before joining, so the result
    does not depend on the order in which they were given.

    Parameters
    ----------
    names : list of str
        Candidate names.
    ids : numpy.ndarray, optional
        Positions in ``names`` to use. When omitted, every name is used.

    Returns
    -------
    str
        The sorted chosen names joined by ``'+'``.
    """
    chosen = names if ids is None else [names[idx] for idx in ids]
    return '+'.join(sorted(chosen))

class AdaptedHeteroGraph(GraphAdapter):
    """Adapt a DGLHeteroGraph to the interface required by the scheduler.

    The adapter exposes a single relation of the wrapped graph: one source node
    type (``stid``), one destination node type (``dtid``) and one edge type
    (``etid``). Every query is forwarded to the wrapped graph for that relation.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    stid : int
        Source node type id
    dtid : int
        Destination node type id
    etid : int
        Edge type id
    """
    def __init__(self, graph, stid, dtid, etid):
        self.graph = graph
        self.stid = stid
        self.dtid = dtid
        self.etid = etid

    @property
    def gidx(self):
        """Graph index object (``graph._graph``) of the wrapped graph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph._graph.number_of_nodes(self.stid)

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph._graph.number_of_nodes(self.dtid)

    def num_edges(self):
        """Number of edges."""
        return self.graph._graph.number_of_edges(self.etid)

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frames[self.stid]

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        return self.graph._node_frames[self.dtid]

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frames[self.etid]

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frames[self.etid]

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index(self.etid)

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(self.etid, val)

    def in_edges(self, nodes):
        """Return the in-edges of ``nodes`` for the adapted edge type.

        Forwards to the graph index's ``in_edges``; the return format is
        whatever that method produces.
        """
        return self.graph._graph.in_edges(self.etid, nodes)

    def out_edges(self, nodes):
        """Return the out-edges of ``nodes`` for the adapted edge type.

        Forwards to the graph index's ``out_edges``; the return format is
        whatever that method produces.
        """
        return self.graph._graph.out_edges(self.etid, nodes)

    def edges(self, form):
        """Return all edges of the adapted edge type.

        ``form`` is passed through unchanged to the graph index's ``edges``.
        """
        return self.graph._graph.edges(self.etid, form)

    def get_immutable_gidx(self, ctx):
        """Return the unit graph of the adapted edge type on context ``ctx``.

        Obtained from the graph index's ``get_unitgraph``.
        """
        return self.graph._graph.get_unitgraph(self.etid, ctx)

    def bits_needed(self):
        """Return the graph index's ``bits_needed`` for the adapted edge type."""
        return self.graph._graph.bits_needed(self.etid)

    @property
    def canonical_etype(self):
        """Canonical edge type."""
        return self.graph.canonical_etypes[self.etid]


def check_same_dtype(graph_dtype, tensor):
    """Verify that ``tensor`` uses the same index dtype as the graph.

    Values that ``F.is_tensor`` does not recognize as tensors are accepted
    without any check.

    Parameters
    ----------
    graph_dtype
        Dtype of the graph index, in the same representation as the values of
        ``F.reverse_data_type_dict`` (presumably a name such as ``'int64'`` —
        confirm against callers).
    tensor
        Value to check.

    Raises
    ------
    utils.InconsistentDtypeException
        If ``tensor`` is a tensor whose dtype differs from ``graph_dtype``.
    """
    if not F.is_tensor(tensor):
        return
    tensor_dtype = F.reverse_data_type_dict[F.dtype(tensor)]
    if graph_dtype != tensor_dtype:
        raise utils.InconsistentDtypeException(
            "Expect the input tensor to be the same as the graph index dtype({}), but got {}"
            .format(graph_dtype, tensor_dtype))


def check_idtype_dict(graph_dtype, tensor_dict):
    """Check every value of ``tensor_dict`` against the graph's index dtype.

    Each value is passed to ``check_same_dtype``. Non-tensor values are
    therefore accepted, and the first tensor with a mismatching dtype raises.

    Parameters
    ----------
    graph_dtype
        Dtype of the graph index.
    tensor_dict : dict
        Mapping whose values are checked; the keys are ignored.

    Raises
    ------
    utils.InconsistentDtypeException
        If any tensor value's dtype differs from ``graph_dtype``.
    """
    # Only the values matter for the check, so iterate them directly.
    for tensor in tensor_dict.values():
        check_same_dtype(graph_dtype, tensor)