"""Module for graph index class definition."""
from __future__ import absolute_import

import numpy as np
import networkx as nx
import scipy

from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F
from . import utils

class BoolFlag(object):
    """Tri-state boolean flag: false, true, or unknown.

    The members are plain ints (not an Enum) so they can be handed to
    ``int()`` and passed straight through to the C API, e.g. as the
    ``multigraph`` argument when the multigraph property is not known yet.
    """
    BOOL_FALSE = 0
    BOOL_TRUE = 1
    BOOL_UNKNOWN = -1

@register_object('graph.Graph')
class GraphIndex(ObjectBase):
    """Graph index object.

    Python handle to a graph structure implemented in the backend; most
    methods forward directly to a ``_CAPI_*`` function.

    Note
    ----
    Do not create GraphIndex directly; create graph index objects using the
    following functions:

    - `dgl.graph_index.from_edge_list`
    - `dgl.graph_index.from_scipy_sparse_matrix`
    - `dgl.graph_index.from_networkx`
    - `dgl.graph_index.from_shared_mem_csr_matrix`
    - `dgl.graph_index.from_csr`
    - `dgl.graph_index.from_coo`
    """
    def __new__(cls):
        obj = ObjectBase.__new__(cls)
        obj._multigraph = None  # python-side cache of the flag
        obj._readonly = None  # python-side cache of the flag
        # Memo table used by the ``utils.cached_member``-decorated methods
        # (``edges``, ``adjacency_matrix_scipy``, ``get_immutable_gidx``).
        obj._cache = {}
        return obj

    def __getstate__(self):
        """Return the pickle state of this graph index.

        Returns
        -------
        tuple
            ``(number_of_nodes, multigraph, readonly, src_nodes, dst_nodes)``.
            ``src_nodes`` and ``dst_nodes`` are ``utils.Index`` arrays listed
            in edge-id order, so rebuilding the graph from them in
            ``__setstate__`` reproduces the same edge ids.
        """
        # Ask for edges sorted by edge id: the default (None) order is
        # arbitrary, and ``__setstate__`` rebuilds the graph positionally
        # from src/dst, so any other order would silently renumber edges.
        src, dst, _ = self.edges('eid')
        n_nodes = self.number_of_nodes()
        # TODO(minjie): should try to avoid calling is_multigraph
        multigraph = self.is_multigraph()
        readonly = self.is_readonly()

        return n_nodes, multigraph, readonly, src, dst

    def __setstate__(self, state):
        """Restore a graph index from its pickle state.

        Parameters
        ----------
        state : tuple
            The 5-tuple ``(number_of_nodes, multigraph, readonly, src_nodes,
            dst_nodes)`` produced by ``__getstate__``. ``multigraph`` may be
            None, meaning "unknown".
        """
        num_nodes, multigraph, readonly, src, dst = state

        self._cache = {}
        self._multigraph = multigraph
        self._readonly = readonly
        if multigraph is None:
            multigraph = BoolFlag.BOOL_UNKNOWN
        self.__init_handle_by_constructor__(
            _CAPI_DGLGraphCreate,
            src.todgltensor(),
            dst.todgltensor(),
            int(multigraph),
            int(num_nodes),
            readonly)

    def add_nodes(self, num):
        """Add nodes.

        Parameters
        ----------
        num : int
            Number of nodes to be added.
        """
        _CAPI_DGLGraphAddVertices(self, int(num))
        self.clear_cache()

    def add_edge(self, u, v):
        """Add one edge.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.
        """
        _CAPI_DGLGraphAddEdge(self, int(u), int(v))
        self.clear_cache()

    def add_edges(self, u, v):
        """Add many edges.

        Parameters
        ----------
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        _CAPI_DGLGraphAddEdges(self, u_array, v_array)
        self.clear_cache()

    def clear(self):
        """Clear the graph."""
        _CAPI_DGLGraphClear(self)
        self.clear_cache()

    def clear_cache(self):
        """Clear the cached graph structures."""
        self._cache.clear()

    def is_multigraph(self):
        """Return whether the graph is a multigraph.

        The answer is queried from the backend once and then cached on the
        Python side.

        Returns
        -------
        bool
            True if it is a multigraph, False otherwise.
        """
        if self._multigraph is None:
            self._multigraph = bool(_CAPI_DGLGraphIsMultigraph(self))
        return self._multigraph

    def is_readonly(self):
        """Indicate whether the graph index is read-only.

        The answer is queried from the backend once and then cached on the
        Python side.

        Returns
        -------
        bool
            True if it is a read-only graph, False otherwise.
        """
        if self._readonly is None:
            self._readonly = bool(_CAPI_DGLGraphIsReadonly(self))
        return self._readonly

    def readonly(self, readonly_state=True):
        """Set the readonly state of graph index in-place.

        The backend graph is rebuilt from this index's own pickle state with
        the new readonly flag, and all cached structures are dropped.

        Parameters
        ----------
        readonly_state : bool
            New readonly state of current graph index.
        """
        # TODO(minjie): very ugly code, should fix this
        n_nodes, multigraph, _, src, dst = self.__getstate__()
        self.clear_cache()
        state = (n_nodes, multigraph, readonly_state, src, dst)
        self.__setstate__(state)

    def number_of_nodes(self):
        """Return the number of nodes.

        Returns
        -------
        int
            The number of nodes
        """
        return _CAPI_DGLGraphNumVertices(self)

    def number_of_edges(self):
        """Return the number of edges.

        Returns
        -------
        int
            The number of edges
        """
        return _CAPI_DGLGraphNumEdges(self)

    def has_node(self, vid):
        """Return true if the node exists.

        Parameters
        ----------
        vid : int
            The node.

        Returns
        -------
        bool
            True if the node exists, False otherwise.
        """
        return bool(_CAPI_DGLGraphHasVertex(self, int(vid)))

    def has_nodes(self, vids):
        """Return true if the nodes exist.

        Parameters
        ----------
        vids : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        vid_array = vids.todgltensor()
        return utils.toindex(_CAPI_DGLGraphHasVertices(self, vid_array))

    def has_edge_between(self, u, v):
        """Return true if the edge exists.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        bool
            True if the edge exists, False otherwise
        """
        return bool(_CAPI_DGLGraphHasEdgeBetween(self, int(u), int(v)))

    def has_edges_between(self, u, v):
        """Return true if the edges exist, pairwise for (u[i], v[i]).

        Parameters
        ----------
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array))

    def predecessors(self, v, radius=1):
        """Return the predecessors of the node.

        Parameters
        ----------
        v : int
            The node.
        radius : int, optional
            The radius of the neighborhood.

        Returns
        -------
        utils.Index
            Array of predecessors
        """
        return utils.toindex(_CAPI_DGLGraphPredecessors(
            self, int(v), int(radius)))

    def successors(self, v, radius=1):
        """Return the successors of the node.

        Parameters
        ----------
        v : int
            The node.
        radius : int, optional
            The radius of the neighborhood.

        Returns
        -------
        utils.Index
            Array of successors
        """
        return utils.toindex(_CAPI_DGLGraphSuccessors(
            self, int(v), int(radius)))

    def edge_id(self, u, v):
        """Return the id array of all edges between u and v.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        utils.Index
            The edge id array.
        """
        return utils.toindex(_CAPI_DGLGraphEdgeId(self, int(u), int(v)))

    def edge_ids(self, u, v):
        """Return a triplet of arrays that contains the edge IDs.

        Parameters
        ----------
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        edge_array = _CAPI_DGLGraphEdgeIds(self, u_array, v_array)

        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))

        return src, dst, eid

    def find_edge(self, eid):
        """Return the edge tuple of the given id.

        Parameters
        ----------
        eid : int
            The edge id.

        Returns
        -------
        int
            src node id
        int
            dst node id
        """
        ret = _CAPI_DGLGraphFindEdge(self, int(eid))
        return ret(0), ret(1)

    def find_edges(self, eid):
        """Return the end points of the given edges.

        Parameters
        ----------
        eid : utils.Index
            The edge ids.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        eid_array = eid.todgltensor()
        edge_array = _CAPI_DGLGraphFindEdges(self, eid_array)

        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))

        return src, dst, eid

    def in_edges(self, v):
        """Return the in edges of the node(s).

        Parameters
        ----------
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if len(v) == 1:
            # A single node goes through the dedicated scalar C API.
            edge_array = _CAPI_DGLGraphInEdges_1(self, int(v[0]))
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLGraphInEdges_2(self, v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def out_edges(self, v):
        """Return the out edges of the node(s).

        Parameters
        ----------
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if len(v) == 1:
            # A single node goes through the dedicated scalar C API.
            edge_array = _CAPI_DGLGraphOutEdges_1(self, int(v[0]))
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLGraphOutEdges_2(self, v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    @utils.cached_member(cache='_cache', prefix='edges')
    def edges(self, order=None):
        """Return all the edges

        Parameters
        ----------
        order : string
            The order of the returned edges. Currently support:

            - 'srcdst' : sorted by their src and dst ids.
            - 'eid'    : sorted by edge Ids.
            - None     : the arbitrary order.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if order is None:
            order = ""  # the C API encodes "arbitrary order" as ""
        edge_array = _CAPI_DGLGraphEdges(self, order)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def in_degree(self, v):
        """Return the in degree of the node.

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        int
            The in degree.
        """
        return _CAPI_DGLGraphInDegree(self, int(v))

    def in_degrees(self, v):
        """Return the in degrees of the nodes.

        Parameters
        ----------
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The in degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLGraphInDegrees(self, v_array))

    def out_degree(self, v):
        """Return the out degree of the node.

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        int
            The out degree.
        """
        return _CAPI_DGLGraphOutDegree(self, int(v))

    def out_degrees(self, v):
        """Return the out degrees of the nodes.

        Parameters
        ----------
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The out degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLGraphOutDegrees(self, v_array))

    def node_subgraph(self, v):
        """Return the induced node subgraph.

        Parameters
        ----------
        v : utils.Index
            The nodes.

        Returns
        -------
        SubgraphIndex
            The subgraph index.
        """
        v_array = v.todgltensor()
        rst = _CAPI_DGLGraphVertexSubgraph(self, v_array)
        induced_edges = utils.toindex(rst(2))
        gidx = rst(0)
        return SubgraphIndex(gidx, self, v, induced_edges)

    def node_subgraphs(self, vs_arr):
        """Return the induced node subgraphs.

        Parameters
        ----------
        vs_arr : a list of utils.Index
            The node sets, one per subgraph.

        Returns
        -------
        list of SubgraphIndex
            One subgraph index per entry of ``vs_arr``, in the same order.
        """
        return [self.node_subgraph(v) for v in vs_arr]

    def edge_subgraph(self, e, preserve_nodes=False):
        """Return the induced edge subgraph.

        Parameters
        ----------
        e : utils.Index
            The edges.
        preserve_nodes : bool
            Indicates whether to preserve all nodes or not.
            If true, keep the nodes which have no edge connected in the subgraph;
            If false, all nodes without edge connected to it would be removed.

        Returns
        -------
        SubgraphIndex
            The subgraph index.
        """
        e_array = e.todgltensor()
        rst = _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
        induced_nodes = utils.toindex(rst(1))
        gidx = rst(0)
        return SubgraphIndex(gidx, self, induced_nodes, e)

    @utils.cached_member(cache='_cache', prefix='scipy_adj')
    def adjacency_matrix_scipy(self, transpose, fmt):
        """Return the scipy adjacency matrix representation of this graph.

        By default, a row of returned adjacency matrix represents the destination
        of an edge and the column represents the source.

        When transpose is True, a row represents the source and a column represents
        a destination.

        The elements in the adjacency matrix are edge ids.

        Parameters
        ----------
        transpose : bool
            A flag to transpose the returned adjacency matrix.
        fmt : str
            Indicates the format of returned adjacency matrix: "csr" or "coo".

        Returns
        -------
        scipy.sparse.spmatrix
            The scipy representation of adjacency matrix.

        Raises
        ------
        DGLError
            If ``transpose`` is not a bool or ``fmt`` is not supported.
        """
        if not isinstance(transpose, bool):
            raise DGLError('Expect bool value for "transpose" arg,'
                           ' but got %s.' % (type(transpose)))
        if fmt not in ('csr', 'coo'):
            # Validate before calling the backend so an unsupported format
            # always surfaces as a DGLError.
            raise DGLError('Unknown sparse format: %s' % str(fmt))
        rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
        n = self.number_of_nodes()
        if fmt == 'csr':
            indptr = utils.toindex(rst(0)).tonumpy()
            indices = utils.toindex(rst(1)).tonumpy()
            # The stored values are edge ids, permuted into CSR order.
            shuffle = utils.toindex(rst(2)).tonumpy()
            return scipy.sparse.csr_matrix((shuffle, indices, indptr), shape=(n, n))
        # fmt == 'coo': rst(0) packs the row and column indices into a single
        # array that is reshaped to (2, m) below.
        idx = utils.toindex(rst(0)).tonumpy()
        m = self.number_of_edges()
        row, col = np.reshape(idx, (2, m))
        shuffle = np.arange(0, m)
        return scipy.sparse.coo_matrix((shuffle, (row, col)), shape=(n, n))

    @utils.cached_member(cache='_cache', prefix='immu_gidx')
    def get_immutable_gidx(self, ctx):
        """Create an immutable graph index and copy to the given device context.

        Note: this internal function is for DGL scheduler use only

        Parameters
        ----------
        ctx : DGLContext
            The context of the returned graph.

        Returns
        -------
        GraphIndex
        """
        return self.to_immutable().asbits(self.bits_needed()).copy_to(ctx)

    def get_csr_shuffle_order(self):
        """Return the edge shuffling order when a coo graph is converted to csr format

        Returns
        -------
        tuple of two utils.Index
            The first element of the tuple is the shuffle order for outward graph
            The second element of the tuple is the shuffle order for inward graph
        """
        csr = _CAPI_DGLGraphGetAdj(self, True, "csr")
        order = csr(2)
        rev_csr = _CAPI_DGLGraphGetAdj(self, False, "csr")
        rev_order = rev_csr(2)
        return utils.toindex(order), utils.toindex(rev_order)

    def adjacency_matrix(self, transpose, ctx):
        """Return the adjacency matrix representation of this graph.

        By default, a row of returned adjacency matrix represents the destination
        of an edge and the column represents the source.

        When transpose is True, a row represents the source and a column represents
        a destination.

        Parameters
        ----------
        transpose : bool
            A flag to transpose the returned adjacency matrix.
        ctx : context
            The context of the returned matrix.

        Returns
        -------
        SparseTensor
            The adjacency matrix.
        utils.Index
            A index for data shuffling due to sparse format change. Return None
            if shuffle is not required.

        Raises
        ------
        DGLError
            If ``transpose`` is not a bool or the backend's preferred sparse
            format is neither "csr" nor "coo".
        """
        if not isinstance(transpose, bool):
            raise DGLError('Expect bool value for "transpose" arg,'
                           ' but got %s.' % (type(transpose)))
        fmt = F.get_preferred_sparse_format()
        if fmt not in ('csr', 'coo'):
            raise DGLError('Unknown sparse format: %s' % str(fmt))
        rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
        if fmt == "csr":
            indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
            indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
            shuffle = utils.toindex(rst(2))
            dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx)
            spmat = F.sparse_matrix(dat, ('csr', indices, indptr),
                                    (self.number_of_nodes(), self.number_of_nodes()))[0]
            return spmat, shuffle
        # fmt == "coo"
        ## FIXME(minjie): data type
        idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
        m = self.number_of_edges()
        idx = F.reshape(idx, (2, m))
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
        n = self.number_of_nodes()
        adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, n))
        shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
        return adj, shuffle_idx

    def incidence_matrix(self, typestr, ctx):
        """Return the incidence matrix representation of this graph.

        An incidence matrix is an n x m sparse matrix, where n is
        the number of nodes and m is the number of edges. Each nnz
        value indicates whether the edge is incident to the node
        or not.

        There are three types of an incidence matrix `I`:
        * "in":
          - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);
          - I[v, e] = 0 otherwise.
        * "out":
          - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);
          - I[v, e] = 0 otherwise.
        * "both":
          - I[v, e] = 1 if e is the in-edge of v;
          - I[v, e] = -1 if e is the out-edge of v;
          - I[v, e] = 0 otherwise (including self-loop).

        Parameters
        ----------
        typestr : str
            Can be either "in", "out" or "both"
        ctx : context
            The context of returned incidence matrix.

        Returns
        -------
        SparseTensor
            The incidence matrix.
        utils.Index
            A index for data shuffling due to sparse format change. Return None
            if shuffle is not required.

        Raises
        ------
        DGLError
            If ``typestr`` is not one of the supported types.
        """
        src, dst, eid = self.edges()
        src = src.tousertensor(ctx)  # the index of the ctx will be cached
        dst = dst.tousertensor(ctx)  # the index of the ctx will be cached
        eid = eid.tousertensor(ctx)  # the index of the ctx will be cached
        n = self.number_of_nodes()
        m = self.number_of_edges()
        if typestr == 'in':
            row = F.unsqueeze(dst, 0)
            col = F.unsqueeze(eid, 0)
            idx = F.cat([row, col], dim=0)
            # FIXME(minjie): data type
            dat = F.ones((m,), dtype=F.float32, ctx=ctx)
            inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
        elif typestr == 'out':
            row = F.unsqueeze(src, 0)
            col = F.unsqueeze(eid, 0)
            idx = F.cat([row, col], dim=0)
            # FIXME(minjie): data type
            dat = F.ones((m,), dtype=F.float32, ctx=ctx)
            inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
        elif typestr == 'both':
            # first remove entries for self loops
            mask = F.logical_not(F.equal(src, dst))
            src = F.boolean_mask(src, mask)
            dst = F.boolean_mask(dst, mask)
            eid = F.boolean_mask(eid, mask)
            n_entries = F.shape(src)[0]
            # create index: src rows get -1, dst rows get +1
            row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
            col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
            idx = F.cat([row, col], dim=0)
            # FIXME(minjie): data type
            x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
            y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
            dat = F.cat([x, y], dim=0)
            inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
        else:
            raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
        shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
        return inc, shuffle_idx

    def to_networkx(self):
        """Convert to networkx graph.

        The edge id will be saved as the 'id' edge attribute.

        Returns
        -------
        networkx.DiGraph or networkx.MultiDiGraph
            The nx graph; a MultiDiGraph when this graph is a multigraph.
        """
        src, dst, eid = self.edges()
        ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
        ret.add_nodes_from(range(self.number_of_nodes()))
        for u, v, e in zip(src, dst, eid):
            ret.add_edge(u, v, id=e)
        return ret

    def line_graph(self, backtracking=True):
        """Return the line graph of this graph.

        Parameters
        ----------
        backtracking : bool, optional (default=True)
          Whether (i, j) ~ (j, i) in L(G).
          (i, j) ~ (j, i) is the behavior of networkx.line_graph.

        Returns
        -------
        GraphIndex
            The line graph of this graph.
        """
        return _CAPI_DGLGraphLineGraph(self, backtracking)

    def to_immutable(self):
        """Convert this graph index to an immutable one.

        Returns
        -------
        GraphIndex
            An immutable graph index.
        """
        return _CAPI_DGLToImmutable(self)

    def ctx(self):
        """Return the context of this graph index.

        Returns
        -------
        DGLContext
            The context of the graph.
        """
        return _CAPI_DGLGraphContext(self)

    def copy_to(self, ctx):
        """Copy this immutable graph index to the given device context.

        NOTE: this method only works for immutable graph index

        Parameters
        ----------
        ctx : DGLContext
            The target device context.

        Returns
        -------
        GraphIndex
            The graph index on the given device context.
        """
        return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id)

    def copyto_shared_mem(self, edge_dir, shared_mem_name):
        """Copy this immutable graph index to shared memory.

        NOTE: this method only works for immutable graph index

        Parameters
        ----------
        edge_dir : string
            Indicate which CSR should copy ("in", "out", "both").
        shared_mem_name : string
            The name of the shared memory.

        Returns
        -------
        GraphIndex
            The graph index stored in the named shared memory.
        """
        return _CAPI_DGLImmutableGraphCopyToSharedMem(self, edge_dir, shared_mem_name)

    def nbits(self):
        """Return the number of integer bits used in the storage (32 or 64).

        Returns
        -------
        int
            The number of bits.
        """
        return _CAPI_DGLGraphNumBits(self)

    def bits_needed(self):
        """Return the number of integer bits needed to represent the graph

        Returns
        -------
        int
            64 if the node or edge count reaches 2**31, otherwise 32.
        """
        # 0x80000000 == 2**31: ids at or beyond it do not fit in a signed
        # 32-bit integer.
        if self.number_of_edges() >= 0x80000000 or self.number_of_nodes() >= 0x80000000:
            return 64
        else:
            return 32

    def asbits(self, bits):
        """Transform the graph to a new one with the given number of bits storage.

        NOTE: this method only works for immutable graph index

        Parameters
        ----------
        bits : int
            The number of integer bits (32 or 64)

        Returns
        -------
        GraphIndex
            The graph index stored using the given number of bits.
        """
        return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))

class SubgraphIndex(object):
    """Internal subgraph data structure.

    A plain record that pairs a subgraph's own graph index with its parent
    graph and with the parent ids of the nodes and edges it contains.
    Pickling is deliberately unsupported.

    Parameters
    ----------
    graph : GraphIndex
        The graph structure of this subgraph.
    parent : GraphIndex
        The parent graph index.
    induced_nodes : utils.Index
        The parent node ids in this subgraph.
    induced_edges : utils.Index
        The parent edge ids in this subgraph.
    """
    def __init__(self, graph, parent, induced_nodes, induced_edges):
        self.graph, self.parent = graph, parent
        self.induced_nodes, self.induced_edges = induced_nodes, induced_edges

    def __getstate__(self):
        message = "SubgraphIndex pickling is not supported yet."
        raise NotImplementedError(message)

    def __setstate__(self, state):
        message = "SubgraphIndex unpickling is not supported yet."
        raise NotImplementedError(message)


###############################################################
# Conversion functions
###############################################################
def from_coo(num_nodes, src, dst, is_multigraph, readonly):
    """Convert from coo arrays.

    Edge ``i`` of the resulting graph is ``(src[i], dst[i])``.

    Parameters
    ----------
    num_nodes : int
        Number of nodes.
    src : Tensor
        Src end nodes of the edges.
    dst : Tensor
        Dst end nodes of the edges.
    is_multigraph : bool or None
        True if the graph is a multigraph. None means determined by data.
        For a mutable graph (``readonly`` False) an unknown value is
        currently treated as False.
    readonly : bool
        True if the returned graph is readonly.

    Returns
    -------
    GraphIndex
        The graph index.
    """
    src = utils.toindex(src)
    dst = utils.toindex(dst)
    if is_multigraph is None:
        is_multigraph = BoolFlag.BOOL_UNKNOWN
    if readonly:
        gidx = _CAPI_DGLGraphCreate(
            src.todgltensor(),
            dst.todgltensor(),
            int(is_multigraph),
            int(num_nodes),
            readonly)
    else:
        # Compare by value, not identity: BOOL_UNKNOWN is an int sentinel
        # and ``is`` on ints only works by accident of interning.
        if is_multigraph == BoolFlag.BOOL_UNKNOWN:
            # TODO(minjie): better behavior in the future
            is_multigraph = BoolFlag.BOOL_FALSE
        gidx = _CAPI_DGLGraphCreateMutable(bool(is_multigraph))
        gidx.add_nodes(num_nodes)
        gidx.add_edges(src, dst)
    return gidx

def from_csr(indptr, indices, is_multigraph,
             direction, shared_mem_name=""):
    """Build a graph index from CSR arrays.

    Parameters
    ----------
    indptr : Tensor
        index pointer in the CSR format
    indices : Tensor
        column index array in the CSR format
    is_multigraph : bool or None
        True if the graph is a multigraph. None means determined by data.
    direction : str
        the edge direction. Either "in" or "out".
    shared_mem_name : str
        the name of shared memory; empty string by default.

    Returns
    -------
    GraphIndex
        The graph index.
    """
    indptr_idx = utils.toindex(indptr)
    indices_idx = utils.toindex(indices)
    # None is forwarded to the backend as the "unknown" sentinel.
    flag = BoolFlag.BOOL_UNKNOWN if is_multigraph is None else is_multigraph
    return _CAPI_DGLGraphCSRCreate(
        indptr_idx.todgltensor(),
        indices_idx.todgltensor(),
        shared_mem_name,
        int(flag),
        direction)

def from_shared_mem_csr_matrix(shared_mem_name,
                               num_nodes, num_edges, edge_dir,
                               is_multigraph):
    """Load a graph whose CSR structure is stored in named shared memory.

    Parameters
    ----------
    shared_mem_name : string
        the name of shared memory
    num_nodes : int
        the number of nodes
    num_edges : int
        the number of edges
    edge_dir : string
        the edge direction. The supported option is "in" and "out".
    is_multigraph : bool
        Whether the graph is a multigraph. Passed to the backend as-is.
        NOTE(review): unlike ``from_csr``, None is not mapped to
        ``BoolFlag.BOOL_UNKNOWN`` here — confirm what the backend accepts.

    Returns
    -------
    GraphIndex
        The graph index.
    """
    node_count = int(num_nodes)
    edge_count = int(num_edges)
    return _CAPI_DGLGraphCSRCreateMMap(
        shared_mem_name, node_count, edge_count, is_multigraph, edge_dir)

def from_networkx(nx_graph, readonly):
    """Convert from networkx graph.

    If the 'id' edge attribute exists, each edge is placed at the position
    given by its 'id', so edge ids follow that attribute. Otherwise, the edge
    order is whatever ``nx_graph.edges`` yields.

    Parameters
    ----------
    nx_graph : networkx.DiGraph
        The nx graph or any graph that can be converted to nx.DiGraph
    readonly : bool
        True if the returned graph is readonly.

    Returns
    -------
    GraphIndex
        The graph index.
    """
    if not isinstance(nx_graph, nx.Graph):
        nx_graph = nx.DiGraph(nx_graph)
    elif not nx_graph.is_directed():
        # to_directed creates a deep copy of the networkx graph, so only call
        # it when the graph is actually undirected.
        nx_graph = nx_graph.to_directed()

    is_multigraph = isinstance(nx_graph, nx.MultiDiGraph)
    num_nodes = nx_graph.number_of_nodes()
    num_edges = nx_graph.number_of_edges()

    # nx_graph.edges(data=True) yields (src, dst, attr_dict). Only the first
    # edge is inspected: 'id' is assumed present on all edges or on none.
    if num_edges > 0:
        has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
    else:
        has_edge_id = False

    if has_edge_id:
        # Ids are expected to be exactly 0..num_edges-1.
        src = np.zeros((num_edges,), dtype=np.int64)
        dst = np.zeros((num_edges,), dtype=np.int64)
        for u, v, attr in nx_graph.edges(data=True):
            eid = attr['id']
            src[eid] = u
            dst[eid] = v
    else:
        src = []
        dst = []
        for e in nx_graph.edges:
            src.append(e[0])
            dst.append(e[1])
    src = utils.toindex(src)
    dst = utils.toindex(dst)
    return from_coo(num_nodes, src, dst, is_multigraph, readonly)

def from_scipy_sparse_matrix(adj, readonly):
    """Convert from scipy sparse matrix.

    Every stored entry ``adj[i, j]`` (regardless of its value) becomes an
    edge from node ``i`` to node ``j``. The graph has ``max(adj.shape)``
    nodes.

    Parameters
    ----------
    adj : scipy sparse matrix
        The adjacency matrix; it may be non-square.
    readonly : bool
        True if the returned graph is readonly.

    Returns
    -------
    GraphIndex
        The graph index.
    """
    n_rows, n_cols = adj.shape
    # Fast path: a readonly graph can be built directly from a CSR matrix.
    # from_csr takes no node count, so the count can only come from indptr
    # (n_rows + 1 entries). For a non-square matrix, column indices could
    # then exceed the node range, so such matrices use the COO path, which
    # sizes the graph by max(shape).
    if readonly and adj.getformat() == 'csr' and n_rows == n_cols:
        return from_csr(adj.indptr, adj.indices, False, "out")
    num_nodes = max(n_rows, n_cols)
    adj_coo = adj.tocoo()
    return from_coo(num_nodes, adj_coo.row, adj_coo.col, False, readonly)

def from_edge_list(elist, is_multigraph, readonly):
    """Convert from an edge list.

    The smallest node id appearing in the list must be 0. The number of
    nodes is inferred as ``max(node id) + 1``.

    Parameters
    ----------
    elist : list
        List of (u, v) edge tuple.
    is_multigraph : bool or None
        Whether the graph is a multigraph; forwarded unchanged to ``from_coo``.
    readonly : bool
        True if the returned graph is readonly.

    Returns
    -------
    GraphIndex
        The graph index.

    Raises
    ------
    DGLError
        If the edge list is empty or its smallest node id is not 0.
    """
    # Transpose [(u0, v0), (u1, v1), ...] into [(u0, u1, ...), (v0, v1, ...)].
    columns = list(zip(*elist))
    if not columns:
        # Without this check, an empty list would fail below with an opaque
        # "not enough values to unpack" error.
        raise DGLError('Invalid edge list. The edge list is empty.')
    src, dst = columns
    src = np.array(src)
    dst = np.array(dst)
    src_ids = utils.toindex(src)
    dst_ids = utils.toindex(dst)
    num_nodes = max(src.max(), dst.max()) + 1
    min_nodes = min(src.min(), dst.min())
    if min_nodes != 0:
        raise DGLError('Invalid edge list. Nodes must start from 0.')
    return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly)

def map_to_subgraph_nid(subgraph, parent_nids):
    """Translate node ids of the parent graph into node ids of ``subgraph``.

    Parameters
    ----------
    subgraph: SubgraphIndex
        The subgraph whose ``induced_nodes`` define the id mapping.

    parent_nids: utils.Index
        Node ids expressed in the parent graph.

    Returns
    -------
    utils.Index
        The same nodes expressed as subgraph node ids.
    """
    induced_tensor = subgraph.induced_nodes.todgltensor()
    query_tensor = parent_nids.todgltensor()
    mapped = _CAPI_DGLMapSubgraphNID(induced_tensor, query_tensor)
    return utils.toindex(mapped)

def transform_ids(mapping, ids):
    """Apply an id mapping to a batch of ids.

    Parameters
    ----------
    mapping : utils.Index
        The lookup table, where ``new_id = mapping[old_id]``.
    ids : utils.Index
        The ids to translate.

    Returns
    -------
    utils.Index
        The translated ids.
    """
    table = mapping.todgltensor()
    old_ids = ids.todgltensor()
    # The same C routine as map_to_subgraph_nid is reused for generic remapping.
    new_ids = _CAPI_DGLMapSubgraphNID(table, old_ids)
    return utils.toindex(new_ids)

def disjoint_union(graphs):
    """Combine several graphs into one graph with no shared nodes or edges.

    Every node and edge of each input graph is kept. Ids are shifted by the
    running total of the sizes of the graphs that come before it in the
    sequence. For example, with inputs [g1, g2, g3] having 5, 6 and 7 nodes,
    node #2 of g2 becomes node #7 (5 + 2) of the result. Edge ids are shifted
    the same way.

    Parameters
    ----------
    graphs : iterable of GraphIndex
        The graphs to combine, in order.

    Returns
    -------
    GraphIndex
        The combined graph.
    """
    graph_list = list(graphs)
    return _CAPI_DGLDisjointUnion(graph_list)

def disjoint_partition(graph, num_or_size_splits):
    """Split a graph into disjoint pieces; the inverse of ``disjoint_union``.

    If an int is given, the graph is split into that many partitions. This
    requires the number of partitions to evenly divide the number of nodes in
    the graph. If a size list is given, the sizes are expected to sum to the
    number of nodes in the graph.

    Parameters
    ----------
    graph : GraphIndex
        The graph to be partitioned.
    num_or_size_splits : int or utils.Index
        Either the number of partitions or the size of each partition.

    Returns
    -------
    list of GraphIndex
        The partitioned graphs.
    """
    if not isinstance(num_or_size_splits, utils.Index):
        # Anything that is not an Index is treated as a partition count.
        return _CAPI_DGLDisjointPartitionByNum(graph, int(num_or_size_splits))
    size_tensor = num_or_size_splits.todgltensor()
    return _CAPI_DGLDisjointPartitionBySizes(graph, size_tensor)

def create_graph_index(graph_data, multigraph, readonly):
    """Create a graph index object.

    Parameters
    ----------
    graph_data : graph data
        Data to initialize graph. Same as networkx's semantics. Accepted
        inputs are:

        - an existing ``GraphIndex``, which is returned as-is;
        - ``None``, which creates an empty mutable graph;
        - a list/tuple of ``(u, v)`` edges;
        - a scipy sparse matrix;
        - anything else, which is handed to ``from_networkx``.
    multigraph : bool
        Whether the graph would be a multigraph. If none, the flag will be
        determined by the data. Only the ``None`` and edge-list inputs use
        this flag; it is ignored for scipy and networkx inputs.
    readonly : bool
        Whether the graph structure is read-only.

    Returns
    -------
    GraphIndex
        The graph index.

    Raises
    ------
    DGLError
        If ``graph_data`` is None while ``readonly`` is True, or if the
        conversion of a networkx-style input fails.
    """
    if isinstance(graph_data, GraphIndex):
        # FIXME(minjie): this return is not correct for mutable graph index
        return graph_data

    if graph_data is None:
        if readonly:
            # DGLError (rather than a bare Exception) keeps error reporting
            # consistent with the rest of this module.
            raise DGLError("can't create an empty immutable graph")
        if multigraph is None:
            multigraph = False
        return _CAPI_DGLGraphCreateMutable(multigraph)
    elif isinstance(graph_data, (list, tuple)):
        # edge list
        return from_edge_list(graph_data, multigraph, readonly)
    elif isinstance(graph_data, scipy.sparse.spmatrix):
        # scipy format
        return from_scipy_sparse_matrix(graph_data, readonly)
    else:
        # networkx - any format
        try:
            gidx = from_networkx(graph_data, readonly)
        except Exception:  # pylint: disable=broad-except
            raise DGLError('Error while creating graph from input of type "%s".'
                           % type(graph_data))
        return gidx

#############################################################
# Hetero graph
#############################################################

