Commit 52d4535b authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by Minjie Wang
Browse files

[Hetero][RFC] Heterogeneous graph Python interfaces & Message Passing (#752)

* moving heterograph index to another file

* node view

* python interfaces

* heterograph init

* bug fixes

* docstring for readonly

* more docstring

* unit tests & lint

* oops

* oops x2

* removed node/edge addition

* addressed comments

* lint

* rw on frames with one node/edge type

* homograph with underlying heterograph demo

* view is not necessary

* bugfix

* replace

* scheduler, builtins not working yet

* moving bipartite.h to header

* moving back bipartite to bipartite.h

* oops

* asbits and copyto for bipartite

* tested update_all and send_and_recv

* lightweight node & edge type retrieval

* oops

* sorry

* removing obsolete code

* oops

* lint

* various bug fixes & more tests

* UDF tests

* multiple type number_of_nodes and number_of_edges

* docstring fixes

* more tests

* going for dict in initialization

* lint

* updated api as per discussions

* lint

* bug

* bugfix

* moving back bipartite impl to cc

* note on views

* fix
parent 66971c1a
......@@ -116,6 +116,11 @@ IdArray IndexSelect(IdArray array, IdArray index);
*/
IdArray Relabel_(const std::vector<IdArray>& arrays);
/*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
//////////////////////////////////////////////////////////////////////
// Sparse matrix
//////////////////////////////////////////////////////////////////////
......
......@@ -17,6 +17,7 @@ from .base import ALL
from .backend import load_backend
from .batched_graph import *
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from .nodeflow import *
from .traversal import *
from .transform import *
......
......@@ -49,6 +49,14 @@ class DGLBaseGraph(object):
"""
return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
......@@ -65,6 +73,10 @@ class DGLBaseGraph(object):
"""
return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self):
"""Return the number of edges in the graph.
......@@ -939,6 +951,14 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index):
self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
......
......@@ -1269,737 +1269,4 @@ def create_graph_index(graph_data, multigraph, readonly):
% type(graph_data))
return gidx
#############################################################
# Hetero graph
#############################################################
@register_object('graph.HeteroGraph')
class HeteroGraphIndex(ObjectBase):
"""HeteroGraph index object.
Note
----
Do not create GraphIndex directly.
"""
def __new__(cls):
obj = ObjectBase.__new__(cls)
return obj
def __getstate__(self):
# TODO
return
def __setstate__(self, state):
# TODO
pass
@property
def meta_graph(self):
"""Meta graph
Returns
-------
GraphIndex
The meta graph.
"""
return _CAPI_DGLHeteroGetMetaGraph(self)
def number_of_ntypes(self):
"""Return number of node types."""
return self.meta_graph.number_of_nodes()
def number_of_etypes(self):
"""Return number of edge types."""
return self.meta_graph.number_of_edges()
def get_relation_graph(self, etype):
"""Get the bipartite graph of the given edge/relation type.
Parameters
----------
etype : int
The edge/relation type.
Returns
-------
HeteroGraphIndex
The bipartite graph.
"""
return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))
def add_nodes(self, ntype, num):
"""Add nodes.
Parameters
----------
ntype : int
Node type
num : int
Number of nodes to be added.
"""
_CAPI_DGLHetero(self, int(ntype), int(num))
def add_edge(self, etype, u, v):
"""Add one edge.
Parameters
----------
etype : int
Edge type
u : int
The src node.
v : int
The dst node.
"""
_CAPI_DGLHeteroAddEdge(self, int(etype), int(u), int(v))
def add_edges(self, etype, u, v):
"""Add many edges.
Parameters
----------
etype : int
Edge type
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
"""
_CAPI_DGLHeteroAddEdges(self, int(etype), u.todgltensor(), v.todgltensor())
def clear(self):
"""Clear the graph."""
_CAPI_DGLHeteroClear(self)
def ctx(self):
"""Return the context of this graph index.
Returns
-------
DGLContext
The context of the graph.
"""
return _CAPI_DGLHeteroContext(self)
def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64).
Returns
-------
int
The number of bits.
"""
return _CAPI_DGLHeteroNumBits(self)
def is_multigraph(self):
"""Return whether the graph is a multigraph
Returns
-------
bool
True if it is a multigraph, False otherwise.
"""
return bool(_CAPI_DGLHeteroIsMultigraph(self))
def is_readonly(self):
"""Return whether the graph index is read-only.
Returns
-------
bool
True if it is a read-only graph, False otherwise.
"""
return bool(_CAPI_DGLHeteroIsReadonly(self))
def number_of_nodes(self, ntype):
"""Return the number of nodes.
Parameters
----------
ntype : int
Node type
Returns
-------
int
The number of nodes
"""
return _CAPI_DGLHeteroNumVertices(self, int(ntype))
def number_of_edges(self, etype):
"""Return the number of edges.
Parameters
----------
etype : int
Edge type
Returns
-------
int
The number of edges
"""
return _CAPI_DGLHeteroNumEdges(self, int(etype))
def has_node(self, ntype, vid):
"""Return true if the node exists.
Parameters
----------
ntype : int
Node type
vid : int
The nodes
Returns
-------
bool
True if the node exists, False otherwise.
"""
return bool(_CAPI_DGLHeteroHasVertex(self, int(ntype), int(vid)))
def has_nodes(self, ntype, vids):
"""Return true if the nodes exist.
Parameters
----------
ntype : int
Node type
vid : utils.Index
The nodes
Returns
-------
utils.Index
0-1 array indicating existence
"""
vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array))
def has_edge_between(self, etype, u, v):
"""Return true if the edge exists.
Parameters
----------
etype : int
Edge type
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists, False otherwise
"""
return bool(_CAPI_DGLHeteroHasEdgeBetween(self, int(etype), int(u), int(v)))
def has_edges_between(self, etype, u, v):
"""Return true if the edge exists.
Parameters
----------
etype : int
Edge type
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
0-1 array indicating existence
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween(
self, int(etype), u_array, v_array))
def predecessors(self, etype, v):
"""Return the predecessors of the node.
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
utils.Index
Array of predecessors
"""
return utils.toindex(_CAPI_DGLHeteroPredecessors(
self, int(etype), int(v)))
def successors(self, etype, v):
"""Return the successors of the node.
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
utils.Index
Array of successors
"""
return utils.toindex(_CAPI_DGLHeteroSuccessors(
self, int(etype), int(v)))
def edge_id(self, etype, u, v):
"""Return the id array of all edges between u and v.
Parameters
----------
etype : int
Edge type
u : int
The src node.
v : int
The dst node.
Returns
-------
utils.Index
The edge id array.
"""
return utils.toindex(_CAPI_DGLHeteroEdgeId(
self, int(etype), int(u), int(v)))
def edge_ids(self, etype, u, v):
"""Return a triplet of arrays that contains the edge IDs.
Parameters
----------
etype : int
Edge type
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def find_edges(self, etype, eid):
"""Return a triplet of arrays that contains the edge IDs.
Parameters
----------
etype : int
Edge type
eid : utils.Index
The edge ids.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
eid_array = eid.todgltensor()
edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def in_edges(self, etype, v):
"""Return the in edges of the node(s).
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLHeteroInEdges_1(self, int(etype), int(v[0]))
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def out_edges(self, etype, v):
"""Return the out edges of the node(s).
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLHeteroOutEdges_1(self, int(etype), int(v[0]))
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def edges(self, etype, order=None):
"""Return all the edges
Parameters
----------
etype : int
Edge type
order : string
The order of the returned edges. Currently support:
- 'srcdst' : sorted by their src and dst ids.
- 'eid' : sorted by edge Ids.
- None : the arbitrary order.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if order is None:
order = ""
edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order)
src = edge_array(0)
dst = edge_array(1)
eid = edge_array(2)
src = utils.toindex(src)
dst = utils.toindex(dst)
eid = utils.toindex(eid)
return src, dst, eid
def in_degree(self, etype, v):
"""Return the in degree of the node.
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
int
The in degree.
"""
return _CAPI_DGLHeteroInDegree(self, int(etype), int(v))
def in_degrees(self, etype, v):
"""Return the in degrees of the nodes.
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The nodes.
Returns
-------
int
The in degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array))
def out_degree(self, etype, v):
"""Return the out degree of the node.
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
int
The out degree.
"""
return _CAPI_DGLHeteroOutDegree(self, int(etype), int(v))
def out_degrees(self, etype, v):
"""Return the out degrees of the nodes.
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The nodes.
Returns
-------
int
The out degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array))
def adjacency_matrix(self, etype, transpose, ctx):
"""Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
of an edge and the column represents the source.
When transpose is True, a row represents the source and a column represents
a destination.
Parameters
----------
etype : int
Edge type
transpose : bool
A flag to transpose the returned adjacency matrix.
ctx : context
The context of the returned matrix.
Returns
-------
SparseTensor
The adjacency matrix.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
# convert to framework-specific sparse matrix
srctype, dsttype = self.meta_graph.find_edge(etype)
nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype)
ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
nnz = self.number_of_edges(etype)
if fmt == "csr":
indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
shuffle = utils.toindex(rst(2))
dat = F.ones(nnz, dtype=F.float32, ctx=ctx) # FIXME(minjie): data type
spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0]
return spmat, shuffle
elif fmt == "coo":
idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
idx = F.reshape(idx, (2, nnz))
dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (nrows, ncols))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return adj, shuffle_idx
else:
raise Exception("unknown format")
def node_subgraph(self, induced_nodes):
"""Return the induced node subgraph.
Parameters
----------
induced_nodes : list of utils.Index
Induced nodes. The length should be equal to the number of
node types in this heterograph.
Returns
-------
SubgraphIndex
The subgraph index.
"""
vids = [nodes.todgltensor() for nodes in induced_nodes]
return _CAPI_DGLHeteroVertexSubgraph(self, vids)
def edge_subgraph(self, induced_edges, preserve_nodes):
"""Return the induced edge subgraph.
Parameters
----------
induced_edges : list of utils.Index
Induced edges. The length should be equal to the number of
edge types in this heterograph.
preserve_nodes : bool
Indicates whether to preserve all nodes or not.
If true, keep the nodes which have no edge connected in the subgraph;
If false, all nodes without edge connected to it would be removed.
Returns
-------
SubgraphIndex
The subgraph index.
"""
eids = [edges.todgltensor() for edges in induced_edges]
return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
"""Hetero-subgraph data structure"""
@property
def graph(self):
"""The subgraph structure
Returns
-------
HeteroGraphIndex
The subgraph
"""
return _CAPI_DGLHeteroSubgraphGetGraph(self)
@property
def induced_nodes(self):
"""Induced nodes for each node type. The return list
length should be equal to the number of node types.
Returns
-------
list of utils.Index
Induced nodes
"""
ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
return [utils.toindex(v.data) for v in ret]
@property
def induced_edges(self):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret]
def create_bipartite_from_coo(num_src, num_dst, row, col):
"""Create a bipartite graph index from COO format
Parameters
----------
num_src : int
Number of nodes in the src type.
num_dst : int
Number of nodes in the dst type.
row : utils.Index
Row index.
col : utils.Index
Col index.
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCOO(
int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
"""Create a bipartite graph index from CSR format
Parameters
----------
num_src : int
Number of nodes in the src type.
num_dst : int
Number of nodes in the dst type.
indptr : utils.Index
CSR indptr.
indices : utils.Index
CSR indices.
edge_ids : utils.Index
Edge shuffle id.
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCSR(
int(num_src), int(num_dst),
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
def create_heterograph(meta_graph, rel_graphs):
"""Create a heterograph from metagraph and graphs of every relation.
Parameters
----------
meta_graph : GraphIndex
Meta-graph.
rel_graphs : list of HeteroGraphIndex
Bipartite graph of each relation.
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateHeteroGraph(meta_graph, rel_graphs)
_init_api("dgl.graph_index")
"""Classes for heterogeneous graphs."""
from collections import defaultdict
import networkx as nx
import scipy.sparse as ssp
from . import heterograph_index, graph_index
from . import utils
from . import backend as F
from . import init
from .runtime import ir, scheduler, Runtime
from .frame import Frame, FrameRef
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, is_all, DGLError
__all__ = ['DGLHeteroGraph']
# TODO: depending on the progress of unifying DGLGraph and Bipartite, we may or may not
# need the code of heterogeneous graph views.
# pylint: disable=unnecessary-pass
class DGLBaseHeteroGraph(object):
"""Base Heterogeneous graph class.
A Heterogeneous graph is defined as a graph with node types and edge
types.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Parameters
----------
metagraph : NetworkX MultiGraph or compatible data structure
The set of node types and edge types, as well as the
source/destination node type of each edge type is specified in the
metagraph.
The edge types are specified as edge keys on the NetworkX MultiGraph.
The node types and edge types must be strings.
number_of_nodes_by_type : dict[str, int]
Number of nodes for each node type.
edge_connections_by_type : dict
Specifies how edges would connect nodes of the source type to nodes of
the destination type in the following form:
{edge_type: edge_specifier}
where ``edge_type`` is a triplet of
(source_node_type_name,
destination_node_type_name,
edge_type_name)
and ``edge_specifier`` can be either of the following:
* (source_node_id_tensor, destination_node_id_tensor)
* ``source_node_id_tensor`` and ``destination_node_id_tensor`` are
IDs within the source and destination node type respectively.
* source node id and destination node id are both in their own ID space.
That is, source nodes and destination nodes may have the same ID,
but they are different nodes if they belong to different node types.
* scipy.sparse.matrix
By default, the rows represent the destination of an edge, and the
column represents the source.
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following
IDs:
User name Alice Bob Carol
User ID 0 1 2
Game name Tetris Minecraft
Game ID 0 1
Developer name Nintendo Mojang
Developer ID 0 1
One can construct the graph as follows:
>>> import networkx as nx
>>> metagraph = nx.MultiGraph([
... ('user', 'user', 'follows'),
... ('user', 'game', 'plays'),
... ('developer', 'game', 'develops')])
>>> g = DGLBaseHeteroGraph(
... metagraph=metagraph,
... number_of_nodes_by_type={'user': 4, 'game': 2, 'developer': 2},
... edge_connections_by_type={
... # Alice follows Bob and Bob follows Carol
... ('user', 'user', 'follows'): ([0, 1], [1, 2]),
... # Alice and Bob play Tetris and Bob and Carol play Minecraft
... ('user', 'game', 'plays'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... # Nintendo develops Tetris and Mojang develops Minecraft
... ('developer', 'game', 'develops'): ([0, 1], [0, 1])})
graph : graph index, optional
The graph index
ntypes : list[str]
The node type names
etypes : list[str]
The edge type names
_ntypes_invmap, _etypes_invmap, _view_ntype_idx, _view_etype_idx :
Internal arguments
"""
# pylint: disable=unused-argument
def __init__(
self,
metagraph,
number_of_nodes_by_type,
edge_connections_by_type):
def __init__(self, graph, ntypes, etypes,
_ntypes_invmap=None, _etypes_invmap=None,
_view_ntype_idx=None, _view_etype_idx=None):
super(DGLBaseHeteroGraph, self).__init__()
def __getitem__(self, key):
"""Returns a view on the heterogeneous graph with given node/edge
type:
self._graph = graph
self._ntypes = ntypes
self._etypes = etypes
# inverse mapping from ntype str to int
self._ntypes_invmap = _ntypes_invmap or \
{ntype: i for i, ntype in enumerate(ntypes)}
# inverse mapping from etype str to int
self._etypes_invmap = _etypes_invmap or \
{etype: i for i, etype in enumerate(etypes)}
* If ``key`` is a str, it returns a heterogeneous subgraph induced
from nodes of type ``key``.
* If ``key`` is a pair of str (type_A, type_B), it returns a
heterogeneous subgraph induced from the union of both node types.
* If ``key`` is a triplet of str
# Indicates which node/edge type (int) it is viewing.
self._view_ntype_idx = _view_ntype_idx
self._view_etype_idx = _view_etype_idx
(src_type_name, dst_type_name, edge_type_name)
self._cache = {}
It returns a heterogeneous subgraph induced from the edges with
source type name ``src_type_name``, destination type name
``dst_type_name``, and edge type name ``edge_type_name``.
def _create_view(self, ntype_idx, etype_idx):
return DGLBaseHeteroGraph(
self._graph, self._ntypes, self._etypes,
self._ntypes_invmap, self._etypes_invmap,
ntype_idx, etype_idx)
The view would share the frames with the parent graph; any
modifications on one's frames would reflect on the other.
Note that the subgraph itself is not materialized until someone
queries the subgraph structure. This implies that calling computation
methods such as
@property
def is_node_type_view(self):
"""Whether this is a node type view of a heterograph."""
return self._view_ntype_idx is not None
g['user'].update_all(...)
@property
def is_edge_type_view(self):
"""Whether this is an edge type view of a heterograph."""
return self._view_etype_idx is not None
would not create a subgraph of users.
@property
def is_view(self):
"""Whether this is a node/view of a heterograph."""
return self.is_node_type_view or self.is_edge_type_view
Parameters
----------
key : str or tuple
See above
@property
def all_node_types(self):
"""Return the list of node types of the entire heterograph."""
return self._ntypes
Returns
-------
DGLBaseHeteroGraphView
The induced subgraph view.
"""
pass
@property
def all_edge_types(self):
"""Return the list of edge types of the entire heterograph."""
return self._etypes
@property
def metagraph(self):
"""Return the metagraph as networkx.MultiDiGraph."""
pass
def number_of_nodes(self):
"""Return the number of nodes in the graph.
"""Return the metagraph as networkx.MultiDiGraph.
The nodes are labeled with node type names.
The edges have their keys holding the edge type names.
"""
nx_graph = self._graph.metagraph.to_networkx()
nx_return_graph = nx.MultiDiGraph()
for u_v in nx_graph.edges:
etype = self._etypes[nx_graph.edges[u_v]['id']]
srctype = self._ntypes[u_v[0]]
dsttype = self._ntypes[u_v[1]]
assert etype[0] == srctype
assert etype[2] == dsttype
nx_return_graph.add_edge(srctype, dsttype, etype[1])
return nx_return_graph
def _endpoint_types(self, etype):
"""Return the source and destination node type (int) of given edge
type (int)."""
return self._graph.metagraph.find_edge(etype)
def _node_types(self):
if self.is_node_type_view:
return [self._view_ntype_idx]
elif self.is_edge_type_view:
srctype_idx, dsttype_idx = self._endpoint_types(self._view_etype_idx)
return [srctype_idx, dsttype_idx] if srctype_idx != dsttype_idx else [srctype_idx]
else:
return range(len(self._ntypes))
def node_types(self):
"""Return the list of node types appearing in the current view.
Returns
-------
int
The number of nodes
"""
pass
list[str]
List of node types
def __len__(self):
"""Return the number of nodes in the graph."""
pass
# TODO: REVIEW
def add_nodes(self, num, node_type, data=None):
"""Add multiple new nodes of the same node type
Examples
--------
Getting all node types.
>>> g.node_types()
['user', 'game', 'developer']
Getting all node types appearing in the subgraph induced by "users"
(which should only yield "user").
>>> g['user'].node_types()
['user']
The node types appearing in subgraph induced by "plays" relationship,
which should only give "user" and "game".
>>> g['plays'].node_types()
['user', 'game']
"""
ntypes = self._node_types()
if isinstance(ntypes, range):
# assuming that the range object always covers the entire node type list
return self._ntypes
else:
return [self._ntypes[i] for i in ntypes]
def _edge_types(self):
if self.is_node_type_view:
etype_indices = self._graph.metagraph.edge_id(
self._view_ntype_idx, self._view_ntype_idx)
return etype_indices
elif self.is_edge_type_view:
return [self._view_etype_idx]
else:
return range(len(self._etypes))
def edge_types(self):
"""Return the list of edge types appearing in the current view.
Parameters
----------
num : int
Number of nodes to be added.
node_type : str
Type of the added nodes. Must appear in the metagraph.
data : dict, optional
Feature data of the added nodes.
Returns
-------
list[str]
List of edge types
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
Getting all edge types.
>>> g.edge_types()
['follows', 'plays', 'develops']
>>> g['game'].number_of_nodes()
2
>>> g.add_nodes(3, 'game') # add 3 new games
>>> g['game'].number_of_nodes()
5
Getting all edge types appearing in subgraph induced by "users".
>>> g['user'].edge_types()
['follows']
The edge types appearing in subgraph induced by "plays" relationship,
which should only give "plays".
>>> g['plays'].edge_types()
['plays']
"""
pass
etypes = self._edge_types()
if isinstance(etypes, range):
return self._etypes
else:
return [self._etypes[i] for i in etypes]
@property
@utils.cached_member('_cache', '_current_ntype_idx')
def _current_ntype_idx(self):
"""Checks the uniqueness of node type in the view and get the index
of that node type.
# TODO: REVIEW
def add_edge(self, u, v, utype, vtype, etype, data=None):
"""Add an edge of ``etype`` between u of type ``utype`` and v of type
``vtype``.
This allows reading/writing node frame data.
"""
node_types = self._node_types()
assert len(node_types) == 1, "only available for subgraphs with one node type"
return node_types[0]
Parameters
----------
u : int
The source node ID of type ``utype``. Must exist in the graph.
v : int
The destination node ID of type ``vtype``. Must exist in the
graph.
utype : str
The source node type name. Must exist in the metagraph.
vtype : str
The destination node type name. Must exist in the metagraph.
etype : str
The edge type name. Must exist in the metagraph.
data : dict, optional
Feature data of the added edge.
@property
@utils.cached_member('_cache', '_current_etype_idx')
def _current_etype_idx(self):
"""Checks the uniqueness of edge type in the view and get the index
of that edge type.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
This allows reading/writing edge frame data and message passing routines.
"""
edge_types = self._edge_types()
assert len(edge_types) == 1, "only available for subgraphs with one edge type"
return edge_types[0]
>>> g['user', 'game', 'plays'].number_of_edges()
4
>>> g.add_edge(2, 0, 'user', 'game', 'plays')
>>> g['user', 'game', 'plays'].number_of_edges()
5
@property
@utils.cached_member('_cache', '_current_srctype_idx')
def _current_srctype_idx(self):
"""Checks the uniqueness of edge type in the view and get the index
of the source type.
This allows reading/writing edge frame data and message passing routines.
"""
pass
srctype_idx, _ = self._endpoint_types(self._current_etype_idx)
return srctype_idx
def add_edges(self, u, v, utype, vtype, etype, data=None):
"""Add multiple edges of ``etype`` between list of source nodes ``u``
of type ``utype`` and list of destination nodes ``v`` of type
``vtype``. A single edge is added between every pair of ``u[i]`` and
``v[i]``.
@property
@utils.cached_member('_cache', '_current_dsttype_idx')
def _current_dsttype_idx(self):
"""Checks the uniqueness of edge type in the view and get the index
of the destination type.
This allows reading/writing edge frame data and message passing routines.
"""
_, dsttype_idx = self._endpoint_types(self._current_etype_idx)
return dsttype_idx
def number_of_nodes(self, ntype):
"""Return the number of nodes of the given type in the heterograph.
Parameters
----------
u : list, tensor
The source node IDs of type ``utype``. Must exist in the graph.
v : list, tensor
The destination node IDs of type ``vtype``. Must exist in the
graph.
utype : str
The source node type name. Must exist in the metagraph.
vtype : str
The destination node type name. Must exist in the metagraph.
etype : str
The edge type name. Must exist in the metagraph.
data : dict, optional
Feature data of the added edge.
ntype : str
The node type
Returns
-------
int
The number of nodes
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
>>> g['user', 'game', 'plays'].number_of_edges()
4
>>> g.add_edges([0, 2], [1, 0], 'user', 'game', 'plays')
>>> g['user', 'game', 'plays'].number_of_edges()
6
>>> g['user'].number_of_nodes()
3
"""
pass
return self._graph.number_of_nodes(self._ntypes_invmap[ntype])
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_srctype_idx)
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_dsttype_idx)
@property
def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise.
"""
pass
assert not self.is_view, 'not supported on views'
return self._graph.is_multigraph()
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
pass
return self._graph.is_readonly()
def number_of_edges(self):
"""Return the number of edges in the graph.
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self._graph.number_of_edges(self._current_etype_idx)
def number_of_edges(self, etype):
"""Return the number of edges of the given type in the heterograph.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
Returns
-------
int
The number of edges
"""
pass
def has_node(self, vid):
"""Return True if the graph contains node `vid`.
Only works if the graph has one node type. For multiple types,
query with
.. code::
Examples
--------
>>> g.number_of_edges(('user', 'plays', 'game'))
4
"""
return self._graph.number_of_edges(self._etypes_invmap[etype])
g['vtype'].has_node(vid)
def has_node(self, ntype, vid):
"""Return True if the graph contains node `vid` of type `ntype`.
Parameters
----------
ntype : str
The node type.
vid : int
The node ID.
......@@ -311,53 +314,26 @@ class DGLBaseHeteroGraph(object):
Examples
--------
>>> g['user'].has_node(0)
>>> g.has_node('user', 0)
True
>>> g['user'].has_node(4)
>>> g.has_node('user', 4)
False
Equivalently,
>>> 0 in g['user']
True
See Also
--------
has_nodes
"""
pass
def __contains__(self, vid):
"""Return True if the graph contains node `vid`.
Only works if the graph has one node type. For multiple types,
query with
.. code::
vid in g['vtype']
Examples
--------
>>> 0 in g['user']
True
"""
pass
return self._graph.has_node(self._ntypes_invmap[ntype], vid)
def has_nodes(self, vids):
def has_nodes(self, ntype, vids):
"""Return a 0-1 array ``a`` given the node ID array ``vids``.
``a[i]`` is 1 if the graph contains node ``vids[i]``, 0 otherwise.
Only works if the graph has one node type. For multiple types,
query with
.. code::
g['vtype'].has_nodes(vids)
``a[i]`` is 1 if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.
Parameters
----------
ntype : str
The node type.
vid : list or tensor
The array of node IDs.
......@@ -370,27 +346,24 @@ class DGLBaseHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g['user'].has_nodes([0, 1, 2, 3, 4])
>>> g.has_nodes('user', [0, 1, 2, 3, 4])
tensor([1, 1, 1, 0, 0])
See Also
--------
has_node
"""
pass
def has_edge_between(self, u, v):
"""Return True if the edge (u, v) is in the graph.
Only works if the graph has one edge type. For multiple types,
query with
vids = utils.toindex(vids)
rst = self._graph.has_nodes(self._ntypes_invmap[ntype], vids)
return rst.tousertensor()
.. code::
g['srctype', 'dsttype', 'edgetype'].has_edge_between(u, v)
def has_edge_between(self, etype, u, v):
"""Return True if the edge (u, v) of type ``etype`` is in the graph.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int
The node ID of source type.
v : int
......@@ -404,34 +377,29 @@ class DGLBaseHeteroGraph(object):
Examples
--------
Check whether Alice plays Tetris
>>> g['user', 'game', 'plays'].has_edge_between(0, 1)
>>> g.has_edge_between(('user', 'plays', 'game'), 0, 1)
True
And whether Alice plays Minecraft
>>> g['user', 'game', 'plays'].has_edge_between(0, 2)
>>> g.has_edge_between(('user', 'plays', 'game'), 0, 2)
False
See Also
--------
has_edges_between
"""
pass
return self._graph.has_edge_between(self._etypes_invmap[etype], u, v)
def has_edges_between(self, u, v):
"""Return a 0-1 array `a` given the source node ID array `u` and
destination node ID array `v`.
def has_edges_between(self, etype, u, v):
"""Return a 0-1 array ``a`` given the source node ID array ``u`` and
destination node ID array ``v``.
`a[i]` is 1 if the graph contains edge `(u[i], v[i])`, 0 otherwise.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].has_edges_between(u, v)
``a[i]`` is 1 if the graph contains edge ``(u[i], v[i])`` of type ``etype``, 0 otherwise.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor
The node ID array of source type.
v : list, tensor
......@@ -446,31 +414,29 @@ class DGLBaseHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g['user', 'game', 'plays'].has_edges_between([0, 0], [1, 2])
>>> g.has_edges_between(('user', 'plays', 'game'), [0, 0], [1, 2])
tensor([1, 0])
See Also
--------
has_edge_between
"""
pass
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges_between(self._etypes_invmap[etype], u, v)
return rst.tousertensor()
def predecessors(self, v):
def predecessors(self, etype, v):
"""Return the predecessors of node `v` in the graph with the same
edge type.
Node `u` is a predecessor of `v` if an edge `(u, v)` exist in the
graph.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].predecessors(v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node of destination type.
......@@ -484,7 +450,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Query who plays Tetris:
>>> g['user', 'game', 'plays'].predecessors(0)
>>> g.predecessors(('user', 'plays', 'game'), 0)
tensor([0, 1])
This indicates User #0 (Alice) and User #1 (Bob).
......@@ -493,38 +459,33 @@ class DGLBaseHeteroGraph(object):
--------
successors
"""
pass
return self._graph.predecessors(self._etypes_invmap[etype], v).tousertensor()
def successors(self, v):
def successors(self, etype, v):
"""Return the successors of node `v` in the graph with the same edge
type.
Node `u` is a successor of `v` if an edge `(v, u)` exist in the
graph.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].successors(v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node of source type.
Returns
-------
tensor
Array of successor node IDs if destination node type.
Array of successor node IDs of destination node type.
Examples
--------
The following example uses PyTorch backend.
Asks which game Alice plays:
>>> g['user', 'game', 'plays'].successors(0)
>>> g.successors(('user', 'plays', 'game'), 0)
tensor([0])
This indicates Game #0 (Tetris).
......@@ -533,21 +494,19 @@ class DGLBaseHeteroGraph(object):
--------
predecessors
"""
pass
return self._graph.successors(self._etypes_invmap[etype], v).tousertensor()
def edge_id(self, u, v, force_multi=False):
def edge_id(self, etype, u, v, force_multi=False):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_id(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int
The node ID of source type.
v : int
......@@ -567,28 +526,27 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the edge ID of "Bob plays Tetris"
>>> g['user', 'game', 'plays'].edge_id(1, 0)
>>> g.edge_id(('user', 'plays', 'game'), 1, 0)
1
See Also
--------
edge_ids
"""
pass
idx = self._graph.edge_id(self._etypes_invmap[etype], u, v)
return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0]
def edge_ids(self, u, v, force_multi=False):
def edge_ids(self, etype, u, v, force_multi=False):
"""Return all edge IDs between source node array `u` and destination
node array `v`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor
The node ID array of source type.
v : list, tensor
......@@ -616,29 +574,30 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the edge IDs of "Alice plays Tetris" and "Bob plays Minecraft".
>>> g['user', 'game', 'plays'].edge_ids([0, 1], [0, 1])
>>> g.edge_ids(('user', 'plays', 'game'), [0, 1], [0, 1])
tensor([0, 2])
See Also
--------
edge_id
"""
pass
u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(self._etypes_invmap[etype], u, v)
if force_multi or self._graph.is_multigraph():
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
return eid.tousertensor()
def find_edges(self, eid):
def find_edges(self, etype, eid):
"""Given an edge ID array, return the source and destination node ID
array `s` and `d`. `s[i]` and `d[i]` are source and destination node
ID for edge `eid[i]`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
eid : list, tensor
The edge ID array.
......@@ -654,23 +613,20 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the user and game of gameplay #0 and #2:
>>> g['user', 'game', 'plays'].find_edges([0, 2])
>>> g.find_edges(('user', 'plays', 'game'), [0, 2])
(tensor([0, 1]), tensor([0, 1]))
"""
pass
eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(self._etypes_invmap[etype], eid)
return src.tousertensor(), dst.tousertensor()
def in_edges(self, v, form='uv'):
def in_edges(self, etype, v, form='uv'):
"""Return the inbound edges of the node(s).
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor
The node(s) of destination type.
form : str, optional
......@@ -696,23 +652,27 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the gameplay IDs of game #0 (Tetris)
>>> g['user', 'game', 'plays'].in_edges(0, 'eid')
>>> g.in_edges(('user', 'plays', 'game'), 0, 'eid')
tensor([0, 1])
"""
pass
def out_edges(self, v, form='uv'):
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(self._etypes_invmap[etype], v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def out_edges(self, etype, v, form='uv'):
"""Return the outbound edges of the node(s).
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor
The node(s) of source type.
form : str, optional
......@@ -738,23 +698,27 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the gameplay IDs of user #0 (Alice)
>>> g['user', 'game', 'plays'].out_edges(0, 'eid')
>>> g.out_edges(('user', 'plays', 'game'), 0, 'eid')
tensor([0])
"""
pass
def all_edges(self, form='uv', order=None):
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(self._etypes_invmap[etype], v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def all_edges(self, etype, form='uv', order=None):
"""Return all the edges.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
form : str, optional
The return form. Currently support:
......@@ -785,57 +749,57 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the user-game pairs for all gameplays:
>>> g['user', 'game', 'plays'].all_edges('uv')
>>> g.all_edges(('user', 'plays', 'game'), 'uv')
(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1]))
"""
pass
def in_degree(self, v):
src, dst, eid = self._graph.edges(self._etypes_invmap[etype], order)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def in_degree(self, etype, v):
"""Return the in-degree of node ``v``.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node ID of destination type.
Returns
-------
etype : (str, str, str)
The source-edge-destination type triplet
int
The in-degree.
Examples
--------
Find how many users are playing Game #0 (Tetris):
>>> g['user', 'game', 'plays'].in_degree(0)
>>> g.in_degree(('user', 'plays', 'game'), 0)
2
See Also
--------
in_degrees
"""
pass
return self._graph.in_degree(self._etypes_invmap[etype], v)
def in_degrees(self, v=ALL):
def in_degrees(self, etype, v=ALL):
"""Return the array `d` of in-degrees of the node array `v`.
`d[i]` is the in-degree of node `v[i]`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor, optional.
The node ID array of destination type. Default is to return the
degrees of all the nodes.
......@@ -850,27 +814,28 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find how many users are playing Game #0 and #1 (Tetris and Minecraft):
>>> g['user', 'game', 'plays'].in_degrees([0, 1])
>>> g.in_degrees(('user', 'plays', 'game'), [0, 1])
tensor([2, 2])
See Also
--------
in_degree
"""
pass
etype_idx = self._etypes_invmap[etype]
_, dsttype_idx = self._endpoint_types(etype_idx)
if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(dsttype_idx)))
else:
v = utils.toindex(v)
return self._graph.in_degrees(etype_idx, v).tousertensor()
def out_degree(self, v):
def out_degree(self, etype, v):
"""Return the out-degree of node `v`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node ID of source type.
......@@ -882,29 +847,24 @@ class DGLBaseHeteroGraph(object):
Examples
--------
Find how many games User #0 Alice is playing
>>> g['user', 'game', 'plays'].out_degree(0)
>>> g.out_degree(('user', 'plays', 'game'), 0)
1
See Also
--------
out_degrees
"""
pass
return self._graph.out_degree(self._etypes_invmap[etype], v)
def out_degrees(self, v=ALL):
def out_degrees(self, etype, v=ALL):
"""Return the array `d` of out-degrees of the node array `v`.
`d[i]` is the out-degree of node `v[i]`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor
The node ID array of source type. Default is to return the degrees
of all the nodes.
......@@ -919,58 +879,413 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find how many games User #0 and #1 (Alice and Bob) are playing
>>> g['user', 'game', 'plays'].out_degrees([0, 1])
>>> g.out_degrees(('user', 'plays', 'game'), [0, 1])
tensor([1, 2])
See Also
--------
out_degree
"""
etype_idx = self._etypes_invmap[etype]
srctype_idx, _ = self._endpoint_types(etype_idx)
if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(srctype_idx)))
else:
v = utils.toindex(v)
return self._graph.out_degrees(etype_idx, v).tousertensor()
def bipartite_from_edge_list(u, v, num_src=None, num_dst=None):
"""Create a bipartite graph component of a heterogeneous graph with a
list of edges.
Parameters
----------
u, v : list[int]
List of source and destination node IDs.
num_src : int, optional
The number of nodes of source type.
By default, the value is the maximum of the source node IDs in the
edge list plus 1.
num_dst : int, optional
The number of nodes of destination type.
By default, the value is the maximum of the destination node IDs in
the edge list plus 1.
"""
num_src = num_src or (max(u) + 1)
num_dst = num_dst or (max(v) + 1)
u = utils.toindex(u)
v = utils.toindex(v)
return heterograph_index.create_bipartite_from_coo(num_src, num_dst, u, v)
def bipartite_from_scipy(spmat, with_edge_id=False):
"""Create a bipartite graph component of a heterogeneous graph with a
scipy sparse matrix.
Parameters
----------
spmat : scipy sparse matrix
The bipartite graph matrix whose rows represent sources and columns
represent destinations.
with_edge_id : bool
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
"""
spmat = spmat.tocsr()
num_src, num_dst = spmat.shape
indptr = utils.toindex(spmat.indptr)
indices = utils.toindex(spmat.indices)
data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices))))
return heterograph_index.create_bipartite_from_csr(num_src, num_dst, indptr, indices, data)
class DGLHeteroGraph(DGLBaseHeteroGraph):
"""Base heterogeneous graph class.
A Heterogeneous graph is defined as a graph with node types and edge
types.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Parameters
----------
graph_data :
The graph data. It can be one of the followings:
* (nx.MultiDiGraph, dict[str, list[tuple[int, int]]])
* (nx.MultiDiGraph, dict[str, scipy.sparse.matrix])
The first element is the metagraph of the heterogeneous graph, as a
networkx directed graph. Its nodes represent the node types, and
its edges represent the edge types. The edge type name should be
stored as edge keys.
The second element is a mapping from edge type to edge list. The
edge list can be either a list of (u, v) pairs, or a scipy sparse
matrix whose rows represents sources and columns represents
destinations. The edges will be added in the (source, destination)
order.
node_frames : dict[str, dict[str, Tensor]]
The node frames for each node type
edge_frames : dict[str, dict[str, Tensor]]
The edge frames for each edge type
multigraph : bool
Whether the heterogeneous graph is a multigraph.
readonly : bool
Whether the heterogeneous graph is readonly.
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following
IDs:
User name Alice Bob Carol
User ID 0 1 2
Game name Tetris Minecraft
Game ID 0 1
Developer name Nintendo Mojang
Developer ID 0 1
One can construct the graph as follows:
>>> mg = nx.MultiDiGraph([('user', 'user', 'follows'),
... ('user', 'game', 'plays'),
... ('developer', 'game', 'develops')])
>>> g = DGLHeteroGraph(
... mg, {
... 'follows': [(0, 1), (1, 2)],
... 'plays': [(0, 0), (1, 0), (1, 1), (2, 1)],
... 'develops': [(0, 0), (1, 1)]})
Then one can query the graph structure as follows:
>>> g['user'].number_of_nodes()
3
>>> g['plays'].number_of_edges()
4
>>> g['develops'].out_degrees() # out-degrees of source nodes of 'develops' relation
tensor([1, 1])
>>> g['develops'].in_edges(0) # in-edges of destination node 0 of 'develops' relation
(tensor([0]), tensor([0]))
Notes
-----
Currently, all heterogeneous graphs are readonly.
"""
# pylint: disable=unused-argument
def __init__(
self,
graph_data=None,
node_frames=None,
edge_frames=None,
multigraph=None,
readonly=True,
_view_ntype_idx=None,
_view_etype_idx=None):
assert readonly, "Only readonly heterogeneous graphs are supported"
# Creating a view of another graph?
if isinstance(graph_data, DGLHeteroGraph):
super(DGLHeteroGraph, self).__init__(
graph_data._graph, graph_data._ntypes, graph_data._etypes,
graph_data._ntypes_invmap, graph_data._etypes_invmap,
graph_data._view_ntype_idx, graph_data._view_etype_idx)
self._node_frames = graph_data._node_frames
self._edge_frames = graph_data._edge_frames
self._msg_frames = graph_data._msg_frames
self._msg_indices = graph_data._msg_indices
self._view_ntype_idx = _view_ntype_idx
self._view_etype_idx = _view_etype_idx
return
if isinstance(graph_data, tuple):
metagraph, edges_by_type = graph_data
if not isinstance(metagraph, nx.MultiDiGraph):
raise TypeError('Metagraph should be networkx.MultiDiGraph')
# create metagraph graph index
srctypes, dsttypes, etypes = [], [], []
ntypes = []
ntypes_invmap = {}
etypes_invmap = {}
for srctype, dsttype, etype in metagraph.edges(keys=True):
srctypes.append(srctype)
dsttypes.append(dsttype)
etypes_invmap[(srctype, etype, dsttype)] = len(etypes_invmap)
etypes.append((srctype, etype, dsttype))
if srctype not in ntypes_invmap:
ntypes_invmap[srctype] = len(ntypes_invmap)
ntypes.append(srctype)
if dsttype not in ntypes_invmap:
ntypes_invmap[dsttype] = len(ntypes_invmap)
ntypes.append(dsttype)
srctypes = [ntypes_invmap[srctype] for srctype in srctypes]
dsttypes = [ntypes_invmap[dsttype] for dsttype in dsttypes]
metagraph_index = graph_index.create_graph_index(
list(zip(srctypes, dsttypes)), None, True) # metagraph is always immutable
# create base bipartites
bipartites = []
num_nodes = defaultdict(int)
# count the number of nodes for each type
for etype_triplet in etypes:
srctype, etype, dsttype = etype_triplet
edges = edges_by_type[etype_triplet]
if ssp.issparse(edges):
num_src, num_dst = edges.shape
elif isinstance(edges, list):
u, v = zip(*edges)
num_src = max(u) + 1
num_dst = max(v) + 1
else:
raise TypeError('unknown edge list type %s' % type(edges))
num_nodes[srctype] = max(num_nodes[srctype], num_src)
num_nodes[dsttype] = max(num_nodes[dsttype], num_dst)
# create actual objects
for etype_triplet in etypes:
srctype, etype, dsttype = etype_triplet
edges = edges_by_type[etype_triplet]
if ssp.issparse(edges):
bipartite = bipartite_from_scipy(edges)
elif isinstance(edges, list):
u, v = zip(*edges)
bipartite = bipartite_from_edge_list(
u, v, num_nodes[srctype], num_nodes[dsttype])
bipartites.append(bipartite)
hg_index = heterograph_index.create_heterograph(metagraph_index, bipartites)
super(DGLHeteroGraph, self).__init__(hg_index, ntypes, etypes)
else:
raise TypeError('Unrecognized graph data type %s' % type(graph_data))
# node and edge frame
if node_frames is None:
self._node_frames = [
FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
for i in range(len(self._ntypes))]
else:
self._node_frames = node_frames
if edge_frames is None:
self._edge_frames = [
FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
for i in range(len(self._etypes))]
else:
self._edge_frames = edge_frames
# message indicators
self._msg_indices = [None] * len(self._etypes)
self._msg_frames = []
for i in range(len(self._etypes)):
frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
frame.set_initializer(init.zero_initializer)
self._msg_frames.append(frame)
def _create_view(self, ntype_idx, etype_idx):
return DGLHeteroGraph(
graph_data=self, _view_ntype_idx=ntype_idx, _view_etype_idx=etype_idx)
def _get_msg_index(self):
if self._msg_indices[self._current_etype_idx] is None:
self._msg_indices[self._current_etype_idx] = utils.zero_index(
size=self._graph.number_of_edges(self._current_etype_idx))
return self._msg_indices[self._current_etype_idx]
def _set_msg_index(self, index):
self._msg_indices[self._current_etype_idx] = index
def __getitem__(self, key):
if key in self._etypes_invmap:
return self._create_view(None, self._etypes_invmap[key])
else:
raise KeyError(key)
@property
def _node_frame(self):
# overrides DGLGraph._node_frame
return self._node_frames[self._current_ntype_idx]
@property
def _edge_frame(self):
# overrides DGLGraph._edge_frame
return self._edge_frames[self._current_etype_idx]
@property
def _src_frame(self):
# overrides DGLGraph._src_frame
return self._node_frames[self._current_srctype_idx]
@property
def _dst_frame(self):
# overrides DGLGraph._dst_frame
return self._node_frames[self._current_dsttype_idx]
@property
def _msg_frame(self):
# overrides DGLGraph._msg_frame
return self._msg_frames[self._current_etype_idx]
def add_nodes(self, node_type, num, data=None):
"""Add multiple new nodes of the same node type
Parameters
----------
node_type : str
Type of the added nodes. Must appear in the metagraph.
num : int
Number of nodes to be added.
data : dict, optional
Feature data of the added nodes.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
>>> g['game'].number_of_nodes()
2
>>> g.add_nodes(3, 'game') # add 3 new games
>>> g['game'].number_of_nodes()
5
"""
pass
def add_edge(self, etype, u, v, data=None):
"""Add an edge of ``etype`` between u of the source node type, and v
of the destination node type..
class DGLBaseHeteroGraphView(DGLBaseHeteroGraph):
"""View on a heterogeneous graph, constructed from
DGLBaseHeteroGraph.__getitem__().
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int
The source node ID of type ``utype``. Must exist in the graph.
v : int
The destination node ID of type ``vtype``. Must exist in the
graph.
data : dict, optional
Feature data of the added edge.
It is semantically the same as a subgraph, except that
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
* The subgraph itself is not materialized until the user explicitly
queries the subgraph structure (e.g. calling ``in_edges``, but not
``update_all``).
>>> g['plays'].number_of_edges()
4
>>> g.add_edge(2, 0, 'plays')
>>> g['plays'].number_of_edges()
5
"""
pass
class DGLHeteroGraph(DGLBaseHeteroGraph):
"""Base heterogeneous graph class.
The graph stores nodes, edges and also their (type-specific) features.
Heterogeneous graphs are by default multigraphs.
def add_edges(self, u, v, etype, data=None):
"""Add multiple edges of ``etype`` between list of source nodes ``u``
and list of destination nodes ``v`` of type ``vtype``. A single edge
is added between every pair of ``u[i]`` and ``v[i]``.
Parameters
----------
metagraph, number_of_nodes_by_type, edge_connections_by_type :
See DGLBaseHeteroGraph
node_frame : dict[str, FrameRef], optional
Node feature storage per type
edge_frame : dict[str, FrameRef], optional
Edge feature storage per type
readonly : bool, optional
Whether the graph structure is read-only (default: False)
u : list, tensor
The source node IDs of type ``utype``. Must exist in the graph.
v : list, tensor
The destination node IDs of type ``vtype``. Must exist in the
graph.
etype : (str, str, str)
The source-edge-destination type triplet
data : dict, optional
Feature data of the added edge.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
>>> g['plays'].number_of_edges()
4
>>> g.add_edges([0, 2], [1, 0], 'plays')
>>> g['plays'].number_of_edges()
6
"""
# pylint: disable=unused-argument
def __init__(
self,
metagraph,
number_of_nodes_by_type,
edge_connections_by_type,
node_frame=None,
edge_frame=None,
readonly=False):
super(DGLHeteroGraph, self).__init__(
metagraph, number_of_nodes_by_type, edge_connections_by_type)
pass
def from_networkx(
self,
......@@ -1020,10 +1335,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
pass
def node_attr_schemes(self, ntype):
"""Return the node feature schemes for a given node type.
"""Return the node feature schemes.
Each feature scheme is a named tuple that stores the shape and data type
of the node feature
of the node feature.
Parameters
----------
......@@ -1034,150 +1349,90 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
-------
dict of str to schemes
The schemes of node feature columns.
Examples
--------
The following uses PyTorch backend.
>>> g.ndata['user']['h'] = torch.randn(3, 4)
>>> g.node_attr_schemes('user')
{'h': Scheme(shape=(4,), dtype=torch.float32)}
"""
pass
return self._node_frames[self._ntypes_invmap[ntype]].schemes
def edge_attr_schemes(self, etype):
"""Return the edge feature schemes for a given edge type.
"""Return the edge feature schemes.
Each feature scheme is a named tuple that stores the shape and data type
of the edge feature
of the edge feature.
Parameters
----------
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
etype : (str, str, str)
The source-edge-destination type triplet
Returns
-------
dict of str to schemes
The schemes of node feature columns.
"""
pass
def set_n_initializer(self, ntype, initializer, field=None):
"""Set the initializer for empty node features of given type.
Initializer is a callable that returns a tensor given the shape, data type
and device context.
When a subset of the nodes are assigned a new feature, initializer is
used to create feature for rest of the nodes.
Parameters
----------
ntype : str
The node type name.
initializer : callable
The initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
"""
pass
def set_e_initializer(self, etype, initializer, field=None):
"""Set the initializer for empty edge features of given type.
Initializer is a callable that returns a tensor given the shape, data
type and device context.
When a subset of the edges are assigned a new feature, initializer is
used to create feature for rest of the edges.
Examples
--------
The following uses PyTorch backend.
Parameters
----------
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
initializer : callable
The initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
>>> g.edata['user', 'plays', 'game']['h'] = torch.randn(4, 4)
>>> g.edge_attr_schemes(('user', 'plays', 'game'))
{'h': Scheme(shape=(4,), dtype=torch.float32)}
"""
pass
return self._edge_frames[self._etypes_invmap[etype]].schemes
@property
def nodes(self):
"""Return a node view that can used to set/get feature data of a
single node type.
Notes
-----
An error is raised if the graph contains multiple node types. Use
g[ntype]
to select nodes with type ``ntype``.
Examples
--------
To set features of User #0 and #2 in a heterogeneous graph:
>>> g['user'].nodes[[0, 2]].data['h'] = torch.zeros(2, 5)
>>> g.nodes['user'][[0, 2]].data['h'] = torch.zeros(2, 5)
"""
pass
return HeteroNodeView(self)
@property
def ndata(self):
"""Return the data view of all the nodes of a single node type.
Notes
-----
An error is raised if the graph contains multiple node types. Use
g[ntype]
to select nodes with type ``ntype``.
Examples
--------
To set features of games in a heterogeneous graph:
>>> g['game'].ndata['h'] = torch.zeros(2, 5)
>>> g.ndata['game']['h'] = torch.zeros(2, 5)
"""
pass
return HeteroNodeDataView(self)
@property
def edges(self):
"""Return an edges view that can used to set/get feature data of a
single edge type.
Notes
-----
An error is raised if the graph contains multiple edge types. Use
g[src_type, dst_type, edge_type]
to select edges with type ``(src_type, dst_type, edge_type)``.
Examples
--------
To set features of gameplays #1 (Bob -> Tetris) and #3 (Carol ->
Minecraft) in a heterogeneous graph:
>>> g['user', 'game', 'plays'].edges[[1, 3]].data['h'] = torch.zeros(2, 5)
>>> g.edges['user', 'plays', 'game'][[1, 3]].data['h'] = torch.zeros(2, 5)
"""
pass
return HeteroEdgeView(self)
@property
def edata(self):
"""Return the data view of all the edges of a single edge type.
Notes
-----
An error is raised if the graph contains multiple edge types. Use
g[src_type, dst_type, edge_type]
to select edges with type ``(src_type, dst_type, edge_type)``.
Examples
--------
>>> g['developer', 'game', 'develops'].edata['h'] = torch.zeros(2, 5)
>>> g.edata['developer', 'develops', 'game']['h'] = torch.zeros(2, 5)
"""
pass
return HeteroEdgeDataView(self)
def set_n_repr(self, data, ntype, u=ALL, inplace=False):
def set_n_repr(self, ntype, data, u=ALL, inplace=False):
"""Set node(s) representation of a single node type.
`data` is a dictionary from the feature name to feature tensor. Each tensor
......@@ -1190,16 +1445,32 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
ntype : str
The node type
data : dict of tensor
Node representation.
ntype : str
Node type.
u : node, container or tensor
The node(s).
inplace : bool
If True, update will be done in place, but autograd will break.
"""
pass
ntype = self._ntypes_invmap[ntype]
if is_all(u):
num_nodes = self._graph.number_of_nodes(ntype)
else:
u = utils.toindex(u)
num_nodes = len(u)
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_nodes:
raise DGLError('Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.' % (nfeats, num_nodes))
if is_all(u):
for key, val in data.items():
self._node_frames[ntype][key] = val
else:
self._node_frames[ntype].update_rows(u, data, inplace=inplace)
def get_n_repr(self, ntype, u=ALL):
"""Get node(s) representation of a single node type.
......@@ -1209,7 +1480,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
ntype : str
Node type.
The node type
u : node, container or tensor
The node(s).
......@@ -1218,7 +1489,14 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict
Representation dict from feature name to feature tensor.
"""
pass
if len(self.node_attr_schemes(ntype)) == 0:
return dict()
ntype_idx = self._ntypes_invmap[ntype]
if is_all(u):
return dict(self._node_frames[ntype_idx])
else:
u = utils.toindex(u)
return self._node_frames[ntype_idx].select_rows(u)
def pop_n_repr(self, ntype, key):
"""Get and remove the specified node repr of a given node type.
......@@ -1226,7 +1504,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
ntype : str
The node type.
The node type
key : str
The attribute name.
......@@ -1235,9 +1513,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor
The popped representation
"""
pass
ntype = self._ntypes_invmap[ntype]
return self._node_frames[ntype].pop(key)
def set_e_repr(self, data, etype, edges=ALL, inplace=False):
def set_e_repr(self, etype, data, edges=ALL, inplace=False):
"""Set edge(s) representation of a single edge type.
`data` is a dictionary from the feature name to feature tensor. Each tensor
......@@ -1249,11 +1528,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
data : tensor or dict of tensor
Edge representation.
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
edges : edges
Edges can be either
......@@ -1265,16 +1543,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
inplace : bool
If True, update will be done in place, but autograd will break.
"""
pass
etype_idx = self._etypes_invmap[etype]
# parse argument
if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
# sanity check
if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(data))
if is_all(eid):
num_edges = self._graph.number_of_edges(etype_idx)
else:
eid = utils.toindex(eid)
num_edges = len(eid)
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_edges:
raise DGLError('Expect number of features to match number of edges.'
' Got %d and %d instead.' % (nfeats, num_edges))
# set
if is_all(eid):
# update column
for key, val in data.items():
self._edge_frames[etype_idx][key] = val
else:
# update row
self._edge_frames[etype_idx].update_rows(eid, data, inplace=inplace)
def get_e_repr(self, etype, edges=ALL):
"""Get edge(s) representation.
Parameters
----------
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
etype : (str, str, str)
The source-edge-destination type triplet
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
......@@ -1284,16 +1596,34 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict
Representation dict
"""
pass
etype_idx = self._etypes_invmap[etype]
if len(self.edge_attr_schemes(etype)) == 0:
return dict()
# parse argument
if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
if is_all(eid):
return dict(self._edge_frames[etype_idx])
else:
eid = utils.toindex(eid)
return self._edge_frames[etype_idx].select_rows(eid)
def pop_e_repr(self, etype, key):
"""Get and remove the specified edge repr of a single edge type.
Parameters
----------
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
etype : (str, str, str)
The source-edge-destination type triplet
key : str
The attribute name.
......@@ -1302,7 +1632,8 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor
The popped representation
"""
pass
etype = self._etypes_invmap[etype]
self._edge_frames[etype].pop(key)
def register_message_func(self, func):
"""Register global message function for each edge type provided.
......@@ -1314,17 +1645,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : callable, dict[etype, callable]
func : callable
Message function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
See Also
--------
send
......@@ -1333,7 +1657,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
push
update_all
"""
pass
raise NotImplementedError
def register_reduce_func(self, func):
"""Register global message reduce function for each edge type provided.
......@@ -1345,17 +1669,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : callable, dict[etype, callable]
func : callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
See Also
--------
recv
......@@ -1364,7 +1681,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
pull
update_all
"""
pass
raise NotImplementedError
def register_apply_node_func(self, func):
"""Register global node apply function for each node type provided.
......@@ -1376,21 +1693,16 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : callable, dict[str, callable]
func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``func`` is not a
dict, it will throw an error.
See Also
--------
apply_nodes
register_apply_edge_func
"""
pass
raise NotImplementedError
def register_apply_edge_func(self, func):
"""Register global edge apply function for each edge type provided.
......@@ -1400,23 +1712,16 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : callable, dict[etype, callable]
func : callable
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
See Also
--------
apply_edges
register_apply_node_func
"""
pass
raise NotImplementedError
def apply_nodes(self, func, v=ALL, inplace=False):
"""Apply the function on the nodes with the same type to update their
......@@ -1426,40 +1731,35 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : callable, dict[str, callable], or None
func : dict[str, callable] or None
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``func`` is not a
dict, it will throw an error.
v : int, iterable of int, tensor, dict, optional
v : dict[str, int or iterable of int or tensor], optional
The (type-specific) node (ids) on which to apply ``func``.
If ``func`` is not a dict, then ``v`` must not be a dict.
If ``func`` is a dict, then ``v`` must either be
* ALL: for computing on all nodes with the given types in ``func``.
* a dict of int, iterable of int, or tensors, with the same keys
as ``func``, indicating the nodes to be updated for each type.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
Examples
--------
>>> g['user'].ndata['h'] = torch.ones(3, 5)
>>> g['user'].apply_nodes(lambda x: {'h': x * 2})
>>> g['user'].ndata['h']
>>> g.ndata['user']['h'] = torch.ones(3, 5)
>>> g.apply_nodes({'user': lambda nodes: {'h': nodes.data['h'] * 2}})
>>> g.ndata['user']['h']
tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]])
>>> g.apply_nodes({'user': lambda x: {'h': x * 2}})
>>> g['user'].ndata['h']
tensor([[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]])
"""
pass
for ntype, nfunc in func.items():
if is_all(v):
v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)))
else:
v_ntype = utils.toindex(v[ntype])
with ir.prog() as prog:
scheduler.schedule_apply_nodes(
graph=self._create_view(self._ntypes_invmap[ntype], None),
v=v_ntype,
apply_func=nfunc,
inplace=inplace)
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, inplace=False):
"""Apply the function on the edges with the same type to update their
......@@ -1469,42 +1769,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : callable, dict[etype, callable], or None
func : dict[(str, str, str), callable] or None
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
edges : any valid edge specification, dict, optional
edges : dict[(str, str, str), any valid edge specification], optional
Edges on which to apply ``func``. See :func:`send` for valid
edge specification.
If ``func`` is not a dict, then ``edges`` must not be a dict.
If ``func`` is a dict, then ``edges`` must either be
* ALL: for computing on all edges with the given types in ``func``.
* a dict of int, iterable of int, or tensors, with the same keys
as ``func``, indicating the edges to be updated for each type.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Examples
--------
>>> g['user', 'game', 'plays'].edata['h'] = torch.ones(3, 5)
>>> g['user', 'game', 'plays'].apply_edges(lambda x: {'h': x * 2})
>>> g['user', 'game', 'plays'].edata['h']
>>> g.edata['user', 'plays', 'game']['h'] = torch.ones(4, 5)
>>> g.apply_edges(
... {('user', 'plays', 'game'): lambda edges: {'h': edges.data['h'] * 2}})
>>> g.edata['user', 'plays', 'game']['h']
tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]])
>>> g.apply_edges({('user', 'game', 'plays'): lambda x: {'h': x * 2}})
tensor([[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]])
"""
pass
for etype, efunc in func.items():
etype_idx = self._etypes_invmap[etype]
if is_all(edges):
u, v, _ = self._graph.edges(etype_idx, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(
graph=self._create_view(None, etype_idx),
u=u,
v=v,
eid=eid,
apply_func=efunc,
inplace=inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
"""Group the edges by nodes and apply the function of the grouped
......@@ -1515,33 +1823,46 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
----------
group_by : str
Specify how to group edges. Expected to be either 'src' or 'dst'
func : callable, dict[etype, callable]
func : dict[(str, str, str), callable]
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. The input of `Edge UDF` should
be (bucket_size, degrees, *feature_shape), and
return the dict with values of the same shapes.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
edges : valid edges type, dict, optional
edges : dict[(str, str, str), valid edges type], optional
Edges on which to group and apply ``func``. See :func:`send` for valid
edges type. Default is all the edges.
If ``func`` is not a dict, then ``edges`` must not be a dict.
If ``func`` is a dict, then ``edges`` must either be
* ALL: for computing on all edges with the given types in ``func``.
* a dict of int, iterable of int, or tensors, with the same keys
as ``func``, indicating the edges to be updated for each type.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
pass
if group_by not in ('src', 'dst'):
raise DGLError("Group_by should be either src or dst")
for etype, efunc in func.items():
etype_idx = self._etypes_invmap[etype]
if is_all(edges):
u, v, _ = self._graph.edges(etype_idx)
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid)
with ir.prog() as prog:
scheduler.schedule_group_apply_edge(
graph=self._create_view(None, etype_idx),
u=u,
v=v,
eid=eid,
apply_func=efunc,
group_by=group_by,
inplace=inplace)
Runtime.run(prog)
# TODO: REVIEW
def send(self, edges=ALL, message_func=None):
"""Send messages along the given edges with the same edge type.
......@@ -1553,7 +1874,13 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
* ``int iterable`` / ``tensor`` : Specify multiple edges using their edge ids.
* ``pair of int iterable`` / ``pair of tensors`` :
Specify multiple edges using their endpoints.
* a dict of all the above, if ``message_func`` is a dict.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].send(edges, message_func)
The UDF returns messages on the edges and can be later fetched in
the destination node's ``mailbox``. Receiving will consume the messages.
......@@ -1564,34 +1891,43 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
edges : valid edges type, dict, optional
edges : valid edges type, optional
Edges on which to apply ``message_func``. Default is sending along all
the edges.
If ``message_func`` is not a dict, then ``edges`` must not be a dict.
If ``message_func`` is a dict, then ``edges`` must either be
* ALL: for computing on all edges with the given types in
``message_func``.
* a dict of int, iterable of int, or tensors, with the same keys
as ``message_func``, indicating the edges to be updated for each
type.
message_func : callable, dict[etype, callable]
message_func : callable
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
Notes
-----
On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent
along all edges between :math:`u` and :math:`v`.
"""
pass
assert not utils.is_dict_like(message_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
if is_all(edges):
eid = utils.toindex(slice(0, self._graph.number_of_edges(self._current_etype_idx)))
u, v, _ = self._graph.edges(self._current_etype_idx)
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(self._current_etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(self._current_etype_idx, eid)
if len(eid) == 0:
# no edge to be triggered
return
with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
message_func=message_func)
Runtime.run(prog)
def recv(self,
v=ALL,
......@@ -1615,53 +1951,53 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
The provided UDF maybe called multiple times so it is recommended to provide
function with no side effect.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].recv(v, reduce_func, apply_node_func, inplace)
Parameters
----------
v : int, container or tensor, dict, optional
v : int, container or tensor, optional
The node(s) to be updated. Default is receiving all the nodes.
If ``apply_node_func`` is not a dict, then ``v`` must not be a
dict.
If ``apply_node_func`` is a dict, then ``v`` must either be
* ALL: for computing on all nodes with the given types in
``apply_node_func``.
* a dict of int, iterable of int, or tensors, indicating the nodes
to be updated for each type.
reduce_func : callable, dict[etype, callable], optional
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable]
apply_node_func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types in ``v``, the node types in ``apply_node_func``,
and the destination types in ``reduce_func`` must be the same.
"""
pass
assert not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert reduce_func is not None
if is_all(v):
v = F.arange(0, self._graph.number_of_nodes(self._current_dsttype_idx))
elif isinstance(v, int):
v = [v]
v = utils.toindex(v)
if len(v) == 0:
# no vertex to be triggered.
return
with ir.prog() as prog:
scheduler.schedule_recv(graph=self,
recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def send_and_recv(self,
edges,
message_func="default",
reduce_func="default",
apply_node_func="default",
message_func=None,
reduce_func=None,
apply_node_func=None,
inplace=False):
"""Send messages along edges with the same edge type, and let destinations
receive them.
......@@ -1673,66 +2009,65 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
``recv(self, dst, reduce_func, apply_node_func)``, where ``dst``
are the destinations of the ``edges``.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].send_and_recv(edges, message_func, reduce_func, apply_node_func, inplace)
Parameters
----------
edges : valid edges type
Edges on which to apply ``func``. See :func:`send` for valid
edges type.
If the functions are not dicts, then ``edges`` must not be a dict.
If the functions are dicts, then ``edges`` must either be
* ALL: for computing on all edges with the given types in the
functions.
* a dict of int, iterable of int, or tensors, indicating the edges
to be updated for each type.
message_func : callable, dict[etype, callable], optional
message_func : callable, optional
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the destination type of ``edges``, the node types in
``apply_node_func``, and the destination types in ``reduce_func``
must be the same.
* the edge type of ``edges``, ``message_func`` and ``reduce_func``
must also be the same.
"""
pass
assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
if isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(self._current_etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(self._current_etype_idx, eid)
if len(u) == 0:
# no edges to be triggered
return
with ir.prog() as prog:
scheduler.schedule_snr(graph=self,
edge_tuples=(u, v, eid),
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def pull(self,
v,
message_func="default",
reduce_func="default",
apply_node_func="default",
message_func=None,
reduce_func=None,
apply_node_func=None,
inplace=False):
"""Pull messages from the node(s)' predecessors and then update their features.
......@@ -1744,126 +2079,102 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
by the column initializer (see :func:`set_n_initializer`). The feature shapes and
dtypes will be inferred.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].pull(v, message_func, reduce_func, apply_node_func, inplace)
Parameters
----------
v : int, container or tensor, dict, optional
v : int, container or tensor, optional
The node(s) to be updated. Default is receiving all the nodes.
If the functions are not dicts, then ``v`` must not be a dict.
If the functions are dicts, then ``v`` must either be
* ALL: for computing on all nodes with the given types in the
functions.
* a dict of int, iterable of int, or tensors, indicating the nodes
to be updated for each type.
message_func : callable, dict[etype, callable], optional
message_func : callable, optional
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types of ``v``, the node types in ``apply_node_func``,
and the destination types in ``reduce_func`` must be the same.
* the edge type of ``message_func`` and ``reduce_func`` must also be
the same.
"""
pass
assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
v = utils.toindex(v)
if len(v) == 0:
return
with ir.prog() as prog:
scheduler.schedule_pull(graph=self,
pull_nodes=v,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def push(self,
u,
message_func="default",
reduce_func="default",
apply_node_func="default",
message_func=None,
reduce_func=None,
apply_node_func=None,
inplace=False):
"""Send message from the node(s) to their successors and update them.
Optionally, apply a function to update the node features after receive.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].push(e, message_func, reduce_func, apply_node_func, inplace)
Parameters
----------
u : int, container or tensor, dict
u : int, container or tensor
The node(s) to push messages out.
If the functions are not dicts, then ``v`` must not be a dict.
If the functions are dicts, then ``v`` must either be
* ALL: for computing on all nodes with the given types in the
functions.
* a dict of int, iterable of int, or tensors, indicating the nodes
to be updated for each type.
message_func : callable, dict[etype, callable], optional
message_func : callable, optional
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types in ``apply_node_func`` and the destination types in
``reduce_func`` must be the same.
* the source types of ``message_func`` and the node types of ``u`` must
be the same.
* the edge type of ``message_func`` and ``reduce_func`` must also be
the same.
"""
pass
assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
u = utils.toindex(u)
if len(u) == 0:
return
with ir.prog() as prog:
scheduler.schedule_push(graph=self,
u=u,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def update_all(self,
message_func="default",
reduce_func="default",
apply_node_func="default"):
message_func=None,
reduce_func=None,
apply_node_func=None):
"""Send messages through all edges and update all nodes.
Optionally, apply a function to update the node features after receive.
......@@ -1872,64 +2183,53 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
``send(self, self.edges(), message_func)`` and
``recv(self, self.nodes(), reduce_func, apply_node_func)``.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].update_all(message_func, reduce_func, apply_node_func)
Parameters
----------
message_func : callable, dict[etype, callable], optional
message_func : callable, optional
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types in ``apply_node_func`` and the destination types in
``reduce_func`` must be the same.
* the edge type of ``message_func`` and ``reduce_func`` must also be
the same.
"""
pass
assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
with ir.prog() as prog:
scheduler.schedule_update_all(graph=self,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func)
Runtime.run(prog)
# TODO should we support this?
def prop_nodes(self,
nodes_generator,
message_func="default",
reduce_func="default",
apply_node_func="default"):
message_func=None,
reduce_func=None,
apply_node_func=None):
"""Node propagation in heterogeneous graph is not supported.
"""
raise NotImplementedError('not supported')
# TODO should we support this?
def prop_edges(self,
edges_generator,
message_func="default",
reduce_func="default",
apply_node_func="default"):
message_func=None,
reduce_func=None,
apply_node_func=None):
"""Edge propagation in heterogeneous graph is not supported.
"""
raise NotImplementedError('not supported')
......@@ -2167,8 +2467,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
"""
pass
# TODO: replace this after implementing frame
# pylint: disable=useless-super-delegation
def __repr__(self):
pass
return super(DGLHeteroGraph, self).__repr__()
# pylint: disable=abstract-method
class DGLHeteroSubGraph(DGLHeteroGraph):
......
"""Module for heterogeneous graph index class definition."""
from __future__ import absolute_import
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F
from . import utils
@register_object('graph.HeteroGraph')
class HeteroGraphIndex(ObjectBase):
"""HeteroGraph index object.
Note
----
Do not create GraphIndex directly.
"""
def __new__(cls):
obj = ObjectBase.__new__(cls)
obj._cache = {}
return obj
def __getstate__(self):
# TODO
return
def __setstate__(self, state):
# TODO
pass
@property
def metagraph(self):
"""Meta graph
Returns
-------
GraphIndex
The meta graph.
"""
return _CAPI_DGLHeteroGetMetaGraph(self)
def number_of_ntypes(self):
"""Return number of node types."""
return self.metagraph.number_of_nodes()
def number_of_etypes(self):
"""Return number of edge types."""
return self.metagraph.number_of_edges()
def get_relation_graph(self, etype):
"""Get the bipartite graph of the given edge/relation type.
Parameters
----------
etype : int
The edge/relation type.
Returns
-------
HeteroGraphIndex
The bipartite graph.
"""
return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))
def add_nodes(self, ntype, num):
"""Add nodes.
Parameters
----------
ntype : int
Node type
num : int
Number of nodes to be added.
"""
_CAPI_DGLHeteroAddVertices(self, int(ntype), int(num))
self.clear_cache()
def add_edge(self, etype, u, v):
"""Add one edge.
Parameters
----------
etype : int
Edge type
u : int
The src node.
v : int
The dst node.
"""
_CAPI_DGLHeteroAddEdge(self, int(etype), int(u), int(v))
self.clear_cache()
def add_edges(self, etype, u, v):
"""Add many edges.
Parameters
----------
etype : int
Edge type
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
"""
_CAPI_DGLHeteroAddEdges(self, int(etype), u.todgltensor(), v.todgltensor())
self.clear_cache()
def clear(self):
"""Clear the graph."""
_CAPI_DGLHeteroClear(self)
self._cache.clear()
def ctx(self):
"""Return the context of this graph index.
Returns
-------
DGLContext
The context of the graph.
"""
return _CAPI_DGLHeteroContext(self)
def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64).
Returns
-------
int
The number of bits.
"""
return _CAPI_DGLHeteroNumBits(self)
def bits_needed(self, etype):
"""Return the number of integer bits needed to represent the bipartite graph.
Parameters
----------
etype : int
The edge type.
Returns
-------
int
The number of bits needed.
"""
stype, dtype = self.metagraph.find_edge(etype)
if (self.number_of_edges(etype) >= 0x80000000 or
self.number_of_nodes(stype) >= 0x80000000 or
self.number_of_nodes(dtype) >= 0x80000000):
return 64
else:
return 32
def asbits(self, bits):
"""Transform the graph to a new one with the given number of bits storage.
NOTE: this method only works for immutable graph index
Parameters
----------
bits : int
The number of integer bits (32 or 64)
Returns
-------
HeteroGraphIndex
The graph index stored using the given number of bits.
"""
return _CAPI_DGLHeteroAsNumBits(self, int(bits))
def copy_to(self, ctx):
"""Copy this immutable graph index to the given device context.
NOTE: this method only works for immutable graph index
Parameters
----------
ctx : DGLContext
The target device context.
Returns
-------
HeteroGraphIndex
The graph index on the given device context.
"""
return _CAPI_DGLHeteroCopyTo(self, ctx.device_type, ctx.device_id)
def is_multigraph(self):
"""Return whether the graph is a multigraph
Returns
-------
bool
True if it is a multigraph, False otherwise.
"""
return bool(_CAPI_DGLHeteroIsMultigraph(self))
def is_readonly(self):
"""Return whether the graph index is read-only.
Returns
-------
bool
True if it is a read-only graph, False otherwise.
"""
return bool(_CAPI_DGLHeteroIsReadonly(self))
def number_of_nodes(self, ntype):
"""Return the number of nodes.
Parameters
----------
ntype : int
Node type
Returns
-------
int
The number of nodes
"""
return _CAPI_DGLHeteroNumVertices(self, int(ntype))
def number_of_edges(self, etype):
"""Return the number of edges.
Parameters
----------
etype : int
Edge type
Returns
-------
int
The number of edges
"""
return _CAPI_DGLHeteroNumEdges(self, int(etype))
def has_node(self, ntype, vid):
"""Return true if the node exists.
Parameters
----------
ntype : int
Node type
vid : int
The nodes
Returns
-------
bool
True if the node exists, False otherwise.
"""
return bool(_CAPI_DGLHeteroHasVertex(self, int(ntype), int(vid)))
def has_nodes(self, ntype, vids):
"""Return true if the nodes exist.
Parameters
----------
ntype : int
Node type
vid : utils.Index
The nodes
Returns
-------
utils.Index
0-1 array indicating existence
"""
vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array))
def has_edge_between(self, etype, u, v):
"""Return true if the edge exists.
Parameters
----------
etype : int
Edge type
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists, False otherwise
"""
return bool(_CAPI_DGLHeteroHasEdgeBetween(self, int(etype), int(u), int(v)))
def has_edges_between(self, etype, u, v):
"""Return true if the edge exists.
Parameters
----------
etype : int
Edge type
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
0-1 array indicating existence
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween(
self, int(etype), u_array, v_array))
def predecessors(self, etype, v):
"""Return the predecessors of the node.
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
utils.Index
Array of predecessors
"""
return utils.toindex(_CAPI_DGLHeteroPredecessors(
self, int(etype), int(v)))
def successors(self, etype, v):
"""Return the successors of the node.
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
utils.Index
Array of successors
"""
return utils.toindex(_CAPI_DGLHeteroSuccessors(
self, int(etype), int(v)))
def edge_id(self, etype, u, v):
"""Return the id array of all edges between u and v.
Parameters
----------
etype : int
Edge type
u : int
The src node.
v : int
The dst node.
Returns
-------
utils.Index
The edge id array.
"""
return utils.toindex(_CAPI_DGLHeteroEdgeId(
self, int(etype), int(u), int(v)))
def edge_ids(self, etype, u, v):
"""Return a triplet of arrays that contains the edge IDs.
Parameters
----------
etype : int
Edge type
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def find_edges(self, etype, eid):
"""Return a triplet of arrays that contains the edge IDs.
Parameters
----------
etype : int
Edge type
eid : utils.Index
The edge ids.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
eid_array = eid.todgltensor()
edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def in_edges(self, etype, v):
"""Return the in edges of the node(s).
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLHeteroInEdges_1(self, int(etype), int(v[0]))
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def out_edges(self, etype, v):
"""Return the out edges of the node(s).
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLHeteroOutEdges_1(self, int(etype), int(v[0]))
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
@utils.cached_member(cache='_cache', prefix='edges')
def edges(self, etype, order=None):
"""Return all the edges
Parameters
----------
etype : int
Edge type
order : string
The order of the returned edges. Currently support:
- 'srcdst' : sorted by their src and dst ids.
- 'eid' : sorted by edge Ids.
- None : the arbitrary order.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if order is None:
order = ""
edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order)
src = edge_array(0)
dst = edge_array(1)
eid = edge_array(2)
src = utils.toindex(src)
dst = utils.toindex(dst)
eid = utils.toindex(eid)
return src, dst, eid
def in_degree(self, etype, v):
"""Return the in degree of the node.
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
int
The in degree.
"""
return _CAPI_DGLHeteroInDegree(self, int(etype), int(v))
def in_degrees(self, etype, v):
"""Return the in degrees of the nodes.
Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The nodes.
Returns
-------
int
The in degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array))
def out_degree(self, etype, v):
"""Return the out degree of the node.
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : int
The node.
Returns
-------
int
The out degree.
"""
return _CAPI_DGLHeteroOutDegree(self, int(etype), int(v))
def out_degrees(self, etype, v):
"""Return the out degrees of the nodes.
Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.
Parameters
----------
etype : int
Edge type
v : utils.Index
The nodes.
Returns
-------
int
The out degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array))
def adjacency_matrix(self, etype, transpose, ctx):
"""Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
of an edge and the column represents the source.
When transpose is True, a row represents the source and a column represents
a destination.
Parameters
----------
etype : int
Edge type
transpose : bool
A flag to transpose the returned adjacency matrix.
ctx : context
The context of the returned matrix.
Returns
-------
SparseTensor
The adjacency matrix.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
# convert to framework-specific sparse matrix
srctype, dsttype = self.metagraph.find_edge(etype)
nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype)
ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
nnz = self.number_of_edges(etype)
if fmt == "csr":
indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
shuffle = utils.toindex(rst(2))
dat = F.ones(nnz, dtype=F.float32, ctx=ctx) # FIXME(minjie): data type
spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0]
return spmat, shuffle
elif fmt == "coo":
idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
idx = F.reshape(idx, (2, nnz))
dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (nrows, ncols))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return adj, shuffle_idx
else:
raise Exception("unknown format")
def node_subgraph(self, induced_nodes):
"""Return the induced node subgraph.
Parameters
----------
induced_nodes : list of utils.Index
Induced nodes. The length should be equal to the number of
node types in this heterograph.
Returns
-------
SubgraphIndex
The subgraph index.
"""
vids = [nodes.todgltensor() for nodes in induced_nodes]
return _CAPI_DGLHeteroVertexSubgraph(self, vids)
def edge_subgraph(self, induced_edges, preserve_nodes):
"""Return the induced edge subgraph.
Parameters
----------
induced_edges : list of utils.Index
Induced edges. The length should be equal to the number of
edge types in this heterograph.
preserve_nodes : bool
Indicates whether to preserve all nodes or not.
If true, keep the nodes which have no edge connected in the subgraph;
If false, all nodes without edge connected to it would be removed.
Returns
-------
SubgraphIndex
The subgraph index.
"""
eids = [edges.todgltensor() for edges in induced_edges]
return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='bipartite')
def get_bipartite(self, etype, ctx):
"""Create a bipartite graph from given edge type and copy to the given device
context.
Note: this internal function is for DGL scheduler use only
Parameters
----------
etype : int, or None
If the graph index is a Bipartite graph index, this argument must be None.
Otherwise, it represents the edge type.
ctx : DGLContext
The context of the returned graph.
Returns
-------
HeteroGraphIndex
"""
g = self.get_relation_graph(etype) if etype is not None else self
return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx)
def get_csr_shuffle_order(self, etype):
"""Return the edge shuffling order when a coo graph is converted to csr format
Parameters
----------
etype : int
The edge type
Returns
-------
tuple of two utils.Index
The first element of the tuple is the shuffle order for outward graph
The second element of the tuple is the shuffle order for inward graph
"""
csr = _CAPI_DGLHeteroGetAdj(self, int(etype), True, "csr")
order = csr(2)
rev_csr = _CAPI_DGLHeteroGetAdj(self, int(etype), False, "csr")
rev_order = rev_csr(2)
return utils.toindex(order), utils.toindex(rev_order)
@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
"""Hetero-subgraph data structure"""
@property
def graph(self):
"""The subgraph structure
Returns
-------
HeteroGraphIndex
The subgraph
"""
return _CAPI_DGLHeteroSubgraphGetGraph(self)
@property
def induced_nodes(self):
"""Induced nodes for each node type. The return list
length should be equal to the number of node types.
Returns
-------
list of utils.Index
Induced nodes
"""
ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
return [utils.toindex(v.data) for v in ret]
@property
def induced_edges(self):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret]
def create_bipartite_from_coo(num_src, num_dst, row, col):
"""Create a bipartite graph index from COO format
Parameters
----------
num_src : int
Number of nodes in the src type.
num_dst : int
Number of nodes in the dst type.
row : utils.Index
Row index.
col : utils.Index
Col index.
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCOO(
int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
"""Create a bipartite graph index from CSR format
Parameters
----------
num_src : int
Number of nodes in the src type.
num_dst : int
Number of nodes in the dst type.
indptr : utils.Index
CSR indptr.
indices : utils.Index
CSR indices.
edge_ids : utils.Index
Edge shuffle id.
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCSR(
int(num_src), int(num_dst),
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
def create_heterograph(metagraph, rel_graphs):
"""Create a heterograph from metagraph and graphs of every relation.
Parameters
----------
metagraph : GraphIndex
Meta-graph.
rel_graphs : list of HeteroGraphIndex
Bipartite graph of each relation.
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
_init_api("dgl.heterograph_index")
......@@ -161,7 +161,8 @@ def gen_group_apply_edge_schedule(
apply_func,
u, v, eid,
group_by,
var_nf,
var_src_nf,
var_dst_nf,
var_ef,
var_out):
"""Create degree bucketing schedule for group_apply_edge
......@@ -186,8 +187,10 @@ def gen_group_apply_edge_schedule(
Edges to apply
group_by: str
If "src", group by u. If "dst", group by v
var_nf : var.FEAT_DICT
The variable for node feature frame.
var_src_nf : var.FEAT_DICT
The variable for source feature frame.
var_dst_nf : var.FEAT_DICT
The variable for destination feature frame.
var_ef : var.FEAT_DICT
The variable for edge frame.
var_out : var.FEAT_DICT
......@@ -213,8 +216,8 @@ def gen_group_apply_edge_schedule(
var_v = var.IDX(v_bkt)
var_eid = var.IDX(eid_bkt)
# apply edge UDF on each bucket
fdsrc = ir.READ_ROW(var_nf, var_u)
fddst = ir.READ_ROW(var_nf, var_v)
fdsrc = ir.READ_ROW(var_src_nf, var_u)
fddst = ir.READ_ROW(var_dst_nf, var_v)
fdedge = ir.READ_ROW(var_ef, var_eid)
fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst, ret=fdedge) # reuse var
# save for merge
......
......@@ -8,6 +8,8 @@ from .. import backend as F
from ..frame import frame_like, FrameRef
from ..function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex
from . import ir
from .ir import var
......@@ -28,6 +30,15 @@ __all__ = [
"schedule_pull"
]
def _dispatch(graph, method, *args, **kwargs):
graph_index = graph._graph
if isinstance(graph_index, GraphIndex):
return getattr(graph._graph, method)(*args, **kwargs)
elif isinstance(graph_index, HeteroGraphIndex):
return getattr(graph._graph, method)(graph._current_etype_idx, *args, **kwargs)
else:
raise TypeError('unknown type %s' % type(graph_index))
def schedule_send(graph, u, v, eid, message_func):
"""get send schedule
......@@ -45,7 +56,8 @@ def schedule_send(graph, u, v, eid, message_func):
The message function
"""
var_mf = var.FEAT_DICT(graph._msg_frame)
var_nf = var.FEAT_DICT(graph._node_frame)
var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_eid = var.IDX(eid)
......@@ -54,8 +66,8 @@ def schedule_send(graph, u, v, eid, message_func):
v=v,
eid=eid,
mfunc=message_func,
var_src_nf=var_nf,
var_dst_nf=var_nf,
var_src_nf=var_src_nf,
var_dst_nf=var_dst_nf,
var_ef=var_ef)
# write tmp msg back
......@@ -83,7 +95,7 @@ def schedule_recv(graph,
inplace: bool
If True, the update will be done in place
"""
src, dst, eid = graph._graph.in_edges(recv_nodes)
src, dst, eid = _dispatch(graph, 'in_edges', recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
......@@ -96,7 +108,7 @@ def schedule_recv(graph,
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
# sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
......@@ -105,12 +117,12 @@ def schedule_recv(graph,
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._get_msg_index().has_nonzero():
......@@ -148,7 +160,7 @@ def schedule_snr(graph,
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
......@@ -156,11 +168,11 @@ def schedule_snr(graph,
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
edge_tuples, graph.number_of_nodes())
edge_tuples, graph._number_of_src_nodes(), graph._number_of_dst_nodes())
out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits)
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._node_frame,
dst_node_frame=graph._node_frame,
src_node_frame=graph._src_frame,
dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame,
message_func=message_func,
reduce_func=reduce_func,
......@@ -170,12 +182,12 @@ def schedule_snr(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat,
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, reduced_feat,
apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph,
message_func,
......@@ -194,27 +206,27 @@ def schedule_update_all(graph,
apply_func: callable
The apply node function
"""
if graph.number_of_edges() == 0:
if graph._number_of_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes()))
nodes = utils.toindex(slice(0, graph._number_of_dst_nodes()))
schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
else:
eid = utils.toindex(slice(0, graph.number_of_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.number_of_nodes())) # ALL
eid = utils.toindex(slice(0, graph._number_of_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph._number_of_dst_nodes())) # ALL
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
var_eid = var.IDX(eid)
# generate send + reduce
def uv_getter():
src, dst, _ = graph._graph.edges('eid')
src, dst, _ = _dispatch(graph, 'edges', 'eid')
return var.IDX(src), var.IDX(dst)
adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph)
out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._node_frame,
dst_node_frame=graph._node_frame,
src_node_frame=graph._src_frame,
dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame,
message_func=message_func,
reduce_func=reduce_func,
......@@ -224,9 +236,9 @@ def schedule_update_all(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat)
ir.WRITE_DICT_(var_dst_nf, final_feat)
def schedule_apply_nodes(graph,
v,
......@@ -326,10 +338,12 @@ def schedule_apply_edges(graph,
A list of executors for DGL Runtime
"""
# vars
var_nf = var.FEAT_DICT(graph._node_frame)
var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func,
var_src_nf=var_nf, var_dst_nf=var_nf, var_ef=var_ef)
var_src_nf=var_src_nf, var_dst_nf=var_dst_nf,
var_ef=var_ef)
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_eid = var.IDX(eid)
# schedule apply edges
......@@ -401,7 +415,7 @@ def schedule_push(graph,
inplace: bool
If True, the update will be done in place
"""
u, v, eid = graph._graph.out_edges(u)
u, v, eid = _dispatch(graph, 'out_edges', u)
if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled.
return
......@@ -434,7 +448,7 @@ def schedule_pull(graph,
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat
# directly from pull node frontier.
u, v, eid = graph._graph.in_edges(pull_nodes)
u, v, eid = _dispatch(graph, 'in_edges', pull_nodes)
if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
......@@ -443,27 +457,27 @@ def schedule_pull(graph,
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
num_nodes = graph.number_of_nodes()
adj_creator = lambda: spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits)
reduced_feat = _gen_send_reduce(graph, graph._node_frame,
graph._node_frame, graph._edge_frame,
reduced_feat = _gen_send_reduce(graph, graph._src_frame,
graph._dst_frame, graph._edge_frame,
message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator,
out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
final_feat = _apply_with_accum(graph, var_pull_nodes, var_dst_nf, reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_pull_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_pull_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
ir.WRITE_ROW_(var_dst_nf, var_pull_nodes, final_feat)
def schedule_group_apply_edge(graph,
u, v, eid,
......@@ -494,11 +508,12 @@ def schedule_group_apply_edge(graph,
A list of executors for DGL Runtime
"""
# vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_src_nf = var.FEAT_DICT(graph._src_frame, name='src_nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by,
var_nf, var_ef, var_out)
var_src_nf, var_dst_nf, var_ef, var_out)
var_eid = var.IDX(eid)
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
......@@ -719,17 +734,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# node frame.
# TODO(minjie): should replace this with an IR call to make the program
# stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
tmpframe = FrameRef(frame_like(graph._dst_frame._frame, len(recv_nodes)))
# vars
var_msg = var.FEAT_DICT(graph._msg_frame, 'msg')
var_nf = var.FEAT_DICT(graph._node_frame, 'nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, 'nf')
var_out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
num_nodes = graph.number_of_nodes()
adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv(
(src, dst, eid), num_nodes)
(src, dst, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
# using edge map instead of message map because messages are in global
# message frame
var_out_map = _build_idx_map(recv_nodes, nbits)
......@@ -744,7 +758,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else:
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes,
var_nf, var_msg, var_out)
var_dst_nf, var_msg, var_out)
return var_out
def _gen_send_reduce(
......@@ -930,12 +944,12 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid = var.IDX(eid)
if mfunc_is_list:
if eid.is_slice(0, graph.number_of_edges()):
if eid.is_slice(0, graph._number_of_edges()):
# full graph case
res = spmv.build_gidx_and_mapping_graph(graph)
else:
num_nodes = graph.number_of_nodes()
res = spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes)
res = spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
adj, edge_map, _ = res
# create a tmp message frame
tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid)))
......
"""Module for SPMV rules."""
from __future__ import absolute_import
from functools import partial
from ..base import DGLError
from .. import backend as F
from .. import utils
from .. import ndarray as nd
from ..graph_index import from_coo
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from . import ir
from .ir import var
......@@ -127,8 +129,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters
----------
graph : DGLGraph
The graph
graph : DGLGraph or DGLHeteroGraph
The homogeneous graph, or a bipartite view of the heterogeneous graph.
Returns
-------
......@@ -141,10 +143,17 @@ def build_gidx_and_mapping_graph(graph):
Number of ints needed to represent the graph
"""
gidx = graph._graph
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of
......@@ -155,8 +164,8 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
---------
edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid)
num_nodes : int
The number of nodes.
num_src, num_dst : int
The number of source and destination nodes.
Returns
-------
......@@ -169,10 +178,10 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
Number of ints needed to represent the graph
"""
u, v, eid = edge_tuples
gidx = from_coo(num_nodes, u, v, None, True)
forward, backward = gidx.get_csr_shuffle_order()
gidx = create_bipartite_from_coo(num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor()
nbits = gidx.bits_needed()
nbits = gidx.bits_needed(0)
forward_map = utils.to_nbits_int(eid[forward.tousertensor()], nbits)
backward_map = utils.to_nbits_int(eid[backward.tousertensor()], nbits)
forward_map = F.zerocopy_to_dgl_ndarray(forward_map)
......@@ -180,7 +189,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx)))
return gidx.get_immutable_gidx, edge_map, nbits
return partial(gidx.get_bipartite, None), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
......@@ -212,6 +221,6 @@ def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
eid = utils.toindex(eid)
else:
u, v, eid = edge_tuples
num_nodes = max(graph.layer_size(block_id), graph.layer_size(block_id + 1))
gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_nodes)
num_src, num_dst = graph.layer_size(block_id), graph.layer_size(block_id + 1)
gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_src, num_dst)
return gidx, edge_map, nbits
......@@ -242,3 +242,229 @@ class BlockDataView(MutableMapping):
def __repr__(self):
data = self._graph._edge_frames[self._flow]
return repr({key : data[key] for key in data})
class HeteroNodeView(object):
"""A NodeView class to act as G.nodes for a DGLHeteroGraph."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, ntype):
return HeteroNodeTypeView(self._graph, ntype)
class HeteroNodeTypeView(object):
"""A NodeView class to act as G.nodes[ntype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __len__(self):
return self._graph.number_of_nodes(self._graph._ntypes_invmap[self._ntype])
def __getitem__(self, nodes):
if isinstance(nodes, slice):
# slice
if not (nodes.start is None and nodes.stop is None
and nodes.step is None):
raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, ALL))
else:
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, nodes))
def __call__(self):
"""Return the nodes."""
return F.arange(0, len(self))
class HeteroNodeTypeDataView(MutableMapping):
"""The data view class when G.nodes[ntype][...].data is called.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype', '_nodes']
def __init__(self, graph, ntype, nodes):
self._graph = graph
self._ntype = ntype
self._nodes = nodes
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}, self._nodes)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __iter__(self):
return iter(self._graph.get_n_repr(self._ntype, self._nodes))
def __repr__(self):
data = self._graph.get_n_repr(self._ntype, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroNodeDataView(object):
"""The data view class when G.ndata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroNodeDataTypeView(self._graph, key)
class HeteroNodeDataTypeView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val})
def __delitem__(self, key):
self._graph.pop_n_repr(self._ntype, key)
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __iter__(self):
return iter(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __repr__(self):
data = self._graph.get_n_repr(self._ntype)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, etype):
return HeteroEdgeTypeView(self._graph, etype)
class HeteroEdgeTypeView(object):
"""A EdgeView class to act as G.edges[etype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __len__(self):
return self._graph.number_of_edges(self._graph._etypes_invmap[self._etype])
def __getitem__(self, edges):
if isinstance(edges, slice):
# slice
if not (edges.start is None and edges.stop is None
and edges.step is None):
raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, ALL))
else:
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, edges))
def __call__(self):
"""Return the edges."""
return F.arange(0, len(self))
class HeteroEdgeTypeDataView(MutableMapping):
"""The data view class when G.edges[etype][...].data is called.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype', '_edges']
def __init__(self, graph, etype, edges):
self._graph = graph
self._etype = etype
self._edges = edges
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}, self._edges)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of edges. Please use `del G.edata[key]` instead.')
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __iter__(self):
return iter(self._graph.get_e_repr(self._etype, self._edges))
def __repr__(self):
data = self._graph.get_e_repr(self._etype, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
class HeteroEdgeDataView(object):
"""The data view class when G.edata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroEdgeDataTypeView(self._graph, key)
class HeteroEdgeDataTypeView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val})
def __delitem__(self, key):
self._graph.pop_e_repr(self._etype, key)
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __iter__(self):
return iter(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __repr__(self):
data = self._graph.get_e_repr(self._etype)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
......@@ -52,11 +52,6 @@ enum BoolFlag {
dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<dgl::runtime::NDArray>& vec);
/*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
/*!
* \brief Copy a vector to an int64_t NDArray.
*
......
......@@ -6,12 +6,15 @@
#include <dgl/array.h>
#include <dgl/lazy.h>
#include <dgl/immutable_graph.h>
#include <dgl/base_heterograph.h>
#include "./bipartite.h"
#include "../c_api_common.h"
namespace dgl {
namespace {
inline GraphPtr CreateBipartiteMetaGraph() {
std::vector<int64_t> row_vec(1, Bipartite::kSrcVType);
std::vector<int64_t> col_vec(1, Bipartite::kDstVType);
......@@ -20,8 +23,9 @@ inline GraphPtr CreateBipartiteMetaGraph() {
GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
return g;
}
static const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph();
} // namespace
const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph();
}; // namespace
//////////////////////////////////////////////////////////
//
......@@ -29,22 +33,20 @@ static const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph();
//
//////////////////////////////////////////////////////////
/*! \brief COO graph */
class Bipartite::COO : public BaseHeteroGraph {
public:
COO(int64_t num_src, int64_t num_dst,
IdArray src, IdArray dst)
COO(int64_t num_src, int64_t num_dst, IdArray src, IdArray dst)
: BaseHeteroGraph(kBipartiteMetaGraph) {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
}
COO(int64_t num_src, int64_t num_dst,
IdArray src, IdArray dst, bool is_multigraph)
COO(int64_t num_src, int64_t num_dst, IdArray src, IdArray dst, bool is_multigraph)
: BaseHeteroGraph(kBipartiteMetaGraph),
is_multigraph_(is_multigraph) {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
}
explicit COO(const aten::COOMatrix& coo)
: BaseHeteroGraph(kBipartiteMetaGraph), adj_(coo) {}
explicit COO(const aten::COOMatrix& coo) : BaseHeteroGraph(kBipartiteMetaGraph), adj_(coo) {}
uint64_t NumVertexTypes() const override {
return 2;
......@@ -155,7 +157,7 @@ class Bipartite::COO : public BaseHeteroGraph {
}
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids),
eids};
......@@ -288,7 +290,6 @@ class Bipartite::COO : public BaseHeteroGraph {
//
//////////////////////////////////////////////////////////
/*! \brief CSR graph */
class Bipartite::CSR : public BaseHeteroGraph {
public:
......@@ -305,8 +306,7 @@ class Bipartite::CSR : public BaseHeteroGraph {
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
}
explicit CSR(const aten::CSRMatrix& csr)
: BaseHeteroGraph(kBipartiteMetaGraph), adj_(csr) {}
explicit CSR(const aten::CSRMatrix& csr) : BaseHeteroGraph(kBipartiteMetaGraph), adj_(csr) {}
uint64_t NumVertexTypes() const override {
return 2;
......@@ -345,6 +345,34 @@ class Bipartite::CSR : public BaseHeteroGraph {
return adj_.indices->dtype.bits;
}
CSR AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
} else {
CSR ret(
adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
CSR CopyTo(const DLContext& ctx) const {
if (Context() == ctx) {
return *this;
} else {
CSR ret(
adj_.num_rows, adj_.num_cols,
adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
bool IsMultigraph() const override {
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
return aten::CSRHasDuplicate(adj_);
......@@ -386,8 +414,8 @@ class Bipartite::CSR : public BaseHeteroGraph {
}
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
}
......@@ -408,8 +436,8 @@ class Bipartite::CSR : public BaseHeteroGraph {
}
EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
CHECK(IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(dst)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
......@@ -443,7 +471,7 @@ class Bipartite::CSR : public BaseHeteroGraph {
}
EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabled, so
......@@ -476,7 +504,7 @@ class Bipartite::CSR : public BaseHeteroGraph {
}
DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::CSRGetRowNNZ(adj_, vids);
}
......@@ -518,8 +546,8 @@ class Bipartite::CSR : public BaseHeteroGraph {
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), 2) << "Number of vertex types mismatch";
CHECK(IsValidIdArray(vids[0])) << "Invalid vertex id array.";
CHECK(IsValidIdArray(vids[1])) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids[0])) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids[1])) << "Invalid vertex id array.";
HeteroSubgraph subg;
const auto& submat = aten::CSRSliceMatrix(adj_, vids[0], vids[1]);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
......@@ -579,7 +607,7 @@ bool Bipartite::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
}
BoolArray Bipartite::HasVertices(dgl_type_t vtype, IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input";
CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices(vtype));
}
......@@ -761,6 +789,37 @@ HeteroGraphPtr Bipartite::CreateFromCSR(
return HeteroGraphPtr(new Bipartite(nullptr, csr, nullptr));
}
HeteroGraphPtr Bipartite::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) {
return g;
} else {
// TODO(minjie): since we don't have int32 operations,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
auto bg = std::dynamic_pointer_cast<Bipartite>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr(new CSR(bg->GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(bg->GetOutCSR()->AsNumBits(bits)));
return HeteroGraphPtr(new Bipartite(new_incsr, new_outcsr, nullptr));
}
}
HeteroGraphPtr Bipartite::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
if (ctx == g->Context()) {
return g;
}
// TODO(minjie): since we don't have GPU implementation of COO<->CSR,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
auto bg = std::dynamic_pointer_cast<Bipartite>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr(new CSR(bg->GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(bg->GetOutCSR()->CopyTo(ctx)));
return HeteroGraphPtr(new Bipartite(new_incsr, new_outcsr, nullptr));
}
Bipartite::Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
: BaseHeteroGraph(kBipartiteMetaGraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
CHECK(GetAny()) << "At least one graph structure should exist.";
......@@ -813,6 +872,18 @@ Bipartite::COOPtr Bipartite::GetCOO() const {
return coo_;
}
aten::CSRMatrix Bipartite::GetInCSRMatrix() const {
return GetInCSR()->adj();
}
aten::CSRMatrix Bipartite::GetOutCSRMatrix() const {
return GetOutCSR()->adj();
}
aten::COOMatrix Bipartite::GetCOOMatrix() const {
return GetCOO()->adj();
}
HeteroGraphPtr Bipartite::GetAny() const {
if (in_csr_) {
return in_csr_;
......
......@@ -7,12 +7,14 @@
#ifndef DGL_GRAPH_BIPARTITE_H_
#define DGL_GRAPH_BIPARTITE_H_
#include <dgl/graph_interface.h>
#include <dgl/base_heterograph.h>
#include <vector>
#include <string>
#include <dgl/lazy.h>
#include <dgl/array.h>
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include "../c_api_common.h"
namespace dgl {
......@@ -32,6 +34,12 @@ class Bipartite : public BaseHeteroGraph {
/*! \brief edge group type */
static constexpr dgl_type_t kEType = 0;
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
uint64_t NumVertexTypes() const override {
return 2;
}
......@@ -140,14 +148,11 @@ class Bipartite : public BaseHeteroGraph {
int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids);
private:
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
/*! \return Return the in-edge CSR format. Create from other format if not exist. */
CSRPtr GetInCSR() const;
......@@ -158,6 +163,18 @@ class Bipartite : public BaseHeteroGraph {
/*! \return Return the COO format. Create from other format if not exist. */
COOPtr GetCOO() const;
/*! \return Return the in-edge CSR in the matrix form */
aten::CSRMatrix GetInCSRMatrix() const;
/*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetOutCSRMatrix() const;
/*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix() const;
private:
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \return Return any existing format. */
HeteroGraphPtr GetAny() const;
......
......@@ -16,8 +16,8 @@ namespace dgl {
Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes,
bool multigraph): is_multigraph_(multigraph) {
CHECK(IsValidIdArray(src_ids));
CHECK(IsValidIdArray(dst_ids));
CHECK(aten::IsValidIdArray(src_ids));
CHECK(aten::IsValidIdArray(dst_ids));
this->AddVertices(num_nodes);
num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
......@@ -66,8 +66,8 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
......@@ -92,7 +92,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
}
BoolArray Graph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -113,8 +113,8 @@ bool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
// O(E*k) pretty slow
BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
......@@ -201,8 +201,8 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
// O(E*k) pretty slow
EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
int64_t i, j;
......@@ -247,7 +247,7 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
}
EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0];
IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
......@@ -291,7 +291,7 @@ EdgeArray Graph::InEdges(dgl_id_t vid) const {
// O(E)
EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
......@@ -337,7 +337,7 @@ EdgeArray Graph::OutEdges(dgl_id_t vid) const {
// O(E)
EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
......@@ -409,7 +409,7 @@ EdgeArray Graph::Edges(const std::string &order) const {
// O(V)
DegreeArray Graph::InDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
......@@ -424,7 +424,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const {
// O(V)
DegreeArray Graph::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
......@@ -438,7 +438,7 @@ DegreeArray Graph::OutDegrees(IdArray vids) const {
}
Subgraph Graph::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<dgl_id_t> edges;
......@@ -468,7 +468,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
}
Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
const auto len = eids->shape[0];
std::vector<dgl_id_t> nodes;
const int64_t* eid_data = static_cast<int64_t*>(eids->data);
......
......@@ -249,8 +249,8 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
}
IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
CHECK(IsValidIdArray(parent_vids)) << "Invalid parent id array.";
CHECK(IsValidIdArray(query)) << "Invalid query id array.";
CHECK(aten::IsValidIdArray(parent_vids)) << "Invalid parent id array.";
CHECK(aten::IsValidIdArray(query)) << "Invalid query id array.";
const auto parent_len = parent_vids->shape[0];
const auto query_len = query->shape[0];
const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
......
......@@ -114,17 +114,33 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph.";
}
// create num verts per type
num_verts_per_type_.resize(meta_graph_->NumVertices(), -1);
for (dgl_type_t vtype = 0; vtype < meta_graph_->NumVertices(); ++vtype) {
for (dgl_type_t etype : meta_graph->OutEdgeVec(vtype)) {
const auto nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType);
if (num_verts_per_type_[vtype] < 0) {
num_verts_per_type_[vtype] = nv;
} else {
CHECK_EQ(num_verts_per_type_[vtype], nv)
<< "Mismatch number of vertices for vertex type " << vtype;
}
}
num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
EdgeArray etype_array = meta_graph->Edges();
dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data);
dgl_type_t *dsttypes = static_cast<dgl_type_t *>(etype_array.dst->data);
dgl_type_t *etypes = static_cast<dgl_type_t *>(etype_array.id->data);
for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {
dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i];
size_t nv;
// # nodes of source type
nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType);
if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv;
else
CHECK_EQ(num_verts_per_type_[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type
nv = rel_graphs[etype]->NumVertices(Bipartite::kDstVType);
if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv;
else
CHECK_EQ(num_verts_per_type_[dsttype], nv)
<< "Mismatch number of vertices for vertex type " << dsttype;
}
}
......@@ -140,7 +156,7 @@ bool HeteroGraph::IsMultigraph() const {
}
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input";
CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices(vtype));
}
......@@ -192,7 +208,7 @@ HeteroGraphPtr CreateHeteroGraph(
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0];
int64_t num_dst = args[1];
......@@ -202,7 +218,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0];
int64_t num_dst = args[1];
......@@ -213,7 +229,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateHeteroGraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1];
......@@ -226,20 +242,20 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateHeteroGraph")
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetMetaGraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = GraphRef(hg->meta_graph());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetRelationGraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddVertices")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
......@@ -247,7 +263,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddVertices")
hg->AddVertices(vtype, num);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdge")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -256,7 +272,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdge")
hg->AddEdge(etype, src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdges")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -265,51 +281,51 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdges")
hg->AddEdges(etype, src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroClear")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
hg->Clear();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroContext")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->Context();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroNumBits")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->NumBits();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroIsMultigraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->IsMultigraph();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroIsReadonly")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->IsReadonly();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroNumVertices")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
*rv = static_cast<int64_t>(hg->NumVertices(vtype));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroNumEdges")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
*rv = static_cast<int64_t>(hg->NumEdges(etype));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertex")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
......@@ -317,7 +333,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertex")
*rv = hg->HasVertex(vtype, vid);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertices")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
......@@ -325,7 +341,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertices")
*rv = hg->HasVertices(vtype, vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgeBetween")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -334,7 +350,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgeBetween")
*rv = hg->HasEdgeBetween(etype, src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgesBetween")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -343,7 +359,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgesBetween")
*rv = hg->HasEdgesBetween(etype, src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroPredecessors")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -351,7 +367,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroPredecessors")
*rv = hg->Predecessors(etype, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSuccessors")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -359,7 +375,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSuccessors")
*rv = hg->Successors(etype, src);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeId")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -368,7 +384,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeId")
*rv = hg->EdgeId(etype, src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeIds")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -378,7 +394,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeIds")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroFindEdges")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -387,7 +403,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroFindEdges")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_1")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -396,7 +412,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_1")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_2")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -405,7 +421,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_2")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_1")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -414,7 +430,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_1")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_2")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -423,7 +439,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_2")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdges")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -432,7 +448,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdges")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegree")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -440,7 +456,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegree")
*rv = static_cast<int64_t>(hg->InDegree(etype, vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegrees")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -448,7 +464,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegrees")
*rv = hg->InDegrees(etype, vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegree")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -456,7 +472,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegree")
*rv = static_cast<int64_t>(hg->OutDegree(etype, vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegrees")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -464,7 +480,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegrees")
*rv = hg->OutDegrees(etype, vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetAdj")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
......@@ -474,7 +490,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetAdj")
hg->GetAdj(etype, transpose, fmt));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroVertexSubgraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> vids = args[1];
......@@ -488,7 +504,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroVertexSubgraph")
*rv = HeteroSubgraphRef(subg);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeSubgraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> eids = args[1];
......@@ -505,13 +521,13 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeSubgraph")
// HeteroSubgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetGraph")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0];
*rv = HeteroGraphRef(subg->graph);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0];
List<Value> induced_verts;
......@@ -521,7 +537,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
*rv = induced_verts;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0];
List<Value> induced_edges;
......@@ -531,4 +547,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
*rv = induced_edges;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int bits = args[1];
HeteroGraphPtr hg_new = Bipartite::AsNumBits(hg.sptr(), bits);
*rv = HeteroGraphRef(hg_new);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int device_type = args[1];
int device_id = args[2];
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
HeteroGraphPtr hg_new = Bipartite::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new);
});
} // namespace dgl
......@@ -59,9 +59,9 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
}
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids));
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
......@@ -69,9 +69,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids));
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
......@@ -79,9 +79,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids));
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0];
......@@ -98,9 +98,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name): is_multigraph_(is_multigraph),
shared_mem_name_(shared_mem_name) {
CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids));
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0];
......@@ -140,7 +140,7 @@ EdgeArray CSR::OutEdges(dgl_id_t vid) const {
}
EdgeArray CSR::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabled, so
......@@ -150,7 +150,7 @@ EdgeArray CSR::OutEdges(IdArray vids) const {
}
DegreeArray CSR::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::CSRGetRowNNZ(adj_, vids);
}
......@@ -161,8 +161,8 @@ bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
}
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
}
......@@ -192,7 +192,7 @@ EdgeArray CSR::Edges(const std::string &order) const {
}
Subgraph CSR::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
......@@ -272,16 +272,16 @@ DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
//
//////////////////////////////////////////////////////////
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
CHECK(IsValidIdArray(src));
CHECK(IsValidIdArray(dst));
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]);
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CHECK(IsValidIdArray(src));
CHECK(IsValidIdArray(dst));
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]);
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
......@@ -301,7 +301,7 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
}
EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids),
eids};
......@@ -316,7 +316,7 @@ EdgeArray COO::Edges(const std::string &order) const {
}
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
COOPtr subcoo;
IdArray induced_nodes;
if (!preserve_nodes) {
......@@ -379,7 +379,7 @@ COO COO::AsNumBits(uint8_t bits) const {
//////////////////////////////////////////////////////////
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input";
CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices());
}
......
......@@ -749,7 +749,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
const bool add_self_loop,
const ValueType *probability) {
// process args
CHECK(IsValidIdArray(seed_nodes));
CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
......@@ -859,7 +859,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes));
CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
......@@ -1017,7 +1017,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_edges));
CHECK(aten::IsValidIdArray(seed_edges));
BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0];
......
......@@ -197,7 +197,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len);
......@@ -219,7 +219,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
const bool has_nontree_edge = args[4];
const bool return_labels = args[5];
CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data);
......
......@@ -10,6 +10,7 @@
#include "./binary_reduce_impl_decl.h"
#include "./utils.h"
#include "../c_api_common.h"
#include "../graph/bipartite.h"
#include "./csr_interface.h"
using namespace dgl::runtime;
......@@ -228,6 +229,31 @@ class ImmutableGraphCSRWrapper : public CSRWrapper {
const ImmutableGraph* gptr_;
};
class BipartiteCSRWrapper : public CSRWrapper {
public:
explicit BipartiteCSRWrapper(const Bipartite* graph) :
gptr_(graph) { }
aten::CSRMatrix GetInCSRMatrix() const override {
return gptr_->GetInCSRMatrix();
}
aten::CSRMatrix GetOutCSRMatrix() const override {
return gptr_->GetOutCSRMatrix();
}
DGLContext Context() const override {
return gptr_->Context();
}
int NumBits() const override {
return gptr_->NumBits();
}
private:
const Bipartite* gptr_;
};
} // namespace
......@@ -293,11 +319,32 @@ void BinaryOpReduce(
}
}
// Comes from DGLArgValue::AsObjectRef() that allows argvalue to be either a GraphRef
// or a HeteroGraphRef
#define CSRWRAPPER_SWITCH(argvalue, wrapper, ...) do { \
DGLArgValue argval = (argvalue); \
DGL_CHECK_TYPE_CODE(argval.type_code(), kObjectHandle); \
std::shared_ptr<Object>& sptr = \
*argval.ptr<std::shared_ptr<Object>>(); \
if (ObjectTypeChecker<GraphRef>::Check(sptr.get())) { \
GraphRef g = argval; \
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); \
CHECK_NOTNULL(igptr); \
ImmutableGraphCSRWrapper wrapper(igptr.get()); \
{__VA_ARGS__} \
} else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) { \
HeteroGraphRef g = argval; \
auto bgptr = std::dynamic_pointer_cast<Bipartite>(g.sptr()); \
CHECK_NOTNULL(bgptr); \
BipartiteCSRWrapper wrapper(bgptr.get()); \
{__VA_ARGS__} \
} \
} while (0)
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_data = args[5];
......@@ -307,14 +354,13 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
CSRWRAPPER_SWITCH(args[2], wrapper, {
BinaryOpReduce(reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping);
});
});
void BackwardLhsBinaryOpReduce(
const std::string& reducer,
......@@ -370,7 +416,6 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_mapping = args[5];
......@@ -382,9 +427,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
CSRWRAPPER_SWITCH(args[2], wrapper, {
BackwardLhsBinaryOpReduce(
reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
......@@ -392,6 +435,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data);
});
});
void BackwardRhsBinaryOpReduce(
const std::string& reducer,
......@@ -446,7 +490,6 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_mapping = args[5];
......@@ -458,9 +501,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
CSRWRAPPER_SWITCH(args[2], wrapper, {
BackwardRhsBinaryOpReduce(
reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
......@@ -468,6 +509,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
lhs_data, rhs_data, out_data, grad_out_data,
grad_rhs_data);
});
});
void CopyReduce(
const std::string& reducer,
......@@ -493,21 +535,19 @@ void CopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
GraphRef g = args[1];
int target = args[2];
NDArray in_data = args[3];
NDArray out_data = args[4];
NDArray in_mapping = args[5];
NDArray out_mapping = args[6];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
CSRWRAPPER_SWITCH(args[1], wrapper, {
CopyReduce(reducer, wrapper,
static_cast<binary_op::Target>(target),
in_data, out_data,
in_mapping, out_mapping);
});
});
void BackwardCopyReduce(
const std::string& reducer,
......@@ -542,7 +582,6 @@ void BackwardCopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
GraphRef g = args[1];
int target = args[2];
NDArray in_data = args[3];
NDArray out_data = args[4];
......@@ -551,15 +590,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7];
NDArray out_mapping = args[8];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
CSRWRAPPER_SWITCH(args[1], wrapper, {
BackwardCopyReduce(
reducer, wrapper, static_cast<binary_op::Target>(target),
in_mapping, out_mapping,
in_data, out_data, grad_out_data,
grad_in_data);
});
});
} // namespace kernel
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment