"""Classes for heterogeneous graphs."""
from collections import defaultdict
from contextlib import contextmanager
import networkx as nx
import numpy as np

from . import graph_index
from . import heterograph_index
from . import utils
from . import backend as F
from . import init
from .runtime import ir, scheduler, Runtime, GraphAdapter
from .frame import Frame, FrameRef, frame_like, sync_frame_initializer
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
from .udf import NodeBatch, EdgeBatch

__all__ = ['DGLHeteroGraph', 'combine_names']

class DGLHeteroGraph(object):
    """Base heterogeneous graph class.

    **Do NOT instantiate from this class directly; use** :mod:`conversion methods
    <dgl.convert>` **instead.**

    A heterogeneous graph is defined as a graph with node types and edge
    types.

    If two edges share the same edge type, then their source nodes, as well
    as their destination nodes, also have the same type (the source node
    types don't have to be the same as the destination node types).

    Examples
    --------
    Suppose that we want to construct the following heterogeneous graph:

    .. graphviz::

       digraph G {
           Alice -> Bob [label=follows]
           Bob -> Carol [label=follows]
           Alice -> Tetris [label=plays]
           Bob -> Tetris [label=plays]
           Bob -> Minecraft [label=plays]
           Carol -> Minecraft [label=plays]
           Nintendo -> Tetris [label=develops]
           Mojang -> Minecraft [label=develops]
           {rank=source; Alice; Bob; Carol}
           {rank=sink; Nintendo; Mojang}
       }

    And suppose that one maps the users, games and developers to the following
    IDs:

    =========  =====  ===  =====
    User name  Alice  Bob  Carol
    =========  =====  ===  =====
    User ID    0      1    2
    =========  =====  ===  =====

    =========  ======  =========
    Game name  Tetris  Minecraft
    =========  ======  =========
    Game ID    0       1
    =========  ======  =========

    ==============  ========  ======
    Developer name  Nintendo  Mojang
    ==============  ========  ======
    Developer ID    0         1
    ==============  ========  ======

    One can construct the graph as follows:

    >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
    >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
    >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
    >>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])

    Or, equivalently:

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ...     ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
    ...     ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
    ...     })

    :func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
    data types including:

    * edge list
    * edge tuples
    * networkx graph
    * scipy sparse matrix

    Click the function names for more details.

    Then one can query the graph structure by specifying the ``ntype`` or ``etype`` arguments:

    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_edges('plays')
    4
    >>> g.out_degrees(etype='develops')  # out-degrees of source nodes of 'develops' relation
    tensor([1, 1])
    >>> g.in_edges(0, etype='develops')  # in-edges of destination node 0 of 'develops' relation
    (tensor([0]), tensor([0]))

    Or on the sliced graph for an edge type:

    >>> g['plays'].number_of_edges()
    4
    >>> g['develops'].out_degrees()
    tensor([1, 1])
    >>> g['develops'].in_edges(0)
    (tensor([0]), tensor([0]))

    Node type names must be distinct (no two types have the same name). Edge types can
    share a name, but they must be distinguishable by the ``(src_type, edge_type, dst_type)``
    triplet (called the *canonical edge type*).

    For example, suppose a graph has two types of relation, "user-watches-movie"
    and "user-watches-TV", as follows:

    >>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie')
    >>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV')
    >>> GG = dgl.hetero_from_relations([g0, g1]) # Merge the two graphs

    To distinguish between the two "watches" edge types, one must specify a full triplet:

    >>> GG.number_of_edges(('user', 'watches', 'movie'))
    3
    >>> GG.number_of_edges(('user', 'watches', 'TV'))
    2
    >>> GG['user', 'watches', 'movie'].out_degrees()
    tensor([1, 2])

    Using only the edge type string "watches" is ambiguous and raises an error:

    >>> GG.number_of_edges('watches')  # AMBIGUOUS!!
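
    The lookup from an edge type name to its canonical triplet can be pictured in
    plain Python. This is only a sketch under stated assumptions: ``resolve_etype``
    is a hypothetical helper, not a DGL function, and it raises built-in exceptions
    instead of ``DGLError``. A name shared by more than one triplet is rejected as
    ambiguous:

    ```python
    def resolve_etype(canonical_etypes, etype):
        # Hypothetical helper: map an edge type name to its canonical triplet.
        # A full (src_type, edge_type, dst_type) triplet is returned unchanged.
        if isinstance(etype, tuple):
            return etype
        matches = [c for c in canonical_etypes if c[1] == etype]
        if not matches:
            raise KeyError('Edge type "{}" does not exist.'.format(etype))
        if len(matches) > 1:
            raise ValueError('Edge type "{}" is ambiguous; use a canonical '
                             'triplet instead.'.format(etype))
        return matches[0]


    watch_etypes = [('user', 'watches', 'movie'), ('user', 'watches', 'TV')]
    print(resolve_etype(watch_etypes, ('user', 'watches', 'TV')))
    # ('user', 'watches', 'TV')
    ```

    Calling ``resolve_etype(watch_etypes, 'watches')`` raises ``ValueError``, because
    the name alone matches two triplets.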

    In many cases, there is only one type of nodes or one type of edges, and the ``ntype``
    and ``etype`` arguments can be omitted. This is very common when using the sliced
    graph, which usually contains only one edge type, and sometimes only one node type:

    >>> g['follows'].number_of_nodes()  # OK!! because g['follows'] only has one node type 'user'
    3
    >>> g['plays'].number_of_nodes()  # ERROR!! There are two types 'user' and 'game'.
    >>> g['plays'].number_of_edges()  # OK!! because there is only one edge type 'plays'
    4
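
    Whether ``ntype`` can be omitted on a single-relation slice depends on the
    slice's node types. A relation whose source and destination types are the same
    yields one node type; any other relation yields two. A minimal sketch of that
    rule, where ``slice_ntypes`` is a hypothetical helper working on canonical
    triplets:

    ```python
    def slice_ntypes(canonical_etype):
        # Hypothetical helper: node types present in a single-relation slice.
        src_type, _, dst_type = canonical_etype
        return [src_type] if src_type == dst_type else [src_type, dst_type]


    print(slice_ntypes(('user', 'follows', 'user')))  # ['user'] -> ntype may be omitted
    print(slice_ntypes(('user', 'plays', 'game')))    # ['user', 'game'] -> ntype required
    ```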

    Metagraph
    ---------
    For each heterogeneous graph, one can often infer the *metagraph*, the template of
    edge connections showing how many types of nodes and edges exist in the graph, and
    which node types each edge type connects.

    One can analyze the example gameplay graph above and figure out the metagraph as
    follows:

    .. graphviz::

       digraph G {
           User -> User [label=follows]
           User -> Game [label=plays]
           Developer -> Game [label=develops]
       }
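
    The metagraph edges can be read directly off the canonical edge types. As a
    library-free sketch (``metagraph_edges`` is a hypothetical helper; the real
    :attr:`metagraph` property returns a ``networkx.MultiDiGraph``), each canonical
    edge type contributes one ``(src_type, dst_type, edge_type)`` edge:

    ```python
    def metagraph_edges(canonical_etypes):
        # Hypothetical helper: one metagraph edge per canonical edge type,
        # pointing from the source node type to the destination node type.
        return [(src, dst, etype) for src, etype, dst in canonical_etypes]


    gameplay_etypes = [('user', 'follows', 'user'),
                       ('user', 'plays', 'game'),
                       ('developer', 'develops', 'game')]
    print(metagraph_edges(gameplay_etypes))
    # [('user', 'user', 'follows'), ('user', 'game', 'plays'), ('developer', 'game', 'develops')]
    ```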


    Parameters
    ----------
    gidx : HeteroGraphIndex
        Graph index object.
    ntypes : list of str
        Node type list. ``ntypes[i]`` stores the name of node type i.
    etypes : list of str
        Edge type list. ``etypes[i]`` stores the name of edge type i.
    node_frames : list of FrameRef, optional
        Node feature storage. If None, an empty frame is created.
        Otherwise, ``node_frames[i]`` stores the node features
        of node type i. (default: None)
    edge_frames : list of FrameRef, optional
        Edge feature storage. If None, an empty frame is created.
        Otherwise, ``edge_frames[i]`` stores the edge features
        of edge type i. (default: None)
    multigraph : bool, optional
        Whether the graph would be a multigraph. If None, the flag will be
        determined by scanning the whole graph. (default: None)
    readonly : bool, optional
        Whether the graph structure is read-only. Currently, only read-only
        graphs are supported. (default: True)
    """
    # pylint: disable=unused-argument
    def __init__(self,
                 gidx,
                 ntypes,
                 etypes,
                 node_frames=None,
                 edge_frames=None,
                 multigraph=None,
                 readonly=True):
        assert readonly, "Only readonly heterogeneous graphs are supported"

        self._graph = gidx
        self._nx_metagraph = None
        self._ntypes = ntypes
        self._etypes = etypes
        self._canonical_etypes = make_canonical_etypes(etypes, ntypes, self._graph.metagraph)
        # An internal map from etype to canonical etype tuple.
        # If two etypes have the same name, an empty tuple is stored instead to indicate ambiguity.
        self._etype2canonical = {}
        for i, ety in enumerate(etypes):
            if ety in self._etype2canonical:
                self._etype2canonical[ety] = tuple()
            else:
                self._etype2canonical[ety] = self._canonical_etypes[i]
        self._ntypes_invmap = {t : i for i, t in enumerate(ntypes)}
        self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}

        # node and edge frame
        if node_frames is None:
            node_frames = [None] * len(self._ntypes)
        node_frames = [FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
                       if frame is None else frame
                       for i, frame in enumerate(node_frames)]
        self._node_frames = node_frames

        if edge_frames is None:
            edge_frames = [None] * len(self._etypes)
        edge_frames = [FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
                       if frame is None else frame
                       for i, frame in enumerate(edge_frames)]
        self._edge_frames = edge_frames

        # message indicators
        self._msg_indices = [None] * len(self._etypes)
        self._msg_frames = []
        for i in range(len(self._etypes)):
            frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
            frame.set_initializer(init.zero_initializer)
            self._msg_frames.append(frame)

        self._is_multigraph = multigraph

    def _get_msg_index(self, etid):
        if self._msg_indices[etid] is None:
            self._msg_indices[etid] = utils.zero_index(
                size=self._graph.number_of_edges(etid))
        return self._msg_indices[etid]

    def _set_msg_index(self, etid, index):
        self._msg_indices[etid] = index

    def __repr__(self):
        if len(self.ntypes) == 1 and len(self.etypes) == 1:
            ret = ('Graph(num_nodes={node}, num_edges={edge},\n'
                   '      ndata_schemes={ndata}\n'
                   '      edata_schemes={edata})')
            return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
                              ndata=str(self.node_attr_schemes()),
                              edata=str(self.edge_attr_schemes()))
        else:
            ret = ('Graph(num_nodes={node},\n'
                   '      num_edges={edge},\n'
                   '      metagraph={meta})')
            nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i)
                          for i in range(len(self.ntypes))}
            nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
                          for i in range(len(self.etypes))}
            meta = str(self.metagraph.edges())
            return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)

    #################################################################
    # Mutation operations
    #################################################################

    def add_nodes(self, num, data=None, ntype=None):
        """Add multiple new nodes of the same node type.

        Currently not supported.
        """
        raise DGLError('Mutation is not supported in heterograph.')

    def add_edge(self, u, v, data=None, etype=None):
        """Add an edge of ``etype`` between u of the source node type and v
        of the destination node type.

        Currently not supported.
        """
        raise DGLError('Mutation is not supported in heterograph.')

    def add_edges(self, u, v, data=None, etype=None):
        """Add multiple edges of ``etype`` between the list of source nodes ``u``
        and the list of destination nodes ``v``.  A single edge
        is added between every pair of ``u[i]`` and ``v[i]``.

        Currently not supported.
        """
        raise DGLError('Mutation is not supported in heterograph.')

    #################################################################
    # Metagraph query
    #################################################################

    @property
    def ntypes(self):
        """Return the list of node types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.ntypes
        ['user', 'game']
        """
        return self._ntypes

    @property
    def etypes(self):
        """Return the list of edge types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.etypes
        ['follows', 'plays']
        """
        return self._etypes

    @property
    def canonical_etypes(self):
        """Return the list of canonical edge types of this graph.

        A canonical edge type is a tuple of strings (src_type, edge_type, dst_type).

        Returns
        -------
        list of 3-tuples

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.canonical_etypes
        [('user', 'follows', 'user'), ('user', 'plays', 'game')]
        """
        return self._canonical_etypes

    @property
    def metagraph(self):
        """Return the metagraph as networkx.MultiDiGraph.

        The nodes are labeled with node type names.
        The edges have their keys holding the edge type names.

        Returns
        -------
        networkx.MultiDiGraph

        Examples
        --------

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> meta_g = g.metagraph

        The metagraph then has two nodes and two edges.

        >>> meta_g.nodes()
        NodeView(('user', 'game'))
        >>> meta_g.number_of_nodes()
        2
        >>> meta_g.edges()
        OutMultiEdgeDataView([('user', 'user'), ('user', 'game')])
        >>> meta_g.number_of_edges()
        2
        """
        if self._nx_metagraph is None:
            nx_graph = self._graph.metagraph.to_networkx()
            self._nx_metagraph = nx.MultiDiGraph()
            for u_v in nx_graph.edges:
                srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']]
                self._nx_metagraph.add_edge(srctype, dsttype, etype)
        return self._nx_metagraph

    def to_canonical_etype(self, etype):
        """Convert edge type to canonical etype: (srctype, etype, dsttype).

        The input can already be a canonical tuple.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type

        Returns
        -------
        tuple of str

        Examples
        --------

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g3 = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'follows', 'game')
        >>> g = dgl.hetero_from_relations([g1, g2, g3])

        Get canonical edge types.

        >>> g.to_canonical_etype('plays')
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype(('user', 'plays', 'game'))
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype('follows')
        DGLError: Edge type "follows" is ambiguous.
        Please use canonical etype type in the form of (srctype, etype, dsttype)
        """
        if isinstance(etype, tuple):
            return etype
        else:
            ret = self._etype2canonical.get(etype, None)
            if ret is None:
                raise DGLError('Edge type "{}" does not exist.'.format(etype))
            if len(ret) == 0:
                raise DGLError('Edge type "%s" is ambiguous. Please use canonical etype '
                               'type in the form of (srctype, etype, dsttype)' % etype)
            return ret

    def get_ntype_id(self, ntype):
        """Return the id of the given node type.

        ntype can also be None. If so, there should be only one node type in the
        graph.

        Parameters
        ----------
        ntype : str or None
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if self._graph.number_of_ntypes() != 1:
                raise DGLError('Node type name must be specified if there are more than one '
                               'node types.')
            return 0
        ntid = self._ntypes_invmap.get(ntype, None)
        if ntid is None:
            raise DGLError('Node type "{}" does not exist.'.format(ntype))
        return ntid

    def get_etype_id(self, etype):
        """Return the id of the given edge type.

        etype can also be None. If so, there should be only one edge type in the
        graph.

        Parameters
        ----------
        etype : str, tuple of str, or None
            Edge type

        Returns
        -------
        int
        """
        if etype is None:
            if self._graph.number_of_etypes() != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            return 0
        etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
        if etid is None:
            raise DGLError('Edge type "{}" does not exist.'.format(etype))
        return etid

    #################################################################
    # View
    #################################################################

    @property
    def nodes(self):
        """Return a node view that can be used to set/get feature
        data of a single node type.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all users:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.zeros(3, 5)

        See Also
        --------
        ndata
        """
        return HeteroNodeView(self)

    @property
    def ndata(self):
        """Return the data view of all the nodes.

        **Only works if the graph has one node type.**

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all nodes in a heterogeneous graph
        with only one node type:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.ndata['h'] = torch.zeros(3, 5)

        See Also
        --------
        nodes
        """
        return HeteroNodeDataView(self, None, ALL)

    @property
    def edges(self):
        """Return an edge view that can be used to set/get feature
        data of a single edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all "plays" relationships:

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edges['plays'].data['h'] = torch.zeros(3, 4)

        See Also
        --------
        edata
        """
        return HeteroEdgeView(self)

    @property
    def edata(self):
        """Return the data view of all the edges.

        **Only works if the graph has one edge type.**

        Examples
        --------
        The following example uses PyTorch backend.

        To set features of all edges in a heterogeneous graph
        with only one edge type:

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.edata['h'] = torch.zeros(2, 5)

        See Also
        --------
        edges
        """
        return HeteroEdgeDataView(self, None, ALL)

    def _find_etypes(self, key):
        etypes = [
            i for i, (srctype, etype, dsttype) in enumerate(self._canonical_etypes) if
            (key[0] == SLICE_FULL or key[0] == srctype) and
            (key[1] == SLICE_FULL or key[1] == etype) and
            (key[2] == SLICE_FULL or key[2] == dsttype)]
        return etypes

    def __getitem__(self, key):
        """Return the relation slice of this graph.

        A relation slice is accessed with ``self[srctype, etype, dsttype]``, where
        ``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
        slice (``:``) representing wildcard (i.e. any source/edge/destination type).

        A relation slice is a homogeneous (with one node type and one edge type) or
        bipartite (with two node types and one edge type) graph, transformed from
        the original heterogeneous graph.

        If only one canonical edge type is found, the returned relation slice is a
        subgraph induced from the original graph.  That is, it is equivalent to
        ``self.edge_type_subgraph(etype)``.  The node and edge features of the
        returned graph are shared with the original graph.

        If multiple canonical edge types are found, the source node type, edge type,
        and destination node type of the returned graph are each a *concatenation*
        of the original types.  The name of each new source/destination node type is
        determined by calling :func:`dgl.combine_names() <dgl.combine_names>` on the
        original source/destination types.  The features of the new source/destination
        nodes are formed by concatenating the features common to the original
        source/destination types, and therefore are not shared with the original
        graph.  The same applies to the edge type.
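
        Which relations a key selects can be sketched in plain Python. This is an
        illustration only: ``find_etypes`` and ``KeyEcho`` are hypothetical helpers,
        and the sketch assumes the wildcard is the ``slice(None)`` object that Python
        produces for ``:`` inside a subscript:

        ```python
        WILDCARD = slice(None)  # what ``:`` turns into inside ``obj[...]``


        class KeyEcho:
            # Hypothetical helper that returns whatever key it is subscripted with.
            def __getitem__(self, key):
                return key


        def find_etypes(canonical_etypes, key):
            # A bare edge type name is treated as (:, name, :).
            if not isinstance(key, tuple):
                key = (WILDCARD, key, WILDCARD)
            # Keep the index of every triplet whose parts all match the key,
            # where a wildcard part matches anything.
            return [i for i, triplet in enumerate(canonical_etypes)
                    if all(k == WILDCARD or k == t for k, t in zip(key, triplet))]


        gameplay_etypes = [('user', 'follows', 'user'),
                           ('user', 'plays', 'game'),
                           ('developer', 'develops', 'game')]
        print(find_etypes(gameplay_etypes, 'plays'))                  # [1]
        print(find_etypes(gameplay_etypes, KeyEcho()[:, :, 'game']))  # [1, 2]
        ```

        A single match corresponds to the induced-subgraph case above. Several
        matches correspond to the concatenated case.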
        """
        err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\
                  "to get view of one relation type. Use : to slice multiple types (e.g. " +\
                  "G['srctype', :, 'dsttype'])."

        if not isinstance(key, tuple):
            key = (SLICE_FULL, key, SLICE_FULL)

        if len(key) != 3:
            raise DGLError(err_msg)

        etypes = self._find_etypes(key)
        if len(etypes) == 1:
            # no ambiguity: return the unitgraph itself
            srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
            stid = self.get_ntype_id(srctype)
            etid = self.get_etype_id((srctype, etype, dsttype))
            dtid = self.get_ntype_id(dsttype)
            new_g = self._graph.get_relation_graph(etid)

            if stid == dtid:
                new_ntypes = [srctype]
                new_nframes = [self._node_frames[stid]]
            else:
                new_ntypes = [srctype, dsttype]
                new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
            new_etypes = [etype]
            new_eframes = [self._edge_frames[etid]]

            return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
        else:
            flat = self._graph.flatten_relations(etypes)
            new_g = flat.graph

            # merge frames
            stids = flat.induced_srctype_set.asnumpy()
            dtids = flat.induced_dsttype_set.asnumpy()
            etids = flat.induced_etype_set.asnumpy()
            new_ntypes = [combine_names(self.ntypes, stids)]
            if new_g.number_of_ntypes() == 2:
                new_ntypes.append(combine_names(self.ntypes, dtids))
                new_nframes = [
                    combine_frames(self._node_frames, stids),
                    combine_frames(self._node_frames, dtids)]
            else:
                assert np.array_equal(stids, dtids)
                new_nframes = [combine_frames(self._node_frames, stids)]
            new_etypes = [combine_names(self.etypes, etids)]
            new_eframes = [combine_frames(self._edge_frames, etids)]

            # create new heterograph
            new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)

            src = new_ntypes[0]
            dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
            # put the parent node/edge type and IDs
            new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype)
            new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid)
            new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype)
            new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid)
            new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
            new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)

            return new_hg

    #################################################################
    # Graph query
    #################################################################

    def number_of_nodes(self, ntype=None):
        """Return the number of nodes of the given type in the heterograph.

        Parameters
        ----------
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        int
            The number of nodes

        Examples
        --------

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.number_of_nodes('user')
        3
        >>> g.number_of_nodes()
        3
        """
        return self._graph.number_of_nodes(self.get_ntype_id(ntype))

    def number_of_edges(self, etype=None):
        """Return the number of edges of the given type in the heterograph.

        Parameters
        ----------
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        int
            The number of edges

        Examples
        --------

        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.number_of_edges(('user', 'follows', 'user'))
        2
        >>> g.number_of_edges('follows')
        2
        >>> g.number_of_edges()
        2
        """
        return self._graph.number_of_edges(self.get_etype_id(etype))

    @property
    def is_multigraph(self):
        """Whether the graph is a multigraph

        Returns
        -------
        bool
            True if the graph is a multigraph, False otherwise.
        """
        if self._is_multigraph is None:
            return self._graph.is_multigraph()
        else:
            return self._is_multigraph

    @property
    def is_readonly(self):
        """Whether the graph is readonly

        Returns
        -------
        bool
            True if the graph is readonly, False otherwise.
        """
        return self._graph.is_readonly()

    def has_node(self, vid, ntype=None):
        """Whether the graph has a node with a particular id and type.

        Parameters
        ----------
        vid : int
            The node ID.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        bool
            True if the node exists, False otherwise

        Examples
        --------
        >>> g.has_node(0, 'user')
        True
        >>> g.has_node(4, 'user')
        False

        See Also
        --------
        has_nodes
        """
        return self._graph.has_node(self.get_ntype_id(ntype), vid)

    def has_nodes(self, vids, ntype=None):
        """Whether the graph has nodes with ids and a particular type.

        Parameters
        ----------
        vids : list or tensor
            The array of node IDs.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph.

        Returns
        -------
        a : tensor
            Binary tensor indicating the existence of nodes with the specified ids and type.
            ``a[i]=1`` if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g.has_nodes([0, 1, 2, 3, 4], 'user')
        tensor([1, 1, 1, 0, 0])

        See Also
        --------
        has_node
        """
        vids = utils.toindex(vids)
        rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids)
        return rst.tousertensor()

    def has_edge_between(self, u, v, etype=None):
        """Whether the graph has an edge (u, v) of type ``etype``.

        Parameters
        ----------
        u : int
            The node ID of source type.
        v : int
            The node ID of destination type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        bool
            True if the edge is in the graph, False otherwise.

        Examples
        --------

        >>> g.has_edge_between(0, 1, ('user', 'plays', 'game'))
        True
        >>> g.has_edge_between(0, 2, ('user', 'plays', 'game'))
        False

        See Also
        --------
        has_edges_between
        """
        return self._graph.has_edge_between(self.get_etype_id(etype), u, v)

    def has_edges_between(self, u, v, etype=None):
        """Whether the graph has edges of type ``etype``.

        Parameters
        ----------
        u : list, tensor
            The node ID array of source type.
        v : list, tensor
            The node ID array of destination type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        a : tensor
            Binary tensor indicating the existence of edges. ``a[i]=1`` if the graph
            contains edge ``(u[i], v[i])`` of type ``etype``, 0 otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g.has_edges_between([0, 0], [1, 2], ('user', 'plays', 'game'))
        tensor([1, 0])

        See Also
        --------
        has_edge_between
        """
        u = utils.toindex(u)
        v = utils.toindex(v)
        rst = self._graph.has_edges_between(self.get_etype_id(etype), u, v)
        return rst.tousertensor()

    def predecessors(self, v, etype=None):
        """Return the predecessors of node `v` in the graph with the specified
        edge type.

        Node `u` is a predecessor of `v` if an edge `(u, v)` with type `etype`
        exists in the graph.

        Parameters
        ----------
        v : int
            The destination node.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Array of predecessor node IDs with the specified edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
        >>> g = dgl.hetero_from_relations([plays_g, devs_g])
        >>> g.predecessors(0, 'plays')
        tensor([0, 1])
        >>> g.predecessors(0, 'develops')
        tensor([0])

        See Also
        --------
        successors
        """
        return self._graph.predecessors(self.get_etype_id(etype), v).tousertensor()

    def successors(self, v, etype=None):
        """Return the successors of node `v` in the graph with the specified edge
        type.

        Node `u` is a successor of `v` if an edge `(v, u)` with type `etype` exists
        in the graph.

        Parameters
        ----------
        v : int
            The source node.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Array of successor node IDs with the specified edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> g.successors(0, 'plays')
        tensor([0])
        >>> g.successors(0, 'follows')
        tensor([1])

        See Also
        --------
        predecessors
        """
        return self._graph.successors(self.get_etype_id(etype), v).tousertensor()

    def edge_id(self, u, v, force_multi=False, etype=None):
        """Return the edge ID, or an array of edge IDs, between source node
        `u` and destination node `v`, with the specified edge type.

        Parameters
        ----------
        u : int
            The node ID of source type.
        v : int
            The node ID of destination type.
        force_multi : bool, optional
            If False, will return a single edge ID if the graph is a simple graph.
            If True, will always return an array. (Default: False)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        int or tensor
            The edge ID if ``force_multi == False`` and the graph is a simple graph.
            The edge ID array otherwise.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for edge id.

        >>> plays_g.edge_id(1, 2, etype=('user', 'plays', 'game'))
        2
        >>> g.edge_id(1, 2, force_multi=True, etype=('user', 'follows', 'user'))
        tensor([1, 2])

        See Also
        --------
        edge_ids
        """
        idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
        return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0]

    def edge_ids(self, u, v, force_multi=False, etype=None):
        """Return all edge IDs between source node array `u` and destination
        node array `v` with the specified edge type.

        Parameters
        ----------
        u : list, tensor
            The node ID array of source type.
        v : list, tensor
            The node ID array of destination type.
        force_multi : bool, optional
            Whether to always treat the graph as a multigraph. See "Returns"
            for its effect. (Default: False)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph.

        Returns
        -------
        tensor, or (tensor, tensor, tensor)

            * If the graph is a simple graph and ``force_multi=False``, return
            a single edge ID array ``e``.  ``e[i]`` is the edge ID between ``u[i]``
            and ``v[i]``.

            * Otherwise, return three arrays ``(eu, ev, e)``.  ``e[i]`` is the ID
            of an edge between ``eu[i]`` and ``ev[i]``.  All edges between ``u[i]``
            and ``v[i]`` are returned.

        Notes
        -----
        If the graph is a simple graph, ``force_multi=False``, and no edge
        exists between some pair ``u[i]`` and ``v[i]``, the result is undefined;
        an empty tensor may be returned.

        Examples
        --------
        The following example uses PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for edge ids.

        >>> plays_g.edge_ids([0], [2], etype=('user', 'plays', 'game'))
        tensor([], dtype=torch.int64)
        >>> plays_g.edge_ids([1], [2], etype=('user', 'plays', 'game'))
        tensor([2])
        >>> g.edge_ids([1], [2], force_multi=True, etype=('user', 'follows', 'user'))
        (tensor([1, 1]), tensor([2, 2]), tensor([1, 2]))

        See Also
        --------
        edge_id
        """
        u = utils.toindex(u)
        v = utils.toindex(v)
        src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
        if force_multi or self._graph.is_multigraph():
            return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
        else:
            return eid.tousertensor()

    def find_edges(self, eid, etype=None):
        """Given an edge ID array with the specified type, return the source
        and destination node ID array ``s`` and ``d``.  ``s[i]`` and ``d[i]``
        are source and destination node ID for edge ``eid[i]``.

        Parameters
        ----------
        eid : list, tensor
            The edge ID array.
1087
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            The source node ID array.
        tensor
            The destination node ID array.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.find_edges([0, 2], ('user', 'plays', 'game'))
        (tensor([0, 1]), tensor([0, 2]))
        >>> g.find_edges([0, 2])
        (tensor([0, 1]), tensor([0, 2]))
        """
        eid = utils.toindex(eid)
        src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
        return src.tousertensor(), dst.tousertensor()

    def in_edges(self, v, form='uv', etype=None):
        """Return the inbound edges of the node(s) with the specified type.

        Parameters
        ----------
        v : int, list, tensor
            The node id(s) of destination type.
        form : str, optional
            The return form. Currently supported:

            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor or (tensor, tensor, tensor) or (tensor, tensor)
            All inbound edges to ``v`` are returned.

            * If ``form='eid'``, return a tensor for the ids of the
              inbound edges of the nodes with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors
              ``(eu, ev, eid)``. ``eid[i]`` gives the ID of the
              edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``eu[i]`` is the source node of an edge to ``ev[i]``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.in_edges([0, 2], form='eid')
        tensor([0, 2])
        >>> g.in_edges([0, 2], form='all')
        (tensor([0, 1]), tensor([0, 2]), tensor([0, 2]))
        >>> g.in_edges([0, 2], form='uv')
        (tensor([0, 1]), tensor([0, 2]))
        """
        v = utils.toindex(v)
        src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        elif form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        elif form == 'eid':
            return eid.tousertensor()
        else:
            raise DGLError('Invalid form: {}'.format(form))

    def out_edges(self, u, form='uv', etype=None):
        """Return the outbound edges of the node(s) with the specified type.

        Parameters
        ----------
        u : int, list, tensor
            The node id(s) of source type.
        form : str, optional
            The return form. Currently supported:

            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor or (tensor, tensor, tensor) or (tensor, tensor)
            All outbound edges from ``u`` are returned.

            * If ``form='eid'``, return a tensor for the ids of the outbound edges
              of the nodes with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors ``(eu, ev, eid)``.
              ``eid[i]`` gives the ID of the edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``ev[i]`` is the destination node of the edge from ``eu[i]``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
        >>> g.out_edges([0, 1], form='eid')
        tensor([0, 1, 2])
        >>> g.out_edges([0, 1], form='all')
        (tensor([0, 1, 1]), tensor([0, 1, 2]), tensor([0, 1, 2]))
        >>> g.out_edges([0, 1], form='uv')
        (tensor([0, 1, 1]), tensor([0, 1, 2]))
        """
        u = utils.toindex(u)
        src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        elif form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        elif form == 'eid':
            return eid.tousertensor()
        else:
            raise DGLError('Invalid form: {}'.format(form))

    def all_edges(self, form='uv', order=None, etype=None):
        """Return all edges with the specified type.

        Parameters
        ----------
        form : str, optional
            The return form. Currently supported:

            - ``'eid'`` : one eid tensor
            - ``'all'`` : a tuple ``(u, v, eid)``
            - ``'uv'``  : a pair ``(u, v)``, default
        order : str or None
            The order of the returned edges. Currently supported:

            - ``'srcdst'`` : sorted by source and destination IDs.
            - ``'eid'``    : sorted by edge IDs.
            - ``None``     : arbitrary order, default
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor or (tensor, tensor, tensor) or (tensor, tensor)

            * If ``form='eid'``, return a tensor for the ids of all edges
              with the specified type.
            * If ``form='all'``, return a 3-tuple of tensors ``(eu, ev, eid)``.
              ``eid[i]`` gives the ID of the edge from ``eu[i]`` to ``ev[i]``.
            * If ``form='uv'``, return a 2-tuple of tensors ``(eu, ev)``.
              ``ev[i]`` is the destination node of the edge from ``eu[i]``.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite([(1, 1), (0, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.all_edges(form='eid', order='srcdst')
        tensor([1, 0, 2])
        >>> g.all_edges(form='all', order='srcdst')
        (tensor([0, 1, 1]), tensor([0, 1, 2]), tensor([1, 0, 2]))
        >>> g.all_edges(form='uv', order='eid')
        (tensor([1, 0, 1]), tensor([1, 0, 2]))
        """
        src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
        if form == 'all':
            return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
        elif form == 'uv':
            return (src.tousertensor(), dst.tousertensor())
        elif form == 'eid':
            return eid.tousertensor()
        else:
            raise DGLError('Invalid form: {}'.format(form))

    def in_degree(self, v, etype=None):
        """Return the in-degree of node ``v`` with edges of type ``etype``.

        Parameters
        ----------
        v : int
            The node ID of destination type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        int
            The in-degree.

        Examples
        --------

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.in_degree(0, 'plays')
        2
        >>> g.in_degree(0, 'follows')
        0

        See Also
        --------
        in_degrees
        """
        return self._graph.in_degree(self.get_etype_id(etype), v)

    def in_degrees(self, v=ALL, etype=None):
        """Return the in-degrees of nodes v with edges of type ``etype``.

        Parameters
        ----------
        v : list, tensor, optional
            The node ID array of the destination type. Default is to return the
            degrees of all nodes.
        etype : str or tuple of str or None, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        d : tensor
            The in-degree array. ``d[i]`` gives the in-degree of node ``v[i]``
            with edges of type ``etype``.

        Examples
        --------
        The following example uses the PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.in_degrees(0, 'plays')
        tensor([2])
        >>> g.in_degrees(etype='follows')
        tensor([0, 1, 2])
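
        As an illustration of what is being counted, the following NumPy sketch
        (not DGL's implementation) computes in-degrees from the destination IDs
        of a single relation with ``np.bincount``. The helper name
        ``in_degrees_sketch`` is hypothetical; the edge lists are copied from
        the example graph above.

        ```python
        import numpy as np

        def in_degrees_sketch(dst, num_dst_nodes, v=None):
            # Count how often each destination ID occurs; minlength keeps
            # trailing nodes that have no incoming edges.
            deg = np.bincount(np.asarray(dst, dtype=np.int64),
                              minlength=num_dst_nodes)
            if v is None:
                return deg  # degrees of all destination nodes
            return deg[np.asarray(v, dtype=np.int64)]  # only the queried nodes

        # 'follows': (0, 1), (1, 2), (1, 2) over 3 users
        print(in_degrees_sketch([1, 2, 2], 3))  # [0 1 2]
        # 'plays': (0, 0), (1, 0), (1, 2), (2, 1) over 3 games
        print(in_degrees_sketch([0, 0, 2, 1], 3, v=[0]))  # [2]
        ```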

        See Also
        --------
        in_degree
        """
        etid = self.get_etype_id(etype)
        _, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(v):
            v = utils.toindex(slice(0, self._graph.number_of_nodes(dtid)))
        else:
            v = utils.toindex(v)
        return self._graph.in_degrees(etid, v).tousertensor()

    def out_degree(self, u, etype=None):
        """Return the out-degree of node ``u`` with edges of type ``etype``.

        Parameters
        ----------
        u : int
            The node ID of the source type.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        int
            The out-degree of node ``u`` with edges of type ``etype``.

        Examples
        --------

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.out_degree(0, 'plays')
        1
        >>> g.out_degree(1, 'follows')
        2

        See Also
        --------
        out_degrees
        """
        return self._graph.out_degree(self.get_etype_id(etype), u)

    def out_degrees(self, u=ALL, etype=None):
        """Return the out-degrees of nodes ``u`` with edges of type ``etype``.

        Parameters
        ----------
        u : list or tensor, optional
            The node ID array of the source type. Default is to return the
            degrees of all the nodes.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        d : tensor
            The out-degree array. ``d[i]`` gives the out-degree of node ``u[i]``
            with edges of type ``etype``.

        Examples
        --------
        The following example uses the PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])

        Query for node degree.

        >>> g.out_degrees(0, 'plays')
        tensor([1])
        >>> g.out_degrees(etype='follows')
        tensor([1, 2, 0])
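
        A minimal NumPy sketch (not DGL's implementation) of the counting behind
        this method: every source node ID of one relation is tallied with
        ``np.bincount``. The helper ``out_degrees_sketch`` is hypothetical; the
        edge lists come from the example graph above.

        ```python
        import numpy as np

        def out_degrees_sketch(src, num_src_nodes, u=None):
            # Tally each source ID; minlength keeps nodes without out-edges.
            deg = np.bincount(np.asarray(src, dtype=np.int64),
                              minlength=num_src_nodes)
            # Return all degrees, or only those of the queried nodes.
            return deg if u is None else deg[np.asarray(u, dtype=np.int64)]

        # 'follows': (0, 1), (1, 2), (1, 2) over 3 users
        print(out_degrees_sketch([0, 1, 1], 3))  # [1 2 0]
        # 'plays': (0, 0), (1, 0), (1, 2), (2, 1); the sources are 3 users
        print(out_degrees_sketch([0, 1, 1, 2], 3, u=[0]))  # [1]
        ```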

        See Also
        --------
        out_degree
        """
        etid = self.get_etype_id(etype)
        stid, _ = self._graph.metagraph.find_edge(etid)
        if is_all(u):
            u = utils.toindex(slice(0, self._graph.number_of_nodes(stid)))
        else:
            u = utils.toindex(u)
        return self._graph.out_degrees(etid, u).tousertensor()

    def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
        """Internal function to create a subgraph."""
        node_frames = [
            FrameRef(Frame(
                self._node_frames[i][induced_nodes_of_ntype],
                num_rows=len(induced_nodes_of_ntype)))
            for i, induced_nodes_of_ntype in enumerate(induced_nodes)]
        edge_frames = [
            FrameRef(Frame(
                self._edge_frames[i][induced_edges_of_etype],
                num_rows=len(induced_edges_of_etype)))
            for i, induced_edges_of_etype in enumerate(induced_edges)]

        hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
        hsg.is_subgraph = True
        for ntype, induced_nid in zip(self.ntypes, induced_nodes):
            hsg.nodes[ntype].data[NID] = induced_nid.tousertensor()
        for etype, induced_eid in zip(self.canonical_etypes, induced_edges):
            hsg.edges[etype].data[EID] = induced_eid.tousertensor()

        return hsg

    def subgraph(self, nodes):
        """Return the subgraph induced on given nodes.

        The metagraph of the returned subgraph is the same as that of the
        parent graph. Features are copied from the original graph.

        Parameters
        ----------
        nodes : dict[str->list or iterable]
            A dictionary mapping node types to the node ID arrays that induce
            the subgraph. Node types missing from the dictionary contribute no
            nodes. All nodes must exist in the graph.

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

            The nodes and edges in the subgraph are relabeled using consecutive
            integers from 0.

            One can retrieve the mapping from subgraph node/edge ID to parent
            node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
            subgraph.

        Examples
        --------
        The following example uses the PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set node features
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> g.subgraph({'user': [4, 5]})  # Error: these nodes do not exist.
        >>> sub_g = g.subgraph({'user': [1, 2]})
        >>> print(sub_g)
        Graph(num_nodes={'user': 2, 'game': 0},
              num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
              metagraph=[('user', 'game'), ('user', 'user')])

        Get the original node/edge indices.

        >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
        tensor([1, 2])
        >>> sub_g['follows'].edata[dgl.EID] # Get the edge indices in the raw graph
        tensor([1, 2])

        Get the copied node features.

        >>> sub_g.nodes['user'].data['h']
        tensor([[1.],
                [2.]])
        >>> sub_g.nodes['user'].data['h'] += 1
        >>> g.nodes['user'].data['h']          # Features are not shared.
        tensor([[0.],
                [1.],
                [2.]])
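
        For a single relation, node-induced extraction can be sketched in NumPy
        as below. This illustrative helper (the name ``node_induced_sketch`` is
        hypothetical) is not DGL's implementation: it keeps the edges whose two
        endpoints are both selected and, as an assumption of the sketch,
        relabels the selected nodes to ``0, 1, ...`` in the order given.

        ```python
        import numpy as np

        def node_induced_sketch(src, dst, nodes):
            src, dst, nodes = np.asarray(src), np.asarray(dst), np.asarray(nodes)
            # Map parent node ID -> new consecutive ID.
            new_id = {int(n): i for i, n in enumerate(nodes)}
            # An edge survives only if both endpoints are selected.
            keep = np.isin(src, nodes) & np.isin(dst, nodes)
            eids = np.nonzero(keep)[0]  # parent edge IDs, the role of dgl.EID
            new_src = [new_id[int(s)] for s in src[keep]]
            new_dst = [new_id[int(d)] for d in dst[keep]]
            return new_src, new_dst, eids

        # 'follows' edges (0, 1), (1, 2), (1, 2) restricted to users [1, 2]
        s, d, eids = node_induced_sketch([0, 1, 1], [1, 2, 2], [1, 2])
        print(s, d, eids)  # [0, 0] [1, 1] [1 2]
        ```

        The parent edge IDs ``[1, 2]`` agree with ``sub_g['follows'].edata[dgl.EID]``
        in the example above.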

        See Also
        --------
        edge_subgraph
        """
        induced_nodes = [utils.toindex(nodes.get(ntype, [])) for ntype in self.ntypes]
        sgi = self._graph.node_subgraph(induced_nodes)
        induced_edges = sgi.induced_edges

        return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)

    def edge_subgraph(self, edges, preserve_nodes=False):
        """Return the subgraph induced on given edges.

        The metagraph of the returned subgraph is the same as that of the parent graph.

        Features are copied from the original graph.

        Parameters
        ----------
        edges : dict[str or tuple of str->list or iterable]
            A dictionary mapping edge types to the edge ID arrays that induce
            the subgraph. All edges must exist in the graph.

            An edge type can be given either by name or as a canonical triplet
            ``(src type, etype, dst type)``.
        preserve_nodes : bool, optional
            If True, keep all the nodes of the parent graph. If False, remove
            the nodes that are not incident to any of the given edges.
            (Default: False)

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

            The nodes and edges are relabeled using consecutive integers from 0.

            One can retrieve the mapping from subgraph node/edge ID to parent
            node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
            subgraph.

        Examples
        --------
        The following example uses the PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set edge features
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> g.edge_subgraph({('user', 'follows', 'user'): [5, 6]})  # Error: no such edges.
        >>> sub_g = g.edge_subgraph({('user', 'follows', 'user'): [1, 2],
        ...                          ('user', 'plays', 'game'): [2]})
        >>> print(sub_g)
        Graph(num_nodes={'user': 2, 'game': 1},
              num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
              metagraph=[('user', 'game'), ('user', 'user')])

        Get the original node/edge indices.

        >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
        tensor([1, 2])
        >>> sub_g['plays'].edata[dgl.EID]   # Get the edge indices in the raw graph
        tensor([2])

        Get the copied edge features.

        >>> sub_g.edges['follows'].data['h']
        tensor([[1.],
                [2.]])
        >>> sub_g.edges['follows'].data['h'] += 1
        >>> g.edges['follows'].data['h']          # Features are not shared.
        tensor([[0.],
                [1.],
                [2.]])
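
        To see why the example subgraph has two ``user`` nodes and one ``game``
        node, the following NumPy sketch collects, per node type, the parent
        node IDs touched by the selected edges, which corresponds to
        ``preserve_nodes=False``. It is a hypothetical helper
        (``edge_induced_sketch``), not DGL's implementation, and it returns
        sorted IDs; no claim is made about DGL's internal ordering.

        ```python
        import numpy as np

        def edge_induced_sketch(relations, edges):
            # relations: {(srctype, etype, dsttype): (src_ids, dst_ids)}
            # edges:     {(srctype, etype, dsttype): edge IDs to keep}
            touched = {}  # node type -> arrays of parent IDs hit by kept edges
            for cetype, eids in edges.items():
                srctype, _, dsttype = cetype
                src, dst = (np.asarray(a) for a in relations[cetype])
                eids = np.asarray(eids, dtype=np.int64)
                touched.setdefault(srctype, []).append(src[eids])
                touched.setdefault(dsttype, []).append(dst[eids])
            # Sorted unique parent IDs per node type (the role of dgl.NID)
            return {nt: np.unique(np.concatenate(ids)) for nt, ids in touched.items()}

        relations = {
            ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
            ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
        }
        kept = edge_induced_sketch(relations, {('user', 'follows', 'user'): [1, 2],
                                               ('user', 'plays', 'game'): [2]})
        print({nt: ids.tolist() for nt, ids in kept.items()})
        # {'user': [1, 2], 'game': [2]}
        ```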

        See Also
        --------
        subgraph
        """
        edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
        induced_edges = [
            utils.toindex(edges.get(canonical_etype, []))
            for canonical_etype in self.canonical_etypes]
        sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
        induced_nodes = sgi.induced_nodes

        return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)

    def node_type_subgraph(self, ntypes):
        """Return the subgraph induced on given node types.

        The metagraph of the returned subgraph is the subgraph of the original
        metagraph induced from the node types.

        Features are shared with the original graph.

        Parameters
        ----------
        ntypes : list[str]
            The node types to keep.

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

        Examples
        --------
        The following example uses the PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set node features
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> sub_g = g.node_type_subgraph(['user'])
        >>> print(sub_g)
        Graph(num_nodes=3, num_edges=3,
              ndata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)}
              edata_schemes={})

        Get the shared node features.

        >>> sub_g.nodes['user'].data['h']
        tensor([[0.],
                [1.],
                [2.]])
        >>> sub_g.nodes['user'].data['h'] += 1
        >>> g.nodes['user'].data['h']          # Features are shared.
        tensor([[1.],
                [2.],
                [3.]])
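
        The metagraph filtering described above can be sketched in plain Python:
        a relation is kept only when both its source and destination types are
        selected, which mirrors the condition tested in this method's loop. The
        helper name ``node_type_sketch`` is hypothetical.

        ```python
        def node_type_sketch(canonical_etypes, ntypes):
            keep = set(ntypes)
            # A relation survives only if both endpoint types are selected.
            return [c for c in canonical_etypes if c[0] in keep and c[2] in keep]

        cetypes = [('user', 'plays', 'game'), ('user', 'follows', 'user')]
        print(node_type_sketch(cetypes, ['user']))  # [('user', 'follows', 'user')]
        ```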

        See Also
        --------
        edge_type_subgraph
        """
        rel_graphs = []
        meta_edges = []
        induced_etypes = []
        node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
        edge_frames = []

        ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
        srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
        for i in range(len(self._etypes)):
            srctype = self._ntypes[srctype_id[i]]
            dsttype = self._ntypes[dsttype_id[i]]

            if srctype in ntypes and dsttype in ntypes:
                meta_edges.append((ntypes_invmap[srctype], ntypes_invmap[dsttype]))
                rel_graphs.append(self._graph.get_relation_graph(i))
                induced_etypes.append(self.etypes[i])
                edge_frames.append(self._edge_frames[i])

        metagraph = graph_index.from_edge_list(meta_edges, True, True)
        hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
        hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
        return hg

    def edge_type_subgraph(self, etypes):
        """Return the subgraph induced on given edge types.

        The metagraph of the returned subgraph is the subgraph of the original metagraph
        induced from the edge types.

        Features are shared with the original graph.

        Parameters
        ----------
        etypes : list[str or tuple]
            The edge types to keep, given as names or as canonical triplets
            ``(src type, etype, dst type)``.

        Returns
        -------
        G : DGLHeteroGraph
            The subgraph.

        Examples
        --------
        The following example uses the PyTorch backend.

        Instantiate a heterograph.

        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
        >>> g = dgl.hetero_from_relations([plays_g, follows_g])
        >>> # Set edge features
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Get subgraphs.

        >>> sub_g = g.edge_type_subgraph(['follows'])
        >>> print(sub_g)
        Graph(num_nodes=3, num_edges=3,
              ndata_schemes={}
              edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})

        Get the shared edge features.

        >>> sub_g.edges['follows'].data['h']
        tensor([[0.],
                [1.],
                [2.]])
        >>> sub_g.edges['follows'].data['h'] += 1
        >>> g.edges['follows'].data['h']          # Features are shared.
        tensor([[1.],
                [2.],
                [3.]])
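
        Selecting edge types also determines which node types remain: only the
        types that appear as an endpoint of a selected relation are kept. A
        plain-Python sketch follows. The helper ``edge_type_sketch`` is
        hypothetical, accepts edge type names only (this method also accepts
        canonical triplets), and sorts node types by name, whereas the method
        body derives their order from a Python set of type IDs.

        ```python
        def edge_type_sketch(canonical_etypes, etypes):
            # Pick the relations whose edge type name was requested.
            selected = [c for c in canonical_etypes if c[1] in etypes]
            # Keep only node types that are an endpoint of a selected relation.
            ntypes = sorted({c[0] for c in selected} | {c[2] for c in selected})
            return ntypes, selected

        cetypes = [('user', 'plays', 'game'), ('user', 'follows', 'user')]
        print(edge_type_sketch(cetypes, ['follows']))
        # (['user'], [('user', 'follows', 'user')])
        ```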

        See Also
        --------
        node_type_subgraph
        """
        etype_ids = [self.get_etype_id(etype) for etype in etypes]
        meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids))
        rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
        meta_src = meta_src.tonumpy()
        meta_dst = meta_dst.tonumpy()
        ntypes_invmap = {n: i for i, n in enumerate(set(meta_src) | set(meta_dst))}
        mapped_meta_src = [ntypes_invmap[v] for v in meta_src]
        mapped_meta_dst = [ntypes_invmap[v] for v in meta_dst]
        node_frames = [self._node_frames[i] for i in ntypes_invmap]
        edge_frames = [self._edge_frames[i] for i in etype_ids]
        induced_ntypes = [self._ntypes[i] for i in ntypes_invmap]
        induced_etypes = [self._etypes[i] for i in etype_ids]   # get the "name" of edge type

        metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
        hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
        hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
        return hg

    def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None):
        """Return the adjacency matrix of edges of the given edge type.

        By default, a row of the returned adjacency matrix represents the
        destination of an edge and a column represents the source.

        When ``transpose`` is True, a row represents the source and a column
        represents the destination.

        Parameters
        ----------
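        transpose : bool, optional
            Whether to transpose the returned adjacency matrix. If not given,
            a warning is issued and False is used; the warning states that the
            default will become True in 0.5. (Default: None)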
        ctx : context, optional
            The context of returned adjacency matrix. (Default: cpu)
        scipy_fmt : str, optional
            If specified, return a scipy sparse matrix in the given format.
            Otherwise, return a backend dependent sparse tensor. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        SparseTensor or scipy.sparse.spmatrix
            Adjacency matrix.

        Examples
        --------

        Instantiate a heterogeneous graph.

        >>> follows_g = dgl.graph([(0, 0), (1, 1)], 'user', 'follows')
        >>> devs_g = dgl.bipartite([(0, 0), (1, 2)], 'developer', 'develops', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, devs_g])

        Get a backend-dependent sparse tensor. The output below uses the PyTorch backend.

        >>> g.adjacency_matrix(etype='develops')
        tensor(indices=tensor([[0, 2],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)

        Get a scipy coo sparse matrix.

        >>> g.adjacency_matrix(scipy_fmt='coo', etype='develops')
        <3x2 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in COOrdinate format>
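
        For intuition about the default layout, the following NumPy sketch
        builds the same matrix densely: entry ``[dst, src]`` counts the edges
        from ``src`` to ``dst``, and ``transpose=True`` swaps the roles. It is
        illustrative only; the helper name ``adjacency_sketch`` is hypothetical,
        and accumulating parallel edges with ``np.add.at`` is a choice of this
        sketch, not a statement about DGL.

        ```python
        import numpy as np

        def adjacency_sketch(src, dst, num_src, num_dst, transpose=False):
            # Rows index destination nodes and columns index source nodes.
            adj = np.zeros((num_dst, num_src), dtype=np.int64)
            # np.add.at accumulates, so parallel edges produce counts above 1.
            np.add.at(adj, (np.asarray(dst), np.asarray(src)), 1)
            return adj.T if transpose else adj

        # 'develops' edges (0, 0) and (1, 2): 2 developers, 3 games
        adj = adjacency_sketch([0, 1], [0, 2], num_src=2, num_dst=3)
        print(np.nonzero(adj))  # same (row, col) pairs as the sparse output above
        ```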
        """
        if transpose is None:
            dgl_warning(
                "Currently adjacency_matrix() returns a matrix with destination as rows"
                " by default.  In 0.5 the result will have source as rows"
                " (i.e. transpose=True)")
            transpose = False

        etid = self.get_etype_id(etype)
        if scipy_fmt is None:
            return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
        else:
            return self._graph.adjacency_matrix_scipy(etid, transpose, scipy_fmt, False)

    # Alias of ``adjacency_matrix``
    adj = adjacency_matrix

    def incidence_matrix(self, typestr, ctx=F.cpu(), etype=None):
        """Return the incidence matrix representation of edges with the given
        edge type.

        An incidence matrix is an n-by-m sparse matrix, where n is
        the number of nodes and m is the number of edges. Each nonzero
        entry indicates that the edge is incident to the node.

        There are three types of incidence matrices :math:`I`:

        * ``in``:

            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`
              (or :math:`v` is the dst node of :math:`e`);
            - :math:`I[v, e] = 0` otherwise.

        * ``out``:

            - :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`
              (or :math:`v` is the src node of :math:`e`);
            - :math:`I[v, e] = 0` otherwise.

        * ``both`` (only if the source and destination node types are the same):

            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`;
            - :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`;
            - :math:`I[v, e] = 0` otherwise (including self-loop).

        Parameters
        ----------
        typestr : str
            One of ``in``, ``out`` or ``both``.
        ctx : context, optional
            The context of returned incidence matrix. (Default: cpu)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        Framework SparseTensor
            The incidence matrix.

        Examples
        --------

        >>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
        >>> g.incidence_matrix('in')
        tensor(indices=tensor([[0, 2],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        >>> g.incidence_matrix('out')
        tensor(indices=tensor([[0, 1],
                               [0, 1]]),
               values=tensor([1., 1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
        >>> g.incidence_matrix('both')
        tensor(indices=tensor([[1, 2],
                               [1, 1]]),
               values=tensor([-1.,  1.]),
               size=(3, 2), nnz=2, layout=torch.sparse_coo)
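
        The three definitions above can be checked with a dense NumPy sketch,
        which reproduces the nonzero pattern of the example outputs for the
        graph with edges ``(0, 0)`` and ``(1, 2)``. The helper
        ``incidence_sketch`` is hypothetical and returns a dense array rather
        than a framework sparse tensor.

        ```python
        import numpy as np

        def incidence_sketch(src, dst, num_nodes, typestr):
            src, dst = np.asarray(src), np.asarray(dst)
            eids = np.arange(len(src))
            inc = np.zeros((num_nodes, len(src)), dtype=np.int64)
            if typestr == 'in':
                inc[dst, eids] = 1          # v is the dst node of e
            elif typestr == 'out':
                inc[src, eids] = 1          # v is the src node of e
            elif typestr == 'both':
                loop = src == dst           # self-loops stay 0
                inc[dst[~loop], eids[~loop]] = 1
                inc[src[~loop], eids[~loop]] = -1
            else:
                raise ValueError("typestr must be 'in', 'out' or 'both'")
            return inc

        both = incidence_sketch([0, 1], [0, 2], 3, 'both')
        print(np.nonzero(both), both[np.nonzero(both)])
        # rows [1, 2], cols [1, 1], values [-1, 1]
        ```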
        """
        etid = self.get_etype_id(etype)
        return self._graph.incidence_matrix(etid, typestr, ctx)[0]

    # Alias of ``incidence_matrix``
    inc = incidence_matrix

    #################################################################
    # Features
    #################################################################

    def node_attr_schemes(self, ntype=None):
        """Return the node feature schemes for the specified type.

        Each feature scheme is a named tuple that stores the shape and data type
        of the node feature.

        Parameters
        ----------
        ntype : str, optional
            The node type. Can be omitted if there is only one node
            type in the graph; an error is raised otherwise.
            (Default: None)

        Returns
        -------
        dict of str to schemes
            The schemes of node feature columns.

        Examples
        --------
        The following example uses the PyTorch backend.

        >>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.randn(3, 4)
        >>> g.node_attr_schemes('user')
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        edge_attr_schemes
        """
        return self._node_frames[self.get_ntype_id(ntype)].schemes

    def edge_attr_schemes(self, etype=None):
        """Return the edge feature schemes for the specified type.

        Each feature scheme is a named tuple that stores the shape and data type
        of the edge feature.

        Parameters
        ----------
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        dict of str to schemes
            The schemes of edge feature columns.

        Examples
        --------
        The following example uses the PyTorch backend.

        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
        >>> g.edge_attr_schemes(('user', 'plays', 'game'))
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        node_attr_schemes
        """
        return self._edge_frames[self.get_etype_id(etype)].schemes

    def set_n_initializer(self, initializer, field=None, ntype=None):
        """Set the initializer for empty node features.

        Initializer is a callable that returns a tensor given the shape, data type
        and device context.

        When only a subset of the nodes is assigned a new feature, the initializer
        is used to create that feature for the remaining nodes.

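        The fill-in behaviour can be sketched without any tensor library. In the
        toy model below, a column that does not exist yet is first created for
        every node by the initializer, and the rows that were actually assigned
        are then overwritten. ``constant_initializer`` and ``assign_subset`` are
        hypothetical helpers, and the three-argument ``(shape, dtype, ctx)`` form
        is an assumption taken from the description above; check
        :func:`dgl.init.base_initializer` for the exact signature DGL expects.

        ```python
        def constant_initializer(value):
            # Hypothetical initializer factory: maps (shape, dtype, ctx) to a
            # nested-list "tensor" filled with `value`. ctx is ignored here.
            def init(shape, dtype, ctx):
                num_rows, width = shape
                return [[dtype(value)] * width for _ in range(num_rows)]
            return init


        def assign_subset(num_nodes, width, ids, rows, initializer):
            # Create the full column with the initializer, then overwrite only
            # the rows that were explicitly assigned.
            column = initializer((num_nodes, width), float, 'cpu')
            for i, row in zip(ids, rows):
                column[i] = list(row)
            return column


        col = assign_subset(3, 2, ids=[1], rows=[[5.0, 5.0]],
                            initializer=constant_initializer(0))
        print(col)  # [[0.0, 0.0], [5.0, 5.0], [0.0, 0.0]]
        ```
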
        Parameters
        ----------
        initializer : callable
            The initializer, mapping (shape, data type, context) to tensor.
        field : str, optional
            The feature field name. Default is to set an initializer for all the
            feature fields.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph; otherwise an error is raised. (Default: None)

        Notes
        -----
        A user-defined initializer must follow the signature of
        :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`.

        See Also
        --------
        set_e_initializer
        """
        ntid = self.get_ntype_id(ntype)
        self._node_frames[ntid].set_initializer(initializer, field)

    def set_e_initializer(self, initializer, field=None, etype=None):
        """Set the initializer for empty edge features.

        Initializer is a callable that returns a tensor given the shape, data
        type and device context.

        When only a subset of the edges is assigned a new feature, the initializer
        is used to create that feature for the remaining edges.

        Parameters
        ----------
        initializer : callable
            The initializer, mapping (shape, data type, context) to tensor.
        field : str, optional
            The feature field name. Default is to set an initializer for all the
            feature fields.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph; otherwise an error is raised. (Default: None)

        Notes
        -----
        A user-defined initializer must follow the signature of
        :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`.

        See Also
        --------
        set_n_initializer
        """
        etid = self.get_etype_id(etype)
        self._edge_frames[etid].set_initializer(initializer, field)

    def _set_n_repr(self, ntid, u, data, inplace=False):
        """Internal API to set node features.

        `data` is a dictionary from the feature name to feature tensor. Each tensor
        is of shape (B, D1, D2, ...), where B is the number of nodes to be updated
        and (D1, D2, ...) is the shape of the node representation tensor. The
        length of the given node IDs must match B (i.e., len(u) == B).

        All updates are done out of place so that autograd works, unless
        ``inplace`` is True.

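        A minimal sketch of the two rules above, using plain lists in place of
        tensors: the number of new rows must equal the number of IDs, and an
        out-of-place update leaves the caller's column untouched.
        ``toy_update_rows`` is a hypothetical helper, not the frame method that
        this function actually calls.

        ```python
        def toy_update_rows(column, ids, values, inplace=False):
            # B (the first dimension of `values`) must match the number of IDs.
            if len(values) != len(ids):
                raise ValueError('Expect number of features to match number of nodes.'
                                 ' Got %d and %d instead.' % (len(values), len(ids)))
            # Out of place: write into a copy, so the old column (which autograd
            # may still reference) is unchanged. In place: mutate it directly.
            target = column if inplace else list(column)
            for i, val in zip(ids, values):
                target[i] = val
            return target


        old = [0.0, 0.0, 0.0]
        new = toy_update_rows(old, [0, 2], [1.0, 2.0])
        print(old, new)  # [0.0, 0.0, 0.0] [1.0, 0.0, 2.0]
        ```
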
        Parameters
        ----------
        ntid : int
            Node type id.
        u : node, container or tensor
            The node(s).
        data : dict of tensor
            Node representation.
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)
        """
        if is_all(u):
            num_nodes = self._graph.number_of_nodes(ntid)
        else:
            u = utils.toindex(u)
            num_nodes = len(u)
        for key, val in data.items():
            nfeats = F.shape(val)[0]
            if nfeats != num_nodes:
                raise DGLError('Expect number of features to match number of nodes (len(u)).'
                               ' Got %d and %d instead.' % (nfeats, num_nodes))

        if is_all(u):
            for key, val in data.items():
                self._node_frames[ntid][key] = val
        else:
            self._node_frames[ntid].update_rows(u, data, inplace=inplace)

    def _get_n_repr(self, ntid, u):
        """Internal API to get the features of node(s) of a single node type.

        The returned feature tensor batches multiple node features on the first dimension.

        Parameters
        ----------
        ntid : int
            Node type id.
        u : node, container or tensor
            The node(s).

        Returns
        -------
        dict
            Representation dict from feature name to feature tensor.
        """
        if is_all(u):
            return dict(self._node_frames[ntid])
        else:
            u = utils.toindex(u)
            return self._node_frames[ntid].select_rows(u)

    def _pop_n_repr(self, ntid, key):
        """Internal API to get and remove the specified node feature.

        Parameters
        ----------
        ntid : int
            Node type id.
        key : str
            The attribute name.

        Returns
        -------
        Tensor
            The popped representation
        """
        return self._node_frames[ntid].pop(key)

    def _set_e_repr(self, etid, edges, data, inplace=False):
        """Internal API to set edge features.

        `data` is a dictionary from the feature name to feature tensor. Each tensor
        is of shape (B, D1, D2, ...), where B is the number of edges to be updated
        and (D1, D2, ...) is the shape of the edge representation tensor.

        All updates are done out of place so that autograd works, unless
        ``inplace`` is True.

        Parameters
        ----------
        etid : int
            Edge type id.
        edges : edges
            Edges can be either

            * A pair of endpoint nodes (u, v), where u is the node ID of source
              node type and v is that of destination node type.
            * A tensor of edge ids of the given type.

            The default value is all the edges.
        data : dict of tensor
            Edge representation, mapping feature names to feature tensors.
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)
        """
        # parse argument
        if is_all(edges):
            eid = ALL
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            _, _, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)

        # sanity check
        if not utils.is_dict_like(data):
            raise DGLError('Expect dictionary type for feature data.'
                           ' Got "%s" instead.' % type(data))

        if is_all(eid):
            num_edges = self._graph.number_of_edges(etid)
        else:
            eid = utils.toindex(eid)
            num_edges = len(eid)
        for key, val in data.items():
            nfeats = F.shape(val)[0]
            if nfeats != num_edges:
                raise DGLError('Expect number of features to match number of edges.'
                               ' Got %d and %d instead.' % (nfeats, num_edges))
        # set
        if is_all(eid):
            # update column
            for key, val in data.items():
                self._edge_frames[etid][key] = val
        else:
            # update row
            self._edge_frames[etid].update_rows(eid, data, inplace=inplace)

    def _get_e_repr(self, etid, edges):
        """Internal API to get edge features.

        Parameters
        ----------
        etid : int
            Edge type id.
        edges : edges
            Edges can be a pair of endpoint nodes (u, v), or a
            tensor of edge ids. The default value is all the edges.

        Returns
        -------
        dict
            Representation dict
        """
        # parse argument
        if is_all(edges):
            eid = ALL
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            _, _, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)

        if is_all(eid):
            return dict(self._edge_frames[etid])
        else:
            eid = utils.toindex(eid)
            return self._edge_frames[etid].select_rows(eid)

    def _pop_e_repr(self, etid, key):
        """Internal API to get and remove the specified edge feature of a single edge type.

        Parameters
        ----------
        etid : int
            Edge type id.
        key : str
            The attribute name.

        Returns
        -------
        Tensor
            The popped representation
        """
        return self._edge_frames[etid].pop(key)

    #################################################################
    # Message passing
    #################################################################

    def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
        """Apply the function on the nodes with the same type to update their
        features.

        If None is provided for ``func``, nothing will happen.

        Parameters
        ----------
        func : callable or None
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        v : int or iterable of int or tensor, optional
            The (type-specific) node (ids) on which to apply ``func``. (Default: ALL)
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.ones(3, 5)
        >>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
        >>> g.nodes['user'].data['h']
        tensor([[2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.]])

        See Also
        --------
        apply_edges
        """
        ntid = self.get_ntype_id(ntype)
        if is_all(v):
            v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)))
        else:
            v_ntype = utils.toindex(v)
        with ir.prog() as prog:
            scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
                                           inplace=inplace)
            Runtime.run(prog)

    def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
        """Apply the function on the edges with the same type to update their
        features.

        If None is provided for ``func``, nothing will happen.

        Parameters
        ----------
        func : callable or None
            Apply function on the edge. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        edges : optional
            Edges on which to apply ``func``. See :func:`send` for valid
            edge specification. (Default: ALL)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
        >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
        >>> g.edges[('user', 'plays', 'game')].data['h']
        tensor([[2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.],
                [2., 2., 2., 2., 2.]])

        See Also
        --------
        apply_nodes
        group_apply_edges
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(edges):
            u, v, _ = self._graph.edges(etid, 'eid')
            eid = utils.toindex(slice(0, self.number_of_edges(etype)))
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        with ir.prog() as prog:
            scheduler.schedule_apply_edges(
                AdaptedHeteroGraph(self, stid, dtid, etid),
                u, v, eid, func, inplace=inplace)
            Runtime.run(prog)

    def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
        """Group the edges by node and apply the function on the grouped
        edges to update their features. The edges must all be of the same edge
        type (hence sharing the same source and destination node types).

        Parameters
        ----------
        group_by : str
            Specify how to group edges. Expected to be either ``'src'`` or ``'dst'``.
        func : callable
            Apply function on the edges. The function should be an
            :mod:`Edge UDF <dgl.udf>`. The features it receives have shape
            ``(bucket_size, degrees, *feature_shape)``, and it should return
            a dict whose values have the same shapes.
        edges : optional
            Edges on which to group and apply ``func``. See :func:`send` for valid
            edge specification. Default is all the edges.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------
        >>> g = dgl.graph([(0, 1), (0, 2), (1, 2)], 'user', 'follows')
        >>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1))
        >>> def softmax_feat(edges):
        ...     return {'norm_feat': torch.softmax(edges.data['feat'], dim=1)}
        >>> g.group_apply_edges(group_by='src', func=softmax_feat)
        >>> g.edata['norm_feat']
        tensor([[0.3796],
                [0.6204],
                [1.0000]])

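        The degree bucketing behind the ``(bucket_size, degrees, *feature_shape)``
        layout can be sketched in pure Python for scalar features: edges are
        grouped by source node, and source nodes with the same out-degree share
        a bucket, so each bucket is a rectangular block that the UDF can process
        at once. ``bucket_by_src`` and ``group_softmax`` are hypothetical helpers;
        the real grouping is performed on tensors by DGL's scheduler.

        ```python
        import math
        from collections import defaultdict


        def bucket_by_src(src):
            # Map each source node to the IDs of its out-edges, then bucket the
            # nodes by out-degree: buckets[d] is a list of length-d edge-ID rows.
            out_edges = defaultdict(list)
            for eid, u in enumerate(src):
                out_edges[u].append(eid)
            buckets = defaultdict(list)
            for eids in out_edges.values():
                buckets[len(eids)].append(eids)
            return buckets


        def softmax(xs):
            # Numerically stable softmax over one row of a bucket.
            m = max(xs)
            exps = [math.exp(x - m) for x in xs]
            total = sum(exps)
            return [e / total for e in exps]


        def group_softmax(src, feats):
            out = [None] * len(feats)
            for rows in bucket_by_src(src).values():
                for eids in rows:  # one row of a (bucket_size, degree) block
                    probs = softmax([feats[e] for e in eids])
                    for e, p in zip(eids, probs):
                        out[e] = p
            return out


        # Edges (0, 1), (0, 2), (1, 2), as in the example above.
        print(group_softmax([0, 0, 1], [0.0, 0.0, 3.0]))  # [0.5, 0.5, 1.0]
        ```
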
        See Also
        --------
        apply_edges
        """
        if group_by not in ('src', 'dst'):
            raise DGLError("group_by should be either 'src' or 'dst'")

        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(edges):
            u, v, _ = self._graph.edges(etid, 'eid')
            eid = utils.toindex(slice(0, self.number_of_edges(etype)))
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        with ir.prog() as prog:
            scheduler.schedule_group_apply_edge(
                AdaptedHeteroGraph(self, stid, dtid, etid),
                u, v, eid,
                func, group_by,
                inplace=inplace)
            Runtime.run(prog)

    def send(self, edges, message_func, etype=None):
        """Send messages along the given edges with the same edge type.

        ``edges`` can be any of the following types:

        * ``int`` : Specify one edge using its edge id (of the given edge type).
        * ``pair of int`` : Specify one edge using its endpoints (of source node type
          and destination node type respectively).
        * ``int iterable`` / ``tensor`` : Specify multiple edges using their edge ids.
        * ``pair of int iterable`` / ``pair of tensors`` :
          Specify multiple edges using their endpoints.

        If the graph has multiple edge types, either pass ``etype`` or use

        .. code::

           g['edgetype'].send(edges, message_func)

        The UDF returns messages on the edges and can be later fetched in
        the destination node's ``mailbox``. Receiving will consume the messages.
        See :func:`recv` for example.

        If ``send`` is triggered multiple times on the same edge without a ``recv``
        in between, messages generated by the later ``send`` overwrite earlier ones.

        Parameters
        ----------
        edges : int, pair of int, iterable, tensor or pair of tensors
            Edges on which to apply ``message_func``, in any of the forms
            listed above.
        message_func : callable
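            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)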

        Notes
        -----
        On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent
        along all edges between :math:`u` and :math:`v`.

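        The endpoint form of ``edges`` can be illustrated with a small
        pure-Python lookup over one edge type's edge list: a single endpoint is
        broadcast against a list of the other endpoint, and every parallel edge
        between a pair is returned. ``lookup_edge_ids`` is a hypothetical helper
        that only mimics this documented behaviour; DGL itself resolves the pair
        through ``edge_ids`` on its graph index.

        ```python
        def lookup_edge_ids(src, dst, u, v):
            # Broadcast a single endpoint against a list of the other endpoint.
            if len(u) == 1 and len(v) > 1:
                u = u * len(v)
            elif len(v) == 1 and len(u) > 1:
                v = v * len(u)
            if len(u) != len(v):
                raise ValueError('u and v must have the same length')
            out_u, out_v, out_e = [], [], []
            for a, b in zip(u, v):
                # On a multigraph, every parallel edge a -> b matches.
                matches = [e for e, (s, d) in enumerate(zip(src, dst))
                           if s == a and d == b]
                if not matches:
                    raise ValueError('no edge from %d to %d' % (a, b))
                for e in matches:
                    out_u.append(a)
                    out_v.append(b)
                    out_e.append(e)
            return out_u, out_v, out_e


        # Two parallel edges 0 -> 1 (IDs 0 and 2) and one edge 0 -> 2 (ID 1).
        src, dst = [0, 0, 0], [1, 2, 1]
        print(lookup_edge_ids(src, dst, [0], [1, 2]))
        # ([0, 0, 0], [1, 1, 2], [0, 2, 1])
        ```
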
        Examples
        --------

        >>> import dgl.function as fn
        >>> import torch
        >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Different ways for sending messages.

        >>> # Send the feature of source nodes along all edges
        >>> g.send(g.edges(), fn.copy_src('h', 'm'))
        >>> # Send the feature of source node along one edge specified by its id
        >>> g.send(0, fn.copy_src('h', 'm'))
        >>> # Send the feature of source node along one edge specified by its end points
        >>> g.send((0, 1), fn.copy_src('h', 'm'))
        >>> # Send the feature of source nodes along multiple edges specified by their ids
        >>> g.send([0, 1], fn.copy_src('h', 'm'))
        >>> # Send the feature of source nodes along multiple edges specified by their end points
        >>> g.send(([0, 1], [1, 2]), fn.copy_src('h', 'm'))
        """
        assert message_func is not None
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        if is_all(edges):
            eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
            u, v, _ = self._graph.edges(etid, 'eid')
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        if len(eid) == 0:
            # no edge to be triggered
            return

        with ir.prog() as prog:
            scheduler.schedule_send(
                AdaptedHeteroGraph(self, stid, dtid, etid),
                u, v, eid,
                message_func)
            Runtime.run(prog)

    def recv(self,
             v,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.

        It calculates:

        .. math::
            h_v^{new} = \sigma(f(\{m_{uv} | u\in\mathcal{N}_{t}(v)\}))

        where :math:`\mathcal{N}_t(v)` defines the predecessors of node(s) :math:`v` connected by
        edges of type :math:`t`, and :math:`m_{uv}` is the message on edge :math:`(u,v)`.

        * ``reduce_func`` specifies :math:`f`, e.g. summation or average.
        * ``apply_node_func`` specifies :math:`\sigma`, e.g. ReLU activation.

        Other notes:

        * ``reduce_func`` will be skipped for nodes with no incoming message.
        * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
        * If some ``v`` have no incoming message, their new feature value will be calculated
          by the column initializer (see :func:`set_n_initializer`). The feature shapes and
          dtypes will be inferred.
        * The node features will be updated by the result of the ``reduce_func``.
        * Messages are consumed once received.
        * The provided UDF may be called multiple times, so it is recommended to
          provide a function without side effects.

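        These rules can be modelled with a toy mailbox keyed by edge ID, using
        scalar features and plain callables instead of UDFs. ``ToyMailbox`` is
        hypothetical and mirrors only the documented behaviour: a later send
        overwrites a pending message, nodes without messages fall back to an
        initializer value, and received messages are consumed.

        ```python
        class ToyMailbox:
            # Hypothetical model of send/recv on a single edge type.

            def __init__(self, src, dst):
                self.src, self.dst = src, dst
                self.pending = {}  # edge ID -> message waiting to be received

            def send(self, eids, h, message_func):
                # A later send on the same edge overwrites the earlier message.
                for e in eids:
                    self.pending[e] = message_func(h[self.src[e]])

            def recv(self, nodes, h, reduce_func, init_value=0.0):
                new_h = list(h)
                for v in nodes:
                    inbox = [m for e, m in self.pending.items() if self.dst[e] == v]
                    # reduce_func is skipped for nodes with no incoming message;
                    # they get the initializer value instead.
                    new_h[v] = reduce_func(inbox) if inbox else init_value
                # Messages are consumed once received.
                for e in [e for e in self.pending if self.dst[e] in nodes]:
                    del self.pending[e]
                return new_h


        # Same graph as the example below: user 0 -> 1 -> 2 via 'follows'.
        box = ToyMailbox(src=[0, 1], dst=[1, 2])
        h = [0.0, 1.0, 2.0]
        box.send([0, 1], h, message_func=lambda x: x)
        print(box.recv([0, 1, 2], h, reduce_func=sum))  # [0.0, 0.0, 1.0]
        ```
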
        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        reduce_func : callable
            Reduce function on the node. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Send and receive.

        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.recv(g.nodes('user'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(v):
            v = F.arange(0, self.number_of_nodes(dtid))
        elif isinstance(v, int):
            v = [v]
        v = utils.toindex(v)
        if len(v) == 0:
            # no vertex to be triggered.
            return
        with ir.prog() as prog:
            scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)

    def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Receive messages from multiple edge types and perform aggregation.

        It calculates:

        .. math::

            \begin{align}
            h_{v, t}^{new} &= f\left(\left\{m_{uv} | u\in\mathcal{N}_{t}(v)\right\}\right)\\
            h_v^{new} &= \sigma\left(g\left(\left\{h_{v, t}^{new} | t\in T_e\right\}\right)\right)
            \end{align}

        * ``reducer_dict`` is a dictionary mapping edge type (str or tuple of str) to
          the reduce function :math:`f` of each type.
        * ``cross_reducer`` specifies :math:`g`.
        * ``apply_node_func`` specifies :math:`\sigma`.
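
        As a rough, DGL-free sketch of these two levels, the helper below uses a
        per-type sum for the per-type reducer, an ordinary Python callable for
        ``cross_reducer``, and treats ``apply_node_func`` as the identity.
        ``multi_recv_sum`` is hypothetical, and zero-filling nodes that receive
        nothing of a given type is an assumption.

        ```python
        # Hypothetical model of multi_recv; per-type sum, then ``cross_reducer``.
        def multi_recv_sum(num_dst, per_type_messages, cross_reducer):
            # per_type_messages: {etype: [(dst, message), ...]} already sent on that type.
            per_type = []
            for msgs in per_type_messages.values():
                h_t = [0.0] * num_dst      # zero for nodes with no message of this type
                for dst, m in msgs:
                    h_t[dst] += m          # per-type reducer: sum within one edge type
                per_type.append(h_t)
            # cross-type reducer: combine each node's per-type results
            return [cross_reducer(vals) for vals in zip(*per_type)]

        # Messages from the example below: user 0 -> user 1 ('follows'),
        # game 0 -> user 1 ('attracts').
        msgs = {'follows': [(1, 1.0)], 'attracts': [(1, 1.0)]}
        print(multi_recv_sum(2, msgs, sum))   # [0.0, 2.0]
        print(multi_recv_sum(2, msgs, max))   # [0.0, 1.0]
        ```

        Passing ``max`` instead of ``sum`` changes only the cross-type step, which
        is the role ``cross_reducer`` plays in this method.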

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        reducer_dict : dict
            Mapping an edge type (str or tuple of str) to either a reduce function
            or a pair ``(reduce_func, apply_node_func)`` for that edge type. Both
            functions should be :mod:`Node UDFs <dgl.udf>`.
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes, called after the cross-type reduction.
            The function should be a :mod:`Node UDF <dgl.udf>`. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

        Send and receive.

        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
        >>> g.multi_recv(g.nodes('user'), {'follows': fn.sum('m', 'h'),
        ...              'attracts': fn.sum('m', 'h')}, "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])
        """
        # infer receive node type
        ntype = infer_ntype_from_dict(self, reducer_dict)
        ntid = self.get_ntype_id(ntype)
        if is_all(v):
            v = F.arange(0, self.number_of_nodes(ntid))
        elif isinstance(v, int):
            v = [v]
        v = utils.toindex(v)
        if len(v) == 0:
            return
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        with ir.prog() as prog:
            for ety, args in reducer_dict.items():
                outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
                args = pad_tuple(args, 2)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be either '
                                   '(1) reduce_func or (2) (reduce_func, apply_node_func)')
                rfunc, afunc = args
                etid = self.get_etype_id(ety)
                stid, dtid = self._graph.metagraph.find_edge(etid)
                scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
                                        v, rfunc, afunc,
                                        inplace=inplace, outframe=outframe)
                all_out.append(outframe)
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[ntid].update(merge_frames(all_out, cross_reducer))
        # apply
        if apply_node_func is not None:
            self.apply_nodes(apply_node_func, v, ntype, inplace)

    def send_and_recv(self,
                      edges,
                      message_func,
                      reduce_func,
                      apply_node_func=None,
                      etype=None,
                      inplace=False):
        """Send messages along edges of the specified type, and let destinations
        receive them.

        Optionally, apply a function to update the node features after "receive".

        This is a convenience method that combines
        :mod:`send <dgl.DGLHeteroGraph.send>` along the ``edges`` and
        :mod:`recv <dgl.DGLHeteroGraph.recv>` for the destinations of the ``edges``.
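
        A minimal pure-Python sketch of this combination for ``copy_src`` and
        ``sum`` follows. ``send_and_recv_sum`` is a hypothetical helper, not DGL
        API; it only illustrates that the destinations of ``edges`` are updated
        while every other node keeps its current feature.

        ```python
        # Hypothetical model of send_and_recv with copy_src + sum; not DGL code.
        def send_and_recv_sum(edges, h):
            # "send": each edge (u, v) carries a copy of the source feature h[u].
            # All messages are computed from the old features before any update.
            messages = [(v, h[u]) for u, v in edges]
            # "recv": only destinations of ``edges`` are updated; others keep h.
            new_h = list(h)
            for v in {dst for dst, _ in messages}:
                new_h[v] = sum(m for dst, m in messages if dst == v)
            return new_h

        h = [0.0, 1.0, 2.0]
        print(send_and_recv_sum([(0, 1), (1, 2)], h))  # [0.0, 0.0, 1.0]
        print(send_and_recv_sum([(1, 2)], h))          # [0.0, 1.0, 1.0]; node 1 keeps 1.0
        ```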

        **Operates on one edge type at a time.** If the graph has multiple edge
        types, specify ``etype`` or call the method on a single-relation view:

        .. code::

           g['edgetype'].send_and_recv(edges, message_func, reduce_func,
                                       apply_node_func, inplace=inplace)

        Parameters
        ----------
        edges : tuple, or int, container or tensor
            Edges along which to send messages, given either as a pair
            ``(u, v)`` of source and destination node IDs or as edge IDs.
            See :func:`send` for valid edge specifications.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])

        Trigger "send" and "receive" separately.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.recv(g.nodes('user'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])

        Trigger "send" and "receive" in one call.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.send_and_recv(g['follows'].edges(), fn.copy_src('h', 'm'),
        ...                 fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [1.]])
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        if isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        if len(u) == 0:
            # no edges to be triggered
            return

        with ir.prog() as prog:
            scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                   (u, v, eid),
                                   message_func, reduce_func, apply_node_func,
                                   inplace=inplace)
            Runtime.run(prog)

    def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Send and receive messages along multiple edge types and perform aggregation.

        Optionally, apply a function to update the node features after "receive".

        This is a convenience method that combines multiple
        :mod:`send <dgl.DGLHeteroGraph.send>` calls along edges of different types and
        :mod:`multi_recv <dgl.DGLHeteroGraph.multi_recv>` for the destinations of all edges.
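
        The flow can be sketched in plain Python, assuming ``copy_src`` and
        ``sum`` for every edge type. The helper ``multi_send_and_recv_sum`` is
        hypothetical: it reports values for destination nodes only, zero-fills a
        type that sent nothing to a node (an assumption), and runs the optional
        apply function once per node in the union of destinations, mirroring how
        this method passes all destination nodes to ``apply_node_func``.

        ```python
        # Hypothetical model of multi_send_and_recv with copy_src + sum on every type.
        def multi_send_and_recv_sum(etype_edges, src_h, cross_reducer, apply_fn=None):
            # etype_edges: {etype: [(u, v), ...]}; src_h: {etype: source-node features}.
            per_type = []           # one {dst: reduced message} dict per edge type
            dst_nodes = set()       # union of destinations over all edge types
            for etype, edges in etype_edges.items():
                reduced = {}
                for u, v in edges:
                    reduced[v] = reduced.get(v, 0.0) + src_h[etype][u]
                per_type.append(reduced)
                dst_nodes.update(reduced)
            out = {}
            for v in sorted(dst_nodes):
                # A type that sent nothing to v contributes 0.0 (zero initializer assumed).
                out[v] = cross_reducer([r.get(v, 0.0) for r in per_type])
                if apply_fn is not None:
                    out[v] = apply_fn(out[v])   # applied once, after the cross-type step
            return out

        edges = {'follows': [(0, 1)], 'attracts': [(0, 1)]}
        src_h = {'follows': [1.0, 2.0], 'attracts': [1.0]}
        print(multi_send_and_recv_sum(edges, src_h, sum))                    # {1: 2.0}
        print(multi_send_and_recv_sum(edges, src_h, sum, lambda x: 10 * x))  # {1: 20.0}
        ```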

        Parameters
        ----------
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type-specific
            configuration, a tuple ``(edges, msg_func, reduce_func[, apply_node_func])``:

            * ``edges``: edges on which to pass messages. See :func:`send` for
              valid edge specifications.
            * ``msg_func``: callable. Message function on the edges. The function
              should be an :mod:`Edge UDF <dgl.udf>`.
            * ``reduce_func``: callable. Reduce function on the nodes. The function
              should be a :mod:`Node UDF <dgl.udf>`.
            * ``apply_node_func``: callable, optional. Apply function on the nodes.
              The function should be a :mod:`Node UDF <dgl.udf>`. (Default: None)
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the destination nodes of all edge types, called after
            the cross-type reduction. The function should be a :mod:`Node UDF <dgl.udf>`.
            (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])

        Trigger "send" and "receive" separately.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
        >>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
        >>> g.multi_recv(g.nodes('user'),
        ...              {'follows': fn.sum('m', 'h'), 'attracts': fn.sum('m', 'h')}, "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])

        Trigger "send" and "receive" in one call.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.multi_send_and_recv(
        ...     {'follows': (g['follows'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        ...      'attracts': (g['attracts'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        ...     "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [2.]])
        """
        # infer receive node type
        ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id(ntype)

        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        all_vs = []
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, _ = self._graph.metagraph.find_edge(etid)
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 4)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(edges, msg_func, reduce_func, [apply_node_func])')
                edges, mfunc, rfunc, afunc = args
                if isinstance(edges, tuple):
                    u, v = edges
                    u = utils.toindex(u)
                    v = utils.toindex(v)
                    # Rewrite u, v to handle edge broadcasting and multigraph.
                    u, v, eid = self._graph.edge_ids(etid, u, v)
                else:
                    eid = utils.toindex(edges)
                    u, v, _ = self._graph.find_edges(etid, eid)
                all_vs.append(v)
                if len(u) == 0:
                    # no edges to be triggered
                    continue
                scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                       (u, v, eid),
                                       mfunc, rfunc, afunc,
                                       inplace=inplace, outframe=outframe)
                all_out.append(outframe)
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
        # apply
        if apply_node_func is not None:
            dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
            self.apply_nodes(apply_node_func, dstnodes, ntype, inplace)

    def pull(self,
             v,
             message_func,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        """Pull messages from the node(s)' predecessors and then update their features.

        Optionally, apply a function to update the node features after receive.

        This is equivalent to :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>`
        on the incoming edges of ``v`` with the specified type.

        Other notes:

        * ``reduce_func`` will be skipped for nodes with no incoming messages.
        * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
        * If some ``v`` have no incoming message, their new feature value will be calculated
          by the column initializer (see :func:`set_n_initializer`). The feature shapes and
          dtypes will be inferred.
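
        The rules above can be checked with a plain-Python sketch of ``pull`` for
        ``copy_src`` and ``sum``. ``pull_sum`` is hypothetical; it assumes a zero
        column initializer and models no apply function, so the apply_nodes
        fallback leaves features unchanged.

        ```python
        # Hypothetical model of the pull rules above (copy_src + sum); not DGL code.
        def pull_sum(edges, h, v):
            # Messages each node in v would receive from its predecessors.
            inbox = {node: [h[u] for u, dst in edges if dst == node] for node in v}
            new_h = list(h)
            if not any(inbox.values()):
                # No node in v has an incoming message: falls back to apply_nodes,
                # which changes nothing here because no apply function is modeled.
                return new_h
            for node, msgs in inbox.items():
                # Nodes without messages skip the reducer and take the initializer
                # value (a zero initializer is assumed).
                new_h[node] = sum(msgs) if msgs else 0.0
            return new_h

        edges = [(0, 1), (1, 2)]
        print(pull_sum(edges, [0.0, 1.0, 2.0], [2]))     # [0.0, 1.0, 1.0], as in the example
        print(pull_sum(edges, [5.0, 1.0, 2.0], [0]))     # [5.0, 1.0, 2.0]: nothing to pull
        print(pull_sum(edges, [5.0, 1.0, 2.0], [0, 2]))  # [0.0, 1.0, 1.0]: node 0 re-initialized
        ```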

        **Operates on one edge type at a time.** If the graph has multiple edge
        types, specify ``etype`` or call the method on a single-relation view:

        .. code::

           g['edgetype'].pull(v, message_func, reduce_func, apply_node_func, inplace=inplace)

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
        >>> plays_g = dgl.bipartite([(0, 0), (2, 1)], 'user', 'plays', 'game')
        >>> g = dgl.hetero_from_relations([follows_g, plays_g])
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Pull.

        >>> g['follows'].pull(2, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [1.],
                [1.]])
        """
        # only one type of edges
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        v = utils.toindex(v)
        if len(v) == 0:
            return
        with ir.prog() as prog:
            scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v,
                                    message_func, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)

    def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
        r"""Pull and receive messages of the given nodes along multiple edge types
        and perform aggregation.

        This is equivalent to :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
        on the incoming edges of ``v`` with the specified types.
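
        A rough model of this equivalence in plain Python: pull along each edge
        type with ``copy_src`` and ``sum``, then combine the per-type results with
        the named cross-type reducer. ``multi_pull_sum`` and the name-to-function
        table are hypothetical; in particular, modeling ``"stack"`` as a list with
        one entry per edge type is an assumption.

        ```python
        # Hypothetical model of multi_pull with copy_src + sum on each edge type.
        CROSS_REDUCERS = {                    # assumed meaning of each cross_reducer name
            'sum': sum,
            'min': min,
            'max': max,
            'mean': lambda xs: sum(xs) / len(xs),
            'stack': list,                    # keep one entry per edge type
        }

        def multi_pull_sum(etype_edges, src_h, v, cross_reducer):
            # etype_edges: {etype: [(u, dst), ...]};
            # src_h: {etype: features of that type's source nodes}.
            result = {}
            for node in v:
                # Per-type sum of incoming messages; a type with no in-edges gives 0.
                per_type = [sum(src_h[et][u] for u, dst in edges if dst == node)
                            for et, edges in etype_edges.items()]
                result[node] = CROSS_REDUCERS[cross_reducer](per_type)
            return result

        # Graph from the example below: user 1 -> user 1, user 1 -> user 0, game 0 -> user 1.
        edges = {'follows': [(1, 1), (1, 0)], 'attracts': [(0, 1)]}
        src_h = {'follows': [1.0, 2.0], 'attracts': [1.0]}
        print(multi_pull_sum(edges, src_h, [1], 'sum'))    # {1: 3.0}
        print(multi_pull_sum(edges, src_h, [1], 'stack'))  # {1: [2.0, 1.0]}
        ```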

        Parameters
        ----------
        v : int, container or tensor
            The node(s) to be updated.
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type-specific
            configuration, a tuple ``(msg_func, reduce_func[, apply_node_func])``:

            * ``msg_func``: callable. Message function on the edges. The function
              should be an :mod:`Edge UDF <dgl.udf>`.
            * ``reduce_func``: callable. Reduce function on the nodes. The function
              should be a :mod:`Node UDF <dgl.udf>`.
            * ``apply_node_func``: callable, optional. Apply function on the nodes.
              The function should be a :mod:`Node UDF <dgl.udf>`. (Default: None)
        cross_reducer : str
            Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes in ``v``, called after the cross-type
            reduction. The function should be a :mod:`Node UDF <dgl.udf>`. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(1, 1), (1, 0)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])

        Pull.

        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
        >>> g.multi_pull(1,
        ...              {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        ...               'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        ...              "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [3.]])
        """
        v = utils.toindex(v)
        if len(v) == 0:
            return
        # infer receive node type
        ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id(ntype)
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = []
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, _ = self._graph.metagraph.find_edge(etid)
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 3)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(msg_func, reduce_func, [apply_node_func])')
                mfunc, rfunc, afunc = args
                scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                        v,
                                        mfunc, rfunc, afunc,
                                        inplace=inplace, outframe=outframe)
                all_out.append(outframe)
            Runtime.run(prog)
        # merge by cross_reducer
        self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
        # apply
        if apply_node_func is not None:
            self.apply_nodes(apply_node_func, v, ntype, inplace)

    def push(self,
             u,
             message_func,
             reduce_func,
             apply_node_func=None,
             etype=None,
             inplace=False):
        """Send messages from the node(s) to their successors and update the successors.

        This is equivalent to performing
        :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>` along the outbound
        edges from ``u``.
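
        A plain-Python sketch of this equivalence for ``copy_src`` and ``sum``;
        ``push_sum`` is a hypothetical helper, not DGL API. Only successors
        reached from the pushing nodes are updated, so in the first call below
        the edge ``(1, 2)`` is ignored because node 1 is not pushing.

        ```python
        # Hypothetical model: push(u) as send_and_recv over the out-edges of u.
        def push_sum(edges, h, u):
            sources = set(u)
            out_edges = [(s, d) for s, d in edges if s in sources]
            new_h = list(h)
            # Only successors reached by the selected out-edges are updated.
            for d in {d for _, d in out_edges}:
                new_h[d] = sum(h[s] for s, dd in out_edges if dd == d)
            return new_h

        edges = [(0, 1), (0, 2), (1, 2)]
        print(push_sum(edges, [3.0, 1.0, 2.0], [0]))     # [3.0, 3.0, 3.0]
        print(push_sum(edges, [3.0, 1.0, 2.0], [0, 1]))  # [3.0, 3.0, 4.0]
        ```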

        **Operates on one edge type at a time.** If the graph has multiple edge
        types, specify ``etype`` or call the method on a single-relation view:

        .. code::

           g['edgetype'].push(u, message_func, reduce_func, apply_node_func, inplace=inplace)

        Parameters
        ----------
        u : int, container or tensor
            The node(s) to push messages from.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str or tuple of str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)
        inplace : bool, optional
            If True, update will be done in place, but autograd will break.
            (Default: False)

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g = dgl.graph([(0, 1), (0, 2)], 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

        Push.

        >>> g['follows'].push(0, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [0.]])
        """
        # only one type of edges
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        u = utils.toindex(u)
        if len(u) == 0:
            return
        with ir.prog() as prog:
            scheduler.schedule_push(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    u,
                                    message_func, reduce_func, apply_node_func,
                                    inplace=inplace)
            Runtime.run(prog)

    def update_all(self,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Send messages through all edges and update all nodes.

        Optionally, apply a function to update the node features after receive.

        This is equivalent to
        :mod:`send_and_recv <dgl.DGLHeteroGraph.send_and_recv>` over all edges
        of the specified type.

        **Only works if the graph has one edge type.** For multiple types, use

        .. code::

           g['edgetype'].update_all(message_func, reduce_func, apply_node_func)

        Parameters
        ----------
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph.

        >>> g = dgl.graph([(0, 1), (1, 2), (2, 2)], 'user', 'follows')

        Update all.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g['follows'].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [3.]])
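
        With the built-in ``copy_src`` message and ``sum`` reducer, the update can be
        read as a matrix product: the new feature of node ``v`` is the sum of
        ``A[u][v] * h[u]`` over all nodes ``u``, where ``A`` is the adjacency matrix.
        The dense pure-Python sketch below only illustrates this reading; it is not
        DGL's implementation, and the helper names are hypothetical. In this
        formulation a node without incoming edges ends up at 0.

        ```python
        def adjacency(num_nodes, edges):
            # A[src][dst] counts the edges src -> dst (parallel edges add up)
            a = [[0] * num_nodes for _ in range(num_nodes)]
            for src, dst in edges:
                a[src][dst] += 1
            return a

        def update_all_copy_sum(num_nodes, edges, h):
            # new_h[v] = sum over u of A[u][v] * h[u], i.e. new_h = A^T h
            a = adjacency(num_nodes, edges)
            return [sum(a[u][v] * h[u] for u in range(num_nodes))
                    for v in range(num_nodes)]

        print(update_all_copy_sum(3, [(0, 1), (1, 2), (2, 2)], [0.0, 1.0, 2.0]))
        # [0.0, 0.0, 3.0], matching the example above
        ```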
        """
        # only one type of edges
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)

        with ir.prog() as prog:
            scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                          message_func, reduce_func,
                                          apply_node_func)
            Runtime.run(prog)

    def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):
        r"""Send and receive messages along all edges.

        This is equivalent to
        :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
        over all edges.

        Parameters
        ----------
        etype_dict : dict
            Mapping an edge type (str or tuple of str) to the type specific
            configuration (3-tuples). Each 3-tuple represents
            (msg_func, reduce_func, apply_node_func):

            * msg_func: callable
                  Message function on the edges. The function should be
                  an :mod:`Edge UDF <dgl.udf>`.
            * reduce_func: callable
                  Reduce function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`.
            * apply_node_func : callable, optional
                  Apply function on the nodes. The function should be
                  a :mod:`Node UDF <dgl.udf>`. (Default: None)
        cross_reducer : str
            Reducer that merges the results computed for different edge types sharing
            the same destination node type. One of ``"sum"``, ``"min"``, ``"max"``,
            ``"mean"``, ``"stack"``.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)

        Examples
        --------
        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        Instantiate a heterograph.

        >>> g1 = dgl.graph([(0, 1), (1, 1)], 'user', 'follows')
        >>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
        >>> g = dgl.hetero_from_relations([g1, g2])
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

        Update all.

        >>> g.multi_update_all(
        >>>     {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
        >>>      'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
        >>>     "sum")
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [4.]])
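
        Conceptually, each edge type first produces its own result for the destination
        node type, and ``cross_reducer`` then merges those results node by node. The
        list-based sketch below is a simplification of that merge step (DGL merges
        feature frames of tensors); ``merge_per_type`` is hypothetical. With the
        per-type results from the example above (``follows`` gives ``[0., 3.]`` and
        ``attracts`` gives ``[0., 1.]`` for the two users), ``"sum"`` yields
        ``[0., 4.]``.

        ```python
        def merge_per_type(results, cross_reducer):
            # results: one list per edge type, each holding one value per
            # destination node (a simplified stand-in for DGL's feature frames).
            per_node = list(zip(*results))
            if cross_reducer == 'sum':
                return [sum(vals) for vals in per_node]
            if cross_reducer == 'min':
                return [min(vals) for vals in per_node]
            if cross_reducer == 'max':
                return [max(vals) for vals in per_node]
            if cross_reducer == 'mean':
                return [sum(vals) / len(vals) for vals in per_node]
            if cross_reducer == 'stack':
                # keep one entry per edge type instead of reducing
                return [list(vals) for vals in per_node]
            raise ValueError('unknown cross_reducer: ' + cross_reducer)

        follows, attracts = [0.0, 3.0], [0.0, 1.0]
        print(merge_per_type([follows, attracts], 'sum'))    # [0.0, 4.0]
        print(merge_per_type([follows, attracts], 'stack'))  # [[0.0, 0.0], [3.0, 1.0]]
        ```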
        """
        # TODO(minjie): currently loop over each edge type and reuse the old schedule.
        #   Should replace it with fused kernel.
        all_out = defaultdict(list)
        with ir.prog() as prog:
            for etype, args in etype_dict.items():
                etid = self.get_etype_id(etype)
                stid, dtid = self._graph.metagraph.find_edge(etid)
                outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
                args = pad_tuple(args, 3)
                if args is None:
                    raise DGLError('Invalid per-type arguments. Should be '
                                   '(msg_func, reduce_func, [apply_node_func])')
                mfunc, rfunc, afunc = args
                scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                              mfunc, rfunc, afunc,
                                              outframe=outframe)
                all_out[dtid].append(outframe)
            Runtime.run(prog)
        for dtid, frames in all_out.items():
            # merge by cross_reducer
            self._node_frames[dtid].update(merge_frames(frames, cross_reducer))
            # apply
            if apply_node_func is not None:
                self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False)

    def prop_nodes(self,
                   nodes_generator,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`pull()` on nodes.

        The traversal order is specified by the ``nodes_generator``. It generates
        node frontiers, each of which is a list or a tensor of nodes. Nodes in the
        same frontier are triggered together, while different frontiers are
        triggered in the order in which they are generated.

        Parameters
        ----------
        nodes_generator : iterable, each element is a list or a tensor of node ids
            The generator of node frontiers. It specifies which nodes perform
            :func:`pull` at each timestep.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph and perform multiple rounds of message passing.

        >>> g = dgl.graph(([0, 1, 2, 3], [2, 3, 4, 4]), 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
        >>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_src('h', 'm'),
        >>>                         fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[1.],
                [2.],
                [1.],
                [2.],
                [3.]])
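
        Frontiers are often produced by a traversal. The sketch below shows one possible
        generator, not a DGL utility: the hypothetical ``topological_frontiers`` groups
        the nodes of a DAG into levels with Kahn's algorithm, and the hypothetical
        ``pull_copy_sum`` emulates ``pull`` with ``copy_src``/``sum`` on plain lists.
        Two assumptions are baked in: a frontier node with no in-edges keeps its value,
        and nodes on a cycle are never yielded. For the graph in the example above the
        levels are ``[0, 1]``, ``[2, 3]`` and ``[4]``, and the final features match.

        ```python
        from collections import defaultdict

        def topological_frontiers(num_nodes, edges):
            # Kahn's algorithm grouped by level: a node is yielded once all of
            # its predecessors appeared in earlier frontiers.
            indeg = [0] * num_nodes
            succ = defaultdict(list)
            for src, dst in edges:
                indeg[dst] += 1
                succ[src].append(dst)
            frontier = [v for v in range(num_nodes) if indeg[v] == 0]
            while frontier:
                yield frontier
                nxt = []
                for u in frontier:
                    for v in succ[u]:
                        indeg[v] -= 1
                        if indeg[v] == 0:
                            nxt.append(v)
                frontier = sorted(nxt)

        def pull_copy_sum(edges, h, frontier):
            # Gather messages for the whole frontier before writing, so nodes in
            # the same frontier are updated together.
            inbox = defaultdict(list)
            targets = set(frontier)
            for src, dst in edges:
                if dst in targets:
                    inbox[dst].append(h[src])
            out = list(h)
            for dst, msgs in inbox.items():
                out[dst] = sum(msgs)
            return out

        edges = [(0, 2), (1, 3), (2, 4), (3, 4)]
        h = [1.0, 2.0, 3.0, 4.0, 5.0]
        for frontier in topological_frontiers(5, edges):
            h = pull_copy_sum(edges, h, frontier)
        print(h)  # [1.0, 2.0, 1.0, 2.0, 3.0]
        ```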

        See Also
        --------
        prop_edges
        """
        for node_frontier in nodes_generator:
            self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype)

    def prop_edges(self,
                   edges_generator,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`send_and_recv()` on edges.

        The traversal order is specified by the ``edges_generator``. It generates
        edge frontiers. The edge frontiers should be of *valid edges type*.
        See :func:`send` for more details.

        Edges in the same frontier will be triggered together, and edges in
        different frontiers will be triggered according to the generating order.

        Parameters
        ----------
        edges_generator : generator
            The generator of edge frontiers.
        message_func : callable
            Message function on the edges. The function should be
            an :mod:`Edge UDF <dgl.udf>`.
        reduce_func : callable
            Reduce function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`.
        apply_node_func : callable, optional
            Apply function on the nodes. The function should be
            a :mod:`Node UDF <dgl.udf>`. (Default: None)
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn

        Instantiate a heterograph and perform multiple rounds of message passing.

        >>> g = dgl.graph(([0, 1, 2, 3], [2, 3, 4, 4]), 'user', 'follows')
        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
        >>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_src('h', 'm'),
        >>>                         fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[1.],
                [2.],
                [1.],
                [2.],
                [3.]])
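
        Assuming that ``pull`` on a node behaves like ``send_and_recv`` on its incoming
        edges, a node frontier can be translated into the edge frontier of its in-edges.
        The hypothetical helper below does this for an edge list indexed by edge ID.
        For the graph in the example above, it turns the node frontiers
        ``[[2, 3], [4]]`` used in the :func:`prop_nodes` example into
        ``[[0, 1], [2, 3]]``, which is why both examples produce the same features.

        ```python
        def in_edge_frontiers(edges, node_frontiers):
            # edges[i] = (src, dst) describes edge ID i. For each node frontier,
            # collect the IDs of the edges that point into one of its nodes.
            result = []
            for frontier in node_frontiers:
                targets = set(frontier)
                result.append([eid for eid, (_, dst) in enumerate(edges)
                               if dst in targets])
            return result

        edges = [(0, 2), (1, 3), (2, 4), (3, 4)]
        print(in_edge_frontiers(edges, [[2, 3], [4]]))  # [[0, 1], [2, 3]]
        ```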

        See Also
        --------
        prop_nodes
        """
        for edge_frontier in edges_generator:
            self.send_and_recv(edge_frontier, message_func, reduce_func,
                               apply_node_func, etype=etype)

    #################################################################
    # Misc
    #################################################################

    def to_networkx(self, node_attrs=None, edge_attrs=None):
        """Convert this graph to a networkx graph.

        Currently only graphs with a single node type and a single edge type are supported.

        The edge id will be saved as the 'id' edge attribute.
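
        The ``'id'`` attribute matters for multigraphs: a ``networkx.DiGraph`` keeps a
        single attribute dictionary per ``(u, v)`` pair, so adding a parallel edge
        overwrites the earlier ``'id'``. That is why the code builds a
        ``networkx.MultiDiGraph`` when ``is_multigraph`` is true. The dictionary-based
        sketch below illustrates the difference without networkx; the helper names are
        hypothetical.

        ```python
        def as_simple_digraph(src, dst):
            # One attribute dict per (u, v) pair, like networkx.DiGraph:
            # a later parallel edge overwrites the earlier 'id'.
            adj = {}
            for eid, (u, v) in enumerate(zip(src, dst)):
                adj[(u, v)] = {'id': eid}
            return adj

        def as_multidigraph(src, dst):
            # One attribute dict per edge, like networkx.MultiDiGraph.
            adj = {}
            for eid, (u, v) in enumerate(zip(src, dst)):
                adj.setdefault((u, v), []).append({'id': eid})
            return adj

        src, dst = [0, 0, 1], [1, 1, 2]
        print(as_simple_digraph(src, dst))  # {(0, 1): {'id': 1}, (1, 2): {'id': 2}}
        print(as_multidigraph(src, dst))    # {(0, 1): [{'id': 0}, {'id': 1}], (1, 2): [{'id': 2}]}
        ```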

        Parameters
        ----------
        node_attrs : iterable of str, optional
            The node attributes to be copied.
        edge_attrs : iterable of str, optional
            The edge attributes to be copied.

        Returns
        -------
        networkx.DiGraph or networkx.MultiDiGraph
            The networkx graph; a ``MultiDiGraph`` is returned if this graph is a multigraph.

        Examples
        --------

        .. note:: Here we use pytorch syntax for demo. The general idea applies
            to other frameworks with minor syntax change (e.g. replace
            ``torch.tensor`` with ``mxnet.ndarray``).

        >>> import dgl
        >>> import torch as th
        >>> g = dgl.graph(([0, 1, 3, 4], [2, 4, 0, 3]), 'user', 'follows')
        >>> g.ndata['n1'] = th.randn(5, 10)
        >>> g.edata['e1'] = th.randn(4, 6)
        >>> nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])

        See Also
        --------
        dgl.to_networkx
        """
        # TODO(minjie): multi-type support
        assert len(self.ntypes) == 1
        assert len(self.etypes) == 1
        src, dst = self.edges()
        src = F.asnumpy(src)
        dst = F.asnumpy(dst)
        nx_graph = nx.MultiDiGraph() if self.is_multigraph else nx.DiGraph()
        nx_graph.add_nodes_from(range(self.number_of_nodes()))
        for eid, (u, v) in enumerate(zip(src, dst)):
            nx_graph.add_edge(u, v, id=eid)

        if node_attrs is not None:
            for nid, attr in nx_graph.nodes(data=True):
                feat_dict = self._get_n_repr(0, nid)
                attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
        if edge_attrs is not None:
            for _, _, attr in nx_graph.edges(data=True):
                eid = attr['id']
                feat_dict = self._get_e_repr(0, eid)
                attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
        return nx_graph

    def filter_nodes(self, predicate, nodes=ALL, ntype=None):
        """Return a tensor of node IDs with the given node type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(nodes) -> tensor``.
            ``nodes`` are :class:`NodeBatch` objects as in :mod:`~dgl.udf`.
            The ``tensor`` returned should be a 1-D boolean tensor with
            each element indicating whether the corresponding node in
            the batch satisfies the predicate.
        nodes : int, iterable or tensor of ints
            The nodes to filter on. Default value is all the nodes.
        ntype : str, optional
            The node type. Can be omitted if there is only one node type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Node ids indicating the nodes that satisfy the predicate.

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn
        >>> g = dgl.graph([], 'user', 'follows', card=4)
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
        tensor([1, 2])
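
        The predicate returns one boolean per node *in the batch*, so the mask is
        positional: with the default ``nodes=ALL`` the positions coincide with the node
        IDs, while for an explicit ``nodes`` argument the mask has to be applied to that
        argument. The list-based sketch below (``filter_ids`` is a hypothetical helper)
        shows both cases with the features from the example above.

        ```python
        def filter_ids(ids, feats, predicate):
            # Evaluate the predicate on the batch, then keep the ids whose
            # positional mask entry is True (what a boolean mask does).
            mask = [predicate(feats[i]) for i in ids]
            return [i for i, keep in zip(ids, mask) if keep]

        h = [0.0, 1.0, 1.0, 0.0]
        print(filter_ids(range(len(h)), h, lambda x: x == 1.0))  # [1, 2]
        print(filter_ids([3, 2], h, lambda x: x == 1.0))         # [2]
        ```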
        """
        ntid = self.get_ntype_id(ntype)
        if is_all(nodes):
            v = utils.toindex(slice(0, self._graph.number_of_nodes(ntid)))
        else:
            v = utils.toindex(nodes)

        n_repr = self._get_n_repr(ntid, v)
        nbatch = NodeBatch(v, n_repr)
        n_mask = F.copy_to(predicate(nbatch), F.cpu())

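        if is_all(nodes):
            return F.nonzero_1d(n_mask)
        # The mask is positional with respect to ``v``, so select from ``v`` itself;
        # this also handles a single int node ID.
        return F.boolean_mask(v.tousertensor(), n_mask)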

    def filter_edges(self, predicate, edges=ALL, etype=None):
        """Return a tensor of edge IDs with the given edge type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(edges) -> tensor``.
            ``edges`` are :class:`EdgeBatch` objects as in :mod:`~dgl.udf`.
            The ``tensor`` returned should be a 1-D boolean tensor with
            each element indicating whether the corresponding edge in
            the batch satisfies the predicate.
        edges : valid edges type
            Edges on which to apply ``func``. See :func:`send` for valid
            edges type. Default value is all the edges.
        etype : str, optional
            The edge type. Can be omitted if there is only one edge type
            in the graph. (Default: None)

        Returns
        -------
        tensor
            Edge ids indicating the edges that satisfy the predicate.

        Examples
        --------
        >>> import torch
        >>> import dgl
        >>> import dgl.function as fn
        >>> g = dgl.graph([(0, 0), (0, 1), (1, 2), (2, 3)], 'user', 'follows')
        >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
        tensor([1, 2])
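
        When ``edges`` is given as a ``(u, v)`` pair of node lists, the pairs are first
        resolved to edge IDs, and in a multigraph one pair may match several parallel
        edges; the predicate mask then lines up with those IDs rather than with ``u``
        and ``v``. The sketch below models that resolution on a plain edge list. The
        helper names are hypothetical, and the matching rule is a simplification of
        the graph index's ``edge_ids`` handling.

        ```python
        def resolve_edge_ids(edge_list, u, v):
            # edge_list[i] = (src, dst) describes edge ID i. Every (u, v) query
            # pair expands to all matching parallel edges.
            ids = []
            for s, d in zip(u, v):
                ids.extend(eid for eid, pair in enumerate(edge_list) if pair == (s, d))
            return ids

        def filter_edges_by_pairs(edge_list, feats, u, v, predicate):
            eids = resolve_edge_ids(edge_list, u, v)
            # The mask is aligned with the resolved edge IDs.
            return [e for e in eids if predicate(feats[e])]

        simple = [(0, 0), (0, 1), (1, 2), (2, 3)]
        print(filter_edges_by_pairs(simple, [0.0, 1.0, 1.0, 0.0], [0, 1], [1, 2],
                                    lambda x: x == 1.0))  # [1, 2]
        multi = [(0, 1), (0, 1), (1, 2)]
        print(resolve_edge_ids(multi, [0], [1]))  # [0, 1]
        ```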
        """
        etid = self.get_etype_id(etype)
        stid, dtid = self._graph.metagraph.find_edge(etid)
        if is_all(edges):
            u, v, _ = self._graph.edges(etid, 'eid')
            eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etid, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etid, eid)

        src_data = self._get_n_repr(stid, u)
        edge_data = self._get_e_repr(etid, eid)
        dst_data = self._get_n_repr(dtid, v)
        ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
        e_mask = F.copy_to(predicate(ebatch), F.cpu())

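        if is_all(edges):
            return F.nonzero_1d(e_mask)
        # The mask is aligned with ``eid``; for the (u, v) form one pair may expand
        # to several parallel edges, so return the masked edge IDs.
        return F.boolean_mask(eid.tousertensor(), e_mask)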

    def to(self, ctx):  # pylint: disable=invalid-name
        """Move the node and edge features of all types to the target device (CPU/GPU).

        This method is framework agnostic.

        Parameters
        ----------
        ctx : framework-specific context object
            The context to move data to.

        Returns
        -------
        g : DGLHeteroGraph
            This graph itself; its features are moved to ``ctx`` in place.

        Examples
        --------
        The following example uses the PyTorch backend.

        >>> import dgl
        >>> import torch
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
        >>> g = g.to(torch.device('cuda:0'))
        """
        for i in range(len(self._node_frames)):
            for k in self._node_frames[i].keys():
                self._node_frames[i][k] = F.copy_to(self._node_frames[i][k], ctx)
        for i in range(len(self._edge_frames)):
            for k in self._edge_frames[i].keys():
                self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx)
        return self

    def local_var(self):
        """Return a heterograph object that can be used in a local function scope.

        The returned graph object shares the feature data and graph structure of this graph.
        However, out-of-place changes to the feature data (adding, removing or replacing
        feature tensors) are not reflected in this graph, which makes the returned object
        convenient to use inside a function scope.

        If initializers are registered on this graph, the local graph object uses the same
        initializers for node and edge features.

        Returns
        -------
        DGLHeteroGraph
            The graph object that can be used as a local variable.

        Notes
        -----
        Internally, the returned graph shares the same feature tensors but constructs a new
        dictionary structure (aka. Frame), so adding or removing feature tensors on the
        returned graph is not reflected in the original graph. However, in-place operations
        do change the shared tensor values and are therefore reflected in the original graph.
        This function has little overhead when the number of feature tensors in this graph
        is small.

        Examples
        --------
        The following example uses the PyTorch backend.

        Avoid accidentally overriding existing feature data. This is quite common when
        implementing a NN module:

        >>> def foo(g):
        >>>     g = g.local_var()
        >>>     g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>     return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
        >>> newh = foo(g)        # get tensor of all ones
        >>> print(g.edata['h'])  # still get tensor of all zeros

        Automatically garbage collect locally-defined tensors without the need to manually
        ``pop`` the tensors.

        >>> def foo(g):
        >>>     g = g.local_var()
        >>>     # This 'h' feature will stay local and be GCed when the function exits
        >>>     g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>     return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> h = foo(g)
        >>> print('h' in g.edata)
        False
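
        The sharing behaviour described in the Notes resembles a shallow copy of a
        dictionary. In the plain-Python analogy below, a ``dict`` stands in for a frame
        and lists stand in for feature tensors (an analogy only, not DGL's ``Frame``):
        adding or replacing a key in the copy leaves the original untouched, while an
        in-place write to a shared value is visible through both.

        ```python
        original = {'h': [0.0, 0.0]}

        local = dict(original)   # new "frame", same value objects
        local['x'] = [1.0]       # new column: not visible in original
        local['h'] = [5.0, 5.0]  # out-of-place replacement: not visible either
        print('x' in original, original['h'])  # False [0.0, 0.0]

        local = dict(original)
        local['h'][0] = 9.0      # in-place write on the shared list: visible
        print(original['h'])     # [9.0, 0.0]
        ```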

        See Also
        --------
        local_scope
        """
        local_node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
        local_edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
        # Use same per-column initializers and default initializer.
        # If registered, a column (based on key) initializer will be used first,
        # otherwise the default initializer will be used.
        for fr1, fr2 in zip(local_node_frames, self._node_frames):
            sync_frame_initializer(fr1._frame, fr2._frame)
        for fr1, fr2 in zip(local_edge_frames, self._edge_frames):
            sync_frame_initializer(fr1._frame, fr2._frame)
        return DGLHeteroGraph(self._graph, self.ntypes, self.etypes,
                              local_node_frames,
                              local_edge_frames)

    @contextmanager
    def local_scope(self):
        """Enter a local scope context for this graph.

        By entering a local scope, out-of-place changes to the feature data are not
        reflected in the original graph, which makes it easier to use in a function scope.

        If initializers are registered on this graph, the local scope uses the same
        initializers for node and edge features.

        Examples
        --------
        The following example uses the PyTorch backend.

        Avoid accidentally overriding existing feature data. This is quite common when
        implementing a NN module:

        >>> def foo(g):
        >>>     with g.local_scope():
        >>>         g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>         return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
        >>> newh = foo(g)        # get tensor of all ones
        >>> print(g.edata['h'])  # still get tensor of all zeros

        Automatically garbage collect locally-defined tensors without the need to manually
        ``pop`` the tensors.

        >>> def foo(g):
        >>>     with g.local_scope():
        >>>         # This 'h' feature will stay local and be GCed when the function exits
        >>>         g.edata['h'] = torch.ones((g.number_of_edges(), 3))
        >>>         return g.edata['h']
        >>>
        >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
        >>> h = foo(g)
        >>> print('h' in g.edata)
        False
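
        A minimal sketch of the swap-and-restore pattern behind a local scope, using a
        hypothetical ``ToyGraph`` whose frames are plain dictionaries. The ``try``/
        ``finally`` in the sketch makes sure the frames are restored even when the
        ``with`` body raises an exception.

        ```python
        from contextlib import contextmanager

        class ToyGraph:
            # Hypothetical stand-in: one dict of features per node type.
            def __init__(self, frames):
                self.frames = frames

            @contextmanager
            def local_scope(self):
                old = self.frames
                # New dicts that share the original feature objects.
                self.frames = [dict(fr) for fr in old]
                try:
                    yield
                finally:
                    # Restore the original frames, even if the body raised.
                    self.frames = old

        g = ToyGraph([{'h': [0.0]}])
        with g.local_scope():
            g.frames[0]['tmp'] = [1.0]
        print('tmp' in g.frames[0])  # False
        ```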

        See Also
        --------
        local_var
        """
        old_nframes = self._node_frames
        old_eframes = self._edge_frames
        self._node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
        self._edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
        # Use same per-column initializers and default initializer.
        # If registered, a column (based on key) initializer will be used first,
        # otherwise the default initializer will be used.
        for fr1, fr2 in zip(self._node_frames, old_nframes):
            sync_frame_initializer(fr1._frame, fr2._frame)
        for fr1, fr2 in zip(self._edge_frames, old_eframes):
            sync_frame_initializer(fr1._frame, fr2._frame)
        try:
            yield
        finally:
            # Restore the original frames even if the body raised an exception.
            self._node_frames = old_nframes
            self._edge_frames = old_eframes

############################################################
# Internal APIs
############################################################

def make_canonical_etypes(etypes, ntypes, metagraph):
    """Internal function to convert etype name to (srctype, etype, dsttype)

    Parameters
    ----------
    etypes : list of str
        Edge type list
    ntypes : list of str
        Node type list
    metagraph : GraphIndex
        Meta graph.

    Returns
    -------
    list of tuples (srctype, etype, dsttype)
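
    Examples
    --------
    A minimal sketch of the same mapping, assuming the metagraph edges are
    already available as parallel ``src``, ``dst`` and ``eid`` lists (the
    function itself reads them from ``metagraph.edges()``). The helper name
    ``canonical_etypes_sketch`` and the type names are hypothetical.

    ```python
    def canonical_etypes_sketch(etypes, ntypes, src, dst, eid):
        # Metagraph edge i goes from node type src[i] to node type dst[i] and
        # its ID eid[i] indexes etypes; emit a (src, etype, dst) triplet.
        if len(etypes) != len(eid):
            raise ValueError('edge type list must match the metagraph edges')
        return [(ntypes[s], etypes[e], ntypes[d]) for s, d, e in zip(src, dst, eid)]


    ntypes = ['game', 'user']
    etypes = ['follows', 'plays']
    # Metagraph: user -> user (edge 0) and user -> game (edge 1).
    print(canonical_etypes_sketch(etypes, ntypes, [1, 1], [1, 0], [0, 1]))
    # [('user', 'follows', 'user'), ('user', 'plays', 'game')]
    ```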
    """
    # sanity check
    if len(etypes) != metagraph.number_of_edges():
        raise DGLError('Length of edge type list must match the number of '
                       'edges in the metagraph. {} vs {}'.format(
                           len(etypes), metagraph.number_of_edges()))
    if len(ntypes) != metagraph.number_of_nodes():
        raise DGLError('Length of node type list must match the number of '
                       'nodes in the metagraph. {} vs {}'.format(
                           len(ntypes), metagraph.number_of_nodes()))
    src, dst, eid = metagraph.edges()
    rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)]
    return rst

def infer_ntype_from_dict(graph, etype_dict):
    """Infer node type from dictionary of edge type to values.

    All the edge types in the dict must share the same destination node type,
    which is returned. Otherwise, a DGLError is raised.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    etype_dict : dict
        Dictionary whose keys are edge types

    Returns
    -------
    str
        Node type
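
    Examples
    --------
    A standalone sketch of the consistency check, assuming the keys are
    already canonical ``(srctype, etype, dsttype)`` tuples, so no
    ``to_canonical_etype`` lookup is needed. ``infer_dst_ntype_sketch`` is a
    hypothetical name.

    ```python
    def infer_dst_ntype_sketch(etype_dict):
        ntype = None
        for _, _, dty in etype_dict:  # iterating a dict yields its keys
            if ntype is None:
                ntype = dty
            if ntype != dty:
                raise ValueError('edge types do not share a destination node type')
        return ntype


    msgs = {('user', 'plays', 'game'): 1, ('store', 'sells', 'game'): 2}
    print(infer_dst_ntype_sketch(msgs))  # game
    ```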
    """
    ntype = None
    for ety in etype_dict:
        _, _, dty = graph.to_canonical_etype(ety)
        if ntype is None:
            ntype = dty
        if ntype != dty:
            raise DGLError("Cannot infer destination node type from the dictionary. "
                           "A valid specification must make sure that all the edge "
                           "type keys share the same destination node type.")
    return ntype

def pad_tuple(tup, length, pad_val=None):
    """Pad the given tuple to the given length.

    If the input is not a tuple, it is first wrapped in a tuple of length one.
    Return None if the tuple is already longer than ``length``.
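
    Examples
    --------
    The padding rules can be checked with a standalone copy of the logic
    (``pad_tuple_sketch`` is a hypothetical name that mirrors the body below):

    ```python
    def pad_tuple_sketch(tup, length, pad_val=None):
        # Wrap a scalar so that 'plays' and ('plays',) are handled alike.
        if not isinstance(tup, tuple):
            tup = (tup,)
        if len(tup) > length:
            return None  # already too long; nothing to pad
        return tup + (pad_val,) * (length - len(tup))


    print(pad_tuple_sketch('plays', 3))       # ('plays', None, None)
    print(pad_tuple_sketch(('a', 'b'), 2))    # ('a', 'b')
    print(pad_tuple_sketch((1, 2, 3, 4), 3))  # None
    ```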
    """
    if not isinstance(tup, tuple):
        tup = (tup, )
    if len(tup) > length:
        return None
    elif len(tup) == length:
        return tup
    else:
        return tup + (pad_val,) * (length - len(tup))

def merge_frames(frames, reducer):
    """Merge input frames into one, resolving conflicting fields with ``reducer``.

    Parameters
    ----------
    frames : list of FrameRef
        Input frames
    reducer : str
        One of "sum", "max", "min", "mean", "stack"

    Returns
    -------
    FrameRef
        Merged frame
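
    Examples
    --------
    For the ``"sum"``, ``"max"``, ``"min"`` and ``"mean"`` reducers, values
    of a field found in several frames are stacked along a new leading axis
    and reduced over it. A NumPy sketch of that step, assuming frames are
    plain dicts of arrays (``merge_columns_sketch`` is hypothetical and
    stands in for ``FrameRef`` and the backend module ``F``):

    ```python
    import numpy as np


    def merge_columns_sketch(frames, reducer):
        if reducer not in ('sum', 'max', 'min', 'mean'):
            raise ValueError('reducer must be one of sum, max, min, mean')
        redfn = getattr(np, reducer)
        keys = set()
        for frm in frames:
            keys.update(frm.keys())
        merged = {}
        for k in keys:
            flist = [frm[k] for frm in frames if k in frm]
            # Fields shared by several frames are reduced element-wise; a
            # field found in only one frame is passed through unchanged.
            merged[k] = redfn(np.stack(flist, 0), 0) if len(flist) > 1 else flist[0]
        return merged


    a = {'h': np.array([[1.0, 2.0], [3.0, 4.0]])}
    b = {'h': np.array([[10.0, 20.0], [30.0, 40.0]]), 'w': np.ones((2, 1))}
    out = merge_columns_sketch([a, b], 'sum')
    print(out['h'].tolist())  # [[11.0, 22.0], [33.0, 44.0]]
    ```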
    """
    if len(frames) == 1:
        return frames[0]
    if reducer == 'stack':
        # TODO(minjie): Stack order does not matter. However, it must
        #   be consistent! Need to enforce one type of order.
        def merger(flist):
            flist = [F.unsqueeze(f, 1) for f in flist]
            return F.stack(flist, 1)
    else:
        redfn = getattr(F, reducer, None)
        if redfn is None:
            raise DGLError('Invalid cross type reducer. Must be one of '
                           '"sum", "max", "min", "mean" or "stack".')
        def merger(flist):
            return redfn(F.stack(flist, 0), 0)
    ret = FrameRef(frame_like(frames[0]._frame))
    keys = set()
    for frm in frames:
        keys.update(frm.keys())
    for k in keys:
        flist = []
        for frm in frames:
            if k in frm:
                flist.append(frm[k])
        if len(flist) > 1:
            ret[k] = merger(flist)
        else:
            ret[k] = flist[0]
    return ret

def combine_frames(frames, ids):
    """Merge the selected frames into one frame by concatenating their rows,
    keeping only the columns common to all of them.

    Return None if there are no common columns.

    Parameters
    ----------
    frames : List[FrameRef]
        List of frames
    ids : List[int]
        List of frame IDs

    Returns
    -------
    FrameRef
        The resulting frame
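
    Examples
    --------
    A sketch of the column selection and row-wise concatenation, assuming
    each frame is a plain dict of NumPy arrays and that a column's scheme is
    its per-row shape plus dtype (the role ``Frame.schemes`` plays here).
    ``combine_columns_sketch`` is a hypothetical name.

    ```python
    import numpy as np


    def combine_columns_sketch(frames, ids):
        def scheme(col):
            return (col.shape[1:], col.dtype)

        # Keep only the columns that every selected frame has, and require
        # their schemes to match.
        schemes = {k: scheme(v) for k, v in frames[ids[0]].items()}
        for i in ids:
            for key, sch in list(schemes.items()):
                if key not in frames[i]:
                    del schemes[key]
                elif scheme(frames[i][key]) != sch:
                    raise ValueError('column %s has mismatched schemes' % key)
        if not schemes:
            return None
        # Concatenate the surviving columns along the row axis.
        return {k: np.concatenate([frames[i][k] for i in ids], axis=0)
                for k in schemes}


    f0 = {'h': np.zeros((2, 3)), 'x': np.ones(2)}
    f1 = {'h': np.ones((1, 3))}
    out = combine_columns_sketch([f0, f1], [0, 1])
    print(sorted(out), out['h'].shape)  # ['h'] (3, 3)
    ```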
    """
    # find common columns and check if their schemes match
    schemes = {key: scheme for key, scheme in frames[ids[0]].schemes.items()}
    for frame_id in ids:
        frame = frames[frame_id]
        for key, scheme in list(schemes.items()):
            if key in frame.schemes:
                if frame.schemes[key] != scheme:
                    raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
                                   (key, frame.schemes[key], scheme))
            else:
                del schemes[key]

    if len(schemes) == 0:
        return None

    # concatenate the columns
    to_cat = lambda key: [frames[i][key] for i in ids if frames[i].num_rows > 0]
    cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
    return FrameRef(Frame(cols))

def combine_names(names, ids=None):
    """Combine the selected names into one new name.

    Parameters
    ----------
    names : list of str
        String names
    ids : numpy.ndarray, optional
        Selected index

    Returns
    -------
    str
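
    Examples
    --------
    A standalone copy of the logic (``combine_names_sketch`` is a
    hypothetical name) shows that names are sorted before joining, so the
    result does not depend on the input order:

    ```python
    def combine_names_sketch(names, ids=None):
        if ids is None:
            return '+'.join(sorted(names))
        # Only the names at the selected positions are combined.
        return '+'.join(sorted(names[i] for i in ids))


    print(combine_names_sketch(['user', 'game', 'item']))          # game+item+user
    print(combine_names_sketch(['user', 'game', 'item'], [2, 0]))  # item+user
    ```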
    """
    if ids is None:
        return '+'.join(sorted(names))
    else:
        selected = sorted([names[i] for i in ids])
        return '+'.join(selected)

class AdaptedHeteroGraph(GraphAdapter):
    """Adapt a DGLHeteroGraph to the interface required by the scheduler.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    stid : int
        Source node type id
    dtid : int
        Destination node type id
    etid : int
        Edge type id
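
    Examples
    --------
    The adapter binds one relation of a heterogeneous graph (source type,
    edge type, destination type) so the scheduler can query it without
    passing type IDs around. A toy sketch of the pattern, assuming a
    hypothetical ``ToyHeteroGraph`` that stores node counts per node type
    and edge lists per edge type (neither class below is part of DGL):

    ```python
    class ToyHeteroGraph:
        def __init__(self, num_nodes, edges):
            self.num_nodes = num_nodes  # node type id -> number of nodes
            self.edges = edges          # edge type id -> list of (src, dst)


    class ToyAdapter:
        # Bind one relation so callers never pass type ids explicitly.
        def __init__(self, graph, stid, dtid, etid):
            self.graph, self.stid, self.dtid, self.etid = graph, stid, dtid, etid

        def num_src(self):
            return self.graph.num_nodes[self.stid]

        def num_dst(self):
            return self.graph.num_nodes[self.dtid]

        def num_edges(self):
            return len(self.graph.edges[self.etid])

        def in_edges(self, nodes):
            # Edges of this relation whose destination is in nodes.
            return [(u, v) for u, v in self.graph.edges[self.etid] if v in nodes]


    # Node types: 0 = user (3 nodes), 1 = game (2 nodes); edge type 0 = plays.
    g = ToyHeteroGraph({0: 3, 1: 2}, {0: [(0, 0), (1, 0), (1, 1)]})
    plays = ToyAdapter(g, stid=0, dtid=1, etid=0)
    print(plays.num_src(), plays.num_dst(), plays.num_edges())  # 3 2 3
    print(plays.in_edges({0}))  # [(0, 0), (1, 0)]
    ```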
    """
    def __init__(self, graph, stid, dtid, etid):
        self.graph = graph
        self.stid = stid
        self.dtid = dtid
        self.etid = etid

    @property
    def gidx(self):
        """Graph index of the adapted heterogeneous graph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph._graph.number_of_nodes(self.stid)

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph._graph.number_of_nodes(self.dtid)

    def num_edges(self):
        """Number of edges."""
        return self.graph._graph.number_of_edges(self.etid)

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frames[self.stid]

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        return self.graph._node_frames[self.dtid]

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frames[self.etid]

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frames[self.etid]

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index(self.etid)

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(self.etid, val)

    def in_edges(self, nodes):
        """In-edges of the given nodes under this edge type."""
        return self.graph._graph.in_edges(self.etid, nodes)

    def out_edges(self, nodes):
        """Out-edges of the given nodes under this edge type."""
        return self.graph._graph.out_edges(self.etid, nodes)

    def edges(self, form):
        """All edges of this edge type; ``form`` is passed to the graph index."""
        return self.graph._graph.edges(self.etid, form)

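    def get_immutable_gidx(self, ctx):
        """Unit graph index of this edge type on device context ``ctx``."""
        return self.graph._graph.get_unitgraph(self.etid, ctx)

    def bits_needed(self):
        """Number of integer bits needed for the IDs of this edge type."""
        return self.graph._graph.bits_needed(self.etid)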