@register_object('graph.HeteroGraph')
class HeteroGraphIndex(ObjectBase):
    """HeteroGraph index object.

    Note
    ----
    Do not create HeteroGraphIndex directly. Use factory functions such as
    ``create_bipartite_from_coo``, ``create_bipartite_from_csr`` or
    ``create_heterograph`` instead.
    """
    def __new__(cls):
        obj = ObjectBase.__new__(cls)
        return obj

    def __getstate__(self):
        # TODO: pickling is not implemented yet. Returning a false state means
        # pickle will not call __setstate__ on load, so an unpickled object
        # presumably has no valid underlying handle.
        return None

    def __setstate__(self, state):
        # TODO: counterpart of __getstate__; not implemented yet.
        pass

    @property
    def meta_graph(self):
        """Meta graph

        Returns
        -------
        GraphIndex
            The meta graph.
        """
        return _CAPI_DGLHeteroGetMetaGraph(self)

    def number_of_ntypes(self):
        """Return number of node types."""
        return self.meta_graph.number_of_nodes()

    def number_of_etypes(self):
        """Return number of edge types."""
        return self.meta_graph.number_of_edges()

    def get_relation_graph(self, etype):
        """Get the bipartite graph of the given edge/relation type.

        Parameters
        ----------
        etype : int
            The edge/relation type.

        Returns
        -------
        HeteroGraphIndex
            The bipartite graph.
        """
        return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))

    def add_nodes(self, ntype, num):
        """Add nodes.

        Parameters
        ----------
        ntype : int
            Node type
        num : int
            Number of nodes to be added.
        """
        # Follows the AddEdge/AddEdges and ...Vertex naming used by the
        # sibling C API calls in this class.
        _CAPI_DGLHeteroAddVertices(self, int(ntype), int(num))

    def add_edge(self, etype, u, v):
        """Add one edge.

        Parameters
        ----------
        etype : int
            Edge type
        u : int
            The src node.
        v : int
            The dst node.
        """
        _CAPI_DGLHeteroAddEdge(self, int(etype), int(u), int(v))

    def add_edges(self, etype, u, v):
        """Add many edges.

        Parameters
        ----------
        etype : int
            Edge type
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.
        """
        _CAPI_DGLHeteroAddEdges(self, int(etype), u.todgltensor(), v.todgltensor())

    def clear(self):
        """Clear the graph."""
        _CAPI_DGLHeteroClear(self)

    def ctx(self):
        """Return the context of this graph index.

        Returns
        -------
        DGLContext
            The context of the graph.
        """
        return _CAPI_DGLHeteroContext(self)

    def nbits(self):
        """Return the number of integer bits used in the storage (32 or 64).

        Returns
        -------
        int
            The number of bits.
        """
        return _CAPI_DGLHeteroNumBits(self)

    def is_multigraph(self):
        """Return whether the graph is a multigraph

        Returns
        -------
        bool
            True if it is a multigraph, False otherwise.
        """
        return bool(_CAPI_DGLHeteroIsMultigraph(self))

    def is_readonly(self):
        """Return whether the graph index is read-only.

        Returns
        -------
        bool
            True if it is a read-only graph, False otherwise.
        """
        return bool(_CAPI_DGLHeteroIsReadonly(self))

    def number_of_nodes(self, ntype):
        """Return the number of nodes.

        Parameters
        ----------
        ntype : int
            Node type

        Returns
        -------
        int
            The number of nodes
        """
        return _CAPI_DGLHeteroNumVertices(self, int(ntype))

    def number_of_edges(self, etype):
        """Return the number of edges.

        Parameters
        ----------
        etype : int
            Edge type

        Returns
        -------
        int
            The number of edges
        """
        return _CAPI_DGLHeteroNumEdges(self, int(etype))

    def has_node(self, ntype, vid):
        """Return true if the node exists.

        Parameters
        ----------
        ntype : int
            Node type
        vid : int
            The node.

        Returns
        -------
        bool
            True if the node exists, False otherwise.
        """
        return bool(_CAPI_DGLHeteroHasVertex(self, int(ntype), int(vid)))

    def has_nodes(self, ntype, vids):
        """Return true if the nodes exist.

        Parameters
        ----------
        ntype : int
            Node type
        vids : utils.Index
            The nodes

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        vid_array = vids.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array))

    def has_edge_between(self, etype, u, v):
        """Return true if the edge exists.

        Parameters
        ----------
        etype : int
            Edge type
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        bool
            True if the edge exists, False otherwise
        """
        return bool(_CAPI_DGLHeteroHasEdgeBetween(self, int(etype), int(u), int(v)))

    def has_edges_between(self, etype, u, v):
        """Return true if the edges exist.

        Parameters
        ----------
        etype : int
            Edge type
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween(
            self, int(etype), u_array, v_array))

    def predecessors(self, etype, v):
        """Return the predecessors of the node.

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        utils.Index
            Array of predecessors
        """
        return utils.toindex(_CAPI_DGLHeteroPredecessors(
            self, int(etype), int(v)))

    def successors(self, etype, v):
        """Return the successors of the node.

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        utils.Index
            Array of successors
        """
        return utils.toindex(_CAPI_DGLHeteroSuccessors(
            self, int(etype), int(v)))

    def edge_id(self, etype, u, v):
        """Return the id array of all edges between u and v.

        Parameters
        ----------
        etype : int
            Edge type
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        utils.Index
            The edge id array.
        """
        return utils.toindex(_CAPI_DGLHeteroEdgeId(
            self, int(etype), int(u), int(v)))

    def edge_ids(self, etype, u, v):
        """Return a triplet of arrays that contains the edge IDs.

        Parameters
        ----------
        etype : int
            Edge type
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array)

        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))

        return src, dst, eid

    def find_edges(self, etype, eid):
        """Return the src and dst nodes of the given edges.

        Parameters
        ----------
        etype : int
            Edge type
        eid : utils.Index
            The edge ids.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        eid_array = eid.todgltensor()
        edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array)

        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))

        return src, dst, eid

    def in_edges(self, etype, v):
        """Return the in edges of the node(s).

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if len(v) == 1:
            # Single-node queries go through the scalar C API entry point.
            edge_array = _CAPI_DGLHeteroInEdges_1(self, int(etype), int(v[0]))
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def out_edges(self, etype, v):
        """Return the out edges of the node(s).

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if len(v) == 1:
            # Single-node queries go through the scalar C API entry point.
            edge_array = _CAPI_DGLHeteroOutEdges_1(self, int(etype), int(v[0]))
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def edges(self, etype, order=None):
        """Return all the edges

        Parameters
        ----------
        etype : int
            Edge type
        order : string
            The order of the returned edges. Currently support:

            - 'srcdst' : sorted by their src and dst ids.
            - 'eid'    : sorted by edge Ids.
            - None     : the arbitrary order.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if order is None:
            # The C API takes an empty string to mean "no particular order".
            order = ""
        edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def in_degree(self, etype, v):
        """Return the in degree of the node.

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        int
            The in degree.
        """
        return _CAPI_DGLHeteroInDegree(self, int(etype), int(v))

    def in_degrees(self, etype, v):
        """Return the in degrees of the nodes.

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The in degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array))

    def out_degree(self, etype, v):
        """Return the out degree of the node.

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        int
            The out degree.
        """
        return _CAPI_DGLHeteroOutDegree(self, int(etype), int(v))

    def out_degrees(self, etype, v):
        """Return the out degrees of the nodes.

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The out degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array))

    def adjacency_matrix(self, etype, transpose, ctx):
        """Return the adjacency matrix representation of this graph.

        By default, a row of returned adjacency matrix represents the destination
        of an edge and the column represents the source.

        When transpose is True, a row represents the source and a column represents
        a destination.

        Parameters
        ----------
        etype : int
            Edge type
        transpose : bool
            A flag to transpose the returned adjacency matrix.
        ctx : context
            The context of the returned matrix.

        Returns
        -------
        SparseTensor
            The adjacency matrix.
        utils.Index
            A index for data shuffling due to sparse format change. Return None
            if shuffle is not required.

        Raises
        ------
        DGLError
            If ``transpose`` is not a bool or the preferred sparse format is
            not "csr" or "coo".
        """
        if not isinstance(transpose, bool):
            raise DGLError('Expect bool value for "transpose" arg,'
                           ' but got %s.' % (type(transpose)))
        fmt = F.get_preferred_sparse_format()
        rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
        # convert to framework-specific sparse matrix
        srctype, dsttype = self.meta_graph.find_edge(etype)
        num_src = self.number_of_nodes(srctype)
        num_dst = self.number_of_nodes(dsttype)
        # Rows index destinations by default and sources when transposed.
        nrows, ncols = (num_src, num_dst) if transpose else (num_dst, num_src)
        nnz = self.number_of_edges(etype)
        if fmt == "csr":
            indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
            indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
            shuffle = utils.toindex(rst(2))
            dat = F.ones(nnz, dtype=F.float32, ctx=ctx)  # FIXME(minjie): data type
            spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0]
            return spmat, shuffle
        elif fmt == "coo":
            idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
            # The C API returns row and col ids concatenated in one flat array.
            idx = F.reshape(idx, (2, nnz))
            dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
            adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (nrows, ncols))
            shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
            return adj, shuffle_idx
        else:
            raise DGLError("unknown format")

    def node_subgraph(self, induced_nodes):
        """Return the induced node subgraph.

        Parameters
        ----------
        induced_nodes : list of utils.Index
            Induced nodes. The length should be equal to the number of
            node types in this heterograph.

        Returns
        -------
        HeteroSubgraphIndex
            The subgraph index.
        """
        vids = [nodes.todgltensor() for nodes in induced_nodes]
        return _CAPI_DGLHeteroVertexSubgraph(self, vids)

    def edge_subgraph(self, induced_edges, preserve_nodes):
        """Return the induced edge subgraph.

        Parameters
        ----------
        induced_edges : list of utils.Index
            Induced edges. The length should be equal to the number of
            edge types in this heterograph.
        preserve_nodes : bool
            Indicates whether to preserve all nodes or not.
            If true, keep the nodes which have no edge connected in the subgraph;
            If false, all nodes without edge connected to it would be removed.

        Returns
        -------
        HeteroSubgraphIndex
            The subgraph index.
        """
        eids = [edges.todgltensor() for edges in induced_edges]
        return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)

@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
    """Hetero-subgraph data structure.

    Holds the extracted subgraph together with the induced nodes and edges,
    each grouped per type.
    """
    @staticmethod
    def _as_index_list(values):
        """Wrap the ``.data`` payload of every returned value in ``utils.Index``."""
        return [utils.toindex(value.data) for value in values]

    @property
    def graph(self):
        """The subgraph structure.

        Returns
        -------
        HeteroGraphIndex
            The subgraph
        """
        subgraph = _CAPI_DGLHeteroSubgraphGetGraph(self)
        return subgraph

    @property
    def induced_nodes(self):
        """Induced nodes, one entry per node type.

        The returned list has as many entries as there are node types.

        Returns
        -------
        list of utils.Index
            Induced nodes
        """
        per_type_nodes = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
        return HeteroSubgraphIndex._as_index_list(per_type_nodes)

    @property
    def induced_edges(self):
        """Induced edges, one entry per edge type.

        The returned list has as many entries as there are edge types.

        Returns
        -------
        list of utils.Index
            Induced edges
        """
        per_type_edges = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
        return HeteroSubgraphIndex._as_index_list(per_type_edges)

def create_bipartite_from_coo(num_src, num_dst, row, col):
    """Build a bipartite graph index from COO (row, col) arrays.

    Parameters
    ----------
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    row : utils.Index
        Row index.
    col : utils.Index
        Col index.

    Returns
    -------
    HeteroGraphIndex
        The bipartite graph.
    """
    n_src, n_dst = int(num_src), int(num_dst)
    row_tensor, col_tensor = row.todgltensor(), col.todgltensor()
    return _CAPI_DGLHeteroCreateBipartiteFromCOO(n_src, n_dst, row_tensor, col_tensor)

def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
    """Build a bipartite graph index from CSR arrays.

    Parameters
    ----------
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    indptr : utils.Index
        CSR row pointer array.
    indices : utils.Index
        CSR column index array.
    edge_ids : utils.Index
        Edge shuffle id.

    Returns
    -------
    HeteroGraphIndex
        The bipartite graph.
    """
    n_src = int(num_src)
    n_dst = int(num_dst)
    csr_arrays = (indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
    return _CAPI_DGLHeteroCreateBipartiteFromCSR(n_src, n_dst, *csr_arrays)

def create_heterograph(meta_graph, rel_graphs):
    """Assemble a heterograph from its metagraph and per-relation graphs.

    Parameters
    ----------
    meta_graph : GraphIndex
        Meta-graph describing node and edge types.
    rel_graphs : list of HeteroGraphIndex
        One bipartite graph per relation.

    Returns
    -------
    HeteroGraphIndex
        The assembled heterograph.
    """
    hetero = _CAPI_DGLHeteroCreateHeteroGraph(meta_graph, rel_graphs)
    return hetero

# ``_init_api`` is expected to inject the ``_CAPI_*`` functions called
# throughout this module; none of them are assigned in the Python code shown
# here. It runs at import time, after all definitions above.
_init_api("dgl.graph_index")
