"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1775647f763f9785a0f06eed7cfaa310b6dc9519"
Commit 52d4535b authored by Quan (Andy) Gan, committed by Minjie Wang
Browse files

[Hetero][RFC] Heterogeneous graph Python interfaces & Message Passing (#752)

* moving heterograph index to another file

* node view

* python interfaces

* heterograph init

* bug fixes

* docstring for readonly

* more docstring

* unit tests & lint

* oops

* oops x2

* removed node/edge addition

* addressed comments

* lint

* rw on frames with one node/edge type

* homograph with underlying heterograph demo

* view is not necessary

* bugfix

* replace

* scheduler, builtins not working yet

* moving bipartite.h to header

* moving back bipartite to bipartite.h

* oops

* asbits and copyto for bipartite

* tested update_all and send_and_recv

* lightweight node & edge type retrieval

* oops

* sorry

* removing obsolete code

* oops

* lint

* various bug fixes & more tests

* UDF tests

* multiple type number_of_nodes and number_of_edges

* docstring fixes

* more tests

* going for dict in initialization

* lint

* updated api as per discussions

* lint

* bug

* bugfix

* moving back bipartite impl to cc

* note on views

* fix
parent 66971c1a
...@@ -116,6 +116,11 @@ IdArray IndexSelect(IdArray array, IdArray index); ...@@ -116,6 +116,11 @@ IdArray IndexSelect(IdArray array, IdArray index);
*/ */
IdArray Relabel_(const std::vector<IdArray>& arrays); IdArray Relabel_(const std::vector<IdArray>& arrays);
/*!
 * \brief Check whether an array qualifies as an ID array.
 * \param arr The array to inspect.
 * \return true only if the array has exactly one dimension and its dtype
 *         code is kDLInt.
 */
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
  if (arr->ndim != 1) {
    return false;
  }
  return arr->dtype.code == kDLInt;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// Sparse matrix // Sparse matrix
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
......
...@@ -17,6 +17,7 @@ from .base import ALL ...@@ -17,6 +17,7 @@ from .base import ALL
from .backend import load_backend from .backend import load_backend
from .batched_graph import * from .batched_graph import *
from .graph import DGLGraph from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from .nodeflow import * from .nodeflow import *
from .traversal import * from .traversal import *
from .transform import * from .transform import *
......
...@@ -49,6 +49,14 @@ class DGLBaseGraph(object): ...@@ -49,6 +49,14 @@ class DGLBaseGraph(object):
""" """
return self._graph.number_of_nodes() return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self): def __len__(self):
"""Return the number of nodes in the graph.""" """Return the number of nodes in the graph."""
return self.number_of_nodes() return self.number_of_nodes()
...@@ -65,6 +73,10 @@ class DGLBaseGraph(object): ...@@ -65,6 +73,10 @@ class DGLBaseGraph(object):
""" """
return self._graph.is_readonly() return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges in the graph. """Return the number of edges in the graph.
...@@ -939,6 +951,14 @@ class DGLGraph(DGLBaseGraph): ...@@ -939,6 +951,14 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index): def _set_msg_index(self, index):
self._msg_index = index self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
"""Add multiple new nodes. """Add multiple new nodes.
......
...@@ -1269,737 +1269,4 @@ def create_graph_index(graph_data, multigraph, readonly): ...@@ -1269,737 +1269,4 @@ def create_graph_index(graph_data, multigraph, readonly):
% type(graph_data)) % type(graph_data))
return gidx return gidx
#############################################################
# Hetero graph
#############################################################
@register_object('graph.HeteroGraph')
class HeteroGraphIndex(ObjectBase):
    """HeteroGraph index object.

    Each method converts its arguments (``int(...)`` for scalars,
    ``todgltensor()`` for ``utils.Index`` values) and forwards them to the
    matching ``_CAPI_DGLHetero*`` function.

    Note
    ----
    Do not create HeteroGraphIndex directly.
    """
    def __new__(cls):
        obj = ObjectBase.__new__(cls)
        return obj

    def __getstate__(self):
        # TODO: pickling is not implemented; no state is returned.
        return

    def __setstate__(self, state):
        # TODO: unpickling is not implemented; ``state`` is ignored.
        pass

    @property
    def meta_graph(self):
        """Meta graph

        Returns
        -------
        GraphIndex
            The meta graph.
        """
        return _CAPI_DGLHeteroGetMetaGraph(self)

    def number_of_ntypes(self):
        """Return number of node types (the node count of the meta graph)."""
        return self.meta_graph.number_of_nodes()

    def number_of_etypes(self):
        """Return number of edge types (the edge count of the meta graph)."""
        return self.meta_graph.number_of_edges()

    def get_relation_graph(self, etype):
        """Get the bipartite graph of the given edge/relation type.

        Parameters
        ----------
        etype : int
            The edge/relation type.

        Returns
        -------
        HeteroGraphIndex
            The bipartite graph.
        """
        return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))

    def add_nodes(self, ntype, num):
        """Add nodes.

        Parameters
        ----------
        ntype : int
            Node type
        num : int
            Number of nodes to be added.
        """
        # NOTE(review): the previous call target was ``_CAPI_DGLHetero``, a
        # truncated name with no operation. The sibling calls use "Vertex"
        # naming (NumVertices / HasVertex / HasVertices), so AddVertices is
        # assumed here -- confirm against the C API registration.
        _CAPI_DGLHeteroAddVertices(self, int(ntype), int(num))

    def add_edge(self, etype, u, v):
        """Add one edge.

        Parameters
        ----------
        etype : int
            Edge type
        u : int
            The src node.
        v : int
            The dst node.
        """
        _CAPI_DGLHeteroAddEdge(self, int(etype), int(u), int(v))

    def add_edges(self, etype, u, v):
        """Add many edges.

        Parameters
        ----------
        etype : int
            Edge type
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.
        """
        _CAPI_DGLHeteroAddEdges(self, int(etype), u.todgltensor(), v.todgltensor())

    def clear(self):
        """Clear the graph."""
        _CAPI_DGLHeteroClear(self)

    def ctx(self):
        """Return the context of this graph index.

        Returns
        -------
        DGLContext
            The context of the graph.
        """
        return _CAPI_DGLHeteroContext(self)

    def nbits(self):
        """Return the number of integer bits used in the storage (32 or 64).

        Returns
        -------
        int
            The number of bits.
        """
        return _CAPI_DGLHeteroNumBits(self)

    def is_multigraph(self):
        """Return whether the graph is a multigraph

        Returns
        -------
        bool
            True if it is a multigraph, False otherwise.
        """
        return bool(_CAPI_DGLHeteroIsMultigraph(self))

    def is_readonly(self):
        """Return whether the graph index is read-only.

        Returns
        -------
        bool
            True if it is a read-only graph, False otherwise.
        """
        return bool(_CAPI_DGLHeteroIsReadonly(self))

    def number_of_nodes(self, ntype):
        """Return the number of nodes.

        Parameters
        ----------
        ntype : int
            Node type

        Returns
        -------
        int
            The number of nodes
        """
        return _CAPI_DGLHeteroNumVertices(self, int(ntype))

    def number_of_edges(self, etype):
        """Return the number of edges.

        Parameters
        ----------
        etype : int
            Edge type

        Returns
        -------
        int
            The number of edges
        """
        return _CAPI_DGLHeteroNumEdges(self, int(etype))

    def has_node(self, ntype, vid):
        """Return true if the node exists.

        Parameters
        ----------
        ntype : int
            Node type
        vid : int
            The node

        Returns
        -------
        bool
            True if the node exists, False otherwise.
        """
        return bool(_CAPI_DGLHeteroHasVertex(self, int(ntype), int(vid)))

    def has_nodes(self, ntype, vids):
        """Return true if the nodes exist.

        Parameters
        ----------
        ntype : int
            Node type
        vids : utils.Index
            The nodes

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        vid_array = vids.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array))

    def has_edge_between(self, etype, u, v):
        """Return true if the edge exists.

        Parameters
        ----------
        etype : int
            Edge type
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        bool
            True if the edge exists, False otherwise
        """
        return bool(_CAPI_DGLHeteroHasEdgeBetween(self, int(etype), int(u), int(v)))

    def has_edges_between(self, etype, u, v):
        """Return true if the edges exist.

        Parameters
        ----------
        etype : int
            Edge type
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween(
            self, int(etype), u_array, v_array))

    def predecessors(self, etype, v):
        """Return the predecessors of the node.

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        utils.Index
            Array of predecessors
        """
        return utils.toindex(_CAPI_DGLHeteroPredecessors(
            self, int(etype), int(v)))

    def successors(self, etype, v):
        """Return the successors of the node.

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        utils.Index
            Array of successors
        """
        return utils.toindex(_CAPI_DGLHeteroSuccessors(
            self, int(etype), int(v)))

    def edge_id(self, etype, u, v):
        """Return the id array of all edges between u and v.

        Parameters
        ----------
        etype : int
            Edge type
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        utils.Index
            The edge id array.
        """
        return utils.toindex(_CAPI_DGLHeteroEdgeId(
            self, int(etype), int(u), int(v)))

    def edge_ids(self, etype, u, v):
        """Return a triplet of arrays that contains the edge IDs.

        Parameters
        ----------
        etype : int
            Edge type
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array)

        # The C API result is callable; index 0/1/2 yield src/dst/eid.
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))

        return src, dst, eid

    def find_edges(self, etype, eid):
        """Return the (src, dst, eid) triplet for the given edge IDs.

        Parameters
        ----------
        etype : int
            Edge type
        eid : utils.Index
            The edge ids.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        eid_array = eid.todgltensor()
        edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array)

        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))

        return src, dst, eid

    def in_edges(self, etype, v):
        """Return the in edges of the node(s).

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        # A single node goes through the scalar entry point (_1); anything
        # else is passed as a tensor to the batched entry point (_2).
        if len(v) == 1:
            edge_array = _CAPI_DGLHeteroInEdges_1(self, int(etype), int(v[0]))
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def out_edges(self, etype, v):
        """Return the out edges of the node(s).

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        # Same scalar (_1) / batched (_2) split as in_edges.
        if len(v) == 1:
            edge_array = _CAPI_DGLHeteroOutEdges_1(self, int(etype), int(v[0]))
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def edges(self, etype, order=None):
        """Return all the edges

        Parameters
        ----------
        etype : int
            Edge type
        order : string
            The order of the returned edges. Currently support:

            - 'srcdst' : sorted by their src and dst ids.
            - 'eid'    : sorted by edge Ids.
            - None     : the arbitrary order.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        # ``None`` is passed to the C API as an empty string.
        if order is None:
            order = ""
        edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order)
        src = edge_array(0)
        dst = edge_array(1)
        eid = edge_array(2)
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        eid = utils.toindex(eid)
        return src, dst, eid

    def in_degree(self, etype, v):
        """Return the in degree of the node.

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        int
            The in degree.
        """
        return _CAPI_DGLHeteroInDegree(self, int(etype), int(v))

    def in_degrees(self, etype, v):
        """Return the in degrees of the nodes.

        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The in degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array))

    def out_degree(self, etype, v):
        """Return the out degree of the node.

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : int
            The node.

        Returns
        -------
        int
            The out degree.
        """
        return _CAPI_DGLHeteroOutDegree(self, int(etype), int(v))

    def out_degrees(self, etype, v):
        """Return the out degrees of the nodes.

        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

        Parameters
        ----------
        etype : int
            Edge type
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The out degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array))

    def adjacency_matrix(self, etype, transpose, ctx):
        """Return the adjacency matrix representation of this graph.

        By default, a row of returned adjacency matrix represents the destination
        of an edge and the column represents the source.

        When transpose is True, a row represents the source and a column represents
        a destination.

        Parameters
        ----------
        etype : int
            Edge type
        transpose : bool
            A flag to transpose the returned adjacency matrix.
        ctx : context
            The context of the returned matrix.

        Returns
        -------
        SparseTensor
            The adjacency matrix.
        utils.Index
            A index for data shuffling due to sparse format change. Return None
            if shuffle is not required.

        Raises
        ------
        DGLError
            If ``transpose`` is not a bool.
        Exception
            If the backend's preferred sparse format is neither "csr" nor "coo".
        """
        if not isinstance(transpose, bool):
            raise DGLError('Expect bool value for "transpose" arg,'
                           ' but got %s.' % (type(transpose)))
        fmt = F.get_preferred_sparse_format()
        rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
        # convert to framework-specific sparse matrix
        srctype, dsttype = self.meta_graph.find_edge(etype)
        nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype)
        ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
        nnz = self.number_of_edges(etype)
        if fmt == "csr":
            indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
            indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
            shuffle = utils.toindex(rst(2))
            dat = F.ones(nnz, dtype=F.float32, ctx=ctx)  # FIXME(minjie): data type
            spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0]
            return spmat, shuffle
        elif fmt == "coo":
            idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
            idx = F.reshape(idx, (2, nnz))
            dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
            adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (nrows, ncols))
            shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
            return adj, shuffle_idx
        else:
            raise Exception("unknown format")

    def node_subgraph(self, induced_nodes):
        """Return the induced node subgraph.

        Parameters
        ----------
        induced_nodes : list of utils.Index
            Induced nodes. The length should be equal to the number of
            node types in this heterograph.

        Returns
        -------
        SubgraphIndex
            The subgraph index.
        """
        vids = [nodes.todgltensor() for nodes in induced_nodes]
        return _CAPI_DGLHeteroVertexSubgraph(self, vids)

    def edge_subgraph(self, induced_edges, preserve_nodes):
        """Return the induced edge subgraph.

        Parameters
        ----------
        induced_edges : list of utils.Index
            Induced edges. The length should be equal to the number of
            edge types in this heterograph.
        preserve_nodes : bool
            Indicates whether to preserve all nodes or not.
            If true, keep the nodes which have no edge connected in the subgraph;
            If false, all nodes without edge connected to it would be removed.

        Returns
        -------
        SubgraphIndex
            The subgraph index.
        """
        eids = [edges.todgltensor() for edges in induced_edges]
        return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
    """Index object for a subgraph of a heterograph."""

    @property
    def graph(self):
        """The structure of the subgraph.

        Returns
        -------
        HeteroGraphIndex
            The subgraph
        """
        subgraph = _CAPI_DGLHeteroSubgraphGetGraph(self)
        return subgraph

    @property
    def induced_nodes(self):
        """Induced nodes, one entry per node type. The returned list
        should have as many entries as there are node types.

        Returns
        -------
        list of utils.Index
            Induced nodes
        """
        vertex_arrays = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
        node_indices = []
        for arr in vertex_arrays:
            node_indices.append(utils.toindex(arr.data))
        return node_indices

    @property
    def induced_edges(self):
        """Induced edges, one entry per edge type. The returned list
        should have as many entries as there are edge types.

        Returns
        -------
        list of utils.Index
            Induced edges
        """
        edge_arrays = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
        edge_indices = []
        for arr in edge_arrays:
            edge_indices.append(utils.toindex(arr.data))
        return edge_indices
def create_bipartite_from_coo(num_src, num_dst, row, col):
    """Build a bipartite graph index from a COO edge list.

    Parameters
    ----------
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    row : utils.Index
        Row index.
    col : utils.Index
        Col index.

    Returns
    -------
    HeteroGraphIndex
    """
    n_src = int(num_src)
    n_dst = int(num_dst)
    row_tensor = row.todgltensor()
    col_tensor = col.todgltensor()
    return _CAPI_DGLHeteroCreateBipartiteFromCOO(n_src, n_dst, row_tensor, col_tensor)
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
    """Build a bipartite graph index from CSR arrays.

    Parameters
    ----------
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    indptr : utils.Index
        CSR indptr.
    indices : utils.Index
        CSR indices.
    edge_ids : utils.Index
        Edge shuffle id.

    Returns
    -------
    HeteroGraphIndex
    """
    n_src = int(num_src)
    n_dst = int(num_dst)
    indptr_tensor = indptr.todgltensor()
    indices_tensor = indices.todgltensor()
    eid_tensor = edge_ids.todgltensor()
    return _CAPI_DGLHeteroCreateBipartiteFromCSR(
        n_src, n_dst, indptr_tensor, indices_tensor, eid_tensor)
def create_heterograph(meta_graph, rel_graphs):
    """Assemble a heterograph from its metagraph and per-relation graphs.

    Parameters
    ----------
    meta_graph : GraphIndex
        Meta-graph.
    rel_graphs : list of HeteroGraphIndex
        Bipartite graph of each relation.

    Returns
    -------
    HeteroGraphIndex
    """
    hetero_index = _CAPI_DGLHeteroCreateHeteroGraph(meta_graph, rel_graphs)
    return hetero_index
_init_api("dgl.graph_index") _init_api("dgl.graph_index")
"""Classes for heterogeneous graphs.""" """Classes for heterogeneous graphs."""
from collections import defaultdict
import networkx as nx
import scipy.sparse as ssp
from . import heterograph_index, graph_index
from . import utils
from . import backend as F
from . import init
from .runtime import ir, scheduler, Runtime
from .frame import Frame, FrameRef
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, is_all, DGLError
__all__ = ['DGLHeteroGraph']
# TODO: depending on the progress of unifying DGLGraph and Bipartite, we may or may not
# need the code of heterogeneous graph views.
# pylint: disable=unnecessary-pass # pylint: disable=unnecessary-pass
class DGLBaseHeteroGraph(object): class DGLBaseHeteroGraph(object):
"""Base Heterogeneous graph class. """Base Heterogeneous graph class.
A Heterogeneous graph is defined as a graph with node types and edge
types.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Parameters Parameters
---------- ----------
metagraph : NetworkX MultiGraph or compatible data structure graph : graph index, optional
The set of node types and edge types, as well as the The graph index
source/destination node type of each edge type is specified in the ntypes : list[str]
metagraph. The node type names
The edge types are specified as edge keys on the NetworkX MultiGraph. etypes : list[str]
The node types and edge types must be strings. The edge type names
number_of_nodes_by_type : dict[str, int] _ntypes_invmap, _etypes_invmap, _view_ntype_idx, _view_etype_idx :
Number of nodes for each node type. Internal arguments
edge_connections_by_type : dict
Specifies how edges would connect nodes of the source type to nodes of
the destination type in the following form:
{edge_type: edge_specifier}
where ``edge_type`` is a triplet of
(source_node_type_name,
destination_node_type_name,
edge_type_name)
and ``edge_specifier`` can be either of the following:
* (source_node_id_tensor, destination_node_id_tensor)
* ``source_node_id_tensor`` and ``destination_node_id_tensor`` are
IDs within the source and destination node type respectively.
* source node id and destination node id are both in their own ID space.
That is, source nodes and destination nodes may have the same ID,
but they are different nodes if they belong to different node types.
* scipy.sparse.matrix
By default, the rows represent the destination of an edge, and the
column represents the source.
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following
IDs:
User name Alice Bob Carol
User ID 0 1 2
Game name Tetris Minecraft
Game ID 0 1
Developer name Nintendo Mojang
Developer ID 0 1
One can construct the graph as follows:
>>> import networkx as nx
>>> metagraph = nx.MultiGraph([
... ('user', 'user', 'follows'),
... ('user', 'game', 'plays'),
... ('developer', 'game', 'develops')])
>>> g = DGLBaseHeteroGraph(
... metagraph=metagraph,
... number_of_nodes_by_type={'user': 4, 'game': 2, 'developer': 2},
... edge_connections_by_type={
... # Alice follows Bob and Bob follows Carol
... ('user', 'user', 'follows'): ([0, 1], [1, 2]),
... # Alice and Bob play Tetris and Bob and Carol play Minecraft
... ('user', 'game', 'plays'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... # Nintendo develops Tetris and Mojang develops Minecraft
... ('developer', 'game', 'develops'): ([0, 1], [0, 1])})
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
def __init__( def __init__(self, graph, ntypes, etypes,
self, _ntypes_invmap=None, _etypes_invmap=None,
metagraph, _view_ntype_idx=None, _view_etype_idx=None):
number_of_nodes_by_type,
edge_connections_by_type):
super(DGLBaseHeteroGraph, self).__init__() super(DGLBaseHeteroGraph, self).__init__()
def __getitem__(self, key): self._graph = graph
"""Returns a view on the heterogeneous graph with given node/edge self._ntypes = ntypes
type: self._etypes = etypes
# inverse mapping from ntype str to int
self._ntypes_invmap = _ntypes_invmap or \
{ntype: i for i, ntype in enumerate(ntypes)}
# inverse mapping from etype str to int
self._etypes_invmap = _etypes_invmap or \
{etype: i for i, etype in enumerate(etypes)}
* If ``key`` is a str, it returns a heterogeneous subgraph induced # Indicates which node/edge type (int) it is viewing.
from nodes of type ``key``. self._view_ntype_idx = _view_ntype_idx
* If ``key`` is a pair of str (type_A, type_B), it returns a self._view_etype_idx = _view_etype_idx
heterogeneous subgraph induced from the union of both node types.
* If ``key`` is a triplet of str
(src_type_name, dst_type_name, edge_type_name) self._cache = {}
It returns a heterogeneous subgraph induced from the edges with def _create_view(self, ntype_idx, etype_idx):
source type name ``src_type_name``, destination type name return DGLBaseHeteroGraph(
``dst_type_name``, and edge type name ``edge_type_name``. self._graph, self._ntypes, self._etypes,
self._ntypes_invmap, self._etypes_invmap,
ntype_idx, etype_idx)
The view would share the frames with the parent graph; any @property
modifications on one's frames would reflect on the other. def is_node_type_view(self):
"""Whether this is a node type view of a heterograph."""
Note that the subgraph itself is not materialized until someone return self._view_ntype_idx is not None
queries the subgraph structure. This implies that calling computation
methods such as
g['user'].update_all(...) @property
def is_edge_type_view(self):
"""Whether this is an edge type view of a heterograph."""
return self._view_etype_idx is not None
would not create a subgraph of users. @property
def is_view(self):
"""Whether this is a node/view of a heterograph."""
return self.is_node_type_view or self.is_edge_type_view
Parameters @property
---------- def all_node_types(self):
key : str or tuple """Return the list of node types of the entire heterograph."""
See above return self._ntypes
Returns @property
------- def all_edge_types(self):
DGLBaseHeteroGraphView """Return the list of edge types of the entire heterograph."""
The induced subgraph view. return self._etypes
"""
pass
@property @property
def metagraph(self): def metagraph(self):
"""Return the metagraph as networkx.MultiDiGraph.""" """Return the metagraph as networkx.MultiDiGraph.
pass
The nodes are labeled with node type names.
def number_of_nodes(self): The edges have their keys holding the edge type names.
"""Return the number of nodes in the graph. """
nx_graph = self._graph.metagraph.to_networkx()
nx_return_graph = nx.MultiDiGraph()
for u_v in nx_graph.edges:
etype = self._etypes[nx_graph.edges[u_v]['id']]
srctype = self._ntypes[u_v[0]]
dsttype = self._ntypes[u_v[1]]
assert etype[0] == srctype
assert etype[2] == dsttype
nx_return_graph.add_edge(srctype, dsttype, etype[1])
return nx_return_graph
def _endpoint_types(self, etype):
"""Return the source and destination node type (int) of given edge
type (int)."""
return self._graph.metagraph.find_edge(etype)
def _node_types(self):
if self.is_node_type_view:
return [self._view_ntype_idx]
elif self.is_edge_type_view:
srctype_idx, dsttype_idx = self._endpoint_types(self._view_etype_idx)
return [srctype_idx, dsttype_idx] if srctype_idx != dsttype_idx else [srctype_idx]
else:
return range(len(self._ntypes))
def node_types(self):
"""Return the list of node types appearing in the current view.
Returns Returns
------- -------
int list[str]
The number of nodes List of node types
"""
pass
def __len__(self): Examples
"""Return the number of nodes in the graph.""" --------
pass Getting all node types.
>>> g.node_types()
# TODO: REVIEW ['user', 'game', 'developer']
def add_nodes(self, num, node_type, data=None):
"""Add multiple new nodes of the same node type Getting all node types appearing in the subgraph induced by "users"
(which should only yield "user").
>>> g['user'].node_types()
['user']
The node types appearing in subgraph induced by "plays" relationship,
which should only give "user" and "game".
>>> g['plays'].node_types()
['user', 'game']
"""
ntypes = self._node_types()
if isinstance(ntypes, range):
# assuming that the range object always covers the entire node type list
return self._ntypes
else:
return [self._ntypes[i] for i in ntypes]
def _edge_types(self):
if self.is_node_type_view:
etype_indices = self._graph.metagraph.edge_id(
self._view_ntype_idx, self._view_ntype_idx)
return etype_indices
elif self.is_edge_type_view:
return [self._view_etype_idx]
else:
return range(len(self._etypes))
def edge_types(self):
"""Return the list of edge types appearing in the current view.
Parameters Returns
---------- -------
num : int list[str]
Number of nodes to be added. List of edge types
node_type : str
Type of the added nodes. Must appear in the metagraph.
data : dict, optional
Feature data of the added nodes.
Examples Examples
-------- --------
The variable ``g`` is constructed from the example in Getting all edge types.
DGLBaseHeteroGraph. >>> g.edge_types()
['follows', 'plays', 'develops']
>>> g['game'].number_of_nodes() Getting all edge types appearing in subgraph induced by "users".
2 >>> g['user'].edge_types()
>>> g.add_nodes(3, 'game') # add 3 new games ['follows']
>>> g['game'].number_of_nodes()
5 The edge types appearing in subgraph induced by "plays" relationship,
which should only give "plays".
>>> g['plays'].edge_types()
['plays']
""" """
pass etypes = self._edge_types()
if isinstance(etypes, range):
return self._etypes
else:
return [self._etypes[i] for i in etypes]
@property
@utils.cached_member('_cache', '_current_ntype_idx')
def _current_ntype_idx(self):
"""Checks the uniqueness of node type in the view and get the index
of that node type.
# TODO: REVIEW This allows reading/writing node frame data.
def add_edge(self, u, v, utype, vtype, etype, data=None): """
"""Add an edge of ``etype`` between u of type ``utype`` and v of type node_types = self._node_types()
``vtype``. assert len(node_types) == 1, "only available for subgraphs with one node type"
return node_types[0]
Parameters @property
---------- @utils.cached_member('_cache', '_current_etype_idx')
u : int def _current_etype_idx(self):
The source node ID of type ``utype``. Must exist in the graph. """Checks the uniqueness of edge type in the view and get the index
v : int of that edge type.
The destination node ID of type ``vtype``. Must exist in the
graph.
utype : str
The source node type name. Must exist in the metagraph.
vtype : str
The destination node type name. Must exist in the metagraph.
etype : str
The edge type name. Must exist in the metagraph.
data : dict, optional
Feature data of the added edge.
Examples This allows reading/writing edge frame data and message passing routines.
-------- """
The variable ``g`` is constructed from the example in edge_types = self._edge_types()
DGLBaseHeteroGraph. assert len(edge_types) == 1, "only available for subgraphs with one edge type"
return edge_types[0]
>>> g['user', 'game', 'plays'].number_of_edges() @property
4 @utils.cached_member('_cache', '_current_srctype_idx')
>>> g.add_edge(2, 0, 'user', 'game', 'plays') def _current_srctype_idx(self):
>>> g['user', 'game', 'plays'].number_of_edges() """Checks the uniqueness of edge type in the view and get the index
5 of the source type.
This allows reading/writing edge frame data and message passing routines.
""" """
pass srctype_idx, _ = self._endpoint_types(self._current_etype_idx)
return srctype_idx
def add_edges(self, u, v, utype, vtype, etype, data=None): @property
"""Add multiple edges of ``etype`` between list of source nodes ``u`` @utils.cached_member('_cache', '_current_dsttype_idx')
of type ``utype`` and list of destination nodes ``v`` of type def _current_dsttype_idx(self):
``vtype``. A single edge is added between every pair of ``u[i]`` and """Checks the uniqueness of edge type in the view and get the index
``v[i]``. of the destination type.
This allows reading/writing edge frame data and message passing routines.
"""
_, dsttype_idx = self._endpoint_types(self._current_etype_idx)
return dsttype_idx
def number_of_nodes(self, ntype):
"""Return the number of nodes of the given type in the heterograph.
Parameters Parameters
---------- ----------
u : list, tensor ntype : str
The source node IDs of type ``utype``. Must exist in the graph. The node type
v : list, tensor
The destination node IDs of type ``vtype``. Must exist in the Returns
graph. -------
utype : str int
The source node type name. Must exist in the metagraph. The number of nodes
vtype : str
The destination node type name. Must exist in the metagraph.
etype : str
The edge type name. Must exist in the metagraph.
data : dict, optional
Feature data of the added edge.
Examples Examples
-------- --------
The variable ``g`` is constructed from the example in >>> g['user'].number_of_nodes()
DGLBaseHeteroGraph. 3
>>> g['user', 'game', 'plays'].number_of_edges()
4
>>> g.add_edges([0, 2], [1, 0], 'user', 'game', 'plays')
>>> g['user', 'game', 'plays'].number_of_edges()
6
""" """
pass return self._graph.number_of_nodes(self._ntypes_invmap[ntype])
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_srctype_idx)
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_dsttype_idx)
@property @property
def is_multigraph(self): def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise. """True if the graph is a multigraph, False otherwise.
""" """
pass assert not self.is_view, 'not supported on views'
return self._graph.is_multigraph()
@property @property
def is_readonly(self): def is_readonly(self):
"""True if the graph is readonly, False otherwise. """True if the graph is readonly, False otherwise.
""" """
pass return self._graph.is_readonly()
def number_of_edges(self): def _number_of_edges(self):
"""Return the number of edges in the graph. """Return number of edges in the current view (only used for scheduler)"""
return self._graph.number_of_edges(self._current_etype_idx)
def number_of_edges(self, etype):
"""Return the number of edges of the given type in the heterograph.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
Returns Returns
------- -------
int int
The number of edges The number of edges
"""
pass
def has_node(self, vid):
"""Return True if the graph contains node `vid`.
Only works if the graph has one node type. For multiple types, Examples
query with --------
>>> g.number_of_edges(('user', 'plays', 'game'))
.. code:: 4
"""
return self._graph.number_of_edges(self._etypes_invmap[etype])
g['vtype'].has_node(vid) def has_node(self, ntype, vid):
"""Return True if the graph contains node `vid` of type `ntype`.
Parameters Parameters
---------- ----------
ntype : str
The node type.
vid : int vid : int
The node ID. The node ID.
...@@ -311,53 +314,26 @@ class DGLBaseHeteroGraph(object): ...@@ -311,53 +314,26 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
>>> g['user'].has_node(0) >>> g.has_node('user', 0)
True True
>>> g['user'].has_node(4) >>> g.has_node('user', 4)
False False
Equivalently,
>>> 0 in g['user']
True
See Also See Also
-------- --------
has_nodes has_nodes
""" """
pass return self._graph.has_node(self._ntypes_invmap[ntype], vid)
def __contains__(self, vid):
"""Return True if the graph contains node `vid`.
Only works if the graph has one node type. For multiple types,
query with
.. code::
vid in g['vtype']
Examples
--------
>>> 0 in g['user']
True
"""
pass
def has_nodes(self, vids): def has_nodes(self, ntype, vids):
"""Return a 0-1 array ``a`` given the node ID array ``vids``. """Return a 0-1 array ``a`` given the node ID array ``vids``.
``a[i]`` is 1 if the graph contains node ``vids[i]``, 0 otherwise. ``a[i]`` is 1 if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.
Only works if the graph has one node type. For multiple types,
query with
.. code::
g['vtype'].has_nodes(vids)
Parameters Parameters
---------- ----------
ntype : str
The node type.
vid : list or tensor vid : list or tensor
The array of node IDs. The array of node IDs.
...@@ -370,27 +346,24 @@ class DGLBaseHeteroGraph(object): ...@@ -370,27 +346,24 @@ class DGLBaseHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g['user'].has_nodes([0, 1, 2, 3, 4]) >>> g.has_nodes('user', [0, 1, 2, 3, 4])
tensor([1, 1, 1, 0, 0]) tensor([1, 1, 1, 0, 0])
See Also See Also
-------- --------
has_node has_node
""" """
pass vids = utils.toindex(vids)
rst = self._graph.has_nodes(self._ntypes_invmap[ntype], vids)
def has_edge_between(self, u, v): return rst.tousertensor()
"""Return True if the edge (u, v) is in the graph.
Only works if the graph has one edge type. For multiple types,
query with
.. code:: def has_edge_between(self, etype, u, v):
"""Return True if the edge (u, v) of type ``etype`` is in the graph.
g['srctype', 'dsttype', 'edgetype'].has_edge_between(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int u : int
The node ID of source type. The node ID of source type.
v : int v : int
...@@ -404,34 +377,29 @@ class DGLBaseHeteroGraph(object): ...@@ -404,34 +377,29 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
Check whether Alice plays Tetris Check whether Alice plays Tetris
>>> g['user', 'game', 'plays'].has_edge_between(0, 1) >>> g.has_edge_between(('user', 'plays', 'game'), 0, 1)
True True
And whether Alice plays Minecraft And whether Alice plays Minecraft
>>> g['user', 'game', 'plays'].has_edge_between(0, 2) >>> g.has_edge_between(('user', 'plays', 'game'), 0, 2)
False False
See Also See Also
-------- --------
has_edges_between has_edges_between
""" """
pass return self._graph.has_edge_between(self._etypes_invmap[etype], u, v)
def has_edges_between(self, u, v): def has_edges_between(self, etype, u, v):
"""Return a 0-1 array `a` given the source node ID array `u` and """Return a 0-1 array ``a`` given the source node ID array ``u`` and
destination node ID array `v`. destination node ID array ``v``.
`a[i]` is 1 if the graph contains edge `(u[i], v[i])`, 0 otherwise. ``a[i]`` is 1 if the graph contains edge ``(u[i], v[i])`` of type ``etype``, 0 otherwise.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].has_edges_between(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor u : list, tensor
The node ID array of source type. The node ID array of source type.
v : list, tensor v : list, tensor
...@@ -446,31 +414,29 @@ class DGLBaseHeteroGraph(object): ...@@ -446,31 +414,29 @@ class DGLBaseHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g['user', 'game', 'plays'].has_edges_between([0, 0], [1, 2]) >>> g.has_edges_between(('user', 'plays', 'game'), [0, 0], [1, 2])
tensor([1, 0]) tensor([1, 0])
See Also See Also
-------- --------
has_edge_between has_edge_between
""" """
pass u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges_between(self._etypes_invmap[etype], u, v)
return rst.tousertensor()
def predecessors(self, v): def predecessors(self, etype, v):
"""Return the predecessors of node `v` in the graph with the same """Return the predecessors of node `v` in the graph with the same
edge type. edge type.
Node `u` is a predecessor of `v` if an edge `(u, v)` exist in the Node `u` is a predecessor of `v` if an edge `(u, v)` exist in the
graph. graph.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].predecessors(v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node of destination type. The node of destination type.
...@@ -484,7 +450,7 @@ class DGLBaseHeteroGraph(object): ...@@ -484,7 +450,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Query who plays Tetris: Query who plays Tetris:
>>> g['user', 'game', 'plays'].predecessors(0) >>> g.predecessors(('user', 'plays', 'game'), 0)
tensor([0, 1]) tensor([0, 1])
This indicates User #0 (Alice) and User #1 (Bob). This indicates User #0 (Alice) and User #1 (Bob).
...@@ -493,38 +459,33 @@ class DGLBaseHeteroGraph(object): ...@@ -493,38 +459,33 @@ class DGLBaseHeteroGraph(object):
-------- --------
successors successors
""" """
pass return self._graph.predecessors(self._etypes_invmap[etype], v).tousertensor()
def successors(self, v): def successors(self, etype, v):
"""Return the successors of node `v` in the graph with the same edge """Return the successors of node `v` in the graph with the same edge
type. type.
Node `u` is a successor of `v` if an edge `(v, u)` exist in the Node `u` is a successor of `v` if an edge `(v, u)` exist in the
graph. graph.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].successors(v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node of source type. The node of source type.
Returns Returns
------- -------
tensor tensor
Array of successor node IDs if destination node type. Array of successor node IDs of destination node type.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
Asks which game Alice plays: Asks which game Alice plays:
>>> g['user', 'game', 'plays'].successors(0) >>> g.successors(('user', 'plays', 'game'), 0)
tensor([0]) tensor([0])
This indicates Game #0 (Tetris). This indicates Game #0 (Tetris).
...@@ -533,21 +494,19 @@ class DGLBaseHeteroGraph(object): ...@@ -533,21 +494,19 @@ class DGLBaseHeteroGraph(object):
-------- --------
predecessors predecessors
""" """
pass return self._graph.successors(self._etypes_invmap[etype], v).tousertensor()
def edge_id(self, u, v, force_multi=False): def edge_id(self, etype, u, v, force_multi=False):
"""Return the edge ID, or an array of edge IDs, between source node """Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`. `u` and destination node `v`.
Only works if the graph has one edge type. For multiple types, Only works if the graph has one edge type. For multiple types,
query with query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_id(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int u : int
The node ID of source type. The node ID of source type.
v : int v : int
...@@ -567,28 +526,27 @@ class DGLBaseHeteroGraph(object): ...@@ -567,28 +526,27 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the edge ID of "Bob plays Tetris" Find the edge ID of "Bob plays Tetris"
>>> g['user', 'game', 'plays'].edge_id(1, 0) >>> g.edge_id(('user', 'plays', 'game'), 1, 0)
1 1
See Also See Also
-------- --------
edge_ids edge_ids
""" """
pass idx = self._graph.edge_id(self._etypes_invmap[etype], u, v)
return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0]
def edge_ids(self, u, v, force_multi=False): def edge_ids(self, etype, u, v, force_multi=False):
"""Return all edge IDs between source node array `u` and destination """Return all edge IDs between source node array `u` and destination
node array `v`. node array `v`.
Only works if the graph has one edge type. For multiple types, Only works if the graph has one edge type. For multiple types,
query with query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor u : list, tensor
The node ID array of source type. The node ID array of source type.
v : list, tensor v : list, tensor
...@@ -616,29 +574,30 @@ class DGLBaseHeteroGraph(object): ...@@ -616,29 +574,30 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the edge IDs of "Alice plays Tetris" and "Bob plays Minecraft". Find the edge IDs of "Alice plays Tetris" and "Bob plays Minecraft".
>>> g['user', 'game', 'plays'].edge_ids([0, 1], [0, 1]) >>> g.edge_ids(('user', 'plays', 'game'), [0, 1], [0, 1])
tensor([0, 2]) tensor([0, 2])
See Also See Also
-------- --------
edge_id edge_id
""" """
pass u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(self._etypes_invmap[etype], u, v)
if force_multi or self._graph.is_multigraph():
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
return eid.tousertensor()
def find_edges(self, eid): def find_edges(self, etype, eid):
"""Given an edge ID array, return the source and destination node ID """Given an edge ID array, return the source and destination node ID
array `s` and `d`. `s[i]` and `d[i]` are source and destination node array `s` and `d`. `s[i]` and `d[i]` are source and destination node
ID for edge `eid[i]`. ID for edge `eid[i]`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
eid : list, tensor eid : list, tensor
The edge ID array. The edge ID array.
...@@ -654,23 +613,20 @@ class DGLBaseHeteroGraph(object): ...@@ -654,23 +613,20 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the user and game of gameplay #0 and #2: Find the user and game of gameplay #0 and #2:
>>> g['user', 'game', 'plays'].find_edges([0, 2]) >>> g.find_edges(('user', 'plays', 'game'), [0, 2])
(tensor([0, 1]), tensor([0, 1])) (tensor([0, 1]), tensor([0, 1]))
""" """
pass eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(self._etypes_invmap[etype], eid)
return src.tousertensor(), dst.tousertensor()
def in_edges(self, v, form='uv'): def in_edges(self, etype, v, form='uv'):
"""Return the inbound edges of the node(s). """Return the inbound edges of the node(s).
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor v : int, list, tensor
The node(s) of destination type. The node(s) of destination type.
form : str, optional form : str, optional
...@@ -696,23 +652,27 @@ class DGLBaseHeteroGraph(object): ...@@ -696,23 +652,27 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the gameplay IDs of game #0 (Tetris) Find the gameplay IDs of game #0 (Tetris)
>>> g['user', 'game', 'plays'].in_edges(0, 'eid') >>> g.in_edges(('user', 'plays', 'game'), 0, 'eid')
tensor([0, 1]) tensor([0, 1])
""" """
pass v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(self._etypes_invmap[etype], v)
def out_edges(self, v, form='uv'): if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def out_edges(self, etype, v, form='uv'):
"""Return the outbound edges of the node(s). """Return the outbound edges of the node(s).
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor v : int, list, tensor
The node(s) of source type. The node(s) of source type.
form : str, optional form : str, optional
...@@ -738,23 +698,27 @@ class DGLBaseHeteroGraph(object): ...@@ -738,23 +698,27 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the gameplay IDs of user #0 (Alice) Find the gameplay IDs of user #0 (Alice)
>>> g['user', 'game', 'plays'].out_edges(0, 'eid') >>> g.out_edges(('user', 'plays', 'game'), 0, 'eid')
tensor([0]) tensor([0])
""" """
pass v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(self._etypes_invmap[etype], v)
def all_edges(self, form='uv', order=None): if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def all_edges(self, etype, form='uv', order=None):
"""Return all the edges. """Return all the edges.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
form : str, optional form : str, optional
The return form. Currently support: The return form. Currently support:
...@@ -785,57 +749,57 @@ class DGLBaseHeteroGraph(object): ...@@ -785,57 +749,57 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the user-game pairs for all gameplays: Find the user-game pairs for all gameplays:
>>> g['user', 'game', 'plays'].all_edges('uv') >>> g.all_edges(('user', 'plays', 'game'), 'uv')
(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1])) (tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1]))
""" """
pass src, dst, eid = self._graph.edges(self._etypes_invmap[etype], order)
if form == 'all':
def in_degree(self, v): return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def in_degree(self, etype, v):
"""Return the in-degree of node ``v``. """Return the in-degree of node ``v``.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node ID of destination type. The node ID of destination type.
Returns Returns
------- -------
etype : (str, str, str)
The source-edge-destination type triplet
int int
The in-degree. The in-degree.
Examples Examples
-------- --------
Find how many users are playing Game #0 (Tetris): Find how many users are playing Game #0 (Tetris):
>>> g['user', 'game', 'plays'].in_degree(0) >>> g.in_degree(('user', 'plays', 'game'), 0)
2 2
See Also See Also
-------- --------
in_degrees in_degrees
""" """
pass return self._graph.in_degree(self._etypes_invmap[etype], v)
def in_degrees(self, v=ALL): def in_degrees(self, etype, v=ALL):
"""Return the array `d` of in-degrees of the node array `v`. """Return the array `d` of in-degrees of the node array `v`.
`d[i]` is the in-degree of node `v[i]`. `d[i]` is the in-degree of node `v[i]`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor, optional. v : list, tensor, optional.
The node ID array of destination type. Default is to return the The node ID array of destination type. Default is to return the
degrees of all the nodes. degrees of all the nodes.
...@@ -850,27 +814,28 @@ class DGLBaseHeteroGraph(object): ...@@ -850,27 +814,28 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find how many users are playing Game #0 and #1 (Tetris and Minecraft): Find how many users are playing Game #0 and #1 (Tetris and Minecraft):
>>> g['user', 'game', 'plays'].in_degrees([0, 1]) >>> g.in_degrees(('user', 'plays', 'game'), [0, 1])
tensor([2, 2]) tensor([2, 2])
See Also See Also
-------- --------
in_degree in_degree
""" """
pass etype_idx = self._etypes_invmap[etype]
_, dsttype_idx = self._endpoint_types(etype_idx)
if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(dsttype_idx)))
else:
v = utils.toindex(v)
return self._graph.in_degrees(etype_idx, v).tousertensor()
def out_degree(self, v): def out_degree(self, etype, v):
"""Return the out-degree of node `v`. """Return the out-degree of node `v`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node ID of source type. The node ID of source type.
...@@ -882,29 +847,24 @@ class DGLBaseHeteroGraph(object): ...@@ -882,29 +847,24 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
Find how many games User #0 Alice is playing Find how many games User #0 Alice is playing
>>> g['user', 'game', 'plays'].out_degree(0) >>> g.out_degree(('user', 'plays', 'game'), 0)
1 1
See Also See Also
-------- --------
out_degrees out_degrees
""" """
pass return self._graph.out_degree(self._etypes_invmap[etype], v)
def out_degrees(self, v=ALL): def out_degrees(self, etype, v=ALL):
"""Return the array `d` of out-degrees of the node array `v`. """Return the array `d` of out-degrees of the node array `v`.
`d[i]` is the out-degree of node `v[i]`. `d[i]` is the out-degree of node `v[i]`.
Only works if the graph has one edge type. For multiple types,
query with
.. code::
g['srctype', 'dsttype', 'edgetype'].edge_ids(u, v)
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor v : list, tensor
The node ID array of source type. Default is to return the degrees The node ID array of source type. Default is to return the degrees
of all the nodes. of all the nodes.
...@@ -919,58 +879,413 @@ class DGLBaseHeteroGraph(object): ...@@ -919,58 +879,413 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find how many games User #0 and #1 (Alice and Bob) are playing Find how many games User #0 and #1 (Alice and Bob) are playing
>>> g['user', 'game', 'plays'].out_degrees([0, 1]) >>> g.out_degrees(('user', 'plays', 'game'), [0, 1])
tensor([1, 2]) tensor([1, 2])
See Also See Also
-------- --------
out_degree out_degree
""" """
etype_idx = self._etypes_invmap[etype]
srctype_idx, _ = self._endpoint_types(etype_idx)
if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(srctype_idx)))
else:
v = utils.toindex(v)
return self._graph.out_degrees(etype_idx, v).tousertensor()
def bipartite_from_edge_list(u, v, num_src=None, num_dst=None):
    """Create a bipartite graph component of a heterogeneous graph with a
    list of edges.

    Parameters
    ----------
    u, v : list[int]
        List of source and destination node IDs; ``u[i]`` and ``v[i]`` are
        the endpoints of the i-th edge.
    num_src : int, optional
        The number of nodes of source type.
        When omitted (``None``) or ``0``, the value is the maximum of the
        source node IDs in the edge list plus 1, or 0 if the list is empty.
    num_dst : int, optional
        The number of nodes of destination type.
        When omitted (``None``) or ``0``, the value is the maximum of the
        destination node IDs in the edge list plus 1, or 0 if the list is
        empty.

    Returns
    -------
    The object produced by ``heterograph_index.create_bipartite_from_coo``.
    """
    # A falsy count means "infer it from the IDs".  ``max()`` raises
    # ValueError on an empty sequence, so an empty edge list infers 0 nodes
    # instead of crashing (this also covers an explicit count of 0).
    if not num_src:
        num_src = (max(u) + 1) if len(u) > 0 else 0
    if not num_dst:
        num_dst = (max(v) + 1) if len(v) > 0 else 0
    u = utils.toindex(u)
    v = utils.toindex(v)
    return heterograph_index.create_bipartite_from_coo(num_src, num_dst, u, v)
def bipartite_from_scipy(spmat, with_edge_id=False):
    """Build one bipartite component of a heterogeneous graph from a scipy
    sparse matrix.

    Parameters
    ----------
    spmat : scipy sparse matrix
        Bipartite adjacency matrix; rows are sources, columns are
        destinations.  It is converted with ``tocsr()`` before use.
    with_edge_id : bool
        If True, the stored entries of the matrix are used as edge IDs.
        Otherwise the entries are ignored and edges are numbered
        consecutively in (source, destination) order.

    Returns
    -------
    The object produced by ``heterograph_index.create_bipartite_from_csr``.
    """
    csr = spmat.tocsr()
    n_rows, n_cols = csr.shape
    row_ptr = utils.toindex(csr.indptr)
    col_idx = utils.toindex(csr.indices)
    if with_edge_id:
        raw_ids = csr.data
    else:
        # No IDs supplied: number the stored entries 0..nnz-1 in CSR order.
        raw_ids = list(range(len(col_idx)))
    edge_ids = utils.toindex(raw_ids)
    return heterograph_index.create_bipartite_from_csr(
        n_rows, n_cols, row_ptr, col_idx, edge_ids)
class DGLHeteroGraph(DGLBaseHeteroGraph):
"""Base heterogeneous graph class.
A Heterogeneous graph is defined as a graph with node types and edge
types.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Parameters
----------
graph_data :
The graph data. It can be one of the followings:
* (nx.MultiDiGraph, dict[str, list[tuple[int, int]]])
* (nx.MultiDiGraph, dict[str, scipy.sparse.matrix])
The first element is the metagraph of the heterogeneous graph, as a
networkx directed graph. Its nodes represent the node types, and
its edges represent the edge types. The edge type name should be
stored as edge keys.
The second element is a mapping from edge type to edge list. The
edge list can be either a list of (u, v) pairs, or a scipy sparse
matrix whose rows represents sources and columns represents
destinations. The edges will be added in the (source, destination)
order.
node_frames : dict[str, dict[str, Tensor]]
The node frames for each node type
edge_frames : dict[str, dict[str, Tensor]]
The edge frames for each edge type
multigraph : bool
Whether the heterogeneous graph is a multigraph.
readonly : bool
Whether the heterogeneous graph is readonly.
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following
IDs:
User name Alice Bob Carol
User ID 0 1 2
Game name Tetris Minecraft
Game ID 0 1
Developer name Nintendo Mojang
Developer ID 0 1
One can construct the graph as follows:
>>> mg = nx.MultiDiGraph([('user', 'user', 'follows'),
... ('user', 'game', 'plays'),
... ('developer', 'game', 'develops')])
>>> g = DGLHeteroGraph(
... mg, {
... 'follows': [(0, 1), (1, 2)],
... 'plays': [(0, 0), (1, 0), (1, 1), (2, 1)],
... 'develops': [(0, 0), (1, 1)]})
Then one can query the graph structure as follows:
>>> g['user'].number_of_nodes()
3
>>> g['plays'].number_of_edges()
4
>>> g['develops'].out_degrees() # out-degrees of source nodes of 'develops' relation
tensor([1, 1])
>>> g['develops'].in_edges(0) # in-edges of destination node 0 of 'develops' relation
(tensor([0]), tensor([0]))
Notes
-----
Currently, all heterogeneous graphs are readonly.
"""
# pylint: disable=unused-argument
def __init__(
        self,
        graph_data=None,
        node_frames=None,
        edge_frames=None,
        multigraph=None,
        readonly=True,
        _view_ntype_idx=None,
        _view_etype_idx=None):
    """Build a heterograph from ``(metagraph, edges)`` data, or a view of
    an existing heterograph.

    Parameters
    ----------
    graph_data : DGLHeteroGraph or tuple
        * A ``DGLHeteroGraph``: a view is created that reuses the other
          graph's index, type maps, frames and message state.
        * A ``(metagraph, edges_by_type)`` tuple: ``metagraph`` must be a
          ``networkx.MultiDiGraph``; ``edges_by_type`` is indexed with
          ``(srctype, etype, dsttype)`` triplets and must yield either a
          list of ``(u, v)`` pairs or a scipy sparse matrix.
    node_frames : optional
        Stored as ``self._node_frames`` when given; otherwise one
        ``FrameRef(Frame(num_rows=...))`` is created per node type.
    edge_frames : optional
        Stored as ``self._edge_frames`` when given; otherwise one
        ``FrameRef(Frame(num_rows=...))`` is created per edge type.
    multigraph : optional
        Not read by this constructor.
    readonly : bool
        Must be truthy; otherwise the assertion below fails.
    _view_ntype_idx, _view_etype_idx : optional
        Only used in the view branch, where they are stored on the new
        object.

    Raises
    ------
    TypeError
        If ``graph_data`` is neither accepted form, if the metagraph is not
        a ``networkx.MultiDiGraph``, or if an edge container is neither a
        list nor a scipy sparse matrix.
    """
    assert readonly, "Only readonly heterogeneous graphs are supported"

    # Creating a view of another graph?
    if isinstance(graph_data, DGLHeteroGraph):
        super(DGLHeteroGraph, self).__init__(
            graph_data._graph, graph_data._ntypes, graph_data._etypes,
            graph_data._ntypes_invmap, graph_data._etypes_invmap,
            graph_data._view_ntype_idx, graph_data._view_etype_idx)
        # The view aliases the parent's frames and message state (no copies).
        self._node_frames = graph_data._node_frames
        self._edge_frames = graph_data._edge_frames
        self._msg_frames = graph_data._msg_frames
        self._msg_indices = graph_data._msg_indices
        # Replace the view indices forwarded from the parent with the
        # ones requested for this view.
        self._view_ntype_idx = _view_ntype_idx
        self._view_etype_idx = _view_etype_idx
        return

    if isinstance(graph_data, tuple):
        metagraph, edges_by_type = graph_data
        if not isinstance(metagraph, nx.MultiDiGraph):
            raise TypeError('Metagraph should be networkx.MultiDiGraph')

        # create metagraph graph index
        srctypes, dsttypes, etypes = [], [], []
        ntypes = []
        ntypes_invmap = {}
        etypes_invmap = {}
        # Node and edge types get integer IDs in order of first appearance
        # while iterating the metagraph edges.
        for srctype, dsttype, etype in metagraph.edges(keys=True):
            srctypes.append(srctype)
            dsttypes.append(dsttype)
            etypes_invmap[(srctype, etype, dsttype)] = len(etypes_invmap)
            etypes.append((srctype, etype, dsttype))

            if srctype not in ntypes_invmap:
                ntypes_invmap[srctype] = len(ntypes_invmap)
                ntypes.append(srctype)
            if dsttype not in ntypes_invmap:
                ntypes_invmap[dsttype] = len(ntypes_invmap)
                ntypes.append(dsttype)

        # Translate type names into the integer IDs assigned above.
        srctypes = [ntypes_invmap[srctype] for srctype in srctypes]
        dsttypes = [ntypes_invmap[dsttype] for dsttype in dsttypes]

        metagraph_index = graph_index.create_graph_index(
            list(zip(srctypes, dsttypes)), None, True) # metagraph is always immutable

        # create base bipartites
        bipartites = []
        num_nodes = defaultdict(int)
        # count the number of nodes for each type
        # (first pass: a node type shared by several relations takes the
        # largest count seen in any of them)
        for etype_triplet in etypes:
            srctype, etype, dsttype = etype_triplet
            edges = edges_by_type[etype_triplet]
            if ssp.issparse(edges):
                num_src, num_dst = edges.shape
            elif isinstance(edges, list):
                # NOTE(review): an empty list makes this unpacking fail
                # with ValueError, since zip(*[]) yields nothing.
                u, v = zip(*edges)
                num_src = max(u) + 1
                num_dst = max(v) + 1
            else:
                raise TypeError('unknown edge list type %s' % type(edges))
            num_nodes[srctype] = max(num_nodes[srctype], num_src)
            num_nodes[dsttype] = max(num_nodes[dsttype], num_dst)
        # create actual objects
        for etype_triplet in etypes:
            srctype, etype, dsttype = etype_triplet
            edges = edges_by_type[etype_triplet]
            if ssp.issparse(edges):
                # NOTE(review): sparse inputs keep their own shape; the
                # aggregated num_nodes is only applied to edge lists.
                bipartite = bipartite_from_scipy(edges)
            elif isinstance(edges, list):
                u, v = zip(*edges)
                bipartite = bipartite_from_edge_list(
                    u, v, num_nodes[srctype], num_nodes[dsttype])
            bipartites.append(bipartite)

        hg_index = heterograph_index.create_heterograph(metagraph_index, bipartites)

        # NOTE(review): the local ntypes_invmap/etypes_invmap are not passed
        # on here; presumably the base class rebuilds them - confirm.
        super(DGLHeteroGraph, self).__init__(hg_index, ntypes, etypes)
    else:
        raise TypeError('Unrecognized graph data type %s' % type(graph_data))

    # node and edge frame
    if node_frames is None:
        self._node_frames = [
            FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
            for i in range(len(self._ntypes))]
    else:
        self._node_frames = node_frames

    if edge_frames is None:
        self._edge_frames = [
            FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
            for i in range(len(self._etypes))]
    else:
        self._edge_frames = edge_frames

    # message indicators
    # One slot per edge type; each starts as None and is filled on first
    # use by _get_msg_index.
    self._msg_indices = [None] * len(self._etypes)
    self._msg_frames = []
    for i in range(len(self._etypes)):
        frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
        frame.set_initializer(init.zero_initializer)
        self._msg_frames.append(frame)
def _create_view(self, ntype_idx, etype_idx):
return DGLHeteroGraph(
graph_data=self, _view_ntype_idx=ntype_idx, _view_etype_idx=etype_idx)
def _get_msg_index(self):
if self._msg_indices[self._current_etype_idx] is None:
self._msg_indices[self._current_etype_idx] = utils.zero_index(
size=self._graph.number_of_edges(self._current_etype_idx))
return self._msg_indices[self._current_etype_idx]
def _set_msg_index(self, index):
self._msg_indices[self._current_etype_idx] = index
def __getitem__(self, key):
if key in self._etypes_invmap:
return self._create_view(None, self._etypes_invmap[key])
else:
raise KeyError(key)
@property
def _node_frame(self):
# overrides DGLGraph._node_frame
return self._node_frames[self._current_ntype_idx]
@property
def _edge_frame(self):
# overrides DGLGraph._edge_frame
return self._edge_frames[self._current_etype_idx]
@property
def _src_frame(self):
# overrides DGLGraph._src_frame
return self._node_frames[self._current_srctype_idx]
@property
def _dst_frame(self):
# overrides DGLGraph._dst_frame
return self._node_frames[self._current_dsttype_idx]
@property
def _msg_frame(self):
# overrides DGLGraph._msg_frame
return self._msg_frames[self._current_etype_idx]
def add_nodes(self, node_type, num, data=None):
    """Add multiple new nodes of the same node type.

    Currently a placeholder: the body is a no-op, so the graph is left
    unchanged and ``None`` is returned. The text below documents the
    intended interface.

    Parameters
    ----------
    node_type : str
        Type of the added nodes. Must appear in the metagraph.
    num : int
        Number of nodes to be added.
    data : dict, optional
        Feature data of the added nodes.

    Examples
    --------
    Intended behavior once implemented. The variable ``g`` is constructed
    from the example in DGLBaseHeteroGraph.

    >>> g['game'].number_of_nodes()
    2
    >>> g.add_nodes('game', 3)  # add 3 new games
    >>> g['game'].number_of_nodes()
    5
    """
    pass
def add_edge(self, etype, u, v, data=None):
    """Add an edge of ``etype`` between ``u`` of the source node type and
    ``v`` of the destination node type.

    Currently a placeholder: the body is a no-op, so the graph is left
    unchanged and ``None`` is returned. The text below documents the
    intended interface.

    Parameters
    ----------
    etype : (str, str, str)
        The source-edge-destination type triplet
    u : int
        The source node ID, of the source node type of ``etype``. Must
        exist in the graph.
    v : int
        The destination node ID, of the destination node type of
        ``etype``. Must exist in the graph.
    data : dict, optional
        Feature data of the added edge.

    Examples
    --------
    Intended behavior once implemented. The variable ``g`` is constructed
    from the example in DGLBaseHeteroGraph.

    >>> g['user', 'plays', 'game'].number_of_edges()
    4
    >>> g.add_edge(('user', 'plays', 'game'), 2, 0)
    >>> g['user', 'plays', 'game'].number_of_edges()
    5
    """
    pass
def add_edges(self, u, v, etype, data=None):
class DGLHeteroGraph(DGLBaseHeteroGraph): """Add multiple edges of ``etype`` between list of source nodes ``u``
"""Base heterogeneous graph class. and list of destination nodes ``v`` of type ``vtype``. A single edge
is added between every pair of ``u[i]`` and ``v[i]``.
The graph stores nodes, edges and also their (type-specific) features.
Heterogeneous graphs are by default multigraphs.
Parameters Parameters
---------- ----------
metagraph, number_of_nodes_by_type, edge_connections_by_type : u : list, tensor
See DGLBaseHeteroGraph The source node IDs of type ``utype``. Must exist in the graph.
node_frame : dict[str, FrameRef], optional v : list, tensor
Node feature storage per type The destination node IDs of type ``vtype``. Must exist in the
edge_frame : dict[str, FrameRef], optional graph.
Edge feature storage per type etype : (str, str, str)
readonly : bool, optional The source-edge-destination type triplet
Whether the graph structure is read-only (default: False) data : dict, optional
Feature data of the added edge.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
>>> g['plays'].number_of_edges()
4
>>> g.add_edges([0, 2], [1, 0], 'plays')
>>> g['plays'].number_of_edges()
6
""" """
# pylint: disable=unused-argument pass
def __init__(
self,
metagraph,
number_of_nodes_by_type,
edge_connections_by_type,
node_frame=None,
edge_frame=None,
readonly=False):
super(DGLHeteroGraph, self).__init__(
metagraph, number_of_nodes_by_type, edge_connections_by_type)
def from_networkx( def from_networkx(
self, self,
...@@ -1020,10 +1335,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1020,10 +1335,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
pass pass
def node_attr_schemes(self, ntype): def node_attr_schemes(self, ntype):
"""Return the node feature schemes for a given node type. """Return the node feature schemes.
Each feature scheme is a named tuple that stores the shape and data type Each feature scheme is a named tuple that stores the shape and data type
of the node feature of the node feature.
Parameters Parameters
---------- ----------
...@@ -1034,150 +1349,90 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1034,150 +1349,90 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
------- -------
dict of str to schemes dict of str to schemes
The schemes of node feature columns. The schemes of node feature columns.
Examples
--------
The following uses PyTorch backend.
>>> g.ndata['user']['h'] = torch.randn(3, 4)
>>> g.node_attr_schemes('user')
{'h': Scheme(shape=(4,), dtype=torch.float32)}
""" """
pass return self._node_frames[self._ntypes_invmap[ntype]].schemes
def edge_attr_schemes(self, etype): def edge_attr_schemes(self, etype):
"""Return the edge feature schemes for a given edge type. """Return the edge feature schemes.
Each feature scheme is a named tuple that stores the shape and data type Each feature scheme is a named tuple that stores the shape and data type
of the edge feature of the edge feature.
Parameters Parameters
---------- ----------
etype : tuple[str, str, str] etype : (str, str, str)
The edge type, characterized by a triplet of source type name, The source-edge-destination type triplet
destination type name, and edge type name.
Returns Returns
------- -------
dict of str to schemes dict of str to schemes
The schemes of node feature columns. The schemes of node feature columns.
"""
pass
def set_n_initializer(self, ntype, initializer, field=None):
"""Set the initializer for empty node features of given type.
Initializer is a callable that returns a tensor given the shape, data type
and device context.
When a subset of the nodes are assigned a new feature, initializer is
used to create feature for rest of the nodes.
Parameters
----------
ntype : str
The node type name.
initializer : callable
The initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
"""
pass
def set_e_initializer(self, etype, initializer, field=None):
"""Set the initializer for empty edge features of given type.
Initializer is a callable that returns a tensor given the shape, data Examples
type and device context. --------
The following uses PyTorch backend.
When a subset of the edges are assigned a new feature, initializer is
used to create feature for rest of the edges.
Parameters >>> g.edata['user', 'plays', 'game']['h'] = torch.randn(4, 4)
---------- >>> g.edge_attr_schemes(('user', 'plays', 'game'))
etype : tuple[str, str, str] {'h': Scheme(shape=(4,), dtype=torch.float32)}
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
initializer : callable
The initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
""" """
pass return self._edge_frames[self._etypes_invmap[etype]].schemes
@property @property
def nodes(self): def nodes(self):
"""Return a node view that can used to set/get feature data of a """Return a node view that can used to set/get feature data of a
single node type. single node type.
Notes
-----
An error is raised if the graph contains multiple node types. Use
g[ntype]
to select nodes with type ``ntype``.
Examples Examples
-------- --------
To set features of User #0 and #2 in a heterogeneous graph: To set features of User #0 and #2 in a heterogeneous graph:
>>> g['user'].nodes[[0, 2]].data['h'] = torch.zeros(2, 5) >>> g.nodes['user'][[0, 2]].data['h'] = torch.zeros(2, 5)
""" """
pass return HeteroNodeView(self)
@property @property
def ndata(self): def ndata(self):
"""Return the data view of all the nodes of a single node type. """Return the data view of all the nodes of a single node type.
Notes
-----
An error is raised if the graph contains multiple node types. Use
g[ntype]
to select nodes with type ``ntype``.
Examples Examples
-------- --------
To set features of games in a heterogeneous graph: To set features of games in a heterogeneous graph:
>>> g['game'].ndata['h'] = torch.zeros(2, 5) >>> g.ndata['game']['h'] = torch.zeros(2, 5)
""" """
pass return HeteroNodeDataView(self)
@property @property
def edges(self): def edges(self):
"""Return an edges view that can used to set/get feature data of a """Return an edges view that can used to set/get feature data of a
single edge type. single edge type.
Notes
-----
An error is raised if the graph contains multiple edge types. Use
g[src_type, dst_type, edge_type]
to select edges with type ``(src_type, dst_type, edge_type)``.
Examples Examples
-------- --------
To set features of gameplays #1 (Bob -> Tetris) and #3 (Carol -> To set features of gameplays #1 (Bob -> Tetris) and #3 (Carol ->
Minecraft) in a heterogeneous graph: Minecraft) in a heterogeneous graph:
>>> g['user', 'game', 'plays'].edges[[1, 3]].data['h'] = torch.zeros(2, 5) >>> g.edges['user', 'plays', 'game'][[1, 3]].data['h'] = torch.zeros(2, 5)
""" """
pass return HeteroEdgeView(self)
@property @property
def edata(self): def edata(self):
"""Return the data view of all the edges of a single edge type. """Return the data view of all the edges of a single edge type.
Notes
-----
An error is raised if the graph contains multiple edge types. Use
g[src_type, dst_type, edge_type]
to select edges with type ``(src_type, dst_type, edge_type)``.
Examples Examples
-------- --------
>>> g['developer', 'game', 'develops'].edata['h'] = torch.zeros(2, 5) >>> g.edata['developer', 'develops', 'game']['h'] = torch.zeros(2, 5)
""" """
pass return HeteroEdgeDataView(self)
def set_n_repr(self, data, ntype, u=ALL, inplace=False): def set_n_repr(self, ntype, data, u=ALL, inplace=False):
"""Set node(s) representation of a single node type. """Set node(s) representation of a single node type.
`data` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
...@@ -1190,16 +1445,32 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1190,16 +1445,32 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
ntype : str
The node type
data : dict of tensor data : dict of tensor
Node representation. Node representation.
ntype : str
Node type.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool inplace : bool
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
""" """
pass ntype = self._ntypes_invmap[ntype]
if is_all(u):
num_nodes = self._graph.number_of_nodes(ntype)
else:
u = utils.toindex(u)
num_nodes = len(u)
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_nodes:
raise DGLError('Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.' % (nfeats, num_nodes))
if is_all(u):
for key, val in data.items():
self._node_frames[ntype][key] = val
else:
self._node_frames[ntype].update_rows(u, data, inplace=inplace)
def get_n_repr(self, ntype, u=ALL): def get_n_repr(self, ntype, u=ALL):
"""Get node(s) representation of a single node type. """Get node(s) representation of a single node type.
...@@ -1209,7 +1480,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1209,7 +1480,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
ntype : str ntype : str
Node type. The node type
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
...@@ -1218,7 +1489,14 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1218,7 +1489,14 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict dict
Representation dict from feature name to feature tensor. Representation dict from feature name to feature tensor.
""" """
pass if len(self.node_attr_schemes(ntype)) == 0:
return dict()
ntype_idx = self._ntypes_invmap[ntype]
if is_all(u):
return dict(self._node_frames[ntype_idx])
else:
u = utils.toindex(u)
return self._node_frames[ntype_idx].select_rows(u)
def pop_n_repr(self, ntype, key): def pop_n_repr(self, ntype, key):
"""Get and remove the specified node repr of a given node type. """Get and remove the specified node repr of a given node type.
...@@ -1226,7 +1504,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1226,7 +1504,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
ntype : str ntype : str
The node type. The node type
key : str key : str
The attribute name. The attribute name.
...@@ -1235,9 +1513,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1235,9 +1513,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor Tensor
The popped representation The popped representation
""" """
pass ntype = self._ntypes_invmap[ntype]
return self._node_frames[ntype].pop(key)
def set_e_repr(self, data, etype, edges=ALL, inplace=False): def set_e_repr(self, etype, data, edges=ALL, inplace=False):
"""Set edge(s) representation of a single edge type. """Set edge(s) representation of a single edge type.
`data` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
...@@ -1249,11 +1528,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1249,11 +1528,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
data : tensor or dict of tensor data : tensor or dict of tensor
Edge representation. Edge representation.
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
edges : edges edges : edges
Edges can be either Edges can be either
...@@ -1265,16 +1543,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1265,16 +1543,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
inplace : bool inplace : bool
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
""" """
pass etype_idx = self._etypes_invmap[etype]
# parse argument
if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
# sanity check
if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(data))
if is_all(eid):
num_edges = self._graph.number_of_edges(etype_idx)
else:
eid = utils.toindex(eid)
num_edges = len(eid)
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_edges:
raise DGLError('Expect number of features to match number of edges.'
' Got %d and %d instead.' % (nfeats, num_edges))
# set
if is_all(eid):
# update column
for key, val in data.items():
self._edge_frames[etype_idx][key] = val
else:
# update row
self._edge_frames[etype_idx].update_rows(eid, data, inplace=inplace)
def get_e_repr(self, etype, edges=ALL): def get_e_repr(self, etype, edges=ALL):
"""Get edge(s) representation. """Get edge(s) representation.
Parameters Parameters
---------- ----------
etype : tuple[str, str, str] etype : (str, str, str)
The edge type, characterized by a triplet of source type name, The source-edge-destination type triplet
destination type name, and edge type name.
edges : edges edges : edges
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids. The default value is all the edges.
...@@ -1284,16 +1596,34 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1284,16 +1596,34 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict dict
Representation dict Representation dict
""" """
pass etype_idx = self._etypes_invmap[etype]
if len(self.edge_attr_schemes(etype)) == 0:
return dict()
# parse argument
if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
if is_all(eid):
return dict(self._edge_frames[etype_idx])
else:
eid = utils.toindex(eid)
return self._edge_frames[etype_idx].select_rows(eid)
def pop_e_repr(self, etype, key):
    """Get and remove the specified edge repr of a single edge type.

    Parameters
    ----------
    etype : (str, str, str)
        The source-edge-destination type triplet
    key : str
        The attribute name.

    Returns
    -------
    Tensor
        The popped representation
    """
    etype_idx = self._etypes_invmap[etype]
    # Hand the popped column back to the caller; without the ``return`` the
    # feature was removed from the frame but silently discarded, unlike
    # ``pop_n_repr`` and contrary to the documented return value.
    return self._edge_frames[etype_idx].pop(key)
def register_message_func(self, func): def register_message_func(self, func):
"""Register global message function for each edge type provided. """Register global message function for each edge type provided.
...@@ -1314,17 +1645,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1314,17 +1645,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : callable, dict[etype, callable] func : callable
Message function on the edge. The function should be Message function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
See Also See Also
-------- --------
send send
...@@ -1333,7 +1657,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1333,7 +1657,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
push push
update_all update_all
""" """
pass raise NotImplementedError
def register_reduce_func(self, func): def register_reduce_func(self, func):
"""Register global message reduce function for each edge type provided. """Register global message reduce function for each edge type provided.
...@@ -1345,17 +1669,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1345,17 +1669,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : callable, dict[etype, callable] func : callable
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
See Also See Also
-------- --------
recv recv
...@@ -1364,7 +1681,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1364,7 +1681,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
pull pull
update_all update_all
""" """
pass raise NotImplementedError
def register_apply_node_func(self, func): def register_apply_node_func(self, func):
"""Register global node apply function for each node type provided. """Register global node apply function for each node type provided.
...@@ -1376,21 +1693,16 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1376,21 +1693,16 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : callable, dict[str, callable] func : callable
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``func`` is not a
dict, it will throw an error.
See Also See Also
-------- --------
apply_nodes apply_nodes
register_apply_edge_func register_apply_edge_func
""" """
pass raise NotImplementedError
def register_apply_edge_func(self, func): def register_apply_edge_func(self, func):
"""Register global edge apply function for each edge type provided. """Register global edge apply function for each edge type provided.
...@@ -1400,23 +1712,16 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1400,23 +1712,16 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : callable, dict[etype, callable] func : callable
Apply function on the edge. The function should be Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
See Also See Also
-------- --------
apply_edges apply_edges
register_apply_node_func register_apply_node_func
""" """
pass raise NotImplementedError
def apply_nodes(self, func, v=ALL, inplace=False): def apply_nodes(self, func, v=ALL, inplace=False):
"""Apply the function on the nodes with the same type to update their """Apply the function on the nodes with the same type to update their
...@@ -1426,40 +1731,35 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1426,40 +1731,35 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : callable, dict[str, callable], or None func : dict[str, callable] or None
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
v : dict[str, int or iterable of int or tensor], optional
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``func`` is not a
dict, it will throw an error.
v : int, iterable of int, tensor, dict, optional
The (type-specific) node (ids) on which to apply ``func``. The (type-specific) node (ids) on which to apply ``func``.
If ``func`` is not a dict, then ``v`` must not be a dict.
If ``func`` is a dict, then ``v`` must either be
* ALL: for computing on all nodes with the given types in ``func``.
* a dict of int, iterable of int, or tensors, with the same keys
as ``func``, indicating the nodes to be updated for each type.
inplace : bool, optional inplace : bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
Examples Examples
-------- --------
>>> g['user'].ndata['h'] = torch.ones(3, 5) >>> g.ndata['user']['h'] = torch.ones(3, 5)
>>> g['user'].apply_nodes(lambda x: {'h': x * 2}) >>> g.apply_nodes({'user': lambda nodes: {'h': nodes.data['h'] * 2}})
>>> g['user'].ndata['h'] >>> g.ndata['user']['h']
tensor([[2., 2., 2., 2., 2.], tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]) [2., 2., 2., 2., 2.]])
>>> g.apply_nodes({'user': lambda x: {'h': x * 2}})
>>> g['user'].ndata['h']
tensor([[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]])
""" """
pass for ntype, nfunc in func.items():
if is_all(v):
v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)))
else:
v_ntype = utils.toindex(v[ntype])
with ir.prog() as prog:
scheduler.schedule_apply_nodes(
graph=self._create_view(self._ntypes_invmap[ntype], None),
v=v_ntype,
apply_func=nfunc,
inplace=inplace)
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, inplace=False): def apply_edges(self, func, edges=ALL, inplace=False):
"""Apply the function on the edges with the same type to update their """Apply the function on the edges with the same type to update their
...@@ -1469,42 +1769,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1469,42 +1769,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : callable, dict[etype, callable], or None func : dict[(str, str, str), callable] or None
Apply function on the edge. The function should be Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
edges : dict[(str, str, str), any valid edge specification], optional
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``func`` is not a
dict, it will throw an error.
edges : any valid edge specification, dict, optional
Edges on which to apply ``func``. See :func:`send` for valid Edges on which to apply ``func``. See :func:`send` for valid
edge specification. edge specification.
If ``func`` is not a dict, then ``edges`` must not be a dict.
If ``func`` is a dict, then ``edges`` must either be
* ALL: for computing on all edges with the given types in ``func``.
* a dict of int, iterable of int, or tensors, with the same keys
as ``func``, indicating the edges to be updated for each type.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
Examples Examples
-------- --------
>>> g['user', 'game', 'plays'].edata['h'] = torch.ones(3, 5) >>> g.edata['user', 'plays', 'game']['h'] = torch.ones(4, 5)
>>> g['user', 'game', 'plays'].apply_edges(lambda x: {'h': x * 2}) >>> g.apply_edges(
>>> g['user', 'game', 'plays'].edata['h'] ... {('user', 'plays', 'game'): lambda edges: {'h': edges.data['h'] * 2}})
>>> g.edata['user', 'plays', 'game']['h']
tensor([[2., 2., 2., 2., 2.], tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]) [2., 2., 2., 2., 2.]])
>>> g.apply_edges({('user', 'game', 'plays'): lambda x: {'h': x * 2}})
tensor([[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]])
""" """
pass for etype, efunc in func.items():
etype_idx = self._etypes_invmap[etype]
if is_all(edges):
u, v, _ = self._graph.edges(etype_idx, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(
graph=self._create_view(None, etype_idx),
u=u,
v=v,
eid=eid,
apply_func=efunc,
inplace=inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
    """Group the edges by nodes and apply the function of the grouped
    edges to update their features, one edge type at a time.

    Parameters
    ----------
    group_by : str
        Specify how to group edges. Expected to be either 'src' or 'dst'
    func : dict[(str, str, str), callable]
        Apply function on the edge, keyed by the source-edge-destination
        type triplet. The function should be
        an :mod:`Edge UDF <dgl.udf>`. The input of `Edge UDF` should
        be (bucket_size, degrees, *feature_shape), and
        return the dict with values of the same shapes.
    edges : valid edges type, optional
        Edges on which to group and apply ``func``. Either ALL (the
        default, all edges of each type in ``func``), a ``(u, v)`` tuple
        of endpoint nodes, or edge IDs. The same specification is used
        for every edge type in ``func``.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.

    Raises
    ------
    DGLError
        If ``group_by`` is neither 'src' nor 'dst'.
    """
    if group_by not in ('src', 'dst'):
        raise DGLError("Group_by should be either src or dst")

    for etype, efunc in func.items():
        etype_idx = self._etypes_invmap[etype]
        if is_all(edges):
            # Ask for edge-ID order so that (u[i], v[i]) are the endpoints
            # of edge i, matching the 0..E-1 ``eid`` built below (same as
            # ``apply_edges``).
            u, v, _ = self._graph.edges(etype_idx, 'eid')
            eid = utils.toindex(slice(0, self.number_of_edges(etype)))
        elif isinstance(edges, tuple):
            u, v = edges
            u = utils.toindex(u)
            v = utils.toindex(v)
            # Rewrite u, v to handle edge broadcasting and multigraph.
            u, v, eid = self._graph.edge_ids(etype_idx, u, v)
        else:
            eid = utils.toindex(edges)
            u, v, _ = self._graph.find_edges(etype_idx, eid)

        with ir.prog() as prog:
            scheduler.schedule_group_apply_edge(
                graph=self._create_view(None, etype_idx),
                u=u,
                v=v,
                eid=eid,
                apply_func=efunc,
                group_by=group_by,
                inplace=inplace)
            Runtime.run(prog)
# TODO: REVIEW
def send(self, edges=ALL, message_func=None):
    """Send messages along the given edges with the same edge type.

    ``edges`` is interpreted as follows:

    * ALL (the default): every edge of the current edge type.
    * a tuple ``(u, v)``: edges specified by their endpoints.
    * anything else: edge ID(s).

    Only works if the graph has one edge type. For multiple types,
    use

    .. code::

       g['edgetype'].send(edges, message_func)

    The UDF returns messages on the edges and can be later fetched in
    the destination node's ``mailbox``. Receiving will consume the messages.

    Parameters
    ----------
    edges : valid edges type, optional
        Edges on which to apply ``message_func``. Default is sending along all
        the edges.
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`. Must not be None or a dict.

    Notes
    -----
    On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent
    along all edges between :math:`u` and :math:`v`.
    """
    assert not utils.is_dict_like(message_func), \
        "multiple-type message passing is not implemented"
    assert message_func is not None

    etype_idx = self._current_etype_idx
    if is_all(edges):
        eid = utils.toindex(slice(0, self._graph.number_of_edges(etype_idx)))
        # Ask for edge-ID order so that (u[i], v[i]) are the endpoints of
        # edge eid[i]; the IDs returned by ``edges`` are discarded here.
        u, v, _ = self._graph.edges(etype_idx, 'eid')
    elif isinstance(edges, tuple):
        u, v = edges
        u = utils.toindex(u)
        v = utils.toindex(v)
        # Rewrite u, v to handle edge broadcasting and multigraph.
        u, v, eid = self._graph.edge_ids(etype_idx, u, v)
    else:
        eid = utils.toindex(edges)
        u, v, _ = self._graph.find_edges(etype_idx, eid)

    if len(eid) == 0:
        # no edge to be triggered
        return

    with ir.prog() as prog:
        scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
                                message_func=message_func)
        Runtime.run(prog)
def recv(self, def recv(self,
v=ALL, v=ALL,
...@@ -1615,53 +1951,53 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1615,53 +1951,53 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
The provided UDF maybe called multiple times so it is recommended to provide The provided UDF maybe called multiple times so it is recommended to provide
function with no side effect. function with no side effect.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].recv(v, reduce_func, apply_node_func, inplace)
Parameters Parameters
---------- ----------
v : int, container or tensor, dict, optional v : int, container or tensor, optional
The node(s) to be updated. Default is receiving all the nodes. The node(s) to be updated. Default is receiving all the nodes.
reduce_func : callable, optional
If ``apply_node_func`` is not a dict, then ``v`` must not be a
dict.
If ``apply_node_func`` is a dict, then ``v`` must either be
* ALL: for computing on all nodes with the given types in
``apply_node_func``.
* a dict of int, iterable of int, or tensors, indicating the nodes
to be updated for each type.
reduce_func : callable, dict[etype, callable], optional
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable]
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types in ``v``, the node types in ``apply_node_func``,
and the destination types in ``reduce_func`` must be the same.
""" """
pass assert not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert reduce_func is not None
if is_all(v):
v = F.arange(0, self._graph.number_of_nodes(self._current_dsttype_idx))
elif isinstance(v, int):
v = [v]
v = utils.toindex(v)
if len(v) == 0:
# no vertex to be triggered.
return
with ir.prog() as prog:
scheduler.schedule_recv(graph=self,
recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def send_and_recv(self, def send_and_recv(self,
edges, edges,
message_func="default", message_func=None,
reduce_func="default", reduce_func=None,
apply_node_func="default", apply_node_func=None,
inplace=False): inplace=False):
"""Send messages along edges with the same edge type, and let destinations """Send messages along edges with the same edge type, and let destinations
receive them. receive them.
...@@ -1673,66 +2009,65 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1673,66 +2009,65 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
``recv(self, dst, reduce_func, apply_node_func)``, where ``dst`` ``recv(self, dst, reduce_func, apply_node_func)``, where ``dst``
are the destinations of the ``edges``. are the destinations of the ``edges``.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].send_and_recv(edges, message_func, reduce_func, apply_node_func, inplace)
Parameters Parameters
---------- ----------
edges : valid edges type edges : valid edges type
Edges on which to apply ``func``. See :func:`send` for valid Edges on which to apply ``func``. See :func:`send` for valid
edges type. edges type.
message_func : callable, optional
If the functions are not dicts, then ``edges`` must not be a dict.
If the functions are dicts, then ``edges`` must either be
* ALL: for computing on all edges with the given types in the
functions.
* a dict of int, iterable of int, or tensors, indicating the edges
to be updated for each type.
message_func : callable, dict[etype, callable], optional
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the destination type of ``edges``, the node types in
``apply_node_func``, and the destination types in ``reduce_func``
must be the same.
* the edge type of ``edges``, ``message_func`` and ``reduce_func``
must also be the same.
""" """
pass assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
if isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(self._current_etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(self._current_etype_idx, eid)
if len(u) == 0:
# no edges to be triggered
return
with ir.prog() as prog:
scheduler.schedule_snr(graph=self,
edge_tuples=(u, v, eid),
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def pull(self, def pull(self,
v, v,
message_func="default", message_func=None,
reduce_func="default", reduce_func=None,
apply_node_func="default", apply_node_func=None,
inplace=False): inplace=False):
"""Pull messages from the node(s)' predecessors and then update their features. """Pull messages from the node(s)' predecessors and then update their features.
...@@ -1744,126 +2079,102 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1744,126 +2079,102 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
by the column initializer (see :func:`set_n_initializer`). The feature shapes and by the column initializer (see :func:`set_n_initializer`). The feature shapes and
dtypes will be inferred. dtypes will be inferred.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].pull(v, message_func, reduce_func, apply_node_func, inplace)
Parameters Parameters
---------- ----------
v : int, container or tensor, dict, optional v : int, container or tensor, optional
The node(s) to be updated. Default is receiving all the nodes. The node(s) to be updated. Default is receiving all the nodes.
message_func : callable, optional
If the functions are not dicts, then ``v`` must not be a dict.
If the functions are dicts, then ``v`` must either be
* ALL: for computing on all nodes with the given types in the
functions.
* a dict of int, iterable of int, or tensors, indicating the nodes
to be updated for each type.
message_func : callable, dict[etype, callable], optional
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types of ``v``, the node types in ``apply_node_func``,
and the destination types in ``reduce_func`` must be the same.
* the edge type of ``message_func`` and ``reduce_func`` must also be
the same.
""" """
pass assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
v = utils.toindex(v)
if len(v) == 0:
return
with ir.prog() as prog:
scheduler.schedule_pull(graph=self,
pull_nodes=v,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def push(self,
         u,
         message_func=None,
         reduce_func=None,
         apply_node_func=None,
         inplace=False):
    """Send message from the node(s) to their successors and update them.

    Optionally, apply a function to update the node features after receive.

    Only works if the graph has one edge type.  For multiple types,
    use

    .. code::

       g['edgetype'].push(u, message_func, reduce_func, apply_node_func, inplace)

    Parameters
    ----------
    u : int, container or tensor
        The node(s) to push messages out.
    message_func : callable, optional
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable, optional
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.

    Notes
    -----
    Although ``message_func`` and ``reduce_func`` default to None, both are
    asserted to be non-None, and dict-like functions are rejected because
    multiple-type message passing is not implemented.
    """
    assert not utils.is_dict_like(message_func) and \
        not utils.is_dict_like(reduce_func) and \
        not utils.is_dict_like(apply_node_func), \
        "multiple-type message passing is not implemented"
    assert message_func is not None
    assert reduce_func is not None

    u = utils.toindex(u)
    if len(u) == 0:
        # no node to push from, hence nothing to schedule
        return

    with ir.prog() as prog:
        scheduler.schedule_push(graph=self,
                                u=u,
                                message_func=message_func,
                                reduce_func=reduce_func,
                                apply_func=apply_node_func,
                                inplace=inplace)
        Runtime.run(prog)
def update_all(self, def update_all(self,
message_func="default", message_func=None,
reduce_func="default", reduce_func=None,
apply_node_func="default"): apply_node_func=None):
"""Send messages through all edges and update all nodes. """Send messages through all edges and update all nodes.
Optionally, apply a function to update the node features after receive. Optionally, apply a function to update the node features after receive.
...@@ -1872,64 +2183,53 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1872,64 +2183,53 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
``send(self, self.edges(), message_func)`` and ``send(self, self.edges(), message_func)`` and
``recv(self, self.nodes(), reduce_func, apply_node_func)``. ``recv(self, self.nodes(), reduce_func, apply_node_func)``.
Only works if the graph has one edge type. For multiple types,
use
.. code::
g['edgetype'].update_all(message_func, reduce_func, apply_node_func)
Parameters Parameters
---------- ----------
message_func : callable, dict[etype, callable], optional message_func : callable, optional
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional
If a dict is provided, the functions will be applied according to
edge type.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``message_func`` is
not a dict, it will throw an error.
reduce_func : callable, dict[etype, callable], optional
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
If a dict is provided, the messages will be aggregated onto the
nodes by the edge type of the message.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
If the graph has more than one edge type and ``reduce_func`` is not
a dict, it will throw an error.
apply_node_func : callable, dict[str, callable], optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
If a dict is provided, the functions will be applied according to
node type.
If the graph has more than one node type and ``apply_func`` is not
a dict, it will throw an error.
Notes
-----
If the graph is heterogeneous (i.e. having more than one node/edge
type),
* the node types in ``apply_node_func`` and the destination types in
``reduce_func`` must be the same.
* the edge type of ``message_func`` and ``reduce_func`` must also be
the same.
""" """
pass assert not utils.is_dict_like(message_func) and \
not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
with ir.prog() as prog:
scheduler.schedule_update_all(graph=self,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func)
Runtime.run(prog)
# TODO should we support this?
def prop_nodes(self,
               nodes_generator,
               message_func=None,
               reduce_func=None,
               apply_node_func=None):
    """Node propagation in heterogeneous graph is not supported.

    Raises
    ------
    NotImplementedError
        Always, regardless of the arguments.
    """
    raise NotImplementedError('not supported')
# TODO should we support this?
def prop_edges(self,
               edges_generator,
               message_func=None,
               reduce_func=None,
               apply_node_func=None):
    """Edge propagation in heterogeneous graph is not supported.

    Raises
    ------
    NotImplementedError
        Always, regardless of the arguments.
    """
    raise NotImplementedError('not supported')
...@@ -2167,8 +2467,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2167,8 +2467,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
""" """
pass pass
# TODO: replace this after implementing frame
# pylint: disable=useless-super-delegation
def __repr__(self):
    """Return the representation produced by the parent class."""
    return super(DGLHeteroGraph, self).__repr__()
# pylint: disable=abstract-method # pylint: disable=abstract-method
class DGLHeteroSubGraph(DGLHeteroGraph): class DGLHeteroSubGraph(DGLHeteroGraph):
......
"""Module for heterogeneous graph index class definition."""
from __future__ import absolute_import
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F
from . import utils
@register_object('graph.HeteroGraph')
class HeteroGraphIndex(ObjectBase):
"""HeteroGraph index object.
Note
----
Do not create GraphIndex directly.
"""
def __new__(cls):
    """Allocate the object and attach an empty per-instance cache."""
    instance = ObjectBase.__new__(cls)
    # Storage referred to by the ``cache='_cache'`` argument of the
    # ``utils.cached_member`` decorators used on this class.
    instance._cache = dict()
    return instance
def __getstate__(self):
    """Pickling hook; not implemented yet, so no state is produced."""
    # TODO
    return None
def __setstate__(self, state):
    """Unpickling hook; not implemented yet, so ``state`` is ignored."""
    # TODO
    return None
@property
def metagraph(self):
    """Meta graph of this heterograph index, obtained from the C API.

    Returns
    -------
    GraphIndex
        The meta graph.
    """
    meta = _CAPI_DGLHeteroGetMetaGraph(self)
    return meta
def number_of_ntypes(self):
    """Return number of node types (the node count of the meta graph)."""
    meta = self.metagraph
    return meta.number_of_nodes()
def number_of_etypes(self):
    """Return number of edge types (the edge count of the meta graph)."""
    meta = self.metagraph
    return meta.number_of_edges()
def get_relation_graph(self, etype):
    """Fetch the bipartite graph of one edge/relation type.

    Parameters
    ----------
    etype : int
        The edge/relation type.

    Returns
    -------
    HeteroGraphIndex
        The bipartite graph.
    """
    relation_id = int(etype)
    return _CAPI_DGLHeteroGetRelationGraph(self, relation_id)
def add_nodes(self, ntype, num):
    """Append nodes of one type to the graph.

    Parameters
    ----------
    ntype : int
        Node type
    num : int
        Number of nodes to be added.
    """
    ntype_id = int(ntype)
    count = int(num)
    _CAPI_DGLHeteroAddVertices(self, ntype_id, count)
    # the graph changed, so reset the cache
    self.clear_cache()
def add_edge(self, etype, u, v):
    """Insert a single edge of the given type.

    Parameters
    ----------
    etype : int
        Edge type
    u : int
        The src node.
    v : int
        The dst node.
    """
    etype_id = int(etype)
    src = int(u)
    dst = int(v)
    _CAPI_DGLHeteroAddEdge(self, etype_id, src, dst)
    # the graph changed, so reset the cache
    self.clear_cache()
def add_edges(self, etype, u, v):
    """Insert a batch of edges of the given type.

    Parameters
    ----------
    etype : int
        Edge type
    u : utils.Index
        The src nodes.
    v : utils.Index
        The dst nodes.
    """
    etype_id = int(etype)
    src_array = u.todgltensor()
    dst_array = v.todgltensor()
    _CAPI_DGLHeteroAddEdges(self, etype_id, src_array, dst_array)
    # the graph changed, so reset the cache
    self.clear_cache()
def clear(self):
    """Clear the graph through the C API, then empty the cache dict."""
    _CAPI_DGLHeteroClear(self)
    cache = self._cache
    cache.clear()
def ctx(self):
    """Query the context of this graph index.

    Returns
    -------
    DGLContext
        The context of the graph.
    """
    context = _CAPI_DGLHeteroContext(self)
    return context
def nbits(self):
    """Query the number of integer bits used in the storage (32 or 64).

    Returns
    -------
    int
        The number of bits.
    """
    num_bits = _CAPI_DGLHeteroNumBits(self)
    return num_bits
def bits_needed(self, etype):
    """Compute the integer width needed to represent the bipartite graph.

    64 is returned as soon as the edge count, the source-type node count
    or the destination-type node count reaches ``0x80000000``; otherwise
    32 is returned.

    Parameters
    ----------
    etype : int
        The edge type.

    Returns
    -------
    int
        The number of bits needed.
    """
    src_ntype, dst_ntype = self.metagraph.find_edge(etype)
    limit = 0x80000000
    # Checked one at a time so later counts are not queried once an
    # earlier one already demands 64 bits.
    if self.number_of_edges(etype) >= limit:
        return 64
    if self.number_of_nodes(src_ntype) >= limit:
        return 64
    if self.number_of_nodes(dst_ntype) >= limit:
        return 64
    return 32
def asbits(self, bits):
    """Convert the graph to one stored with the given integer width.

    NOTE: this method only works for immutable graph index

    Parameters
    ----------
    bits : int
        The number of integer bits (32 or 64)

    Returns
    -------
    HeteroGraphIndex
        The graph index stored using the given number of bits.
    """
    width = int(bits)
    return _CAPI_DGLHeteroAsNumBits(self, width)
def copy_to(self, ctx):
    """Copy this immutable graph index to the given device context.

    NOTE: this method only works for immutable graph index

    Parameters
    ----------
    ctx : DGLContext
        The target device context.

    Returns
    -------
    HeteroGraphIndex
        The graph index on the given device context.
    """
    # the C API takes the device type and device id separately
    dev_type = ctx.device_type
    dev_id = ctx.device_id
    return _CAPI_DGLHeteroCopyTo(self, dev_type, dev_id)
def is_multigraph(self):
    """Tell whether the graph is a multigraph.

    Returns
    -------
    bool
        True if it is a multigraph, False otherwise.
    """
    flag = _CAPI_DGLHeteroIsMultigraph(self)
    return bool(flag)
def is_readonly(self):
    """Tell whether the graph index is read-only.

    Returns
    -------
    bool
        True if it is a read-only graph, False otherwise.
    """
    flag = _CAPI_DGLHeteroIsReadonly(self)
    return bool(flag)
def number_of_nodes(self, ntype):
    """Count the nodes of one node type.

    Parameters
    ----------
    ntype : int
        Node type

    Returns
    -------
    int
        The number of nodes
    """
    ntype_id = int(ntype)
    return _CAPI_DGLHeteroNumVertices(self, ntype_id)
def number_of_edges(self, etype):
    """Count the edges of one edge type.

    Parameters
    ----------
    etype : int
        Edge type

    Returns
    -------
    int
        The number of edges
    """
    etype_id = int(etype)
    return _CAPI_DGLHeteroNumEdges(self, etype_id)
def has_node(self, ntype, vid):
    """Check whether a single node exists.

    Parameters
    ----------
    ntype : int
        Node type
    vid : int
        The node

    Returns
    -------
    bool
        True if the node exists, False otherwise.
    """
    ntype_id = int(ntype)
    node_id = int(vid)
    found = _CAPI_DGLHeteroHasVertex(self, ntype_id, node_id)
    return bool(found)
def has_nodes(self, ntype, vids):
    """Check, for each given node, whether it exists.

    Parameters
    ----------
    ntype : int
        Node type
    vids : utils.Index
        The nodes

    Returns
    -------
    utils.Index
        0-1 array indicating existence
    """
    node_array = vids.todgltensor()
    mask = _CAPI_DGLHeteroHasVertices(self, int(ntype), node_array)
    return utils.toindex(mask)
def has_edge_between(self, etype, u, v):
    """Check whether an edge of the given type connects ``u`` to ``v``.

    Parameters
    ----------
    etype : int
        Edge type
    u : int
        The src node.
    v : int
        The dst node.

    Returns
    -------
    bool
        True if the edge exists, False otherwise
    """
    etype_id = int(etype)
    src = int(u)
    dst = int(v)
    found = _CAPI_DGLHeteroHasEdgeBetween(self, etype_id, src, dst)
    return bool(found)
def has_edges_between(self, etype, u, v):
    """Check, for each (src, dst) pair, whether an edge exists.

    Parameters
    ----------
    etype : int
        Edge type
    u : utils.Index
        The src nodes.
    v : utils.Index
        The dst nodes.

    Returns
    -------
    utils.Index
        0-1 array indicating existence
    """
    src_array = u.todgltensor()
    dst_array = v.todgltensor()
    mask = _CAPI_DGLHeteroHasEdgesBetween(self, int(etype), src_array, dst_array)
    return utils.toindex(mask)
def predecessors(self, etype, v):
    """List the predecessors of a node.

    Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : int
        The node.

    Returns
    -------
    utils.Index
        Array of predecessors
    """
    etype_id = int(etype)
    node_id = int(v)
    preds = _CAPI_DGLHeteroPredecessors(self, etype_id, node_id)
    return utils.toindex(preds)
def successors(self, etype, v):
    """List the successors of a node.

    Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : int
        The node.

    Returns
    -------
    utils.Index
        Array of successors
    """
    etype_id = int(etype)
    node_id = int(v)
    succs = _CAPI_DGLHeteroSuccessors(self, etype_id, node_id)
    return utils.toindex(succs)
def edge_id(self, etype, u, v):
    """Look up the ids of all edges between ``u`` and ``v``.

    Parameters
    ----------
    etype : int
        Edge type
    u : int
        The src node.
    v : int
        The dst node.

    Returns
    -------
    utils.Index
        The edge id array.
    """
    etype_id = int(etype)
    src = int(u)
    dst = int(v)
    ids = _CAPI_DGLHeteroEdgeId(self, etype_id, src, dst)
    return utils.toindex(ids)
def edge_ids(self, etype, u, v):
    """Look up edges between the given source and destination nodes.

    Parameters
    ----------
    etype : int
        Edge type
    u : utils.Index
        The src nodes.
    v : utils.Index
        The dst nodes.

    Returns
    -------
    utils.Index
        The src nodes.
    utils.Index
        The dst nodes.
    utils.Index
        The edge ids.
    """
    src_array = u.todgltensor()
    dst_array = v.todgltensor()
    result = _CAPI_DGLHeteroEdgeIds(self, int(etype), src_array, dst_array)
    # components 0, 1, 2 of the result are src, dst and edge ids
    return tuple(utils.toindex(result(col)) for col in range(3))
def find_edges(self, etype, eid):
    """Look up the endpoints of the given edge ids.

    Parameters
    ----------
    etype : int
        Edge type
    eid : utils.Index
        The edge ids.

    Returns
    -------
    utils.Index
        The src nodes.
    utils.Index
        The dst nodes.
    utils.Index
        The edge ids.
    """
    query = eid.todgltensor()
    result = _CAPI_DGLHeteroFindEdges(self, int(etype), query)
    # components 0, 1, 2 of the result are src, dst and edge ids
    return tuple(utils.toindex(result(col)) for col in range(3))
def in_edges(self, etype, v):
    """Collect the in edges of the node(s).

    Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : utils.Index
        The node(s).

    Returns
    -------
    utils.Index
        The src nodes.
    utils.Index
        The dst nodes.
    utils.Index
        The edge ids.
    """
    if len(v) != 1:
        # several (or zero) nodes: use the array variant of the C API
        node_array = v.todgltensor()
        result = _CAPI_DGLHeteroInEdges_2(self, int(etype), node_array)
    else:
        # exactly one node: use the scalar variant of the C API
        result = _CAPI_DGLHeteroInEdges_1(self, int(etype), int(v[0]))
    return tuple(utils.toindex(result(col)) for col in range(3))
def out_edges(self, etype, v):
    """Collect the out edges of the node(s).

    Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : utils.Index
        The node(s).

    Returns
    -------
    utils.Index
        The src nodes.
    utils.Index
        The dst nodes.
    utils.Index
        The edge ids.
    """
    if len(v) != 1:
        # several (or zero) nodes: use the array variant of the C API
        node_array = v.todgltensor()
        result = _CAPI_DGLHeteroOutEdges_2(self, int(etype), node_array)
    else:
        # exactly one node: use the scalar variant of the C API
        result = _CAPI_DGLHeteroOutEdges_1(self, int(etype), int(v[0]))
    return tuple(utils.toindex(result(col)) for col in range(3))
@utils.cached_member(cache='_cache', prefix='edges')
def edges(self, etype, order=None):
    """Collect every edge of the given type.

    Parameters
    ----------
    etype : int
        Edge type
    order : string
        The order of the returned edges. Currently support:

        - 'srcdst' : sorted by their src and dst ids.
        - 'eid'    : sorted by edge Ids.
        - None     : the arbitrary order.

    Returns
    -------
    utils.Index
        The src nodes.
    utils.Index
        The dst nodes.
    utils.Index
        The edge ids.
    """
    # None is passed to the C API as an empty string
    order_arg = "" if order is None else order
    result = _CAPI_DGLHeteroEdges(self, int(etype), order_arg)
    # fetch src, dst and edge ids first, then wrap each one
    raw_columns = [result(col) for col in range(3)]
    return tuple(utils.toindex(column) for column in raw_columns)
def in_degree(self, etype, v):
    """Query the in degree of a single node.

    Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : int
        The node.

    Returns
    -------
    int
        The in degree.
    """
    etype_id = int(etype)
    node_id = int(v)
    return _CAPI_DGLHeteroInDegree(self, etype_id, node_id)
def in_degrees(self, etype, v):
    """Query the in degrees of several nodes.

    Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : utils.Index
        The nodes.

    Returns
    -------
    utils.Index
        The in degree array.
    """
    node_array = v.todgltensor()
    degrees = _CAPI_DGLHeteroInDegrees(self, int(etype), node_array)
    return utils.toindex(degrees)
def out_degree(self, etype, v):
    """Query the out degree of a single node.

    Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : int
        The node.

    Returns
    -------
    int
        The out degree.
    """
    etype_id = int(etype)
    node_id = int(v)
    return _CAPI_DGLHeteroOutDegree(self, etype_id, node_id)
def out_degrees(self, etype, v):
    """Query the out degrees of several nodes.

    Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.

    Parameters
    ----------
    etype : int
        Edge type
    v : utils.Index
        The nodes.

    Returns
    -------
    utils.Index
        The out degree array.
    """
    node_array = v.todgltensor()
    degrees = _CAPI_DGLHeteroOutDegrees(self, int(etype), node_array)
    return utils.toindex(degrees)
def adjacency_matrix(self, etype, transpose, ctx):
    """Build the adjacency matrix of one edge type as a backend sparse matrix.

    By default, a row of returned adjacency matrix represents the destination
    of an edge and the column represents the source.

    When transpose is True, a row represents the source and a column represents
    a destination.

    Parameters
    ----------
    etype : int
        Edge type
    transpose : bool
        A flag to transpose the returned adjacency matrix.
    ctx : context
        The context of the returned matrix.

    Returns
    -------
    SparseTensor
        The adjacency matrix.
    utils.Index
        A index for data shuffling due to sparse format change. Return None
        if shuffle is not required.

    Raises
    ------
    DGLError
        If ``transpose`` is not a bool.
    Exception
        If the backend's preferred sparse format is neither "csr" nor "coo".
    """
    if not isinstance(transpose, bool):
        raise DGLError('Expect bool value for "transpose" arg,'
                       ' but got %s.' % (type(transpose)))
    sparse_fmt = F.get_preferred_sparse_format()
    capi_result = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, sparse_fmt)

    # convert to framework-specific sparse matrix
    src_ntype, dst_ntype = self.metagraph.find_edge(etype)
    if transpose:
        num_rows = self.number_of_nodes(src_ntype)
        num_cols = self.number_of_nodes(dst_ntype)
    else:
        num_rows = self.number_of_nodes(dst_ntype)
        num_cols = self.number_of_nodes(src_ntype)
    num_entries = self.number_of_edges(etype)

    if sparse_fmt == "csr":
        row_ptr = F.copy_to(utils.toindex(capi_result(0)).tousertensor(), ctx)
        col_idx = F.copy_to(utils.toindex(capi_result(1)).tousertensor(), ctx)
        # here the shuffle order is the third component of the C API result
        shuffle = utils.toindex(capi_result(2))
        values = F.ones(num_entries, dtype=F.float32, ctx=ctx)  # FIXME(minjie): data type
        matrix = F.sparse_matrix(values, ('csr', col_idx, row_ptr), (num_rows, num_cols))[0]
        return matrix, shuffle

    if sparse_fmt == "coo":
        coords = F.copy_to(utils.toindex(capi_result(0)).tousertensor(), ctx)
        coords = F.reshape(coords, (2, num_entries))
        values = F.ones((num_entries,), dtype=F.float32, ctx=ctx)
        # here the shuffle order comes from the backend and may be None
        matrix, shuffle = F.sparse_matrix(values, ('coo', coords), (num_rows, num_cols))
        if shuffle is not None:
            shuffle = utils.toindex(shuffle)
        return matrix, shuffle

    raise Exception("unknown format")
def node_subgraph(self, induced_nodes):
    """Return the subgraph induced by the given nodes.

    Parameters
    ----------
    induced_nodes : list of utils.Index
        Induced nodes. The length should be equal to the number of
        node types in this heterograph.

    Returns
    -------
    SubgraphIndex
        The subgraph index.
    """
    # The C API takes one DGL tensor per entry of ``induced_nodes``.
    vids = []
    for type_nodes in induced_nodes:
        vids.append(type_nodes.todgltensor())
    return _CAPI_DGLHeteroVertexSubgraph(self, vids)
def edge_subgraph(self, induced_edges, preserve_nodes):
    """Return the subgraph induced by the given edges.

    Parameters
    ----------
    induced_edges : list of utils.Index
        Induced edges. The length should be equal to the number of
        edge types in this heterograph.
    preserve_nodes : bool
        Indicates whether to preserve all nodes or not.
        If true, keep the nodes which have no edge connected in the subgraph;
        If false, all nodes without edge connected to it would be removed.

    Returns
    -------
    SubgraphIndex
        The subgraph index.
    """
    # The C API takes one DGL tensor per entry of ``induced_edges``.
    eids = []
    for type_edges in induced_edges:
        eids.append(type_edges.todgltensor())
    return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='bipartite')
def get_bipartite(self, etype, ctx):
    """Create a bipartite graph from given edge type and copy to the given device
    context.

    Note: this internal function is for DGL scheduler use only

    Parameters
    ----------
    etype : int, or None
        If the graph index is a Bipartite graph index, this argument must be None.
        Otherwise, it represents the edge type.
    ctx : DGLContext
        The context of the returned graph.

    Returns
    -------
    HeteroGraphIndex
    """
    # ``None`` means this index is used as-is instead of extracting a relation.
    if etype is not None:
        rel_graph = self.get_relation_graph(etype)
    else:
        rel_graph = self
    # The bit width is always queried on ``self``; a None etype maps to 0.
    nbits = self.bits_needed(etype or 0)
    return rel_graph.asbits(nbits).copy_to(ctx)
def get_csr_shuffle_order(self, etype):
    """Return the edge shuffling order when a coo graph is converted to csr format

    Parameters
    ----------
    etype : int
        The edge type

    Returns
    -------
    tuple of two utils.Index
        The first element of the tuple is the shuffle order for outward graph
        The second element of the tuple is the shuffle order for inward graph
    """
    # Element 2 of each CSR result is the edge shuffle order.
    out_order = _CAPI_DGLHeteroGetAdj(self, int(etype), True, "csr")(2)
    in_order = _CAPI_DGLHeteroGetAdj(self, int(etype), False, "csr")(2)
    return utils.toindex(out_order), utils.toindex(in_order)
@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
    """Hetero-subgraph data structure"""

    @property
    def graph(self):
        """The subgraph structure

        Returns
        -------
        HeteroGraphIndex
            The subgraph
        """
        return _CAPI_DGLHeteroSubgraphGetGraph(self)

    @property
    def induced_nodes(self):
        """Induced nodes for each node type. The return list
        length should be equal to the number of node types.

        Returns
        -------
        list of utils.Index
            Induced nodes
        """
        per_type = []
        # Each returned item exposes its array through ``.data``.
        for item in _CAPI_DGLHeteroSubgraphGetInducedVertices(self):
            per_type.append(utils.toindex(item.data))
        return per_type

    @property
    def induced_edges(self):
        """Induced edges for each edge type. The return list
        length should be equal to the number of edge types.

        Returns
        -------
        list of utils.Index
            Induced edges
        """
        per_type = []
        # Each returned item exposes its array through ``.data``.
        for item in _CAPI_DGLHeteroSubgraphGetInducedEdges(self):
            per_type.append(utils.toindex(item.data))
        return per_type
def create_bipartite_from_coo(num_src, num_dst, row, col):
    """Create a bipartite graph index from COO format

    Parameters
    ----------
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    row : utils.Index
        Row index.
    col : utils.Index
        Col index.

    Returns
    -------
    HeteroGraphIndex
    """
    # Node counts are coerced to int before being handed to the C API.
    n_src = int(num_src)
    n_dst = int(num_dst)
    row_tensor = row.todgltensor()
    col_tensor = col.todgltensor()
    return _CAPI_DGLHeteroCreateBipartiteFromCOO(n_src, n_dst, row_tensor, col_tensor)
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
    """Create a bipartite graph index from CSR format

    Parameters
    ----------
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    indptr : utils.Index
        CSR indptr.
    indices : utils.Index
        CSR indices.
    edge_ids : utils.Index
        Edge shuffle id.

    Returns
    -------
    HeteroGraphIndex
    """
    # Node counts are coerced to int before being handed to the C API.
    n_src = int(num_src)
    n_dst = int(num_dst)
    indptr_tensor = indptr.todgltensor()
    indices_tensor = indices.todgltensor()
    eid_tensor = edge_ids.todgltensor()
    return _CAPI_DGLHeteroCreateBipartiteFromCSR(
        n_src, n_dst, indptr_tensor, indices_tensor, eid_tensor)
def create_heterograph(metagraph, rel_graphs):
    """Create a heterograph from metagraph and graphs of every relation.

    Parameters
    ----------
    metagraph : GraphIndex
        Meta-graph.
    rel_graphs : list of HeteroGraphIndex
        Bipartite graph of each relation.

    Returns
    -------
    HeteroGraphIndex
    """
    # Both arguments are forwarded to the C API unchanged.
    hetero_gidx = _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
    return hetero_gidx
# Initialise the API namespace "dgl.heterograph_index" for this module.
# NOTE(review): presumably this is what makes the _CAPI_DGLHetero* functions
# used above available as module globals -- confirm against _init_api.
_init_api("dgl.heterograph_index")
...@@ -161,7 +161,8 @@ def gen_group_apply_edge_schedule( ...@@ -161,7 +161,8 @@ def gen_group_apply_edge_schedule(
apply_func, apply_func,
u, v, eid, u, v, eid,
group_by, group_by,
var_nf, var_src_nf,
var_dst_nf,
var_ef, var_ef,
var_out): var_out):
"""Create degree bucketing schedule for group_apply_edge """Create degree bucketing schedule for group_apply_edge
...@@ -186,8 +187,10 @@ def gen_group_apply_edge_schedule( ...@@ -186,8 +187,10 @@ def gen_group_apply_edge_schedule(
Edges to apply Edges to apply
group_by: str group_by: str
If "src", group by u. If "dst", group by v If "src", group by u. If "dst", group by v
var_nf : var.FEAT_DICT var_src_nf : var.FEAT_DICT
The variable for node feature frame. The variable for source feature frame.
var_dst_nf : var.FEAT_DICT
The variable for destination feature frame.
var_ef : var.FEAT_DICT var_ef : var.FEAT_DICT
The variable for edge frame. The variable for edge frame.
var_out : var.FEAT_DICT var_out : var.FEAT_DICT
...@@ -213,8 +216,8 @@ def gen_group_apply_edge_schedule( ...@@ -213,8 +216,8 @@ def gen_group_apply_edge_schedule(
var_v = var.IDX(v_bkt) var_v = var.IDX(v_bkt)
var_eid = var.IDX(eid_bkt) var_eid = var.IDX(eid_bkt)
# apply edge UDF on each bucket # apply edge UDF on each bucket
fdsrc = ir.READ_ROW(var_nf, var_u) fdsrc = ir.READ_ROW(var_src_nf, var_u)
fddst = ir.READ_ROW(var_nf, var_v) fddst = ir.READ_ROW(var_dst_nf, var_v)
fdedge = ir.READ_ROW(var_ef, var_eid) fdedge = ir.READ_ROW(var_ef, var_eid)
fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst, ret=fdedge) # reuse var fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst, ret=fdedge) # reuse var
# save for merge # save for merge
......
...@@ -8,6 +8,8 @@ from .. import backend as F ...@@ -8,6 +8,8 @@ from .. import backend as F
from ..frame import frame_like, FrameRef from ..frame import frame_like, FrameRef
from ..function.base import BuiltinFunction from ..function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch from ..udf import EdgeBatch, NodeBatch
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex
from . import ir from . import ir
from .ir import var from .ir import var
...@@ -28,6 +30,15 @@ __all__ = [ ...@@ -28,6 +30,15 @@ __all__ = [
"schedule_pull" "schedule_pull"
] ]
def _dispatch(graph, method, *args, **kwargs):
    """Call ``method`` on the graph's index, adapting to the index type.

    For a ``GraphIndex`` the method is called with the given arguments only.
    For a ``HeteroGraphIndex`` the graph's ``_current_etype_idx`` is passed
    as an extra leading argument.

    Parameters
    ----------
    graph : object
        An object exposing ``_graph`` (and ``_current_etype_idx`` when the
        index is a ``HeteroGraphIndex``).
    method : str
        Name of the graph index method to invoke.
    *args, **kwargs
        Forwarded to the method.

    Returns
    -------
    Whatever the invoked graph index method returns.

    Raises
    ------
    TypeError
        If the graph index is neither a GraphIndex nor a HeteroGraphIndex.
    """
    gidx = graph._graph
    if isinstance(gidx, GraphIndex):
        bound = getattr(graph._graph, method)
        return bound(*args, **kwargs)
    if isinstance(gidx, HeteroGraphIndex):
        bound = getattr(graph._graph, method)
        return bound(graph._current_etype_idx, *args, **kwargs)
    raise TypeError('unknown type %s' % type(gidx))
def schedule_send(graph, u, v, eid, message_func): def schedule_send(graph, u, v, eid, message_func):
"""get send schedule """get send schedule
...@@ -45,7 +56,8 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -45,7 +56,8 @@ def schedule_send(graph, u, v, eid, message_func):
The message function The message function
""" """
var_mf = var.FEAT_DICT(graph._msg_frame) var_mf = var.FEAT_DICT(graph._msg_frame)
var_nf = var.FEAT_DICT(graph._node_frame) var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame) var_ef = var.FEAT_DICT(graph._edge_frame)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
...@@ -54,8 +66,8 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -54,8 +66,8 @@ def schedule_send(graph, u, v, eid, message_func):
v=v, v=v,
eid=eid, eid=eid,
mfunc=message_func, mfunc=message_func,
var_src_nf=var_nf, var_src_nf=var_src_nf,
var_dst_nf=var_nf, var_dst_nf=var_dst_nf,
var_ef=var_ef) var_ef=var_ef)
# write tmp msg back # write tmp msg back
...@@ -83,7 +95,7 @@ def schedule_recv(graph, ...@@ -83,7 +95,7 @@ def schedule_recv(graph,
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
""" """
src, dst, eid = graph._graph.in_edges(recv_nodes) src, dst, eid = _dispatch(graph, 'in_edges', recv_nodes)
if len(eid) > 0: if len(eid) > 0:
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero() nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx) eid = eid.get_items(nonzero_idx)
...@@ -96,7 +108,7 @@ def schedule_recv(graph, ...@@ -96,7 +108,7 @@ def schedule_recv(graph,
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace) schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else: else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
# sort and unique the argument # sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes)
...@@ -105,12 +117,12 @@ def schedule_recv(graph, ...@@ -105,12 +117,12 @@ def schedule_recv(graph,
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid), reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes) recv_nodes)
# apply # apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
else: else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
# set message indicator to 0 # set message indicator to 0
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0)) graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._get_msg_index().has_nonzero(): if not graph._get_msg_index().has_nonzero():
...@@ -148,7 +160,7 @@ def schedule_snr(graph, ...@@ -148,7 +160,7 @@ def schedule_snr(graph,
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes)
# create vars # create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_u = var.IDX(u) var_u = var.IDX(u)
var_v = var.IDX(v) var_v = var.IDX(v)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
...@@ -156,11 +168,11 @@ def schedule_snr(graph, ...@@ -156,11 +168,11 @@ def schedule_snr(graph,
# generate send and reduce schedule # generate send and reduce schedule
uv_getter = lambda: (var_u, var_v) uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv( adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
edge_tuples, graph.number_of_nodes()) edge_tuples, graph._number_of_src_nodes(), graph._number_of_dst_nodes())
out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits) out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits)
reduced_feat = _gen_send_reduce(graph=graph, reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._node_frame, src_node_frame=graph._src_frame,
dst_node_frame=graph._node_frame, dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame, edge_frame=graph._edge_frame,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -170,12 +182,12 @@ def schedule_snr(graph, ...@@ -170,12 +182,12 @@ def schedule_snr(graph,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator)
# generate apply schedule # generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, reduced_feat,
apply_func) apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
else: else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph, def schedule_update_all(graph,
message_func, message_func,
...@@ -194,27 +206,27 @@ def schedule_update_all(graph, ...@@ -194,27 +206,27 @@ def schedule_update_all(graph,
apply_func: callable apply_func: callable
The apply node function The apply node function
""" """
if graph.number_of_edges() == 0: if graph._number_of_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes # All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None: if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes())) nodes = utils.toindex(slice(0, graph._number_of_dst_nodes()))
schedule_apply_nodes(graph, nodes, apply_func, inplace=False) schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
else: else:
eid = utils.toindex(slice(0, graph.number_of_edges())) # ALL eid = utils.toindex(slice(0, graph._number_of_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.number_of_nodes())) # ALL recv_nodes = utils.toindex(slice(0, graph._number_of_dst_nodes())) # ALL
# create vars # create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes') var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# generate send + reduce # generate send + reduce
def uv_getter(): def uv_getter():
src, dst, _ = graph._graph.edges('eid') src, dst, _ = _dispatch(graph, 'edges', 'eid')
return var.IDX(src), var.IDX(dst) return var.IDX(src), var.IDX(dst)
adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph) adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph)
out_map_creator = lambda nbits: None out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph, reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._node_frame, src_node_frame=graph._src_frame,
dst_node_frame=graph._node_frame, dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame, edge_frame=graph._edge_frame,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -224,9 +236,9 @@ def schedule_update_all(graph, ...@@ -224,9 +236,9 @@ def schedule_update_all(graph,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat) ir.WRITE_DICT_(var_dst_nf, final_feat)
def schedule_apply_nodes(graph, def schedule_apply_nodes(graph,
v, v,
...@@ -326,10 +338,12 @@ def schedule_apply_edges(graph, ...@@ -326,10 +338,12 @@ def schedule_apply_edges(graph,
A list of executors for DGL Runtime A list of executors for DGL Runtime
""" """
# vars # vars
var_nf = var.FEAT_DICT(graph._node_frame) var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame) var_ef = var.FEAT_DICT(graph._edge_frame)
var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func, var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func,
var_src_nf=var_nf, var_dst_nf=var_nf, var_ef=var_ef) var_src_nf=var_src_nf, var_dst_nf=var_dst_nf,
var_ef=var_ef)
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef') var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# schedule apply edges # schedule apply edges
...@@ -401,7 +415,7 @@ def schedule_push(graph, ...@@ -401,7 +415,7 @@ def schedule_push(graph,
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
""" """
u, v, eid = graph._graph.out_edges(u) u, v, eid = _dispatch(graph, 'out_edges', u)
if len(eid) == 0: if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled. # All the pushing nodes have no out edges. No computation is scheduled.
return return
...@@ -434,7 +448,7 @@ def schedule_pull(graph, ...@@ -434,7 +448,7 @@ def schedule_pull(graph,
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs # TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat # can be specialized to SPMV. This needs support for creating adjmat
# directly from pull node frontier. # directly from pull node frontier.
u, v, eid = graph._graph.in_edges(pull_nodes) u, v, eid = _dispatch(graph, 'in_edges', pull_nodes)
if len(eid) == 0: if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply. # All the nodes are 0deg; downgrades to apply.
if apply_func is not None: if apply_func is not None:
...@@ -443,27 +457,27 @@ def schedule_pull(graph, ...@@ -443,27 +457,27 @@ def schedule_pull(graph,
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor())) pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes) pull_nodes = utils.toindex(pull_nodes)
# create vars # create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes') var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u) var_u = var.IDX(u)
var_v = var.IDX(v) var_v = var.IDX(v)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# generate send and reduce schedule # generate send and reduce schedule
uv_getter = lambda: (var_u, var_v) uv_getter = lambda: (var_u, var_v)
num_nodes = graph.number_of_nodes() adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
adj_creator = lambda: spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes) (u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits) out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits)
reduced_feat = _gen_send_reduce(graph, graph._node_frame, reduced_feat = _gen_send_reduce(graph, graph._src_frame,
graph._node_frame, graph._edge_frame, graph._dst_frame, graph._edge_frame,
message_func, reduce_func, var_eid, message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator, var_pull_nodes, uv_getter, adj_creator,
out_map_creator) out_map_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_pull_nodes, var_dst_nf, reduced_feat, apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_dst_nf, var_pull_nodes, final_feat)
else: else:
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_(var_dst_nf, var_pull_nodes, final_feat)
def schedule_group_apply_edge(graph, def schedule_group_apply_edge(graph,
u, v, eid, u, v, eid,
...@@ -494,11 +508,12 @@ def schedule_group_apply_edge(graph, ...@@ -494,11 +508,12 @@ def schedule_group_apply_edge(graph,
A list of executors for DGL Runtime A list of executors for DGL Runtime
""" """
# vars # vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_src_nf = var.FEAT_DICT(graph._src_frame, name='src_nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef') var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_out = var.FEAT_DICT(name='new_ef') var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by, db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by,
var_nf, var_ef, var_out) var_src_nf, var_dst_nf, var_ef, var_out)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out) ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
...@@ -719,17 +734,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -719,17 +734,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# node frame. # node frame.
# TODO(minjie): should replace this with an IR call to make the program # TODO(minjie): should replace this with an IR call to make the program
# stateless. # stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes))) tmpframe = FrameRef(frame_like(graph._dst_frame._frame, len(recv_nodes)))
# vars # vars
var_msg = var.FEAT_DICT(graph._msg_frame, 'msg') var_msg = var.FEAT_DICT(graph._msg_frame, 'msg')
var_nf = var.FEAT_DICT(graph._node_frame, 'nf') var_dst_nf = var.FEAT_DICT(graph._dst_frame, 'nf')
var_out = var.FEAT_DICT(data=tmpframe) var_out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list: if rfunc_is_list:
num_nodes = graph.number_of_nodes()
adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv( adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv(
(src, dst, eid), num_nodes) (src, dst, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
# using edge map instead of message map because messages are in global # using edge map instead of message map because messages are in global
# message frame # message frame
var_out_map = _build_idx_map(recv_nodes, nbits) var_out_map = _build_idx_map(recv_nodes, nbits)
...@@ -744,7 +758,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -744,7 +758,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else: else:
# gen degree bucketing schedule for UDF recv # gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes, db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes,
var_nf, var_msg, var_out) var_dst_nf, var_msg, var_out)
return var_out return var_out
def _gen_send_reduce( def _gen_send_reduce(
...@@ -930,12 +944,12 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): ...@@ -930,12 +944,12 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if mfunc_is_list: if mfunc_is_list:
if eid.is_slice(0, graph.number_of_edges()): if eid.is_slice(0, graph._number_of_edges()):
# full graph case # full graph case
res = spmv.build_gidx_and_mapping_graph(graph) res = spmv.build_gidx_and_mapping_graph(graph)
else: else:
num_nodes = graph.number_of_nodes() res = spmv.build_gidx_and_mapping_uv(
res = spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes) (u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
adj, edge_map, _ = res adj, edge_map, _ = res
# create a tmp message frame # create a tmp message frame
tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid))) tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid)))
......
"""Module for SPMV rules.""" """Module for SPMV rules."""
from __future__ import absolute_import from __future__ import absolute_import
from functools import partial
from ..base import DGLError from ..base import DGLError
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from .. import ndarray as nd from .. import ndarray as nd
from ..graph_index import from_coo from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from . import ir from . import ir
from .ir import var from .ir import var
...@@ -127,8 +129,8 @@ def build_gidx_and_mapping_graph(graph): ...@@ -127,8 +129,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLGraph or DGLHeteroGraph
The graph The homogeneous graph, or a bipartite view of the heterogeneous graph.
Returns Returns
------- -------
...@@ -141,10 +143,17 @@ def build_gidx_and_mapping_graph(graph): ...@@ -141,10 +143,17 @@ def build_gidx_and_mapping_graph(graph):
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
gidx = graph._graph gidx = graph._graph
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed() return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
def build_gidx_and_mapping_uv(edge_tuples, num_nodes): def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges """Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of The matrix is of shape (len(reduce_nodes), n), where n is the number of
...@@ -155,8 +164,8 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes): ...@@ -155,8 +164,8 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
--------- ---------
edge_tuples : tuple of three utils.Index edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid) A tuple of (u, v, eid)
num_nodes : int num_src, num_dst : int
The number of nodes. The number of source and destination nodes.
Returns Returns
------- -------
...@@ -169,10 +178,10 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes): ...@@ -169,10 +178,10 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
u, v, eid = edge_tuples u, v, eid = edge_tuples
gidx = from_coo(num_nodes, u, v, None, True) gidx = create_bipartite_from_coo(num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order() forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor() eid = eid.tousertensor()
nbits = gidx.bits_needed() nbits = gidx.bits_needed(0)
forward_map = utils.to_nbits_int(eid[forward.tousertensor()], nbits) forward_map = utils.to_nbits_int(eid[forward.tousertensor()], nbits)
backward_map = utils.to_nbits_int(eid[backward.tousertensor()], nbits) backward_map = utils.to_nbits_int(eid[backward.tousertensor()], nbits)
forward_map = F.zerocopy_to_dgl_ndarray(forward_map) forward_map = F.zerocopy_to_dgl_ndarray(forward_map)
...@@ -180,7 +189,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes): ...@@ -180,7 +189,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
edge_map = utils.CtxCachedObject( edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx), lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx))) nd.array(backward_map, ctx=ctx)))
return gidx.get_immutable_gidx, edge_map, nbits return partial(gidx.get_bipartite, None), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None): def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
...@@ -212,6 +221,6 @@ def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None): ...@@ -212,6 +221,6 @@ def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
eid = utils.toindex(eid) eid = utils.toindex(eid)
else: else:
u, v, eid = edge_tuples u, v, eid = edge_tuples
num_nodes = max(graph.layer_size(block_id), graph.layer_size(block_id + 1)) num_src, num_dst = graph.layer_size(block_id), graph.layer_size(block_id + 1)
gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_nodes) gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_src, num_dst)
return gidx, edge_map, nbits return gidx, edge_map, nbits
...@@ -242,3 +242,229 @@ class BlockDataView(MutableMapping): ...@@ -242,3 +242,229 @@ class BlockDataView(MutableMapping):
def __repr__(self): def __repr__(self):
data = self._graph._edge_frames[self._flow] data = self._graph._edge_frames[self._flow]
return repr({key : data[key] for key in data}) return repr({key : data[key] for key in data})
class HeteroNodeView(object):
    """A NodeView class to act as G.nodes for a DGLHeteroGraph."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        # Keep only a reference to the graph; per-type views are built lazily.
        self._graph = graph

    def __getitem__(self, ntype):
        """Return the node view restricted to node type ``ntype``."""
        typed_view = HeteroNodeTypeView(self._graph, ntype)
        return typed_view
class HeteroNodeTypeView(object):
    """A NodeView class to act as G.nodes[ntype] for a DGLHeteroGraph.

    See Also
    --------
    dgl.DGLGraph.nodes
    """
    __slots__ = ['_graph', '_ntype']

    def __init__(self, graph, ntype):
        self._graph = graph
        self._ntype = ntype

    def __len__(self):
        """Number of nodes of this node type."""
        graph = self._graph
        # The graph counts nodes by type index, so translate the type first.
        return graph.number_of_nodes(graph._ntypes_invmap[self._ntype])

    def __getitem__(self, nodes):
        """Return a NodeSpace whose data view covers the selected nodes.

        Only the full slice ``:`` is accepted among slices; any other value
        is passed through as the node selection.
        """
        if isinstance(nodes, slice):
            # slice
            is_partial = (nodes.start is not None or nodes.stop is not None
                          or nodes.step is not None)
            if is_partial:
                raise DGLError('Currently only full slice ":" is supported')
            nodes = ALL
        return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, nodes))

    def __call__(self):
        """Return the nodes."""
        count = len(self)
        return F.arange(0, count)
class HeteroNodeTypeDataView(MutableMapping):
    """The data view class when G.nodes[ntype][...].data is called.

    Reads and writes go through the graph's ``get_n_repr`` / ``set_n_repr``
    for the stored node type and node selection.

    See Also
    --------
    dgl.DGLGraph.nodes
    """
    __slots__ = ['_graph', '_ntype', '_nodes']

    def __init__(self, graph, ntype, nodes):
        self._graph = graph
        self._ntype = ntype
        self._nodes = nodes

    def __getitem__(self, key):
        """Return feature ``key`` for the selected nodes."""
        return self._graph.get_n_repr(self._ntype, self._nodes)[key]

    def __setitem__(self, key, val):
        """Set feature ``key`` to ``val`` for the selected nodes."""
        self._graph.set_n_repr(self._ntype, {key : val}, self._nodes)

    def __delitem__(self, key):
        """Always raises: features cannot be deleted through a node subset.

        Raises
        ------
        DGLError
            Unconditionally.
        """
        # G.ndata on a heterograph is indexed by node type first, so the
        # working deletion form is G.ndata[ntype][key].
        raise DGLError('Delete feature data is not supported on only a subset'
                       ' of nodes. Please use `del G.ndata[ntype][key]` instead.')

    def __len__(self):
        """Length of the node frame of this node type."""
        return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])

    def __iter__(self):
        """Iterate over what ``get_n_repr`` returns for the selection."""
        return iter(self._graph.get_n_repr(self._ntype, self._nodes))

    def __repr__(self):
        data = self._graph.get_n_repr(self._ntype, self._nodes)
        # Keys come from the node frame; values from the selected-node data.
        return repr({key : data[key]
                     for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroNodeDataView(object):
    """The data view class when G.ndata is called."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        # Keep only a reference to the graph; per-type views are built lazily.
        self._graph = graph

    def __getitem__(self, key):
        """Return the node data view for node type ``key``."""
        typed_view = HeteroNodeDataTypeView(self._graph, key)
        return typed_view
class HeteroNodeDataTypeView(MutableMapping):
    """The data view class when G.ndata[ntype] is called."""
    __slots__ = ['_graph', '_ntype']

    def __init__(self, graph, ntype):
        self._graph = graph
        self._ntype = ntype

    def _frame(self):
        """Return the graph's node frame for this node type."""
        graph = self._graph
        return graph._node_frames[graph._ntypes_invmap[self._ntype]]

    def __getitem__(self, key):
        features = self._graph.get_n_repr(self._ntype)
        return features[key]

    def __setitem__(self, key, val):
        update = {key : val}
        self._graph.set_n_repr(self._ntype, update)

    def __delitem__(self, key):
        self._graph.pop_n_repr(self._ntype, key)

    def __len__(self):
        return len(self._frame())

    def __iter__(self):
        return iter(self._frame())

    def __repr__(self):
        data = self._graph.get_n_repr(self._ntype)
        # Keys come from the node frame; values from the fetched data.
        return repr(dict((key, data[key]) for key in self._frame()))
class HeteroEdgeView(object):
    """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        # Keep only a reference to the graph; per-type views are built lazily.
        self._graph = graph

    def __getitem__(self, etype):
        """Return the edge view restricted to edge type ``etype``."""
        typed_view = HeteroEdgeTypeView(self._graph, etype)
        return typed_view
class HeteroEdgeTypeView(object):
    """A EdgeView class to act as G.edges[etype] for a DGLHeteroGraph.

    See Also
    --------
    dgl.DGLGraph.edges
    """
    __slots__ = ['_graph', '_etype']

    def __init__(self, graph, etype):
        self._graph = graph
        self._etype = etype

    def __len__(self):
        """Number of edges of this edge type."""
        graph = self._graph
        # The graph counts edges by type index, so translate the type first.
        return graph.number_of_edges(graph._etypes_invmap[self._etype])

    def __getitem__(self, edges):
        """Return an EdgeSpace whose data view covers the selected edges.

        Only the full slice ``:`` is accepted among slices; any other value
        is passed through as the edge selection.
        """
        if isinstance(edges, slice):
            # slice
            is_partial = (edges.start is not None or edges.stop is not None
                          or edges.step is not None)
            if is_partial:
                raise DGLError('Currently only full slice ":" is supported')
            edges = ALL
        return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, edges))

    def __call__(self):
        """Return the edges."""
        count = len(self)
        return F.arange(0, count)
class HeteroEdgeTypeDataView(MutableMapping):
    """Mapping-style view returned by ``G.edges[etype][...].data``.

    Reads and writes go through the graph's ``get_e_repr`` / ``set_e_repr``
    for the stored edge type and edge selection. Deleting through this view
    is rejected.

    See Also
    --------
    dgl.DGLGraph.edges
    """
    __slots__ = ['_graph', '_etype', '_edges']

    def __init__(self, graph, etype, edges):
        """Bind the view to ``graph``, edge type ``etype`` and ``edges``."""
        self._graph = graph
        self._etype = etype
        self._edges = edges

    def __getitem__(self, key):
        """Return the feature ``key`` for the selected edges."""
        features = self._graph.get_e_repr(self._etype, self._edges)
        return features[key]

    def __setitem__(self, key, val):
        """Set the feature ``key`` to ``val`` on the selected edges."""
        update = {key : val}
        self._graph.set_e_repr(self._etype, update, self._edges)

    def __delitem__(self, key):
        """Always raise ``DGLError``; deletion on an edge subset is unsupported."""
        # NOTE(review): the hint below suggests `del G.edata[key]`, but the
        # G.edata view defines only __getitem__ -- confirm the intended
        # spelling (possibly `del G.edata[etype][key]`).
        raise DGLError('Delete feature data is not supported on only a subset'
                       ' of edges. Please use `del G.edata[key]` instead.')

    def __len__(self):
        """Return the length of the edge frame for this type."""
        frames = self._graph._edge_frames
        type_id = self._graph._etypes_invmap[self._etype]
        return len(frames[type_id])

    def __iter__(self):
        """Iterate over what ``get_e_repr`` returns for the selection."""
        features = self._graph.get_e_repr(self._etype, self._edges)
        return iter(features)

    def __repr__(self):
        """Return the repr of a plain dict of every feature in the frame."""
        features = self._graph.get_e_repr(self._etype, self._edges)
        frames = self._graph._edge_frames
        type_id = self._graph._etypes_invmap[self._etype]
        snapshot = {}
        for name in frames[type_id]:
            snapshot[name] = features[name]
        return repr(snapshot)
class HeteroEdgeDataView(object):
    """Entry point for ``G.edata`` on a heterogeneous graph.

    Indexing this object yields a ``HeteroEdgeDataTypeView`` built from the
    wrapped graph and the index key.
    """
    __slots__ = ['_graph']

    def __init__(self, graph):
        """Remember the graph whose edge features are being viewed."""
        self._graph = graph

    def __getitem__(self, key):
        """Return the per-type edge data view for ``key``."""
        typed_view = HeteroEdgeDataTypeView(self._graph, key)
        return typed_view
class HeteroEdgeDataTypeView(MutableMapping):
    """Mapping-style view returned by ``G.edata[etype]``.

    Reads, writes and deletions are forwarded to the graph's
    ``get_e_repr`` / ``set_e_repr`` / ``pop_e_repr`` for the stored edge
    type. Length and iteration come from the edge frame that
    ``_etypes_invmap`` associates with that type.
    """
    __slots__ = ['_graph', '_etype']

    def __init__(self, graph, etype):
        """Bind the view to ``graph`` and the edge type ``etype``."""
        self._graph = graph
        self._etype = etype

    def __getitem__(self, key):
        """Return the feature named ``key`` for this edge type."""
        features = self._graph.get_e_repr(self._etype)
        return features[key]

    def __setitem__(self, key, val):
        """Set the feature ``key`` to ``val`` for this edge type."""
        update = {key : val}
        self._graph.set_e_repr(self._etype, update)

    def __delitem__(self, key):
        """Remove the feature ``key`` for this edge type."""
        self._graph.pop_e_repr(self._etype, key)

    def __len__(self):
        """Return the length of the edge frame for this type."""
        frames = self._graph._edge_frames
        type_id = self._graph._etypes_invmap[self._etype]
        return len(frames[type_id])

    def __iter__(self):
        """Iterate over the edge frame for this type."""
        frames = self._graph._edge_frames
        type_id = self._graph._etypes_invmap[self._etype]
        return iter(frames[type_id])

    def __repr__(self):
        """Return the repr of a plain dict of every feature in the frame."""
        features = self._graph.get_e_repr(self._etype)
        frames = self._graph._edge_frames
        type_id = self._graph._etypes_invmap[self._etype]
        snapshot = {}
        for name in frames[type_id]:
            snapshot[name] = features[name]
        return repr(snapshot)
...@@ -52,11 +52,6 @@ enum BoolFlag { ...@@ -52,11 +52,6 @@ enum BoolFlag {
dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc( dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<dgl::runtime::NDArray>& vec); const std::vector<dgl::runtime::NDArray>& vec);
/*!
 * \brief Check that the array is one-dimensional with an integer dtype code.
 * \param arr The array to inspect.
 * \return true iff arr has ndim == 1 and dtype code kDLInt.
 */
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
  if (arr->ndim != 1)
    return false;
  return arr->dtype.code == kDLInt;
}
/*! /*!
* \brief Copy a vector to an int64_t NDArray. * \brief Copy a vector to an int64_t NDArray.
* *
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/lazy.h> #include <dgl/lazy.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/base_heterograph.h>
#include "./bipartite.h" #include "./bipartite.h"
#include "../c_api_common.h" #include "../c_api_common.h"
namespace dgl { namespace dgl {
namespace { namespace {
inline GraphPtr CreateBipartiteMetaGraph() { inline GraphPtr CreateBipartiteMetaGraph() {
std::vector<int64_t> row_vec(1, Bipartite::kSrcVType); std::vector<int64_t> row_vec(1, Bipartite::kSrcVType);
std::vector<int64_t> col_vec(1, Bipartite::kDstVType); std::vector<int64_t> col_vec(1, Bipartite::kDstVType);
...@@ -20,8 +23,9 @@ inline GraphPtr CreateBipartiteMetaGraph() { ...@@ -20,8 +23,9 @@ inline GraphPtr CreateBipartiteMetaGraph() {
GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col); GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
return g; return g;
} }
static const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph(); const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph();
} // namespace
}; // namespace
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
// //
...@@ -29,22 +33,20 @@ static const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph(); ...@@ -29,22 +33,20 @@ static const GraphPtr kBipartiteMetaGraph = CreateBipartiteMetaGraph();
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
/*! \brief COO graph */
class Bipartite::COO : public BaseHeteroGraph { class Bipartite::COO : public BaseHeteroGraph {
public: public:
COO(int64_t num_src, int64_t num_dst, COO(int64_t num_src, int64_t num_dst, IdArray src, IdArray dst)
IdArray src, IdArray dst)
: BaseHeteroGraph(kBipartiteMetaGraph) { : BaseHeteroGraph(kBipartiteMetaGraph) {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst}; adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
} }
COO(int64_t num_src, int64_t num_dst,
IdArray src, IdArray dst, bool is_multigraph) COO(int64_t num_src, int64_t num_dst, IdArray src, IdArray dst, bool is_multigraph)
: BaseHeteroGraph(kBipartiteMetaGraph), : BaseHeteroGraph(kBipartiteMetaGraph),
is_multigraph_(is_multigraph) { is_multigraph_(is_multigraph) {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst}; adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
} }
explicit COO(const aten::COOMatrix& coo)
: BaseHeteroGraph(kBipartiteMetaGraph), adj_(coo) {} explicit COO(const aten::COOMatrix& coo) : BaseHeteroGraph(kBipartiteMetaGraph), adj_(coo) {}
uint64_t NumVertexTypes() const override { uint64_t NumVertexTypes() const override {
return 2; return 2;
...@@ -155,7 +157,7 @@ class Bipartite::COO : public BaseHeteroGraph { ...@@ -155,7 +157,7 @@ class Bipartite::COO : public BaseHeteroGraph {
} }
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override { EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids), return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids), aten::IndexSelect(adj_.col, eids),
eids}; eids};
...@@ -288,7 +290,6 @@ class Bipartite::COO : public BaseHeteroGraph { ...@@ -288,7 +290,6 @@ class Bipartite::COO : public BaseHeteroGraph {
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
/*! \brief CSR graph */ /*! \brief CSR graph */
class Bipartite::CSR : public BaseHeteroGraph { class Bipartite::CSR : public BaseHeteroGraph {
public: public:
...@@ -305,8 +306,7 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -305,8 +306,7 @@ class Bipartite::CSR : public BaseHeteroGraph {
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
explicit CSR(const aten::CSRMatrix& csr) explicit CSR(const aten::CSRMatrix& csr) : BaseHeteroGraph(kBipartiteMetaGraph), adj_(csr) {}
: BaseHeteroGraph(kBipartiteMetaGraph), adj_(csr) {}
uint64_t NumVertexTypes() const override { uint64_t NumVertexTypes() const override {
return 2; return 2;
...@@ -345,6 +345,34 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -345,6 +345,34 @@ class Bipartite::CSR : public BaseHeteroGraph {
return adj_.indices->dtype.bits; return adj_.indices->dtype.bits;
} }
/*!
 * \brief Return a CSR whose arrays are stored with the given bit width.
 * \param bits The requested number of bits.
 * \return A copy of this CSR when NumBits() already equals \a bits;
 *         otherwise a new CSR built from the converted indptr, indices
 *         and data arrays, with the multigraph flag carried over.
 */
CSR AsNumBits(uint8_t bits) const {
  // Nothing to convert: hand back a copy of ourselves.
  if (NumBits() == bits)
    return *this;
  CSR converted(
      adj_.num_rows, adj_.num_cols,
      aten::AsNumBits(adj_.indptr, bits),
      aten::AsNumBits(adj_.indices, bits),
      aten::AsNumBits(adj_.data, bits));
  converted.is_multigraph_ = is_multigraph_;
  return converted;
}
/*!
 * \brief Return a CSR whose arrays live on the given context.
 * \param ctx The target context.
 * \return A copy of this CSR when Context() already equals \a ctx;
 *         otherwise a new CSR built from indptr, indices and data copied
 *         to \a ctx, with the multigraph flag carried over.
 */
CSR CopyTo(const DLContext& ctx) const {
  // Already on the requested context: hand back a copy of ourselves.
  if (Context() == ctx)
    return *this;
  CSR moved(
      adj_.num_rows, adj_.num_cols,
      adj_.indptr.CopyTo(ctx),
      adj_.indices.CopyTo(ctx),
      adj_.data.CopyTo(ctx));
  moved.is_multigraph_ = is_multigraph_;
  return moved;
}
bool IsMultigraph() const override { bool IsMultigraph() const override {
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () { return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
return aten::CSRHasDuplicate(adj_); return aten::CSRHasDuplicate(adj_);
...@@ -386,8 +414,8 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -386,8 +414,8 @@ class Bipartite::CSR : public BaseHeteroGraph {
} }
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override { BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids); return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
} }
...@@ -408,8 +436,8 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -408,8 +436,8 @@ class Bipartite::CSR : public BaseHeteroGraph {
} }
EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override { EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
CHECK(IsValidIdArray(src)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(dst)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst); const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
return EdgeArray{arrs[0], arrs[1], arrs[2]}; return EdgeArray{arrs[0], arrs[1], arrs[2]};
} }
...@@ -443,7 +471,7 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -443,7 +471,7 @@ class Bipartite::CSR : public BaseHeteroGraph {
} }
EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override { EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids); auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false); auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabled, so // Note that the row id in the csr submat is relabled, so
...@@ -476,7 +504,7 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -476,7 +504,7 @@ class Bipartite::CSR : public BaseHeteroGraph {
} }
DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override { DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::CSRGetRowNNZ(adj_, vids); return aten::CSRGetRowNNZ(adj_, vids);
} }
...@@ -518,8 +546,8 @@ class Bipartite::CSR : public BaseHeteroGraph { ...@@ -518,8 +546,8 @@ class Bipartite::CSR : public BaseHeteroGraph {
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), 2) << "Number of vertex types mismatch"; CHECK_EQ(vids.size(), 2) << "Number of vertex types mismatch";
CHECK(IsValidIdArray(vids[0])) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids[0])) << "Invalid vertex id array.";
CHECK(IsValidIdArray(vids[1])) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids[1])) << "Invalid vertex id array.";
HeteroSubgraph subg; HeteroSubgraph subg;
const auto& submat = aten::CSRSliceMatrix(adj_, vids[0], vids[1]); const auto& submat = aten::CSRSliceMatrix(adj_, vids[0], vids[1]);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context()); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
...@@ -579,7 +607,7 @@ bool Bipartite::HasVertex(dgl_type_t vtype, dgl_id_t vid) const { ...@@ -579,7 +607,7 @@ bool Bipartite::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
} }
BoolArray Bipartite::HasVertices(dgl_type_t vtype, IdArray vids) const { BoolArray Bipartite::HasVertices(dgl_type_t vtype, IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input"; CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices(vtype)); return aten::LT(vids, NumVertices(vtype));
} }
...@@ -761,6 +789,37 @@ HeteroGraphPtr Bipartite::CreateFromCSR( ...@@ -761,6 +789,37 @@ HeteroGraphPtr Bipartite::CreateFromCSR(
return HeteroGraphPtr(new Bipartite(nullptr, csr, nullptr)); return HeteroGraphPtr(new Bipartite(nullptr, csr, nullptr));
} }
/*!
 * \brief Convert a bipartite graph to the given index bit width.
 * \param g The graph to convert; must point to a Bipartite.
 * \param bits The requested number of bits.
 * \return \a g itself when it already uses \a bits; otherwise a new
 *         Bipartite holding converted in- and out-CSRs and no COO.
 */
HeteroGraphPtr Bipartite::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;
  // TODO(minjie): since we don't have int32 operations, the conversion
  //   goes through the in-/out-CSR obtained from GetInCSR()/GetOutCSR()
  //   and the result carries no COO. This should be fixed later.
  auto bipartite = std::dynamic_pointer_cast<Bipartite>(g);
  CHECK_NOTNULL(bipartite);
  CSRPtr in_csr = CSRPtr(new CSR(bipartite->GetInCSR()->AsNumBits(bits)));
  CSRPtr out_csr = CSRPtr(new CSR(bipartite->GetOutCSR()->AsNumBits(bits)));
  return HeteroGraphPtr(new Bipartite(in_csr, out_csr, nullptr));
}
/*!
 * \brief Copy a bipartite graph to another context.
 * \param g The graph to copy; must point to a Bipartite.
 * \param ctx The target context.
 * \return \a g itself when it is already on \a ctx; otherwise a new
 *         Bipartite holding in- and out-CSRs copied to \a ctx and no COO.
 */
HeteroGraphPtr Bipartite::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
  if (!(ctx == g->Context())) {
    // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
    //   we make sure that this graph (on CPU) has materialized CSR,
    //   and then copy them to other context (usually GPU). This should
    //   be fixed later.
    auto bipartite = std::dynamic_pointer_cast<Bipartite>(g);
    CHECK_NOTNULL(bipartite);
    CSRPtr in_csr = CSRPtr(new CSR(bipartite->GetInCSR()->CopyTo(ctx)));
    CSRPtr out_csr = CSRPtr(new CSR(bipartite->GetOutCSR()->CopyTo(ctx)));
    return HeteroGraphPtr(new Bipartite(in_csr, out_csr, nullptr));
  }
  return g;
}
Bipartite::Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo) Bipartite::Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
: BaseHeteroGraph(kBipartiteMetaGraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { : BaseHeteroGraph(kBipartiteMetaGraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
CHECK(GetAny()) << "At least one graph structure should exist."; CHECK(GetAny()) << "At least one graph structure should exist.";
...@@ -813,6 +872,18 @@ Bipartite::COOPtr Bipartite::GetCOO() const { ...@@ -813,6 +872,18 @@ Bipartite::COOPtr Bipartite::GetCOO() const {
return coo_; return coo_;
} }
/*! \brief Return the adjacency matrix held by the in-edge CSR from GetInCSR(). */
aten::CSRMatrix Bipartite::GetInCSRMatrix() const {
  const CSRPtr in_csr = GetInCSR();
  return in_csr->adj();
}
/*! \brief Return the adjacency matrix held by the out-edge CSR from GetOutCSR(). */
aten::CSRMatrix Bipartite::GetOutCSRMatrix() const {
  const CSRPtr out_csr = GetOutCSR();
  return out_csr->adj();
}
/*! \brief Return the adjacency matrix held by the COO from GetCOO(). */
aten::COOMatrix Bipartite::GetCOOMatrix() const {
  const COOPtr coo = GetCOO();
  return coo->adj();
}
HeteroGraphPtr Bipartite::GetAny() const { HeteroGraphPtr Bipartite::GetAny() const {
if (in_csr_) { if (in_csr_) {
return in_csr_; return in_csr_;
......
...@@ -7,12 +7,14 @@ ...@@ -7,12 +7,14 @@
#ifndef DGL_GRAPH_BIPARTITE_H_ #ifndef DGL_GRAPH_BIPARTITE_H_
#define DGL_GRAPH_BIPARTITE_H_ #define DGL_GRAPH_BIPARTITE_H_
#include <dgl/graph_interface.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <vector> #include <dgl/lazy.h>
#include <string> #include <dgl/array.h>
#include <utility> #include <utility>
#include <memory> #include <string>
#include <vector>
#include "../c_api_common.h"
namespace dgl { namespace dgl {
...@@ -32,6 +34,12 @@ class Bipartite : public BaseHeteroGraph { ...@@ -32,6 +34,12 @@ class Bipartite : public BaseHeteroGraph {
/*! \brief edge group type */ /*! \brief edge group type */
static constexpr dgl_type_t kEType = 0; static constexpr dgl_type_t kEType = 0;
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
uint64_t NumVertexTypes() const override { uint64_t NumVertexTypes() const override {
return 2; return 2;
} }
...@@ -140,14 +148,11 @@ class Bipartite : public BaseHeteroGraph { ...@@ -140,14 +148,11 @@ class Bipartite : public BaseHeteroGraph {
int64_t num_src, int64_t num_dst, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids); IdArray indptr, IdArray indices, IdArray edge_ids);
private: /*! \brief Convert the graph to use the given number of bits for storage */
// internal data structure static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo); /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
/*! \return Return the in-edge CSR format. Create from other format if not exist. */ /*! \return Return the in-edge CSR format. Create from other format if not exist. */
CSRPtr GetInCSR() const; CSRPtr GetInCSR() const;
...@@ -158,6 +163,18 @@ class Bipartite : public BaseHeteroGraph { ...@@ -158,6 +163,18 @@ class Bipartite : public BaseHeteroGraph {
/*! \return Return the COO format. Create from other format if not exist. */ /*! \return Return the COO format. Create from other format if not exist. */
COOPtr GetCOO() const; COOPtr GetCOO() const;
/*! \return Return the in-edge CSR in the matrix form */
aten::CSRMatrix GetInCSRMatrix() const;
/*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetOutCSRMatrix() const;
/*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix() const;
private:
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
......
...@@ -16,8 +16,8 @@ namespace dgl { ...@@ -16,8 +16,8 @@ namespace dgl {
Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes, Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes,
bool multigraph): is_multigraph_(multigraph) { bool multigraph): is_multigraph_(multigraph) {
CHECK(IsValidIdArray(src_ids)); CHECK(aten::IsValidIdArray(src_ids));
CHECK(IsValidIdArray(dst_ids)); CHECK(aten::IsValidIdArray(dst_ids));
this->AddVertices(num_nodes); this->AddVertices(num_nodes);
num_edges_ = src_ids->shape[0]; num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
...@@ -66,8 +66,8 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) { ...@@ -66,8 +66,8 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) { void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed."; CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data); const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
...@@ -92,7 +92,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) { ...@@ -92,7 +92,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
} }
BoolArray Graph::HasVertices(IdArray vids) const { BoolArray Graph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx); BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
...@@ -113,8 +113,8 @@ bool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const { ...@@ -113,8 +113,8 @@ bool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
// O(E*k) pretty slow // O(E*k) pretty slow
BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const { BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen); const auto rstlen = std::max(srclen, dstlen);
...@@ -201,8 +201,8 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { ...@@ -201,8 +201,8 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
// O(E*k) pretty slow // O(E*k) pretty slow
EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
int64_t i, j; int64_t i, j;
...@@ -247,7 +247,7 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -247,7 +247,7 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
} }
EdgeArray Graph::FindEdges(IdArray eids) const { EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0]; int64_t len = eids->shape[0];
IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx); IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
...@@ -291,7 +291,7 @@ EdgeArray Graph::InEdges(dgl_id_t vid) const { ...@@ -291,7 +291,7 @@ EdgeArray Graph::InEdges(dgl_id_t vid) const {
// O(E) // O(E)
EdgeArray Graph::InEdges(IdArray vids) const { EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0; int64_t rstlen = 0;
...@@ -337,7 +337,7 @@ EdgeArray Graph::OutEdges(dgl_id_t vid) const { ...@@ -337,7 +337,7 @@ EdgeArray Graph::OutEdges(dgl_id_t vid) const {
// O(E) // O(E)
EdgeArray Graph::OutEdges(IdArray vids) const { EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0; int64_t rstlen = 0;
...@@ -409,7 +409,7 @@ EdgeArray Graph::Edges(const std::string &order) const { ...@@ -409,7 +409,7 @@ EdgeArray Graph::Edges(const std::string &order) const {
// O(V) // O(V)
DegreeArray Graph::InDegrees(IdArray vids) const { DegreeArray Graph::InDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx); DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
...@@ -424,7 +424,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const { ...@@ -424,7 +424,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const {
// O(V) // O(V)
DegreeArray Graph::OutDegrees(IdArray vids) const { DegreeArray Graph::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx); DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
...@@ -438,7 +438,7 @@ DegreeArray Graph::OutDegrees(IdArray vids) const { ...@@ -438,7 +438,7 @@ DegreeArray Graph::OutDegrees(IdArray vids) const {
} }
Subgraph Graph::VertexSubgraph(IdArray vids) const { Subgraph Graph::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv; std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<dgl_id_t> edges; std::vector<dgl_id_t> edges;
...@@ -468,7 +468,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -468,7 +468,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
} }
Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array."; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
const auto len = eids->shape[0]; const auto len = eids->shape[0];
std::vector<dgl_id_t> nodes; std::vector<dgl_id_t> nodes;
const int64_t* eid_data = static_cast<int64_t*>(eids->data); const int64_t* eid_data = static_cast<int64_t*>(eids->data);
......
...@@ -249,8 +249,8 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes( ...@@ -249,8 +249,8 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
} }
IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) { IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
CHECK(IsValidIdArray(parent_vids)) << "Invalid parent id array."; CHECK(aten::IsValidIdArray(parent_vids)) << "Invalid parent id array.";
CHECK(IsValidIdArray(query)) << "Invalid query id array."; CHECK(aten::IsValidIdArray(query)) << "Invalid query id array.";
const auto parent_len = parent_vids->shape[0]; const auto parent_len = parent_vids->shape[0];
const auto query_len = query->shape[0]; const auto query_len = query->shape[0];
const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data); const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
......
...@@ -114,17 +114,33 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -114,17 +114,33 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph."; CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph.";
} }
// create num verts per type // create num verts per type
num_verts_per_type_.resize(meta_graph_->NumVertices(), -1); num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
for (dgl_type_t vtype = 0; vtype < meta_graph_->NumVertices(); ++vtype) {
for (dgl_type_t etype : meta_graph->OutEdgeVec(vtype)) { EdgeArray etype_array = meta_graph->Edges();
const auto nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType); dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data);
if (num_verts_per_type_[vtype] < 0) { dgl_type_t *dsttypes = static_cast<dgl_type_t *>(etype_array.dst->data);
num_verts_per_type_[vtype] = nv; dgl_type_t *etypes = static_cast<dgl_type_t *>(etype_array.id->data);
} else {
CHECK_EQ(num_verts_per_type_[vtype], nv) for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {
<< "Mismatch number of vertices for vertex type " << vtype; dgl_type_t srctype = srctypes[i];
} dgl_type_t dsttype = dsttypes[i];
} dgl_type_t etype = etypes[i];
size_t nv;
// # nodes of source type
nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType);
if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv;
else
CHECK_EQ(num_verts_per_type_[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type
nv = rel_graphs[etype]->NumVertices(Bipartite::kDstVType);
if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv;
else
CHECK_EQ(num_verts_per_type_[dsttype], nv)
<< "Mismatch number of vertices for vertex type " << dsttype;
} }
} }
...@@ -140,7 +156,7 @@ bool HeteroGraph::IsMultigraph() const { ...@@ -140,7 +156,7 @@ bool HeteroGraph::IsMultigraph() const {
} }
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input"; CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices(vtype)); return aten::LT(vids, NumVertices(vtype));
} }
...@@ -192,7 +208,7 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -192,7 +208,7 @@ HeteroGraphPtr CreateHeteroGraph(
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCOO") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0]; int64_t num_src = args[0];
int64_t num_dst = args[1]; int64_t num_dst = args[1];
...@@ -202,7 +218,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCOO") ...@@ -202,7 +218,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCSR") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0]; int64_t num_src = args[0];
int64_t num_dst = args[1]; int64_t num_dst = args[1];
...@@ -213,7 +229,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCSR") ...@@ -213,7 +229,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateHeteroGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1]; List<HeteroGraphRef> rel_graphs = args[1];
...@@ -226,20 +242,20 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateHeteroGraph") ...@@ -226,20 +242,20 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroCreateHeteroGraph")
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetMetaGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = GraphRef(hg->meta_graph()); *rv = GraphRef(hg->meta_graph());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetRelationGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
*rv = HeteroGraphRef(hg->GetRelationGraph(etype)); *rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
...@@ -247,7 +263,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddVertices") ...@@ -247,7 +263,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddVertices")
hg->AddVertices(vtype, num); hg->AddVertices(vtype, num);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdge") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -256,7 +272,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdge") ...@@ -256,7 +272,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdge")
hg->AddEdge(etype, src, dst); hg->AddEdge(etype, src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -265,51 +281,51 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdges") ...@@ -265,51 +281,51 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroAddEdges")
hg->AddEdges(etype, src, dst); hg->AddEdges(etype, src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroClear") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
hg->Clear(); hg->Clear();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroContext") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->Context(); *rv = hg->Context();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroNumBits") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->NumBits(); *rv = hg->NumBits();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroIsMultigraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->IsMultigraph(); *rv = hg->IsMultigraph();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroIsReadonly") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->IsReadonly(); *rv = hg->IsReadonly();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroNumVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
*rv = static_cast<int64_t>(hg->NumVertices(vtype)); *rv = static_cast<int64_t>(hg->NumVertices(vtype));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroNumEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
*rv = static_cast<int64_t>(hg->NumEdges(etype)); *rv = static_cast<int64_t>(hg->NumEdges(etype));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertex") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
...@@ -317,7 +333,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertex") ...@@ -317,7 +333,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertex")
*rv = hg->HasVertex(vtype, vid); *rv = hg->HasVertex(vtype, vid);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
...@@ -325,7 +341,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertices") ...@@ -325,7 +341,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasVertices")
*rv = hg->HasVertices(vtype, vids); *rv = hg->HasVertices(vtype, vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgeBetween") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -334,7 +350,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgeBetween") ...@@ -334,7 +350,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgeBetween")
*rv = hg->HasEdgeBetween(etype, src, dst); *rv = hg->HasEdgeBetween(etype, src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgesBetween") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -343,7 +359,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgesBetween") ...@@ -343,7 +359,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroHasEdgesBetween")
*rv = hg->HasEdgesBetween(etype, src, dst); *rv = hg->HasEdgesBetween(etype, src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroPredecessors") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -351,7 +367,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroPredecessors") ...@@ -351,7 +367,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroPredecessors")
*rv = hg->Predecessors(etype, dst); *rv = hg->Predecessors(etype, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSuccessors") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -359,7 +375,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSuccessors") ...@@ -359,7 +375,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSuccessors")
*rv = hg->Successors(etype, src); *rv = hg->Successors(etype, src);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeId") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -368,7 +384,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeId") ...@@ -368,7 +384,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeId")
*rv = hg->EdgeId(etype, src, dst); *rv = hg->EdgeId(etype, src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeIds") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -378,7 +394,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeIds") ...@@ -378,7 +394,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeIds")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroFindEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -387,7 +403,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroFindEdges") ...@@ -387,7 +403,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroFindEdges")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_1") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -396,7 +412,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_1") ...@@ -396,7 +412,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_1")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_2") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -405,7 +421,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_2") ...@@ -405,7 +421,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInEdges_2")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_1") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -414,7 +430,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_1") ...@@ -414,7 +430,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_1")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_2") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -423,7 +439,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_2") ...@@ -423,7 +439,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutEdges_2")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -432,7 +448,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdges") ...@@ -432,7 +448,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdges")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegree") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -440,7 +456,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegree") ...@@ -440,7 +456,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegree")
*rv = static_cast<int64_t>(hg->InDegree(etype, vid)); *rv = static_cast<int64_t>(hg->InDegree(etype, vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegrees") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -448,7 +464,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegrees") ...@@ -448,7 +464,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroInDegrees")
*rv = hg->InDegrees(etype, vids); *rv = hg->InDegrees(etype, vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegree") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -456,7 +472,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegree") ...@@ -456,7 +472,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegree")
*rv = static_cast<int64_t>(hg->OutDegree(etype, vid)); *rv = static_cast<int64_t>(hg->OutDegree(etype, vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegrees") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -464,7 +480,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegrees") ...@@ -464,7 +480,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroOutDegrees")
*rv = hg->OutDegrees(etype, vids); *rv = hg->OutDegrees(etype, vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetAdj") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
...@@ -474,7 +490,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetAdj") ...@@ -474,7 +490,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroGetAdj")
hg->GetAdj(etype, transpose, fmt)); hg->GetAdj(etype, transpose, fmt));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroVertexSubgraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> vids = args[1]; List<Value> vids = args[1];
...@@ -488,7 +504,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroVertexSubgraph") ...@@ -488,7 +504,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroVertexSubgraph")
*rv = HeteroSubgraphRef(subg); *rv = HeteroSubgraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeSubgraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> eids = args[1]; List<Value> eids = args[1];
...@@ -505,13 +521,13 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeSubgraph") ...@@ -505,13 +521,13 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroEdgeSubgraph")
// HeteroSubgraph C APIs // HeteroSubgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0]; HeteroSubgraphRef subg = args[0];
*rv = HeteroGraphRef(subg->graph); *rv = HeteroGraphRef(subg->graph);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0]; HeteroSubgraphRef subg = args[0];
List<Value> induced_verts; List<Value> induced_verts;
...@@ -521,7 +537,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedVertices") ...@@ -521,7 +537,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
*rv = induced_verts; *rv = induced_verts;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0]; HeteroSubgraphRef subg = args[0];
List<Value> induced_edges; List<Value> induced_edges;
...@@ -531,4 +547,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedEdges") ...@@ -531,4 +547,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
*rv = induced_edges; *rv = induced_edges;
}); });
// C API "heterograph_index._CAPI_DGLHeteroAsNumBits":
// forwards the graph and an integer bit-width argument to Bipartite::AsNumBits
// and hands the resulting graph back to the caller wrapped in a HeteroGraphRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // args[0]: heterograph handle; args[1]: requested number of bits (passed through unchanged).
    HeteroGraphRef graph = args[0];
    const int num_bits = args[1];
    *rv = HeteroGraphRef(Bipartite::AsNumBits(graph.sptr(), num_bits));
  });
// C API "heterograph_index._CAPI_DGLHeteroCopyTo":
// builds a DLContext from the integer device type / device id arguments and
// returns the result of Bipartite::CopyTo for that context as a HeteroGraphRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // args[0]: heterograph handle; args[1]: device type as int; args[2]: device id.
    HeteroGraphRef graph = args[0];
    const int dev_type = args[1];
    const int dev_id = args[2];
    // The device type crosses the FFI boundary as a plain int, so cast it back
    // to the enum before filling in the context.
    DLContext target;
    target.device_type = static_cast<DLDeviceType>(dev_type);
    target.device_id = dev_id;
    *rv = HeteroGraphRef(Bipartite::CopyTo(graph.sptr(), target));
  });
} // namespace dgl } // namespace dgl
...@@ -59,9 +59,9 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph) ...@@ -59,9 +59,9 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) { CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
CHECK(IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1; const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
...@@ -69,9 +69,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) { ...@@ -69,9 +69,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: is_multigraph_(is_multigraph) { : is_multigraph_(is_multigraph) {
CHECK(IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1; const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
...@@ -79,9 +79,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) ...@@ -79,9 +79,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) { const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
CHECK(IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1; const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
...@@ -98,9 +98,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -98,9 +98,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name): is_multigraph_(is_multigraph), const std::string &shared_mem_name): is_multigraph_(is_multigraph),
shared_mem_name_(shared_mem_name) { shared_mem_name_(shared_mem_name) {
CHECK(IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1; const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
...@@ -140,7 +140,7 @@ EdgeArray CSR::OutEdges(dgl_id_t vid) const { ...@@ -140,7 +140,7 @@ EdgeArray CSR::OutEdges(dgl_id_t vid) const {
} }
EdgeArray CSR::OutEdges(IdArray vids) const { EdgeArray CSR::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids); auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false); auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabeled, so // Note that the row id in the csr submat is relabeled, so
...@@ -150,7 +150,7 @@ EdgeArray CSR::OutEdges(IdArray vids) const { ...@@ -150,7 +150,7 @@ EdgeArray CSR::OutEdges(IdArray vids) const {
} }
DegreeArray CSR::OutDegrees(IdArray vids) const { DegreeArray CSR::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::CSRGetRowNNZ(adj_, vids); return aten::CSRGetRowNNZ(adj_, vids);
} }
...@@ -161,8 +161,8 @@ bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const { ...@@ -161,8 +161,8 @@ bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
} }
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const { BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids); return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
} }
...@@ -192,7 +192,7 @@ EdgeArray CSR::Edges(const std::string &order) const { ...@@ -192,7 +192,7 @@ EdgeArray CSR::Edges(const std::string &order) const {
} }
Subgraph CSR::VertexSubgraph(IdArray vids) const { Subgraph CSR::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids); const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context()); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids)); CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
...@@ -272,16 +272,16 @@ DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const { ...@@ -272,16 +272,16 @@ DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) { COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
CHECK(IsValidIdArray(src)); CHECK(aten::IsValidIdArray(src));
CHECK(IsValidIdArray(dst)); CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]); CHECK_EQ(src->shape[0], dst->shape[0]);
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst}; adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
} }
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph) COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
: is_multigraph_(is_multigraph) { : is_multigraph_(is_multigraph) {
CHECK(IsValidIdArray(src)); CHECK(aten::IsValidIdArray(src));
CHECK(IsValidIdArray(dst)); CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]); CHECK_EQ(src->shape[0], dst->shape[0]);
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst}; adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
} }
...@@ -301,7 +301,7 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const { ...@@ -301,7 +301,7 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
} }
EdgeArray COO::FindEdges(IdArray eids) const { EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids), return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids), aten::IndexSelect(adj_.col, eids),
eids}; eids};
...@@ -316,7 +316,7 @@ EdgeArray COO::Edges(const std::string &order) const { ...@@ -316,7 +316,7 @@ EdgeArray COO::Edges(const std::string &order) const {
} }
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array."; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
COOPtr subcoo; COOPtr subcoo;
IdArray induced_nodes; IdArray induced_nodes;
if (!preserve_nodes) { if (!preserve_nodes) {
...@@ -379,7 +379,7 @@ COO COO::AsNumBits(uint8_t bits) const { ...@@ -379,7 +379,7 @@ COO COO::AsNumBits(uint8_t bits) const {
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
BoolArray ImmutableGraph::HasVertices(IdArray vids) const { BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input"; CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices()); return aten::LT(vids, NumVertices());
} }
......
...@@ -749,7 +749,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr, ...@@ -749,7 +749,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
const bool add_self_loop, const bool add_self_loop,
const ValueType *probability) { const ValueType *probability) {
// process args // process args
CHECK(IsValidIdArray(seed_nodes)); CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0]; const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers, const int64_t num_workers = std::min(max_num_workers,
...@@ -859,7 +859,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -859,7 +859,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
// process args // process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes)); CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0]; const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers, const int64_t num_workers = std::min(max_num_workers,
...@@ -1017,7 +1017,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") ...@@ -1017,7 +1017,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
// process args // process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_edges)); CHECK(aten::IsValidIdArray(seed_edges));
BuildCoo(*gptr); BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0]; const int64_t num_seeds = seed_edges->shape[0];
......
...@@ -197,7 +197,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -197,7 +197,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array."; CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len); std::vector<std::vector<dgl_id_t>> edges(len);
...@@ -219,7 +219,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") ...@@ -219,7 +219,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
const bool has_nontree_edge = args[4]; const bool has_nontree_edge = args[4];
const bool return_labels = args[5]; const bool return_labels = args[5];
CHECK(IsValidIdArray(source)) << "Invalid source node id array."; CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "./binary_reduce_impl_decl.h" #include "./binary_reduce_impl_decl.h"
#include "./utils.h" #include "./utils.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "../graph/bipartite.h"
#include "./csr_interface.h" #include "./csr_interface.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -228,6 +229,31 @@ class ImmutableGraphCSRWrapper : public CSRWrapper { ...@@ -228,6 +229,31 @@ class ImmutableGraphCSRWrapper : public CSRWrapper {
const ImmutableGraph* gptr_; const ImmutableGraph* gptr_;
}; };
class BipartiteCSRWrapper : public CSRWrapper {
public:
explicit BipartiteCSRWrapper(const Bipartite* graph) :
gptr_(graph) { }
aten::CSRMatrix GetInCSRMatrix() const override {
return gptr_->GetInCSRMatrix();
}
aten::CSRMatrix GetOutCSRMatrix() const override {
return gptr_->GetOutCSRMatrix();
}
DGLContext Context() const override {
return gptr_->Context();
}
int NumBits() const override {
return gptr_->NumBits();
}
private:
const Bipartite* gptr_;
};
} // namespace } // namespace
...@@ -293,11 +319,32 @@ void BinaryOpReduce( ...@@ -293,11 +319,32 @@ void BinaryOpReduce(
} }
} }
// Runs the statements in __VA_ARGS__ with `wrapper` bound to a CSRWrapper that
// matches the dynamic type of a graph argument. The type dispatch is adapted
// from DGLArgValue::AsObjectRef(), which allows the argument value to be either
// a GraphRef or a HeteroGraphRef.
//
//   argvalue : the packed-function argument holding the graph; its type code
//              must be kObjectHandle.
//   wrapper  : identifier under which the CSRWrapper is visible to the
//              statements.
//   ...      : statements to execute; they are expanded once per branch.
//
// Dispatch rules:
//   * GraphRef       -> must hold an ImmutableGraph -> ImmutableGraphCSRWrapper
//   * HeteroGraphRef -> must hold a Bipartite       -> BipartiteCSRWrapper
//   * anything else  -> fatal error (instead of silently skipping the
//                       statements and leaving the outputs untouched)
//
// NOTE: the expansion declares the locals `argval`, `sptr`, `g`, `igptr` and
// `bgptr`; same-named variables of the caller are shadowed inside the
// statements.
#define CSRWRAPPER_SWITCH(argvalue, wrapper, ...) do {                        \
  DGLArgValue argval = (argvalue);                                            \
  DGL_CHECK_TYPE_CODE(argval.type_code(), kObjectHandle);                     \
  std::shared_ptr<Object>& sptr =                                             \
    *argval.ptr<std::shared_ptr<Object>>();                                   \
  if (ObjectTypeChecker<GraphRef>::Check(sptr.get())) {                       \
    GraphRef g = argval;                                                      \
    auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());         \
    CHECK(igptr)                                                              \
      << "Invalid graph object argument. Must be an immutable graph.";        \
    ImmutableGraphCSRWrapper wrapper(igptr.get());                            \
    {__VA_ARGS__}                                                             \
  } else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) {          \
    HeteroGraphRef g = argval;                                                \
    auto bgptr = std::dynamic_pointer_cast<Bipartite>(g.sptr());              \
    CHECK(bgptr)                                                              \
      << "Invalid heterograph object argument. Must be a bipartite graph.";   \
    BipartiteCSRWrapper wrapper(bgptr.get());                                 \
    {__VA_ARGS__}                                                             \
  } else {                                                                    \
    LOG(FATAL)                                                                \
      << "Invalid graph object argument. "                                    \
      << "Must be either a GraphRef or a HeteroGraphRef.";                    \
  }                                                                           \
} while (0)
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
std::string op = args[1]; std::string op = args[1];
GraphRef g = args[2];
int lhs = args[3]; int lhs = args[3];
int rhs = args[4]; int rhs = args[4];
NDArray lhs_data = args[5]; NDArray lhs_data = args[5];
...@@ -307,14 +354,13 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce") ...@@ -307,14 +354,13 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9]; NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10]; NDArray out_mapping = args[10];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); CSRWRAPPER_SWITCH(args[2], wrapper, {
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BinaryOpReduce(reducer, op, wrapper, BinaryOpReduce(reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data, lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping); lhs_mapping, rhs_mapping, out_mapping);
}); });
});
void BackwardLhsBinaryOpReduce( void BackwardLhsBinaryOpReduce(
const std::string& reducer, const std::string& reducer,
...@@ -370,7 +416,6 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce") ...@@ -370,7 +416,6 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
std::string op = args[1]; std::string op = args[1];
GraphRef g = args[2];
int lhs = args[3]; int lhs = args[3];
int rhs = args[4]; int rhs = args[4];
NDArray lhs_mapping = args[5]; NDArray lhs_mapping = args[5];
...@@ -382,9 +427,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce") ...@@ -382,9 +427,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11]; NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12]; NDArray grad_lhs_data = args[12];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); CSRWRAPPER_SWITCH(args[2], wrapper, {
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BackwardLhsBinaryOpReduce( BackwardLhsBinaryOpReduce(
reducer, op, wrapper, reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
...@@ -392,6 +435,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce") ...@@ -392,6 +435,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data); grad_lhs_data);
}); });
});
void BackwardRhsBinaryOpReduce( void BackwardRhsBinaryOpReduce(
const std::string& reducer, const std::string& reducer,
...@@ -446,7 +490,6 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce") ...@@ -446,7 +490,6 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
std::string op = args[1]; std::string op = args[1];
GraphRef g = args[2];
int lhs = args[3]; int lhs = args[3];
int rhs = args[4]; int rhs = args[4];
NDArray lhs_mapping = args[5]; NDArray lhs_mapping = args[5];
...@@ -458,9 +501,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce") ...@@ -458,9 +501,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11]; NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12]; NDArray grad_rhs_data = args[12];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); CSRWRAPPER_SWITCH(args[2], wrapper, {
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BackwardRhsBinaryOpReduce( BackwardRhsBinaryOpReduce(
reducer, op, wrapper, reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
...@@ -468,6 +509,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce") ...@@ -468,6 +509,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
grad_rhs_data); grad_rhs_data);
}); });
});
void CopyReduce( void CopyReduce(
const std::string& reducer, const std::string& reducer,
...@@ -493,21 +535,19 @@ void CopyReduce( ...@@ -493,21 +535,19 @@ void CopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
GraphRef g = args[1];
int target = args[2]; int target = args[2];
NDArray in_data = args[3]; NDArray in_data = args[3];
NDArray out_data = args[4]; NDArray out_data = args[4];
NDArray in_mapping = args[5]; NDArray in_mapping = args[5];
NDArray out_mapping = args[6]; NDArray out_mapping = args[6];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); CSRWRAPPER_SWITCH(args[1], wrapper, {
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
CopyReduce(reducer, wrapper, CopyReduce(reducer, wrapper,
static_cast<binary_op::Target>(target), static_cast<binary_op::Target>(target),
in_data, out_data, in_data, out_data,
in_mapping, out_mapping); in_mapping, out_mapping);
}); });
});
void BackwardCopyReduce( void BackwardCopyReduce(
const std::string& reducer, const std::string& reducer,
...@@ -542,7 +582,6 @@ void BackwardCopyReduce( ...@@ -542,7 +582,6 @@ void BackwardCopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
GraphRef g = args[1];
int target = args[2]; int target = args[2];
NDArray in_data = args[3]; NDArray in_data = args[3];
NDArray out_data = args[4]; NDArray out_data = args[4];
...@@ -551,15 +590,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce") ...@@ -551,15 +590,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7]; NDArray in_mapping = args[7];
NDArray out_mapping = args[8]; NDArray out_mapping = args[8];
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); CSRWRAPPER_SWITCH(args[1], wrapper, {
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BackwardCopyReduce( BackwardCopyReduce(
reducer, wrapper, static_cast<binary_op::Target>(target), reducer, wrapper, static_cast<binary_op::Target>(target),
in_mapping, out_mapping, in_mapping, out_mapping,
in_data, out_data, grad_out_data, in_data, out_data, grad_out_data,
grad_in_data); grad_in_data);
}); });
});
} // namespace kernel } // namespace kernel
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment