Unverified Commit 9b4d6079 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Hetero] New syntax (#824)

* WIP. remove graph arg in NodeBatch and EdgeBatch

* refactor: use graph adapter for scheduler

* WIP: recv

* draft impl

* stuck at bipartite

* bipartite->unitgraph; support dsttype == srctype

* pass test_query

* pass test_query

* pass test_view

* test apply

* pass udf message passing tests

* pass quan's test using builtins

* WIP: wildcard slicing

* new construct methods

* broken

* good

* add stack cross reducer

* fix bug; fix mx

* fix bug in csrmm2 when the CSR is not square

* lint

* removed FlattenedHeteroGraph class

* WIP

* prop nodes, prop edges, filter nodes/edges

* add DGLGraph tests to heterograph. Fix several bugs

* finish nx<->hetero graph conversion

* create bipartite from nx

* more spec on hetero/homo conversion

* silly fixes

* check node and edge types

* repr

* to api

* adj APIs

* inc

* fix some lints and bugs

* fix some lints

* hetero/homo conversion

* fix flatten test

* more spec in hetero_from_homo and test

* flatten using concat names

* WIP: creators

* rewrite hetero_from_homo in a more efficient way

* remove useless variables

* fix lint

* subgraphs and typed subgraphs

* lint & removed heterosubgraph class

* lint x2

* disable heterograph mutation test

* docstring update

* add edge id for nx graph test

* fix mx unittests

* fix bug

* try fix

* fix unittest when cross_reducer is stack

* fix ci

* fix nx bipartite bug; docstring

* fix scipy creation bug

* lint

* fix bug when converting heterograph from homograph

* fix bug in hetero_from_homo about ntype order

* trailing white

* docstring fixes for add_foo and data views

* docstring for relation slice

* to_hetero and to_homo with feature support

* lint

* lint

* DGLGraph compatibility

* incidence matrix & docstring fixes

* example string fixes

* feature in hetero_from_relations

* deduplication of edge types in to_hetero

* fix lint

* fix
parent ddb5d804
...@@ -21,7 +21,9 @@ namespace dgl { ...@@ -21,7 +21,9 @@ namespace dgl {
// Forward declaration // Forward declaration
class BaseHeteroGraph; class BaseHeteroGraph;
class FlattenedHeteroGraph;
typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr; typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;
typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;
struct HeteroSubgraph; struct HeteroSubgraph;
/*! /*!
...@@ -46,10 +48,14 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -46,10 +48,14 @@ class BaseHeteroGraph : public runtime::Object {
////////////////////////// query/operations on meta graph //////////////////////// ////////////////////////// query/operations on meta graph ////////////////////////
/*! \return the number of vertex types */ /*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const = 0; virtual uint64_t NumVertexTypes() const {
return meta_graph_->NumVertices();
}
/*! \return the number of edge types */ /*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const = 0; virtual uint64_t NumEdgeTypes() const {
return meta_graph_->NumEdges();
}
/*! \return the meta graph */ /*! \return the meta graph */
virtual GraphPtr meta_graph() const { virtual GraphPtr meta_graph() const {
...@@ -351,6 +357,17 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -351,6 +357,17 @@ class BaseHeteroGraph : public runtime::Object {
virtual HeteroSubgraph EdgeSubgraph( virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0; const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*!
 * \brief Flatten the unit graphs of the requested edge types into a single unit graph.
 *
 * This default implementation does not support the operation: it reports a fatal
 * error through LOG(FATAL).
 * NOTE(review): subclasses that support flattening presumably override this method;
 * no override is visible in this hunk -- confirm.
 *
 * \param etypes The list of edge type IDs.
 * \return The flattened graph, with induced source/edge/destination types/IDs.
 *         The default implementation returns nullptr after logging the fatal error.
 */
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const {
LOG(FATAL) << "Flatten operation unsupported";
// NOTE(review): presumably only here to give the function a return value on
// every path; confirm whether LOG(FATAL) can return in this build.
return nullptr;
}
static constexpr const char* _type_key = "graph.HeteroGraph"; static constexpr const char* _type_key = "graph.HeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);
...@@ -381,6 +398,62 @@ struct HeteroSubgraph : public runtime::Object { ...@@ -381,6 +398,62 @@ struct HeteroSubgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object);
}; };
/*!
 * \brief The flattened heterograph.
 *
 * Holds the flattened graph together with, for each of its source nodes, edges and
 * destination nodes, the type and local ID that element has in the parent graph.
 */
struct FlattenedHeteroGraph : public runtime::Object {
/*! \brief The graph */
HeteroGraphRef graph;
/*!
 * \brief Mapping from source node ID to node type in parent graph
 * \note The induced type array guarantees that entries of the same type always
 *       appear contiguously.
 */
IdArray induced_srctype;
/*!
 * \brief The set of node types in the parent graph that appear among the source nodes.
 */
IdArray induced_srctype_set;
/*! \brief Mapping from source node ID to local node ID in parent graph */
IdArray induced_srcid;
/*!
 * \brief Mapping from edge ID to edge type in parent graph
 * \note The induced type array guarantees that entries of the same type always
 *       appear contiguously.
 */
IdArray induced_etype;
/*!
 * \brief The set of edge types in the parent graph that appear among the edges.
 */
IdArray induced_etype_set;
/*! \brief Mapping from edge ID to local edge ID in parent graph */
IdArray induced_eid;
/*!
 * \brief Mapping from destination node ID to node type in parent graph
 * \note The induced type array guarantees that entries of the same type always
 *       appear contiguously.
 */
IdArray induced_dsttype;
/*!
 * \brief The set of node types in the parent graph that appear among the destination nodes.
 */
IdArray induced_dsttype_set;
/*! \brief Mapping from destination node ID to local node ID in parent graph */
IdArray induced_dstid;
/*!
 * \brief Pass every member to the attribute visitor.
 *
 * Each member is visited under a key equal to its field name.
 */
void VisitAttrs(runtime::AttrVisitor *v) final {
v->Visit("graph", &graph);
v->Visit("induced_srctype", &induced_srctype);
v->Visit("induced_srctype_set", &induced_srctype_set);
v->Visit("induced_srcid", &induced_srcid);
v->Visit("induced_etype", &induced_etype);
v->Visit("induced_etype_set", &induced_etype_set);
v->Visit("induced_eid", &induced_eid);
v->Visit("induced_dsttype", &induced_dsttype);
v->Visit("induced_dsttype_set", &induced_dsttype_set);
v->Visit("induced_dstid", &induced_dstid);
}
/*! \brief Type key under which this object type is declared. */
static constexpr const char* _type_key = "graph.FlattenedHeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(FlattenedHeteroGraph, runtime::Object);
};
DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Define HeteroSubgraphRef // Define HeteroSubgraphRef
DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph); DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);
......
...@@ -18,6 +18,7 @@ namespace runtime { ...@@ -18,6 +18,7 @@ namespace runtime {
// forward declaration // forward declaration
class Object; class Object;
class ObjectRef; class ObjectRef;
class NDArray;
/*! /*!
* \brief Visitor class to each object attribute. * \brief Visitor class to each object attribute.
...@@ -33,6 +34,7 @@ class AttrVisitor { ...@@ -33,6 +34,7 @@ class AttrVisitor {
virtual void Visit(const char* key, bool* value) = 0; virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0; virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0; virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum, template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type> typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) { void Visit(const char* key, ENum* ptr) {
......
...@@ -13,9 +13,10 @@ from ._ffi.runtime_ctypes import TypeCode ...@@ -13,9 +13,10 @@ from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
from ._ffi.base import DGLError, __version__ from ._ffi.base import DGLError, __version__
from .base import ALL from .base import ALL, NTYPE, NID, ETYPE, EID
from .backend import load_backend from .backend import load_backend
from .batched_graph import * from .batched_graph import *
from .convert import *
from .graph import DGLGraph from .graph import DGLGraph
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from .nodeflow import * from .nodeflow import *
......
...@@ -8,6 +8,13 @@ from ._ffi.function import _init_internal_api ...@@ -8,6 +8,13 @@ from ._ffi.function import _init_internal_api
# A special symbol for selecting all nodes or edges. # A special symbol for selecting all nodes or edges.
ALL = "__ALL__" ALL = "__ALL__"
# An alias for [:]
SLICE_FULL = slice(None, None, None)
# Reserved column names for storing parent node/edge types and IDs in flattened heterographs
NTYPE = '_TYPE'
NID = '_ID'
ETYPE = '_TYPE'
EID = '_ID'
def is_all(arg): def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges.""" """Return true if the argument is a special symbol for all nodes or edges."""
......
"""Module for converting graph from/to other object."""
from collections import defaultdict

import numpy as np
import scipy as sp
import networkx as nx

from . import backend as F
from . import heterograph_index
from .heterograph import DGLHeteroGraph, combine_frames
from . import graph_index
from . import utils
from .base import NTYPE, ETYPE, NID, EID
from ._ffi.base import DGLError
__all__ = [
'graph',
'bipartite',
'hetero_from_relations',
'to_hetero',
'to_homo',
'to_networkx',
]
def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
    """Create a graph.

    The graph has only one type of nodes and edges.

    In the sparse matrix perspective, :func:`dgl.graph` creates a graph
    whose adjacency matrix must be square while :func:`dgl.bipartite`
    creates a graph that does not necessarily have square adjacency matrix.

    Examples
    --------
    Create from edges pairs:

    >>> g = dgl.graph([(0, 2), (0, 3), (1, 2)])

    Create from pair of vertex IDs lists

    >>> u = [0, 0, 1]
    >>> v = [2, 3, 2]
    >>> g = dgl.graph((u, v))

    The IDs can also be stored in framework-specific tensors

    >>> import torch
    >>> u = torch.tensor([0, 0, 1])
    >>> v = torch.tensor([2, 3, 2])
    >>> g = dgl.graph((u, v))

    Create from scipy sparse matrix

    >>> from scipy.sparse import coo_matrix
    >>> spmat = coo_matrix(([1,1,1], ([0, 0, 1], [2, 3, 2])), shape=(4, 4))
    >>> g = dgl.graph(spmat)

    Create from networkx graph

    >>> import networkx as nx
    >>> nxg = nx.path_graph(3)
    >>> g = dgl.graph(nxg)

    Specify node and edge type names

    >>> g = dgl.graph(..., 'user', 'follows')
    >>> g.ntypes
    ['user']
    >>> g.etypes
    ['follows']
    >>> g.canonical_etypes
    [('user', 'follows', 'user')]

    Parameters
    ----------
    data : graph data
        Data to initialize graph structure. Supported data formats are
        (1) list of edge pairs (e.g. [(0, 2), (3, 1), ...])
        (2) pair of vertex IDs representing end nodes (e.g. ([0, 3, ...],  [2, 1, ...]))
        (3) scipy sparse matrix
        (4) networkx graph
    ntype : str, optional
        Node type name. (Default: _N)
    etype : str, optional
        Edge type name. (Default: _E)
    card : int, optional
        Cardinality (number of nodes in the graph). If None, infer from input data.
        Only used for formats (1) and (2); it is not forwarded for scipy or
        networkx input. (Default: None)
    kwargs : key-word arguments, optional
        Other key word arguments. Only forwarded when ``data`` is a networkx graph.

    Returns
    -------
    DGLHeteroGraph

    Raises
    ------
    DGLError
        If ``data`` is not one of the supported formats.
    """
    # Source and destination nodes share one ID space, hence the same range twice.
    if card is not None:
        urange, vrange = card, card
    else:
        urange, vrange = None, None
    if isinstance(data, tuple):
        u, v = data
        return create_from_edges(u, v, ntype, etype, ntype, urange, vrange)
    elif isinstance(data, list):
        return create_from_edge_list(data, ntype, etype, ntype, urange, vrange)
    elif isinstance(data, sp.sparse.spmatrix):
        return create_from_scipy(data, ntype, etype, ntype)
    elif isinstance(data, nx.Graph):
        return create_from_networkx(data, ntype, etype, **kwargs)
    else:
        # Format into a single string; passing the type as a second positional
        # argument would make the message render as a tuple.
        raise DGLError('Unsupported graph data type: {}'.format(type(data)))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
    """Create a bipartite graph.

    The result graph is directed and edges must be from ``utype`` nodes
    to ``vtype`` nodes. Nodes of each type have their own ID counts.

    In the sparse matrix perspective, :func:`dgl.graph` creates a graph
    whose adjacency matrix must be square while :func:`dgl.bipartite`
    creates a graph that does not necessarily have square adjacency matrix.

    Examples
    --------
    Create from edges pairs:

    >>> g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game')
    >>> g.ntypes
    ['user', 'game']
    >>> g.etypes
    ['plays']
    >>> g.canonical_etypes
    [('user', 'plays', 'game')]
    >>> g.number_of_nodes('user')
    2
    >>> g.number_of_nodes('game')
    4
    >>> g.number_of_edges('plays')  # 'plays' could be omitted here
    3

    Create from pair of vertex IDs lists

    >>> u = [0, 0, 1]
    >>> v = [2, 3, 2]
    >>> g = dgl.bipartite((u, v))

    The IDs can also be stored in framework-specific tensors

    >>> import torch
    >>> u = torch.tensor([0, 0, 1])
    >>> v = torch.tensor([2, 3, 2])
    >>> g = dgl.bipartite((u, v))

    Create from scipy sparse matrix. Since scipy sparse matrix has explicit
    shape, the cardinality of the result graph is derived from that.

    >>> from scipy.sparse import coo_matrix
    >>> spmat = coo_matrix(([1,1,1], ([0, 0, 1], [2, 3, 2])), shape=(4, 4))
    >>> g = dgl.bipartite(spmat, 'user', 'plays', 'game')
    >>> g.number_of_nodes('user')
    4
    >>> g.number_of_nodes('game')
    4

    Create from networkx graph. The given graph must follow the bipartite
    graph convention in networkx. Each node has a ``bipartite`` attribute
    with values 0 or 1. The result graph has two types of nodes and only
    edges from ``bipartite=0`` to ``bipartite=1`` will be included.

    >>> import networkx as nx
    >>> nxg = nx.complete_bipartite_graph(3, 4)
    >>> g = dgl.bipartite(nxg, 'user', 'plays', 'game')
    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_nodes('game')
    4
    >>> g.edges()
    (tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))

    Parameters
    ----------
    data : graph data
        Data to initialize graph structure. Supported data formats are
        (1) list of edge pairs (e.g. [(0, 2), (3, 1), ...])
        (2) pair of vertex IDs representing end nodes (e.g. ([0, 3, ...],  [2, 1, ...]))
        (3) scipy sparse matrix
        (4) networkx graph
    utype : str, optional
        Source node type name. (Default: _U)
    etype : str, optional
        Edge type name. (Default: _E)
    vtype : str, optional
        Destination node type name. (Default: _V)
    card : pair of int, optional
        Cardinality (number of nodes in the source and destination group). If None,
        infer from input data. Only used for formats (1) and (2); it is not
        forwarded for scipy or networkx input. (Default: None)
    kwargs : key-word arguments, optional
        Other key word arguments. Only forwarded when ``data`` is a networkx graph.

    Returns
    -------
    DGLHeteroGraph

    Raises
    ------
    DGLError
        If ``utype`` equals ``vtype`` or ``data`` is not one of the supported formats.
    """
    if utype == vtype:
        raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
    if card is not None:
        urange, vrange = card
    else:
        urange, vrange = None, None
    if isinstance(data, tuple):
        u, v = data
        return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
    elif isinstance(data, list):
        return create_from_edge_list(data, utype, etype, vtype, urange, vrange)
    elif isinstance(data, sp.sparse.spmatrix):
        return create_from_scipy(data, utype, etype, vtype)
    elif isinstance(data, nx.Graph):
        return create_from_networkx_bipartite(data, utype, etype, vtype, **kwargs)
    else:
        # Format into a single string; passing the type as a second positional
        # argument would make the message render as a tuple.
        raise DGLError('Unsupported graph data type: {}'.format(type(data)))
def hetero_from_relations(rel_graphs):
    """Create a heterograph from per-relation graphs.

    Node types are numbered in the order they are first encountered while
    scanning ``rel_graphs`` (source type before destination type); edge types
    keep the order of ``rel_graphs``. Node and edge features of the input
    graphs are copied into the result.

    TODO(minjie): this API can be generalized as a union operation of
    the input graphs

    Parameters
    ----------
    rel_graphs : list of DGLHeteroGraph
        Graph for each relation. Each graph must have exactly one edge type.

    Returns
    -------
    DGLHeteroGraph
        A heterograph.
    """
    ntype_to_id = {}    # node type name -> node type ID in the result
    ntypes = []
    etypes = []
    meta_edges = []

    def _ntype_id(name):
        """Return the ID of node type ``name``, registering it on first use."""
        if name not in ntype_to_id:
            ntype_to_id[name] = len(ntypes)
            ntypes.append(name)
        return ntype_to_id[name]

    # infer the meta graph: one meta edge per relation graph
    for rel in rel_graphs:
        assert len(rel.etypes) == 1
        src_name, rel_name, dst_name = rel.canonical_etypes[0]
        src_tid = _ntype_id(src_name)
        dst_tid = _ntype_id(dst_name)
        meta_edges.append((src_tid, dst_tid))
        etypes.append(rel_name)
    metagraph = graph_index.from_edge_list(meta_edges, True, True)

    # create the graph index from the per-relation indices
    rel_indices = [rel._graph for rel in rel_graphs]
    hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_indices)
    result = DGLHeteroGraph(hgidx, ntypes, etypes)

    # carry over node and edge features
    for rel_id, rel in enumerate(rel_graphs):
        for name in rel.ntypes:
            result.nodes[name].data.update(rel.nodes[name].data)
        result._edge_frames[rel_id].update(rel._edge_frames[0])
    return result
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
    """Convert the given graph to a heterogeneous graph.

    The input graph should have only one type of nodes and edges. Each node and edge
    stores an integer feature (under ``ntype_field`` and ``etype_field``), representing
    the type id, which can be used to retrieve the type names stored
    in the given ``ntypes`` and ``etypes`` arguments.

    Examples
    --------
    TBD

    Parameters
    ----------
    G : DGLHeteroGraph
        Input homogenous graph.
    ntypes : list of str
        The node type names.
    etypes : list of str
        The edge type names.
    ntype_field : str, optional
        The feature field used to store node type. (Default: dgl.NTYPE)
    etype_field : str, optional
        The feature field used to store edge type. (Default: dgl.ETYPE)
    metagraph : networkx MultiDiGraph, optional
        Metagraph of the returned heterograph.
        If provided, DGL assumes that G can indeed be described with the given metagraph.
        If None, DGL will infer the metagraph from the given inputs, which would be
        potentially slower for large graphs.

    Returns
    -------
    DGLHeteroGraph
        A heterograph.
        The parent node and edge ID are stored in the column dgl.NID and dgl.EID
        respectively for all node/edge types.

    Raises
    ------
    DGLError
        If the input graph reports more than one node type or edge type.

    Notes
    -----
    The returned node and edge types may not necessarily be in the same order as
    ``ntypes`` and ``etypes``.  And edge types may be duplicated if the source
    and destination types differ.

    The node IDs of a single type in the returned heterogeneous graph is ordered
    the same as the nodes with the same ``ntype_field`` feature.  Edge IDs of
    a single type is similar.
    """
    # TODO(minjie): use hasattr to support DGLGraph input; should be fixed once
    #  DGLGraph is merged with DGLHeteroGraph
    if (hasattr(G, 'ntypes') and len(G.ntypes) > 1
            or hasattr(G, 'etypes') and len(G.etypes) > 1):
        raise DGLError('The input graph should be homogenous and have only one '
                       ' type of nodes and edges.')

    num_ntypes = len(ntypes)

    ntype_ids = F.asnumpy(G.ndata[ntype_field])
    etype_ids = F.asnumpy(G.edata[etype_field])

    # relabel nodes to per-type local IDs
    ntype_count = np.bincount(ntype_ids, minlength=num_ntypes)
    ntype_offset = np.insert(np.cumsum(ntype_count), 0, 0)
    # A stable sort keeps nodes of the same type in their original relative
    # order, which is what the Notes section promises. The default sort kind
    # gives no such guarantee.
    ntype_ids_sortidx = np.argsort(ntype_ids, kind='mergesort')
    ntype_local_ids = np.zeros_like(ntype_ids)
    node_groups = []
    for i in range(num_ntypes):
        node_group = ntype_ids_sortidx[ntype_offset[i]:ntype_offset[i+1]]
        node_groups.append(node_group)
        ntype_local_ids[node_group] = np.arange(ntype_count[i])

    src, dst = G.all_edges(order='eid')
    src = F.asnumpy(src)
    dst = F.asnumpy(dst)
    src_local = ntype_local_ids[src]
    dst_local = ntype_local_ids[dst]
    srctype_ids = ntype_ids[src]
    dsttype_ids = ntype_ids[dst]
    # one (source type, edge type, destination type) row per edge
    canon_etype_ids = np.stack([srctype_ids, etype_ids, dsttype_ids], 1)

    # infer metagraph
    if metagraph is None:
        canonical_etids, _, etype_remapped = \
            utils.make_invmap(list(tuple(row) for row in canon_etype_ids), False)
        etype_mask = (etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None])
    else:
        ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)}
        etypes_invmap = {et: i for i, et in enumerate(etypes)}
        canonical_etids = []
        # With keys=True each metagraph edge is (source, destination, key);
        # the key is the edge type name.
        for srctype, dsttype, etype in metagraph.edges(keys=True):
            srctype_id = ntypes_invmap[srctype]
            etype_id = etypes_invmap[etype]
            dsttype_id = ntypes_invmap[dsttype]
            canonical_etids.append((srctype_id, etype_id, dsttype_id))
        canonical_etids = np.array(canonical_etids)
        etype_mask = (canon_etype_ids[None, :] == canonical_etids[:, None]).all(2)
    edge_groups = [etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))]

    # build one relation graph per canonical edge type
    rel_graphs = []
    for i, (stid, etid, dtid) in enumerate(canonical_etids):
        src_of_etype = src_local[edge_groups[i]]
        dst_of_etype = dst_local[edge_groups[i]]
        if stid == dtid:
            rel_graph = graph(
                (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
                card=ntype_count[stid])
        else:
            rel_graph = bipartite(
                (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
                card=(ntype_count[stid], ntype_count[dtid]))
        rel_graphs.append(rel_graph)

    hg = hetero_from_relations(rel_graphs)

    # record the parent node/edge IDs
    ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
    for ntid, ntype in enumerate(hg.ntypes):
        hg._node_frames[ntid][NID] = F.tensor(ntype2ngrp[ntype])

    for etid in range(len(hg.canonical_etypes)):
        hg._edge_frames[etid][EID] = F.tensor(edge_groups[etid])

    # features
    for key, data in G.ndata.items():
        for ntid, ntype in enumerate(hg.ntypes):
            rows = F.copy_to(F.tensor(ntype2ngrp[ntype]), F.context(data))
            hg._node_frames[ntid][key] = F.gather_row(data, rows)
    for key, data in G.edata.items():
        for etid in range(len(hg.canonical_etypes)):
            rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data))
            hg._edge_frames[etid][key] = F.gather_row(data, rows)

    return hg
def to_homo(G):
    """Convert the given graph to a homogeneous graph.

    The returned graph has only one type of nodes and edges.

    Node and edge types are stored as features in the returned graph. Each feature
    is an integer representing the type id, which can be used to retrieve the type
    names stored in ``G.ntypes`` and ``G.etypes`` arguments.

    Nodes are laid out type by type in the order of ``G.ntypes``; edges are laid
    out type by type in the order of ``G.canonical_etypes``.

    Examples
    --------
    TBD

    Parameters
    ----------
    G : DGLHeteroGraph
        Input heterogenous graph.

    Returns
    -------
    DGLHeteroGraph
        A homogenous graph.
        The parent node and edge type/ID are stored in columns dgl.NTYPE/dgl.NID and
        dgl.ETYPE/dgl.EID respectively.
    """
    num_nodes_per_ntype = [G.number_of_nodes(ntype) for ntype in G.ntypes]
    # offset_per_ntype[i] is the homogeneous ID of the first node of type i;
    # the last entry is the total number of nodes.
    offset_per_ntype = np.insert(np.cumsum(num_nodes_per_ntype), 0, 0)
    total_num_nodes = int(offset_per_ntype[-1])
    srcs = []
    dsts = []
    etype_ids = []
    eids = []
    ntype_ids = []
    nids = []

    for ntype_id, ntype in enumerate(G.ntypes):
        num_nodes = G.number_of_nodes(ntype)
        ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu()))
        nids.append(F.arange(0, num_nodes))

    for etype_id, etype in enumerate(G.canonical_etypes):
        srctype, _, dsttype = etype
        src, dst = G.all_edges(etype=etype, order='eid')
        num_edges = len(src)
        srcs.append(src + offset_per_ntype[G.get_ntype_id(srctype)])
        dsts.append(dst + offset_per_ntype[G.get_ntype_id(dsttype)])
        etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
        eids.append(F.arange(0, num_edges))

    # Pass the node count explicitly. Inferring it from the largest edge
    # endpoint would drop trailing nodes that have no edges, and the ndata
    # assignments below would then have the wrong length.
    retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes)
    retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
    retg.ndata[NID] = F.cat(nids, 0)
    retg.edata[ETYPE] = F.cat(etype_ids, 0)
    retg.edata[EID] = F.cat(eids, 0)

    # features
    comb_nf = combine_frames(G._node_frames, range(len(G.ntypes)))
    comb_ef = combine_frames(G._edge_frames, range(len(G.etypes)))
    if comb_nf is not None:
        retg.ndata.update(comb_nf)
    if comb_ef is not None:
        retg.edata.update(comb_ef)

    return retg
############################################################
# Internal APIs
############################################################
def _node_id_range(ids, explicit_range):
    """Return the node ID range to use for one endpoint of an edge list.

    A non-zero ``explicit_range`` is returned unchanged. Otherwise the range is
    the largest ID in ``ids`` plus one, or 0 when ``ids`` is empty.
    """
    if explicit_range:
        return explicit_range
    if len(ids) == 0:
        # No edges to infer from; a max-reduction over an empty ID list
        # would fail, so report an empty range instead.
        return 0
    return int(F.asnumpy(F.max(ids.tousertensor(), dim=0))) + 1


def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
    """Internal function to create a graph from incident nodes with types.

    utype could be equal to vtype

    Parameters
    ----------
    u : iterable of int
        List of source node IDs.
    v : iterable of int
        List of destination node IDs.
    utype : str
        Source node type name.
    etype : str
        Edge type name.
    vtype : str
        Destination node type name.
    urange : int, optional
        The source node ID range. If None, the value is the maximum
        of the source node IDs in the edge list plus 1, or 0 if the
        edge list is empty. (Default: None)
    vrange : int, optional
        The destination node ID range. If None, the value is the
        maximum of the destination node IDs in the edge list plus 1,
        or 0 if the edge list is empty. (Default: None)

    Returns
    -------
    DGLHeteroGraph
    """
    u = utils.toindex(u)
    v = utils.toindex(v)
    urange = _node_id_range(u, urange)
    vrange = _node_id_range(v, vrange)
    if utype == vtype:
        # Source and destination share one ID space; it must cover both ends.
        urange = vrange = max(urange, vrange)
        num_ntypes = 1
    else:
        num_ntypes = 2
    hgidx = heterograph_index.create_unitgraph_from_coo(num_ntypes, urange, vrange, u, v)
    if utype == vtype:
        return DGLHeteroGraph(hgidx, [utype], [etype])
    else:
        return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
    """Internal function to create a graph from a list of edge tuples with types.

    The edge list is split into a source ID list and a destination ID list,
    which are then handed to :func:`create_from_edges`.

    utype could be equal to vtype

    Examples
    --------
    TBD

    Parameters
    ----------
    elist : iterable of int pairs
        List of (src, dst) node ID pairs.
    utype : str
        Source node type name.
    etype : str
        Edge type name.
    vtype : str
        Destination node type name.
    urange : int, optional
        The source node ID range. If None, the value is the maximum
        of the source node IDs in the edge list plus 1. (Default: None)
    vrange : int, optional
        The destination node ID range. If None, the value is the
        maximum of the destination node IDs in the edge list plus 1. (Default: None)

    Returns
    -------
    DGLHeteroGraph
    """
    if len(elist) > 0:
        # transpose [(s0, d0), (s1, d1), ...] into (s0, s1, ...) and (d0, d1, ...)
        src_column, dst_column = zip(*elist)
        src_ids = list(src_column)
        dst_ids = list(dst_column)
    else:
        src_ids, dst_ids = [], []
    return create_from_edges(src_ids, dst_ids, utype, etype, vtype, urange, vrange)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
    """Internal function to create a graph from a scipy sparse matrix with types.

    The numbers of source and destination nodes are taken from the matrix shape.
    A COO matrix is used directly; any other format is first converted to CSR.

    Parameters
    ----------
    spmat : scipy.sparse.spmatrix
        The adjacency matrix whose rows represent sources and columns
        represent destinations.
    utype : str
        Source node type name.
    etype : str
        Edge type name.
    vtype : str
        Destination node type name.
    with_edge_id : bool
        If True, the entries in the sparse matrix are treated as edge IDs.
        Otherwise, the entries are ignored and edges will be added in
        (source, destination) order.

    Returns
    -------
    DGLHeteroGraph
    """
    num_src, num_dst = spmat.shape
    same_type = (utype == vtype)
    num_ntypes = 1 if same_type else 2

    if spmat.getformat() == 'coo':
        rows = utils.toindex(spmat.row)
        cols = utils.toindex(spmat.col)
        hgidx = heterograph_index.create_unitgraph_from_coo(
            num_ntypes, num_src, num_dst, rows, cols)
    else:
        csr = spmat.tocsr()
        indptr = utils.toindex(csr.indptr)
        indices = utils.toindex(csr.indices)
        # TODO(minjie): with_edge_id is only reasonable for csr matrix. How to fix?
        if with_edge_id:
            edge_ids = csr.data
        else:
            edge_ids = list(range(len(indices)))
        data = utils.toindex(edge_ids)
        hgidx = heterograph_index.create_unitgraph_from_csr(
            num_ntypes, num_src, num_dst, indptr, indices, data)

    ntype_names = [utype] if num_ntypes == 1 else [utype, vtype]
    return DGLHeteroGraph(hgidx, ntype_names, [etype])
def create_from_networkx(nx_graph,
                         ntype, etype,
                         edge_id_attr_name='id',
                         node_attrs=None,
                         edge_attrs=None):
    """Create graph that has only one set of nodes and edges.

    Undirected input is converted to a directed graph first, and nodes are
    relabeled with consecutive integers in sorted order.

    Parameters
    ----------
    nx_graph : networkx graph
        The input graph.
    ntype : str
        Node type name.
    etype : str
        Edge type name.
    edge_id_attr_name : str, optional
        Name of the edge attribute that stores edge IDs. If the first edge
        has this attribute, every edge is placed at the position given by
        its ID. Otherwise edges are numbered in iteration order. (Default: 'id')
    node_attrs : iterable of str, optional
        Node attributes to copy into ``ndata``. (Default: None)
    edge_attrs : iterable of str, optional
        Edge attributes to copy into ``edata``. (Default: None)

    Returns
    -------
    DGLHeteroGraph

    Raises
    ------
    DGLError
        If a pre-specified edge ID is not smaller than the number of edges, or
        if some edge lacks a value for a requested edge attribute.
    """
    if not nx_graph.is_directed():
        nx_graph = nx_graph.to_directed()

    # Relabel nodes using consecutive integers
    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')

    # nx_graph.edges(data=True) returns src, dst, attr_dict
    if nx_graph.number_of_edges() > 0:
        has_edge_id = edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
    else:
        has_edge_id = False

    if has_edge_id:
        num_edges = nx_graph.number_of_edges()
        src = np.zeros((num_edges,), dtype=np.int64)
        dst = np.zeros((num_edges,), dtype=np.int64)
        for u, v, attr in nx_graph.edges(data=True):
            eid = attr[edge_id_attr_name]
            src[eid] = u
            dst[eid] = v
    else:
        src = []
        dst = []
        for e in nx_graph.edges:
            src.append(e[0])
            dst.append(e[1])
    src = utils.toindex(src)
    dst = utils.toindex(dst)
    num_nodes = nx_graph.number_of_nodes()
    g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes)

    # handle features
    # copy attributes
    def _batcher(lst):
        if F.is_tensor(lst[0]):
            return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
        else:
            return F.tensor(lst)

    if node_attrs is not None:
        # mapping from feature name to a list of tensors to be concatenated
        attr_dict = defaultdict(list)
        for nid in range(g.number_of_nodes()):
            for attr in node_attrs:
                attr_dict[attr].append(nx_graph.nodes[nid][attr])
        for attr in node_attrs:
            g.ndata[attr] = _batcher(attr_dict[attr])

    if edge_attrs is not None:
        # mapping from feature name to a list of tensors to be concatenated
        attr_dict = defaultdict(lambda: [None] * g.number_of_edges())
        # each defaultdict value is initialized to be a list of None
        # None here serves as placeholder to be replaced by feature with
        # corresponding edge id
        if has_edge_id:
            num_edges = g.number_of_edges()
            for _, _, attrs in nx_graph.edges(data=True):
                # Use the configured attribute name throughout; a hard-coded
                # 'id' key breaks when a different name is configured.
                eid = attrs[edge_id_attr_name]
                if eid >= num_edges:
                    raise DGLError('Expect the pre-specified edge ids to be'
                                   ' smaller than the number of edges --'
                                   ' {}, got {}.'.format(num_edges, eid))
                for key in edge_attrs:
                    attr_dict[key][eid] = attrs[key]
        else:
            # XXX: assuming networkx iteration order is deterministic
            #      so the order is the same as graph_index.from_networkx
            for eid, (_, _, attrs) in enumerate(nx_graph.edges(data=True)):
                for key in edge_attrs:
                    attr_dict[key][eid] = attrs[key]
        for attr in edge_attrs:
            for val in attr_dict[attr]:
                if val is None:
                    raise DGLError('Not all edges have attribute {}.'.format(attr))
            g.edata[attr] = _batcher(attr_dict[attr])

    return g
def create_from_networkx_bipartite(nx_graph,
                                   utype, etype, vtype,
                                   edge_id_attr_name='id',
                                   node_attrs=None,
                                   edge_attrs=None):
    """Create a bipartite graph with two sets of nodes and one set of edges.

    The input graph must follow the bipartite graph convention of networkx.
    Each node has an attribute ``bipartite`` with values 0 and 1 indicating which
    set it belongs to.

    Only edges from node set 0 to node set 1 are added to the returned graph.
    Nodes of each set are relabeled with consecutive integers in sorted order.

    Parameters
    ----------
    nx_graph : networkx graph
        The input bipartite graph.
    utype : str
        Source node type name (node set 0).
    etype : str
        Edge type name.
    vtype : str
        Destination node type name (node set 1).
    edge_id_attr_name : str, optional
        Name of the edge attribute that stores edge IDs. If the first edge
        has this attribute, every kept edge is placed at the position given
        by its ID. Otherwise edges are numbered in iteration order. (Default: 'id')
    node_attrs : optional
        Not supported yet; must be None.
    edge_attrs : optional
        Not supported yet; must be None.

    Returns
    -------
    DGLHeteroGraph
    """
    if not nx_graph.is_directed():
        nx_graph = nx_graph.to_directed()

    top_nodes = {n for n, d in nx_graph.nodes(data=True) if d['bipartite'] == 0}
    bottom_nodes = set(nx_graph) - top_nodes
    top_nodes = sorted(top_nodes)
    bottom_nodes = sorted(bottom_nodes)
    top_map = {n : i for i, n in enumerate(top_nodes)}
    bottom_map = {n : i for i, n in enumerate(bottom_nodes)}

    if nx_graph.number_of_edges() > 0:
        has_edge_id = edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
    else:
        has_edge_id = False

    # Keep only edges leaving node set 0. ``to_directed()`` also creates the
    # reverse (set 1 -> set 0) edges, and those must be dropped in both the
    # edge-id and the no-edge-id path.
    kept_edges = [(u, v, attr) for u, v, attr in nx_graph.edges(data=True)
                  if u in top_map]

    if has_edge_id:
        num_edges = len(kept_edges)
        src = np.zeros((num_edges,), dtype=np.int64)
        dst = np.zeros((num_edges,), dtype=np.int64)
        for u, v, attr in kept_edges:
            eid = attr[edge_id_attr_name]
            src[eid] = top_map[u]
            dst[eid] = bottom_map[v]
    else:
        src = []
        dst = []
        for u, v, _ in kept_edges:
            src.append(top_map[u])
            dst.append(bottom_map[v])
    src = utils.toindex(src)
    dst = utils.toindex(dst)
    g = create_from_edges(src, dst, utype, etype, vtype, len(top_nodes), len(bottom_nodes))

    # TODO attributes
    assert node_attrs is None
    assert edge_attrs is None
    return g
def to_networkx(g, node_attrs=None, edge_attrs=None):
    """Convert to networkx graph.

    Thin functional wrapper: the conversion is delegated to the
    ``to_networkx`` method of ``g``, with both attribute arguments
    forwarded positionally.

    See Also
    --------
    DGLHeteroGraph.to_networkx
    """
    converted = g.to_networkx(node_attrs, edge_attrs)
    return converted
...@@ -186,8 +186,8 @@ class Frame(MutableMapping): ...@@ -186,8 +186,8 @@ class Frame(MutableMapping):
update on one will not reflect to the other. The inplace update will update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container. be seen by both. This follows the semantic of python's container.
num_rows : int, optional [default=0] num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided, ``num_rows`` The number of rows in this frame. If ``data`` is provided and is not empty,
will be ignored and inferred from the given data. ``num_rows`` will be ignored and inferred from the given data.
""" """
def __init__(self, data=None, num_rows=0): def __init__(self, data=None, num_rows=0):
if data is None: if data is None:
...@@ -202,7 +202,7 @@ class Frame(MutableMapping): ...@@ -202,7 +202,7 @@ class Frame(MutableMapping):
elif len(self._columns) != 0: elif len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values()))) self._num_rows = len(next(iter(self._columns.values())))
else: else:
self._num_rows = 0 self._num_rows = num_rows
# sanity check # sanity check
for name, col in self._columns.items(): for name, col in self._columns.items():
if len(col) != self._num_rows: if len(col) != self._num_rows:
...@@ -880,23 +880,23 @@ class FrameRef(MutableMapping): ...@@ -880,23 +880,23 @@ class FrameRef(MutableMapping):
""" """
return self._index.get_items(query) return self._index.get_items(query)
def frame_like(other, num_rows): def frame_like(other, num_rows=None):
"""Create a new frame that has the same scheme as the given one. """Create an empty frame that has the same initializer as the given one.
Parameters Parameters
---------- ----------
other : Frame other : Frame
The given frame. The given frame.
num_rows : int num_rows : int
The number of rows of the new one. The number of rows of the new one. If None, use other.num_rows
(Default: None)
Returns Returns
------- -------
Frame Frame
The new frame. The new frame.
""" """
# TODO(minjie): scheme is not inherited at the moment. Fix this num_rows = other.num_rows if num_rows is None else num_rows
# when moving per-col initializer to column scheme.
newf = Frame(num_rows=num_rows) newf = Frame(num_rows=num_rows)
# set global initializr # set global initializr
if other.get_initializer() is None: if other.get_initializer() is None:
......
...@@ -11,7 +11,7 @@ from . import backend as F ...@@ -11,7 +11,7 @@ from . import backend as F
from . import init from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from . import graph_index from . import graph_index
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime, GraphAdapter
from . import utils from . import utils
from .view import NodeView, EdgeView from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
...@@ -49,14 +49,6 @@ class DGLBaseGraph(object): ...@@ -49,14 +49,6 @@ class DGLBaseGraph(object):
""" """
return self._graph.number_of_nodes() return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self): def __len__(self):
"""Return the number of nodes in the graph.""" """Return the number of nodes in the graph."""
return self.number_of_nodes() return self.number_of_nodes()
...@@ -73,10 +65,6 @@ class DGLBaseGraph(object): ...@@ -73,10 +65,6 @@ class DGLBaseGraph(object):
""" """
return self._graph.is_readonly() return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges in the graph. """Return the number of edges in the graph.
...@@ -951,14 +939,6 @@ class DGLGraph(DGLBaseGraph): ...@@ -951,14 +939,6 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index): def _set_msg_index(self, index):
self._msg_index = index self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
"""Add multiple new nodes. """Add multiple new nodes.
...@@ -2089,9 +2069,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -2089,9 +2069,9 @@ class DGLGraph(DGLBaseGraph):
else: else:
v = utils.toindex(v) v = utils.toindex(v)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_nodes(graph=self, scheduler.schedule_apply_nodes(v=v,
v=v,
apply_func=func, apply_func=func,
node_frame=self._node_frame,
inplace=inplace) inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
...@@ -2159,12 +2139,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2159,12 +2139,7 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid) u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_edges(graph=self, scheduler.schedule_apply_edges(AdaptedDGLGraph(self), u, v, eid, func, inplace)
u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False): def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
...@@ -2241,10 +2216,8 @@ class DGLGraph(DGLBaseGraph): ...@@ -2241,10 +2216,8 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid) u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_group_apply_edge(graph=self, scheduler.schedule_group_apply_edge(graph=AdaptedDGLGraph(self),
u=u, u=u, v=v, eid=eid,
v=v,
eid=eid,
apply_func=func, apply_func=func,
group_by=group_by, group_by=group_by,
inplace=inplace) inplace=inplace)
...@@ -2308,7 +2281,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2308,7 +2281,7 @@ class DGLGraph(DGLBaseGraph):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid, scheduler.schedule_send(graph=AdaptedDGLGraph(self), u=u, v=v, eid=eid,
message_func=message_func) message_func=message_func)
Runtime.run(prog) Runtime.run(prog)
...@@ -2407,7 +2380,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2407,7 +2380,7 @@ class DGLGraph(DGLBaseGraph):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_recv(graph=self, scheduler.schedule_recv(graph=AdaptedDGLGraph(self),
recv_nodes=v, recv_nodes=v,
reduce_func=reduce_func, reduce_func=reduce_func,
apply_func=apply_node_func, apply_func=apply_node_func,
...@@ -2515,7 +2488,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2515,7 +2488,7 @@ class DGLGraph(DGLBaseGraph):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_snr(graph=self, scheduler.schedule_snr(graph=AdaptedDGLGraph(self),
edge_tuples=(u, v, eid), edge_tuples=(u, v, eid),
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -2618,7 +2591,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2618,7 +2591,7 @@ class DGLGraph(DGLBaseGraph):
if len(v) == 0: if len(v) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_pull(graph=self, scheduler.schedule_pull(graph=AdaptedDGLGraph(self),
pull_nodes=v, pull_nodes=v,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -2715,7 +2688,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2715,7 +2688,7 @@ class DGLGraph(DGLBaseGraph):
if len(u) == 0: if len(u) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_push(graph=self, scheduler.schedule_push(graph=AdaptedDGLGraph(self),
u=u, u=u,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -2762,7 +2735,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2762,7 +2735,7 @@ class DGLGraph(DGLBaseGraph):
assert reduce_func is not None assert reduce_func is not None
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_update_all(graph=self, scheduler.schedule_update_all(graph=AdaptedDGLGraph(self),
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
apply_func=apply_node_func) apply_func=apply_node_func)
...@@ -3219,7 +3192,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3219,7 +3192,7 @@ class DGLGraph(DGLBaseGraph):
v = utils.toindex(nodes) v = utils.toindex(nodes)
n_repr = self.get_n_repr(v) n_repr = self.get_n_repr(v)
nbatch = NodeBatch(self, v, n_repr) nbatch = NodeBatch(v, n_repr)
n_mask = F.copy_to(predicate(nbatch), F.cpu()) n_mask = F.copy_to(predicate(nbatch), F.cpu())
if is_all(nodes): if is_all(nodes):
...@@ -3277,8 +3250,8 @@ class DGLGraph(DGLBaseGraph): ...@@ -3277,8 +3250,8 @@ class DGLGraph(DGLBaseGraph):
filter_nodes filter_nodes
""" """
if is_all(edges): if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges('eid') u, v, _ = self._graph.edges('eid')
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
...@@ -3292,7 +3265,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3292,7 +3265,7 @@ class DGLGraph(DGLBaseGraph):
src_data = self.get_n_repr(u) src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid) edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v) dst_data = self.get_n_repr(v)
ebatch = EdgeBatch(self, (u, v, eid), src_data, edge_data, dst_data) ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
e_mask = F.copy_to(predicate(ebatch), F.cpu()) e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges): if is_all(edges):
...@@ -3492,3 +3465,79 @@ class DGLGraph(DGLBaseGraph): ...@@ -3492,3 +3465,79 @@ class DGLGraph(DGLBaseGraph):
yield yield
self._node_frame = old_nframe self._node_frame = old_nframe
self._edge_frame = old_eframe self._edge_frame = old_eframe
############################################################
# Internal APIs
############################################################
class AdaptedDGLGraph(GraphAdapter):
    """Adapt DGLGraph to interface required by scheduler.

    Every member simply forwards to the wrapped graph. The source and
    destination views resolve to the same node count and the same node
    frame of the wrapped graph.

    Parameters
    ----------
    graph : DGLGraph
        Graph
    """
    def __init__(self, graph):
        self.graph = graph

    @property
    def gidx(self):
        """Graph index of the wrapped graph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph.number_of_nodes()

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph.number_of_nodes()

    def num_edges(self):
        """Number of edges."""
        return self.graph.number_of_edges()

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frame

    @property
    def dstframe(self):
        """Frame to store destination node features.

        This is the same node frame object that ``srcframe`` returns.
        """
        return self.graph._node_frame

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frame

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frame

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index()

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(val)

    # The structure queries below all go through ``self.gidx`` so that the
    # graph index is looked up in exactly one place.

    def in_edges(self, nodes):
        """Forward ``in_edges(nodes)`` to the graph index."""
        return self.gidx.in_edges(nodes)

    def out_edges(self, nodes):
        """Forward ``out_edges(nodes)`` to the graph index."""
        return self.gidx.out_edges(nodes)

    def edges(self, form):
        """Forward ``edges(form)`` to the graph index."""
        return self.gidx.edges(form)

    def get_immutable_gidx(self, ctx):
        """Forward ``get_immutable_gidx(ctx)`` to the graph index."""
        return self.gidx.get_immutable_gidx(ctx)

    def bits_needed(self):
        """Forward ``bits_needed()`` to the graph index."""
        return self.gidx.bits_needed()
...@@ -1129,10 +1129,13 @@ def from_edge_list(elist, is_multigraph, readonly): ...@@ -1129,10 +1129,13 @@ def from_edge_list(elist, is_multigraph, readonly):
Parameters Parameters
--------- ---------
elist : list elist : list, tuple
List of (u, v) edge tuple. List of (u, v) edge tuple, or a tuple of src/dst lists
""" """
src, dst = zip(*elist) if isinstance(elist, tuple):
src, dst = elist
else:
src, dst = zip(*elist)
src = np.array(src) src = np.array(src)
dst = np.array(dst) dst = np.array(dst)
src_ids = utils.toindex(src) src_ids = utils.toindex(src)
......
"""Classes for heterogeneous graphs.""" """Classes for heterogeneous graphs."""
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager
import networkx as nx import networkx as nx
import scipy.sparse as ssp import numpy as np
from . import heterograph_index, graph_index
from . import graph_index
from . import heterograph_index
from . import utils from . import utils
from . import backend as F from . import backend as F
from . import init from . import init
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime, GraphAdapter
from .frame import Frame, FrameRef from .frame import Frame, FrameRef, frame_like, sync_frame_initializer
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, is_all, DGLError from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError
__all__ = ['DGLHeteroGraph', 'combine_names']
class DGLHeteroGraph(object):
"""Base heterogeneous graph class.
Do NOT instantiate from this class directly; use :mod:`conversion methods
<dgl.convert>` instead.
A Heterogeneous graph is defined as a graph with node types and edge
types.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following
IDs:
User name Alice Bob Carol
User ID 0 1 2
Game name Tetris Minecraft
Game ID 0 1
Developer name Nintendo Mojang
Developer ID 0 1
One can construct the graph as follows:
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> dev_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, dev_g])
:func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
data types including: edge list, edge tuples, networkx graph and scipy sparse matrix.
Click the function name for more details.
Then one can query the graph structure by specifying the ``ntype`` or ``etype`` arguments:
>>> g.number_of_nodes('user')
3
>>> g.number_of_edges('plays')
4
>>> g.out_degrees(etype='develops') # out-degrees of source nodes of 'develops' relation
tensor([1, 1])
>>> g.in_edges(0, etype='develops') # in-edges of destination node 0 of 'develops' relation
(tensor([0]), tensor([0]))
Or on the sliced graph for an edge type:
__all__ = ['DGLHeteroGraph'] >>> g['plays'].number_of_edges()
4
>>> g['develops'].out_degrees()
tensor([1, 1])
>>> g['develops'].in_edges(0)
Node type names must be distinct (no two types have the same name). Edge types could
have the same name but they must be distinguishable by the ``(src_type, edge_type, dst_type)``
triplet (called *canonical edge type*).
For example, suppose a graph that has two types of relation "user-watches-movie"
and "user-watches-TV" as follows:
>>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV')
>>> GG = dgl.hetero_from_relations([g0, g1])
To distinguish between the two "watches" edge type, one must specify a full triplet:
>>> GG.number_of_edges(('user', 'watches', 'movie'))
3
>>> GG.number_of_edges(('user', 'watches', 'TV'))
2
>>> GG['user', 'watches', 'movie'].out_degrees()
tensor([1, 2])
Using only one single edge type string "watches" is ambiguous and will cause error:
>>> GG.number_of_edges('watches') # AMBIGUOUS!!
In many cases, there is only one type of nodes or one type of edges, and the ``ntype``
and ``etype`` argument could be omitted. This is very common when using the sliced
graph, which usually contains only one edge type, and sometimes only one node type:
# TODO: depending on the progress of unifying DGLGraph and Bipartite, we may or may not >>> g['follows'].number_of_nodes() # OK!! because g['follows'] only has one node type 'user'
# need the code of heterogeneous graph views. 3
# pylint: disable=unnecessary-pass >>> g['plays'].number_of_nodes() # ERROR!! There are two types 'user' and 'game'.
class DGLBaseHeteroGraph(object): >>> g['plays'].number_of_edges() # OK!! because there is only one edge type 'plays'
"""Base Heterogeneous graph class.
Parameters Parameters
---------- ----------
graph : graph index, optional gidx : HeteroGraphIndex
The graph index Graph index object.
ntypes : list[str] ntypes : list of str
The node type names Node type list. The i^th element stores the type name
etypes : list[str] of node type i.
The edge type names etypes : list of str
_ntypes_invmap, _etypes_invmap, _view_ntype_idx, _view_etype_idx : Edge type list. The i^th element stores the type name
Internal arguments of edge type i.
""" node_frames : list of FrameRef, optional
Node feature storage. The i^th element stores the node features
of node type i. If None, empty frame is created. (default: None)
edge_frames : list of FrameRef, optional
Edge feature storage. The i^th element stores the edge features
of edge type i. If None, empty frame is created. (default: None)
multigraph : bool, optional
Whether the graph would be a multigraph. If none, the flag will be determined
by scanning the whole graph. (default: None)
readonly : bool, optional
Whether the graph structure is read-only (default: True).
Notes
-----
Currently, all heterogeneous graphs are readonly.
"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
def __init__(self, graph, ntypes, etypes, def __init__(self,
_ntypes_invmap=None, _etypes_invmap=None, gidx,
_view_ntype_idx=None, _view_etype_idx=None): ntypes,
super(DGLBaseHeteroGraph, self).__init__() etypes,
node_frames=None,
edge_frames=None,
multigraph=None,
readonly=True):
assert readonly, "Only readonly heterogeneous graphs are supported"
self._graph = graph self._graph = gidx
self._nx_metagraph = None
self._ntypes = ntypes self._ntypes = ntypes
self._etypes = etypes self._etypes = etypes
# inverse mapping from ntype str to int self._canonical_etypes = make_canonical_etypes(etypes, ntypes, self._graph.metagraph)
self._ntypes_invmap = _ntypes_invmap or \ # An internal map from etype to canonical etype tuple.
{ntype: i for i, ntype in enumerate(ntypes)} # If two etypes have the same name, an empty tuple is stored instead to indicate ambiguity.
# inverse mapping from etype str to int self._etype2canonical = {}
self._etypes_invmap = _etypes_invmap or \ for i, ety in enumerate(etypes):
{etype: i for i, etype in enumerate(etypes)} if ety in self._etype2canonical:
self._etype2canonical[ety] = tuple()
else:
self._etype2canonical[ety] = self._canonical_etypes[i]
self._ntypes_invmap = {t : i for i, t in enumerate(ntypes)}
self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}
# Indicates which node/edge type (int) it is viewing. # node and edge frame
self._view_ntype_idx = _view_ntype_idx if node_frames is None:
self._view_etype_idx = _view_etype_idx node_frames = [None] * len(self._ntypes)
node_frames = [FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
if frame is None else frame
for i, frame in enumerate(node_frames)]
self._node_frames = node_frames
self._cache = {} if edge_frames is None:
edge_frames = [None] * len(self._etypes)
edge_frames = [FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
if frame is None else frame
for i, frame in enumerate(edge_frames)]
self._edge_frames = edge_frames
def _create_view(self, ntype_idx, etype_idx): # message indicators
return DGLBaseHeteroGraph( self._msg_indices = [None] * len(self._etypes)
self._graph, self._ntypes, self._etypes, self._msg_frames = []
self._ntypes_invmap, self._etypes_invmap, for i in range(len(self._etypes)):
ntype_idx, etype_idx) frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
frame.set_initializer(init.zero_initializer)
self._msg_frames.append(frame)
@property def _get_msg_index(self, etid):
def is_node_type_view(self): if self._msg_indices[etid] is None:
"""Whether this is a node type view of a heterograph.""" self._msg_indices[etid] = utils.zero_index(
return self._view_ntype_idx is not None size=self._graph.number_of_edges(etid))
return self._msg_indices[etid]
@property def _set_msg_index(self, etid, index):
def is_edge_type_view(self): self._msg_indices[etid] = index
"""Whether this is an edge type view of a heterograph."""
return self._view_etype_idx is not None
@property def __repr__(self):
def is_view(self): if len(self.ntypes) == 1 and len(self.etypes) == 1:
"""Whether this is a node/view of a heterograph.""" ret = ('Graph(num_nodes={node}, num_edges={edge},\n'
return self.is_node_type_view or self.is_edge_type_view ' ndata_schemes={ndata}\n'
' edata_schemes={edata})')
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes()))
else:
ret = ('Graph(num_nodes={node},\n'
' num_edges={edge},\n'
' metagraph={meta})')
nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i)
for i in range(len(self.ntypes))}
nedge_dict = {self.etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))}
meta = str(self.metagraph.edges())
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
#################################################################
# Mutation operations
#################################################################
def add_nodes(self, num, data=None, ntype=None):
"""Add multiple new nodes of the same node type
Currently not supported.
"""
raise DGLError('Mutation is not supported in heterograph.')
def add_edge(self, u, v, data=None, etype=None):
"""Add an edge of ``etype`` between u of the source node type, and v
of the destination node type.
Currently not supported.
"""
raise DGLError('Mutation is not supported in heterograph.')
def add_edges(self, u, v, data=None, etype=None):
"""Add multiple edges of ``etype`` between list of source nodes ``u``
and list of destination nodes ``v`` of type ``vtype``. A single edge
is added between every pair of ``u[i]`` and ``v[i]``.
Currently not supported.
"""
raise DGLError('Mutation is not supported in heterograph.')
#################################################################
# Metagraph query
#################################################################
@property @property
def all_node_types(self): def ntypes(self):
"""Return the list of node types of the entire heterograph.""" """Return the list of node types of this graph."""
return self._ntypes return self._ntypes
@property @property
def all_edge_types(self): def etypes(self):
"""Return the list of edge types of the entire heterograph.""" """Return the list of edge types of this graph."""
return self._etypes return self._etypes
@property
def canonical_etypes(self):
"""Return the list of canonical edge types of this graph.
A canonical edge type is a tuple of string (src_type, edge_type, dst_type).
"""
return self._canonical_etypes
@property @property
def metagraph(self): def metagraph(self):
"""Return the metagraph as networkx.MultiDiGraph. """Return the metagraph as networkx.MultiDiGraph.
The nodes are labeled with node type names. The nodes are labeled with node type names.
The edges have their keys holding the edge type names. The edges have their keys holding the edge type names.
Returns
-------
networkx.MultiDiGraph
"""
if self._nx_metagraph is None:
nx_graph = self._graph.metagraph.to_networkx()
self._nx_metagraph = nx.MultiDiGraph()
for u_v in nx_graph.edges:
srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']]
self._nx_metagraph.add_edge(srctype, dsttype, etype)
return self._nx_metagraph
def to_canonical_etype(self, etype):
"""Convert edge type to canonical etype: (srctype, etype, dsttype).
The input can already be a canonical tuple.
Parameters
----------
etype : str or tuple of str
Edge type
Returns
-------
tuple of str
""" """
nx_graph = self._graph.metagraph.to_networkx() if isinstance(etype, tuple):
nx_return_graph = nx.MultiDiGraph() return etype
for u_v in nx_graph.edges:
etype = self._etypes[nx_graph.edges[u_v]['id']]
srctype = self._ntypes[u_v[0]]
dsttype = self._ntypes[u_v[1]]
assert etype[0] == srctype
assert etype[2] == dsttype
nx_return_graph.add_edge(srctype, dsttype, etype[1])
return nx_return_graph
def _endpoint_types(self, etype):
"""Return the source and destination node type (int) of given edge
type (int)."""
return self._graph.metagraph.find_edge(etype)
def _node_types(self):
if self.is_node_type_view:
return [self._view_ntype_idx]
elif self.is_edge_type_view:
srctype_idx, dsttype_idx = self._endpoint_types(self._view_etype_idx)
return [srctype_idx, dsttype_idx] if srctype_idx != dsttype_idx else [srctype_idx]
else: else:
return range(len(self._ntypes)) ret = self._etype2canonical.get(etype, None)
if ret is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype))
if len(ret) == 0:
raise DGLError('Edge type "%s" is ambiguous. Please use canonical etype '
'type in the form of (srctype, etype, dsttype)' % etype)
return ret
def get_ntype_id(self, ntype):
"""Return the id of the given node type.
ntype can also be None. If so, there should be only one node type in the
graph.
def node_types(self): Parameters
"""Return the list of node types appearing in the current view. ----------
ntype : str
Node type
Returns Returns
------- -------
list[str] int
List of node types """
if ntype is None:
if self._graph.number_of_ntypes() != 1:
raise DGLError('Node type name must be specified if there are more than one '
'node types.')
return 0
ntid = self._ntypes_invmap.get(ntype, None)
if ntid is None:
raise DGLError('Node type "{}" does not exist.'.format(ntype))
return ntid
Examples def get_etype_id(self, etype):
-------- """Return the id of the given edge type.
Getting all node types.
>>> g.node_types()
['user', 'game', 'developer']
Getting all node types appearing in the subgraph induced by "users"
(which should only yield "user").
>>> g['user'].node_types()
['user']
The node types appearing in subgraph induced by "plays" relationship,
which should only give "user" and "game".
>>> g['plays'].node_types()
['user', 'game']
"""
ntypes = self._node_types()
if isinstance(ntypes, range):
# assuming that the range object always covers the entire node type list
return self._ntypes
else:
return [self._ntypes[i] for i in ntypes]
def _edge_types(self):
if self.is_node_type_view:
etype_indices = self._graph.metagraph.edge_id(
self._view_ntype_idx, self._view_ntype_idx)
return etype_indices
elif self.is_edge_type_view:
return [self._view_etype_idx]
else:
return range(len(self._etypes))
def edge_types(self): etype can also be None. If so, there should be only one edge type in the
"""Return the list of edge types appearing in the current view. graph.
Parameters
----------
etype : str or tuple of str
Edge type
Returns Returns
------- -------
list[str] int
List of edge types """
if etype is None:
if self._graph.number_of_etypes() != 1:
raise DGLError('Edge type name must be specified if there are more than one '
'edge types.')
return 0
etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
if etid is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype))
return etid
Examples #################################################################
-------- # View
Getting all edge types. #################################################################
>>> g.edge_types()
['follows', 'plays', 'develops']
Getting all edge types appearing in subgraph induced by "users".
>>> g['user'].edge_types()
['follows']
The edge types appearing in subgraph induced by "plays" relationship,
which should only give "plays".
>>> g['plays'].edge_types()
['plays']
"""
etypes = self._edge_types()
if isinstance(etypes, range):
return self._etypes
else:
return [self._etypes[i] for i in etypes]
@property @property
@utils.cached_member('_cache', '_current_ntype_idx') def nodes(self):
def _current_ntype_idx(self): """Return a node view that can used to set/get feature data of a
"""Checks the uniqueness of node type in the view and get the index single node type.
of that node type.
This allows reading/writing node frame data. Examples
--------
To set features of all Users:
>>> g.nodes['user'].data['h'] = torch.zeros(3, 5)
""" """
node_types = self._node_types() return HeteroNodeView(self)
assert len(node_types) == 1, "only available for subgraphs with one node type"
return node_types[0]
@property @property
@utils.cached_member('_cache', '_current_etype_idx') def ndata(self):
def _current_etype_idx(self): """Return the data view of all the nodes.
"""Checks the uniqueness of edge type in the view and get the index
of that edge type.
This allows reading/writing edge frame data and message passing routines. Only works if the graph has only one node type.
Examples
--------
To set features of all nodes in a heterogeneous graph with only one node type:
>>> g.ndata['h'] = torch.zeros(2, 5)
""" """
edge_types = self._edge_types() return HeteroNodeDataView(self, None, ALL)
assert len(edge_types) == 1, "only available for subgraphs with one edge type"
return edge_types[0]
@property @property
@utils.cached_member('_cache', '_current_srctype_idx') def edges(self):
def _current_srctype_idx(self): """Return an edges view that can used to set/get feature data of a
"""Checks the uniqueness of edge type in the view and get the index single edge type.
of the source type.
This allows reading/writing edge frame data and message passing routines. Examples
--------
To set features of all "play" relationships:
>>> g.edges['plays'].data['h'] = torch.zeros(4, 4)
""" """
srctype_idx, _ = self._endpoint_types(self._current_etype_idx) return HeteroEdgeView(self)
return srctype_idx
@property @property
@utils.cached_member('_cache', '_current_dsttype_idx') def edata(self):
def _current_dsttype_idx(self): """Return the data view of all the edges.
"""Checks the uniqueness of edge type in the view and get the index
of the destination type.
This allows reading/writing edge frame data and message passing routines. Only works if the graph has only one edge type
Examples
--------
To set features of all edges in a heterogeneous graph with only one edge type:
>>> g.edata['h'] = torch.zeros(2, 5)
""" """
_, dsttype_idx = self._endpoint_types(self._current_etype_idx) return HeteroEdgeDataView(self, None, ALL)
return dsttype_idx
def _find_etypes(self, key):
etypes = [
i for i, (srctype, etype, dsttype) in enumerate(self._canonical_etypes) if
(key[0] == SLICE_FULL or key[0] == srctype) and
(key[1] == SLICE_FULL or key[1] == etype) and
(key[2] == SLICE_FULL or key[2] == dsttype)]
return etypes
def __getitem__(self, key):
    """Return the relation slice of this graph.

    A relation slice is accessed with ``self[srctype, etype, dsttype]``, where
    ``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
    slice (``:``) representing wildcard (i.e. any source/edge/destination type).
    A key that is not a tuple is treated as the edge type, i.e. ``self[etype]``
    is handled as ``self[:, etype, :]``.

    A relation slice is a homogeneous (with one node type and one edge type) or
    bipartite (with two node types and one edge type) graph, transformed from
    the original heterogeneous graph.

    If there is only one canonical edge type found, then the returned relation
    slice would be a subgraph induced from the original graph. That is, it is
    equivalent to ``self.edge_type_subgraph(etype)``. The node and edge features
    of the returned graph would be shared with the original graph.

    If there are multiple canonical edge types found, then the source/edge/destination
    node types would be a *concatenation* of original node/edge types. The
    new source/destination node type would have the concatenation determined by
    :func:`dgl.combine_names() <dgl.combine_names>` called on original source/destination
    types as its name. The source/destination node would be formed by concatenating the
    common features of the original source/destination types, therefore they are not
    shared with the original graph. Edge type is similar.

    Parameters
    ----------
    key : str or tuple
        Either a single edge type, or a ``(srctype, etype, dsttype)`` triplet
        in which any element may be the full slice ``:``.

    Returns
    -------
    DGLHeteroGraph
        The relation slice.

    Raises
    ------
    DGLError
        If ``key`` is a tuple whose length is not 3, or if no canonical edge
        type of this graph matches ``key``.
    """
    err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\
              "to get view of one relation type. Use : to slice multiple types (e.g. " +\
              "G['srctype', :, 'dsttype'])."

    # Keep the key as the caller wrote it for error reporting.
    orig_key = key
    if not isinstance(key, tuple):
        key = (SLICE_FULL, key, SLICE_FULL)

    if len(key) != 3:
        raise DGLError(err_msg)

    etypes = self._find_etypes(key)
    if len(etypes) == 0:
        # Nothing matched: fail here with a clear message instead of passing
        # an empty edge type list on to ``flatten_relations``.
        raise DGLError(
            'Invalid key "{}": no edge type in the graph matches it.'.format(orig_key))

    if len(etypes) == 1:
        # no ambiguity: return the unitgraph itself
        srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
        stid = self.get_ntype_id(srctype)
        etid = self.get_etype_id((srctype, etype, dsttype))
        dtid = self.get_ntype_id(dsttype)
        new_g = self._graph.get_relation_graph(etid)

        if stid == dtid:
            # source and destination share one node type: a single node frame
            new_ntypes = [srctype]
            new_nframes = [self._node_frames[stid]]
        else:
            new_ntypes = [srctype, dsttype]
            new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
        new_etypes = [etype]
        # the frame objects themselves are handed over, not copies
        new_eframes = [self._edge_frames[etid]]

        return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
    else:
        flat = self._graph.flatten_relations(etypes)
        new_g = flat.graph

        # merge frames
        stids = flat.induced_srctype_set.asnumpy()
        dtids = flat.induced_dsttype_set.asnumpy()
        etids = flat.induced_etype_set.asnumpy()
        new_ntypes = [combine_names(self.ntypes, stids)]
        if new_g.number_of_ntypes() == 2:
            new_ntypes.append(combine_names(self.ntypes, dtids))
            new_nframes = [
                combine_frames(self._node_frames, stids),
                combine_frames(self._node_frames, dtids)]
        else:
            assert np.array_equal(stids, dtids)
            new_nframes = [combine_frames(self._node_frames, stids)]
        new_etypes = [combine_names(self.etypes, etids)]
        new_eframes = [combine_frames(self._edge_frames, etids)]

        # create new heterograph
        new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)

        src = new_ntypes[0]
        dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
        # put the parent node/edge type and IDs
        # NOTE: when the flattened graph has a single node type, ``dst`` is
        # ``src``, so the two dst assignments overwrite the src ones.
        new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype)
        new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid)
        new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype)
        new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid)
        new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
        new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)

        return new_hg
#################################################################
# Graph query
#################################################################
def number_of_nodes(self, ntype=None):
"""Return the number of nodes of the given type in the heterograph. """Return the number of nodes of the given type in the heterograph.
Parameters Parameters
---------- ----------
ntype : str ntype : str, optional
The node type The node type. Can be omitted if there is only one node type
in the graph.
Returns Returns
------- -------
...@@ -250,40 +554,16 @@ class DGLBaseHeteroGraph(object): ...@@ -250,40 +554,16 @@ class DGLBaseHeteroGraph(object):
>>> g['user'].number_of_nodes() >>> g['user'].number_of_nodes()
3 3
""" """
return self._graph.number_of_nodes(self._ntypes_invmap[ntype]) return self._graph.number_of_nodes(self.get_ntype_id(ntype))
def _number_of_src_nodes(self): def number_of_edges(self, etype=None):
"""Return number of source nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_srctype_idx)
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_dsttype_idx)
@property
def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise.
"""
assert not self.is_view, 'not supported on views'
return self._graph.is_multigraph()
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self._graph.number_of_edges(self._current_etype_idx)
def number_of_edges(self, etype):
"""Return the number of edges of the given type in the heterograph. """Return the number of edges of the given type in the heterograph.
Parameters Parameters
---------- ----------
etype : (str, str, str) etype : str or tuple of str, optional
The source-edge-destination type triplet The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -295,17 +575,28 @@ class DGLBaseHeteroGraph(object): ...@@ -295,17 +575,28 @@ class DGLBaseHeteroGraph(object):
>>> g.number_of_edges(('user', 'plays', 'game')) >>> g.number_of_edges(('user', 'plays', 'game'))
4 4
""" """
return self._graph.number_of_edges(self._etypes_invmap[etype]) return self._graph.number_of_edges(self.get_etype_id(etype))
@property
def is_multigraph(self):
    """Whether the graph is a multigraph.

    Returns
    -------
    bool
        The value reported by ``self._graph.is_multigraph()``: True if the
        graph is a multigraph, False otherwise.
    """
    graph_index = self._graph
    return graph_index.is_multigraph()
@property
def is_readonly(self):
    """Whether the graph is read-only.

    Returns
    -------
    bool
        The value reported by ``self._graph.is_readonly()``: True if the
        graph is readonly, False otherwise.
    """
    graph_index = self._graph
    return graph_index.is_readonly()
def has_node(self, ntype, vid): def has_node(self, vid, ntype=None):
"""Return True if the graph contains node `vid` of type `ntype`. """Return True if the graph contains node `vid` of type `ntype`.
Parameters Parameters
---------- ----------
ntype : str
The node type.
vid : int vid : int
The node ID. The node ID.
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
Returns Returns
------- -------
...@@ -314,28 +605,29 @@ class DGLBaseHeteroGraph(object): ...@@ -314,28 +605,29 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
>>> g.has_node('user', 0) >>> g.has_node(0, 'user')
True True
>>> g.has_node('user', 4) >>> g.has_node(4, 'user')
False False
See Also See Also
-------- --------
has_nodes has_nodes
""" """
return self._graph.has_node(self._ntypes_invmap[ntype], vid) return self._graph.has_node(self.get_ntype_id(ntype), vid)
def has_nodes(self, ntype, vids): def has_nodes(self, vids, ntype=None):
"""Return a 0-1 array ``a`` given the node ID array ``vids``. """Return a 0-1 array ``a`` given the node ID array ``vids``.
``a[i]`` is 1 if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise. ``a[i]`` is 1 if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.
Parameters Parameters
---------- ----------
ntype : str
The node type.
vids : list or tensor vids : list or tensor
The array of node IDs. The array of node IDs.
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
Returns Returns
------- -------
...@@ -346,7 +638,7 @@ class DGLBaseHeteroGraph(object): ...@@ -346,7 +638,7 @@ class DGLBaseHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g.has_nodes('user', [0, 1, 2, 3, 4]) >>> g.has_nodes([0, 1, 2, 3, 4], 'user')
tensor([1, 1, 1, 0, 0]) tensor([1, 1, 1, 0, 0])
See Also See Also
...@@ -354,20 +646,21 @@ class DGLBaseHeteroGraph(object): ...@@ -354,20 +646,21 @@ class DGLBaseHeteroGraph(object):
has_node has_node
""" """
vids = utils.toindex(vids) vids = utils.toindex(vids)
rst = self._graph.has_nodes(self._ntypes_invmap[ntype], vids) rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids)
return rst.tousertensor() return rst.tousertensor()
def has_edge_between(self, etype, u, v): def has_edge_between(self, u, v, etype=None):
"""Return True if the edge (u, v) of type ``etype`` is in the graph. """Return True if the edge (u, v) of type ``etype`` is in the graph.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int u : int
The node ID of source type. The node ID of source type.
v : int v : int
The node ID of destination type. The node ID of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -377,20 +670,20 @@ class DGLBaseHeteroGraph(object): ...@@ -377,20 +670,20 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
Check whether Alice plays Tetris Check whether Alice plays Tetris
>>> g.has_edge_between(('user', 'plays', 'game'), 0, 1) >>> g.has_edge_between(0, 1, ('user', 'plays', 'game'))
True True
And whether Alice plays Minecraft And whether Alice plays Minecraft
>>> g.has_edge_between(('user', 'plays', 'game'), 0, 2) >>> g.has_edge_between(0, 2, ('user', 'plays', 'game'))
False False
See Also See Also
-------- --------
has_edges_between has_edges_between
""" """
return self._graph.has_edge_between(self._etypes_invmap[etype], u, v) return self._graph.has_edge_between(self.get_etype_id(etype), u, v)
def has_edges_between(self, etype, u, v): def has_edges_between(self, u, v, etype=None):
"""Return a 0-1 array ``a`` given the source node ID array ``u`` and """Return a 0-1 array ``a`` given the source node ID array ``u`` and
destination node ID array ``v``. destination node ID array ``v``.
...@@ -398,12 +691,13 @@ class DGLBaseHeteroGraph(object): ...@@ -398,12 +691,13 @@ class DGLBaseHeteroGraph(object):
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor u : list, tensor
The node ID array of source type. The node ID array of source type.
v : list, tensor v : list, tensor
The node ID array of destination type. The node ID array of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -414,7 +708,7 @@ class DGLBaseHeteroGraph(object): ...@@ -414,7 +708,7 @@ class DGLBaseHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g.has_edges_between(('user', 'plays', 'game'), [0, 0], [1, 2]) >>> g.has_edges_between([0, 0], [1, 2], ('user', 'plays', 'game'))
tensor([1, 0]) tensor([1, 0])
See Also See Also
...@@ -423,10 +717,10 @@ class DGLBaseHeteroGraph(object): ...@@ -423,10 +717,10 @@ class DGLBaseHeteroGraph(object):
""" """
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
rst = self._graph.has_edges_between(self._etypes_invmap[etype], u, v) rst = self._graph.has_edges_between(self.get_etype_id(etype), u, v)
return rst.tousertensor() return rst.tousertensor()
def predecessors(self, etype, v): def predecessors(self, v, etype=None):
"""Return the predecessors of node `v` in the graph with the same """Return the predecessors of node `v` in the graph with the same
edge type. edge type.
...@@ -435,10 +729,11 @@ class DGLBaseHeteroGraph(object): ...@@ -435,10 +729,11 @@ class DGLBaseHeteroGraph(object):
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node of destination type. The node of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -450,7 +745,7 @@ class DGLBaseHeteroGraph(object): ...@@ -450,7 +745,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Query who plays Tetris: Query who plays Tetris:
>>> g.predecessors(('user', 'plays', 'game'), 0) >>> g.predecessors(0, ('user', 'plays', 'game'))
tensor([0, 1]) tensor([0, 1])
This indicates User #0 (Alice) and User #1 (Bob). This indicates User #0 (Alice) and User #1 (Bob).
...@@ -459,9 +754,9 @@ class DGLBaseHeteroGraph(object): ...@@ -459,9 +754,9 @@ class DGLBaseHeteroGraph(object):
-------- --------
successors successors
""" """
return self._graph.predecessors(self._etypes_invmap[etype], v).tousertensor() return self._graph.predecessors(self.get_etype_id(etype), v).tousertensor()
def successors(self, etype, v): def successors(self, v, etype=None):
"""Return the successors of node `v` in the graph with the same edge """Return the successors of node `v` in the graph with the same edge
type. type.
...@@ -470,10 +765,11 @@ class DGLBaseHeteroGraph(object): ...@@ -470,10 +765,11 @@ class DGLBaseHeteroGraph(object):
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node of source type. The node of source type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -485,7 +781,7 @@ class DGLBaseHeteroGraph(object): ...@@ -485,7 +781,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Asks which game Alice plays: Asks which game Alice plays:
>>> g.successors(('user', 'plays', 'game'), 0) >>> g.successors(0, ('user', 'plays', 'game'))
tensor([0]) tensor([0])
This indicates Game #0 (Tetris). This indicates Game #0 (Tetris).
...@@ -494,26 +790,24 @@ class DGLBaseHeteroGraph(object): ...@@ -494,26 +790,24 @@ class DGLBaseHeteroGraph(object):
-------- --------
predecessors predecessors
""" """
return self._graph.successors(self._etypes_invmap[etype], v).tousertensor() return self._graph.successors(self.get_etype_id(etype), v).tousertensor()
def edge_id(self, etype, u, v, force_multi=False): def edge_id(self, u, v, force_multi=False, etype=None):
"""Return the edge ID, or an array of edge IDs, between source node """Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`. `u` and destination node `v`.
Only works if the graph has one edge type. For multiple types,
query with
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int u : int
The node ID of source type. The node ID of source type.
v : int v : int
The node ID of destination type. The node ID of destination type.
force_multi : bool force_multi : bool, optional
If False, will return a single edge ID if the graph is a simple graph. If False, will return a single edge ID if the graph is a simple graph.
If True, will always return an array. If True, will always return an array. (Default: False)
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -526,33 +820,31 @@ class DGLBaseHeteroGraph(object): ...@@ -526,33 +820,31 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the edge ID of "Bob plays Tetris" Find the edge ID of "Bob plays Tetris"
>>> g.edge_id(('user', 'plays', 'game'), 1, 0) >>> g.edge_id(1, 0, etype=('user', 'plays', 'game'))
1 1
See Also See Also
-------- --------
edge_ids edge_ids
""" """
idx = self._graph.edge_id(self._etypes_invmap[etype], u, v) idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0] return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0]
def edge_ids(self, etype, u, v, force_multi=False): def edge_ids(self, u, v, force_multi=False, etype=None):
"""Return all edge IDs between source node array `u` and destination """Return all edge IDs between source node array `u` and destination
node array `v`. node array `v`.
Only works if the graph has one edge type. For multiple types,
query with
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor u : list, tensor
The node ID array of source type. The node ID array of source type.
v : list, tensor v : list, tensor
The node ID array of destination type. The node ID array of destination type.
force_multi : bool force_multi : bool, optional
Whether to always treat the graph as a multigraph. Whether to always treat the graph as a multigraph. (Default: False)
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -574,7 +866,7 @@ class DGLBaseHeteroGraph(object): ...@@ -574,7 +866,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the edge IDs of "Alice plays Tetris" and "Bob plays Minecraft". Find the edge IDs of "Alice plays Tetris" and "Bob plays Minecraft".
>>> g.edge_ids(('user', 'plays', 'game'), [0, 1], [0, 1]) >>> g.edge_ids([0, 1], [0, 1], etype=('user', 'plays', 'game'))
tensor([0, 2]) tensor([0, 2])
See Also See Also
...@@ -583,23 +875,24 @@ class DGLBaseHeteroGraph(object): ...@@ -583,23 +875,24 @@ class DGLBaseHeteroGraph(object):
""" """
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(self._etypes_invmap[etype], u, v) src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
if force_multi or self._graph.is_multigraph(): if force_multi or self._graph.is_multigraph():
return src.tousertensor(), dst.tousertensor(), eid.tousertensor() return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else: else:
return eid.tousertensor() return eid.tousertensor()
def find_edges(self, etype, eid): def find_edges(self, eid, etype=None):
"""Given an edge ID array, return the source and destination node ID """Given an edge ID array, return the source and destination node ID
array `s` and `d`. `s[i]` and `d[i]` are source and destination node array `s` and `d`. `s[i]` and `d[i]` are source and destination node
ID for edge `eid[i]`. ID for edge `eid[i]`.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
eid : list, tensor eid : list, tensor
The edge ID array. The edge ID array.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -613,20 +906,18 @@ class DGLBaseHeteroGraph(object): ...@@ -613,20 +906,18 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the user and game of gameplay #0 and #2: Find the user and game of gameplay #0 and #2:
>>> g.find_edges(('user', 'plays', 'game'), [0, 2]) >>> g.find_edges([0, 2], ('user', 'plays', 'game'))
(tensor([0, 1]), tensor([0, 1])) (tensor([0, 1]), tensor([0, 1]))
""" """
eid = utils.toindex(eid) eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(self._etypes_invmap[etype], eid) src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
return src.tousertensor(), dst.tousertensor() return src.tousertensor(), dst.tousertensor()
def in_edges(self, etype, v, form='uv'): def in_edges(self, v, form='uv', etype=None):
"""Return the inbound edges of the node(s). """Return the inbound edges of the node(s).
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor v : int, list, tensor
The node(s) of destination type. The node(s) of destination type.
form : str, optional form : str, optional
...@@ -635,6 +926,9 @@ class DGLBaseHeteroGraph(object): ...@@ -635,6 +926,9 @@ class DGLBaseHeteroGraph(object):
- 'all' : a tuple (u, v, eid) - 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default - 'uv' : a pair (u, v), default
- 'eid' : one eid tensor - 'eid' : one eid tensor
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -652,11 +946,11 @@ class DGLBaseHeteroGraph(object): ...@@ -652,11 +946,11 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the gameplay IDs of game #0 (Tetris) Find the gameplay IDs of game #0 (Tetris)
>>> g.in_edges(('user', 'plays', 'game'), 0, 'eid') >>> g.in_edges(0, 'eid', ('user', 'plays', 'game'))
tensor([0, 1]) tensor([0, 1])
""" """
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(self._etypes_invmap[etype], v) src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
if form == 'all': if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor()) return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv': elif form == 'uv':
...@@ -666,13 +960,11 @@ class DGLBaseHeteroGraph(object): ...@@ -666,13 +960,11 @@ class DGLBaseHeteroGraph(object):
else: else:
raise DGLError('Invalid form:', form) raise DGLError('Invalid form:', form)
def out_edges(self, etype, v, form='uv'): def out_edges(self, v, form='uv', etype=None):
"""Return the outbound edges of the node(s). """Return the outbound edges of the node(s).
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor v : int, list, tensor
The node(s) of source type. The node(s) of source type.
form : str, optional form : str, optional
...@@ -681,6 +973,9 @@ class DGLBaseHeteroGraph(object): ...@@ -681,6 +973,9 @@ class DGLBaseHeteroGraph(object):
- 'all' : a tuple (u, v, eid) - 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default - 'uv' : a pair (u, v), default
- 'eid' : one eid tensor - 'eid' : one eid tensor
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -698,11 +993,11 @@ class DGLBaseHeteroGraph(object): ...@@ -698,11 +993,11 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the gameplay IDs of user #0 (Alice) Find the gameplay IDs of user #0 (Alice)
>>> g.out_edges(('user', 'plays', 'game'), 0, 'eid') >>> g.out_edges(0, 'eid', ('user', 'plays', 'game'))
tensor([0]) tensor([0])
""" """
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(self._etypes_invmap[etype], v) src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), v)
if form == 'all': if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor()) return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv': elif form == 'uv':
...@@ -712,13 +1007,11 @@ class DGLBaseHeteroGraph(object): ...@@ -712,13 +1007,11 @@ class DGLBaseHeteroGraph(object):
else: else:
raise DGLError('Invalid form:', form) raise DGLError('Invalid form:', form)
def all_edges(self, etype, form='uv', order=None): def all_edges(self, form='uv', order=None, etype=None):
"""Return all the edges. """Return all the edges.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
form : str, optional form : str, optional
The return form. Currently support: The return form. Currently support:
...@@ -731,6 +1024,9 @@ class DGLBaseHeteroGraph(object): ...@@ -731,6 +1024,9 @@ class DGLBaseHeteroGraph(object):
- 'srcdst' : sorted by their src and dst ids. - 'srcdst' : sorted by their src and dst ids.
- 'eid' : sorted by edge Ids. - 'eid' : sorted by edge Ids.
- None : the arbitrary order. - None : the arbitrary order.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -749,10 +1045,10 @@ class DGLBaseHeteroGraph(object): ...@@ -749,10 +1045,10 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find the user-game pairs for all gameplays: Find the user-game pairs for all gameplays:
>>> g.all_edges(('user', 'plays', 'game'), 'uv') >>> g.all_edges('uv', etype=('user', 'plays', 'game'))
(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1])) (tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1]))
""" """
src, dst, eid = self._graph.edges(self._etypes_invmap[etype], order) src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
if form == 'all': if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor()) return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv': elif form == 'uv':
...@@ -762,15 +1058,16 @@ class DGLBaseHeteroGraph(object): ...@@ -762,15 +1058,16 @@ class DGLBaseHeteroGraph(object):
else: else:
raise DGLError('Invalid form:', form) raise DGLError('Invalid form:', form)
def in_degree(self, etype, v): def in_degree(self, v, etype=None):
"""Return the in-degree of node ``v``. """Return the in-degree of node ``v``.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node ID of destination type. The node ID of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -782,27 +1079,28 @@ class DGLBaseHeteroGraph(object): ...@@ -782,27 +1079,28 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
Find how many users are playing Game #0 (Tetris): Find how many users are playing Game #0 (Tetris):
>>> g.in_degree(('user', 'plays', 'game'), 0) >>> g.in_degree(0, ('user', 'plays', 'game'))
2 2
See Also See Also
-------- --------
in_degrees in_degrees
""" """
return self._graph.in_degree(self._etypes_invmap[etype], v) return self._graph.in_degree(self.get_etype_id(etype), v)
def in_degrees(self, etype, v=ALL): def in_degrees(self, v=ALL, etype=None):
"""Return the array `d` of in-degrees of the node array `v`. """Return the array `d` of in-degrees of the node array `v`.
`d[i]` is the in-degree of node `v[i]`. `d[i]` is the in-degree of node `v[i]`.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor, optional. v : list, tensor, optional.
The node ID array of destination type. Default is to return the The node ID array of destination type. Default is to return the
degrees of all the nodes. degrees of all the nodes.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -814,30 +1112,31 @@ class DGLBaseHeteroGraph(object): ...@@ -814,30 +1112,31 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find how many users are playing Game #0 and #1 (Tetris and Minecraft): Find how many users are playing Game #0 and #1 (Tetris and Minecraft):
>>> g.in_degrees(('user', 'plays', 'game'), [0, 1]) >>> g.in_degrees([0, 1], ('user', 'plays', 'game'))
tensor([2, 2]) tensor([2, 2])
See Also See Also
-------- --------
in_degree in_degree
""" """
etype_idx = self._etypes_invmap[etype] etid = self.get_etype_id(etype)
_, dsttype_idx = self._endpoint_types(etype_idx) _, dtid = self._graph.metagraph.find_edge(etid)
if is_all(v): if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(dsttype_idx))) v = utils.toindex(slice(0, self._graph.number_of_nodes(dtid)))
else: else:
v = utils.toindex(v) v = utils.toindex(v)
return self._graph.in_degrees(etype_idx, v).tousertensor() return self._graph.in_degrees(etid, v).tousertensor()
def out_degree(self, etype, v): def out_degree(self, v, etype=None):
"""Return the out-degree of node `v`. """Return the out-degree of node `v`.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int v : int
The node ID of source type. The node ID of source type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -847,27 +1146,28 @@ class DGLBaseHeteroGraph(object): ...@@ -847,27 +1146,28 @@ class DGLBaseHeteroGraph(object):
Examples Examples
-------- --------
Find how many games User #0 Alice is playing Find how many games User #0 Alice is playing
>>> g.out_degree(('user', 'plays', 'game'), 0) >>> g.out_degree(0, ('user', 'plays', 'game'))
1 1
See Also See Also
-------- --------
out_degrees out_degrees
""" """
return self._graph.out_degree(self._etypes_invmap[etype], v) return self._graph.out_degree(self.get_etype_id(etype), v)
def out_degrees(self, etype, v=ALL): def out_degrees(self, v=ALL, etype=None):
"""Return the array `d` of out-degrees of the node array `v`. """Return the array `d` of out-degrees of the node array `v`.
`d[i]` is the out-degree of node `v[i]`. `d[i]` is the out-degree of node `v[i]`.
Parameters Parameters
---------- ----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor v : list, tensor
The node ID array of source type. Default is to return the degrees The node ID array of source type. Default is to return the degrees
of all the nodes. of all the nodes.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -879,462 +1179,294 @@ class DGLBaseHeteroGraph(object): ...@@ -879,462 +1179,294 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
Find how many games User #0 and #1 (Alice and Bob) are playing Find how many games User #0 and #1 (Alice and Bob) are playing
>>> g.out_degrees(('user', 'plays', 'game'), [0, 1]) >>> g.out_degrees([0, 1], ('user', 'plays', 'game'))
tensor([1, 2]) tensor([1, 2])
See Also See Also
-------- --------
out_degree out_degree
""" """
etype_idx = self._etypes_invmap[etype] etid = self.get_etype_id(etype)
srctype_idx, _ = self._endpoint_types(etype_idx) stid, _ = self._graph.metagraph.find_edge(etid)
if is_all(v): if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(srctype_idx))) v = utils.toindex(slice(0, self._graph.number_of_nodes(stid)))
else: else:
v = utils.toindex(v) v = utils.toindex(v)
return self._graph.out_degrees(etype_idx, v).tousertensor() return self._graph.out_degrees(etid, v).tousertensor()
def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
def bipartite_from_edge_list(u, v, num_src=None, num_dst=None): """Internal function to create a subgraph."""
"""Create a bipartite graph component of a heterogeneous graph with a node_frames = [
list of edges. FrameRef(Frame(
self._node_frames[i][induced_nodes_of_ntype],
Parameters num_rows=len(induced_nodes_of_ntype)))
---------- for i, induced_nodes_of_ntype in enumerate(induced_nodes)]
u, v : list[int] edge_frames = [
List of source and destination node IDs. FrameRef(Frame(
num_src : int, optional self._edge_frames[i][induced_edges_of_etype],
The number of nodes of source type. num_rows=len(induced_edges_of_etype)))
for i, induced_edges_of_etype in enumerate(induced_edges)]
By default, the value is the maximum of the source node IDs in the
edge list plus 1. hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
num_dst : int, optional hsg.is_subgraph = True
The number of nodes of destination type. for ntype, induced_nid in zip(self.ntypes, induced_nodes):
hsg.nodes[ntype].data[NID] = induced_nid.tousertensor()
By default, the value is the maximum of the destination node IDs in for etype, induced_eid in zip(self.canonical_etypes, induced_edges):
the edge list plus 1. hsg.edges[etype].data[EID] = induced_eid.tousertensor()
"""
num_src = num_src or (max(u) + 1) return hsg
num_dst = num_dst or (max(v) + 1)
u = utils.toindex(u)
v = utils.toindex(v)
return heterograph_index.create_bipartite_from_coo(num_src, num_dst, u, v)
def bipartite_from_scipy(spmat, with_edge_id=False): def subgraph(self, nodes):
"""Create a bipartite graph component of a heterogeneous graph with a """Return the subgraph induced on given nodes.
scipy sparse matrix.
Parameters The metagraph of the returned subgraph is the same as the parent graph.
----------
spmat : scipy sparse matrix
The bipartite graph matrix whose rows represent sources and columns
represent destinations.
with_edge_id : bool
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
"""
spmat = spmat.tocsr()
num_src, num_dst = spmat.shape
indptr = utils.toindex(spmat.indptr)
indices = utils.toindex(spmat.indices)
data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices))))
return heterograph_index.create_bipartite_from_csr(num_src, num_dst, indptr, indices, data)
Features are copied from the original graph.
class DGLHeteroGraph(DGLBaseHeteroGraph): Examples
"""Base heterogeneous graph class. --------
TBD
A Heterogeneous graph is defined as a graph with node types and edge Parameters
types. ----------
nodes : dict[str, list or iterable]
A dictionary of node types to node ID array to construct
subgraph.
All nodes must exist in the graph.
If two edges share the same edge type, then their source nodes, as well Returns
as their destination nodes, also have the same type (the source node -------
types don't have to be the same as the destination node types). G : DGLHeteroGraph
The subgraph.
Parameters The nodes are relabeled so that node `i` of type `t` in the
---------- subgraph is mapped to the ``nodes[i]`` of type `t` in the
graph_data : original graph.
The graph data. It can be one of the followings: The edges are also relabeled.
One can retrieve the mapping from subgraph node/edge ID to parent
* (nx.MultiDiGraph, dict[str, list[tuple[int, int]]]) node/edge ID via `dgl.NID` and `dgl.EID` node/edge features of the
* (nx.MultiDiGraph, dict[str, scipy.sparse.matrix]) subgraph.
"""
The first element is the metagraph of the heterogeneous graph, as a induced_nodes = [utils.toindex(nodes.get(ntype, [])) for ntype in self.ntypes]
networkx directed graph. Its nodes represent the node types, and sgi = self._graph.node_subgraph(induced_nodes)
its edges represent the edge types. The edge type name should be induced_edges = sgi.induced_edges
stored as edge keys.
The second element is a mapping from edge type to edge list. The
edge list can be either a list of (u, v) pairs, or a scipy sparse
matrix whose rows represents sources and columns represents
destinations. The edges will be added in the (source, destination)
order.
node_frames : dict[str, dict[str, Tensor]]
The node frames for each node type
edge_frames : dict[str, dict[str, Tensor]]
The edge frames for each edge type
multigraph : bool
Whether the heterogeneous graph is a multigraph.
readonly : bool
Whether the heterogeneous graph is readonly.
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G { return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following def edge_subgraph(self, edges, preserve_nodes=False):
IDs: """Return the subgraph induced on given edges.
User name Alice Bob Carol The metagraph of the returned subgraph is the same as the parent graph.
User ID 0 1 2
Game name Tetris Minecraft Features are copied from the original graph.
Game ID 0 1
Developer name Nintendo Mojang Examples
Developer ID 0 1 --------
TBD
One can construct the graph as follows: Parameters
----------
edges : dict[etype, list or iterable]
A dictionary of edge types to edge ID array to construct
subgraph.
All edges must exist in the subgraph.
The edge type is characterized by a triplet of source type name,
destination type name, and edge type name.
>>> mg = nx.MultiDiGraph([('user', 'user', 'follows'), Returns
... ('user', 'game', 'plays'), -------
... ('developer', 'game', 'develops')]) G : DGLHeteroGraph
>>> g = DGLHeteroGraph( The subgraph.
... mg, { The edges are relabeled so that edge `i` of type `t` in the
... 'follows': [(0, 1), (1, 2)], subgraph is mapped to the ``edges[i]`` of type `t` in the
... 'plays': [(0, 0), (1, 0), (1, 1), (2, 1)], original graph.
... 'develops': [(0, 0), (1, 1)]}) One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via `dgl.NID` and `dgl.EID` node/edge features of the
subgraph.
"""
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [
utils.toindex(edges.get(canonical_etype, []))
for canonical_etype in self.canonical_etypes]
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes
Then one can query the graph structure as follows: return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
>>> g['user'].number_of_nodes() def node_type_subgraph(self, ntypes):
3 """Return the subgraph induced on given node types.
>>> g['plays'].number_of_edges()
4
>>> g['develops'].out_degrees() # out-degrees of source nodes of 'develops' relation
tensor([1, 1])
>>> g['develops'].in_edges(0) # in-edges of destination node 0 of 'develops' relation
(tensor([0]), tensor([0]))
Notes The metagraph of the returned subgraph is the subgraph of the original metagraph
----- induced from the node types.
Currently, all heterogeneous graphs are readonly.
"""
# pylint: disable=unused-argument
def __init__(
self,
graph_data=None,
node_frames=None,
edge_frames=None,
multigraph=None,
readonly=True,
_view_ntype_idx=None,
_view_etype_idx=None):
assert readonly, "Only readonly heterogeneous graphs are supported"
# Creating a view of another graph? Features are shared with the original graph.
if isinstance(graph_data, DGLHeteroGraph):
super(DGLHeteroGraph, self).__init__(
graph_data._graph, graph_data._ntypes, graph_data._etypes,
graph_data._ntypes_invmap, graph_data._etypes_invmap,
graph_data._view_ntype_idx, graph_data._view_etype_idx)
self._node_frames = graph_data._node_frames
self._edge_frames = graph_data._edge_frames
self._msg_frames = graph_data._msg_frames
self._msg_indices = graph_data._msg_indices
self._view_ntype_idx = _view_ntype_idx
self._view_etype_idx = _view_etype_idx
return
if isinstance(graph_data, tuple): Examples
metagraph, edges_by_type = graph_data --------
if not isinstance(metagraph, nx.MultiDiGraph): TBD
raise TypeError('Metagraph should be networkx.MultiDiGraph')
# create metagraph graph index
srctypes, dsttypes, etypes = [], [], []
ntypes = []
ntypes_invmap = {}
etypes_invmap = {}
for srctype, dsttype, etype in metagraph.edges(keys=True):
srctypes.append(srctype)
dsttypes.append(dsttype)
etypes_invmap[(srctype, etype, dsttype)] = len(etypes_invmap)
etypes.append((srctype, etype, dsttype))
if srctype not in ntypes_invmap:
ntypes_invmap[srctype] = len(ntypes_invmap)
ntypes.append(srctype)
if dsttype not in ntypes_invmap:
ntypes_invmap[dsttype] = len(ntypes_invmap)
ntypes.append(dsttype)
srctypes = [ntypes_invmap[srctype] for srctype in srctypes]
dsttypes = [ntypes_invmap[dsttype] for dsttype in dsttypes]
metagraph_index = graph_index.create_graph_index(
list(zip(srctypes, dsttypes)), None, True) # metagraph is always immutable
# create base bipartites
bipartites = []
num_nodes = defaultdict(int)
# count the number of nodes for each type
for etype_triplet in etypes:
srctype, etype, dsttype = etype_triplet
edges = edges_by_type[etype_triplet]
if ssp.issparse(edges):
num_src, num_dst = edges.shape
elif isinstance(edges, list):
u, v = zip(*edges)
num_src = max(u) + 1
num_dst = max(v) + 1
else:
raise TypeError('unknown edge list type %s' % type(edges))
num_nodes[srctype] = max(num_nodes[srctype], num_src)
num_nodes[dsttype] = max(num_nodes[dsttype], num_dst)
# create actual objects
for etype_triplet in etypes:
srctype, etype, dsttype = etype_triplet
edges = edges_by_type[etype_triplet]
if ssp.issparse(edges):
bipartite = bipartite_from_scipy(edges)
elif isinstance(edges, list):
u, v = zip(*edges)
bipartite = bipartite_from_edge_list(
u, v, num_nodes[srctype], num_nodes[dsttype])
bipartites.append(bipartite)
hg_index = heterograph_index.create_heterograph(metagraph_index, bipartites)
super(DGLHeteroGraph, self).__init__(hg_index, ntypes, etypes)
else:
raise TypeError('Unrecognized graph data type %s' % type(graph_data))
# node and edge frame Parameters
if node_frames is None: ----------
self._node_frames = [ ntypes : list[str]
FrameRef(Frame(num_rows=self._graph.number_of_nodes(i))) The node types
for i in range(len(self._ntypes))]
else:
self._node_frames = node_frames
if edge_frames is None: Returns
self._edge_frames = [ -------
FrameRef(Frame(num_rows=self._graph.number_of_edges(i))) G : DGLHeteroGraph
for i in range(len(self._etypes))] The subgraph.
else: """
self._edge_frames = edge_frames rel_graphs = []
meta_edges = []
induced_etypes = []
node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
edge_frames = []
# message indicators ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
self._msg_indices = [None] * len(self._etypes) srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
self._msg_frames = []
for i in range(len(self._etypes)): for i in range(len(self._etypes)):
frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i))) srctype = self._ntypes[srctype_id[i]]
frame.set_initializer(init.zero_initializer) dsttype = self._ntypes[dsttype_id[i]]
self._msg_frames.append(frame)
def _create_view(self, ntype_idx, etype_idx): if srctype in ntypes and dsttype in ntypes:
return DGLHeteroGraph( meta_edges.append((ntypes_invmap[srctype], ntypes_invmap[dsttype]))
graph_data=self, _view_ntype_idx=ntype_idx, _view_etype_idx=etype_idx) rel_graphs.append(self._graph.get_relation_graph(i))
induced_etypes.append(self.etypes[i])
edge_frames.append(self._edge_frames[i])
def _get_msg_index(self): metagraph = graph_index.from_edge_list(meta_edges, True, True)
if self._msg_indices[self._current_etype_idx] is None: hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
self._msg_indices[self._current_etype_idx] = utils.zero_index( hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
size=self._graph.number_of_edges(self._current_etype_idx)) return hg
return self._msg_indices[self._current_etype_idx]
def _set_msg_index(self, index): def edge_type_subgraph(self, etypes):
self._msg_indices[self._current_etype_idx] = index """Return the subgraph induced on given edge types.
def __getitem__(self, key): The metagraph of the returned subgraph is the subgraph of the original metagraph
if key in self._etypes_invmap: induced from the edge types.
return self._create_view(None, self._etypes_invmap[key])
else:
raise KeyError(key)
@property Features are shared with the original graph.
def _node_frame(self):
# overrides DGLGraph._node_frame
return self._node_frames[self._current_ntype_idx]
@property Examples
def _edge_frame(self): --------
# overrides DGLGraph._edge_frame TBD
return self._edge_frames[self._current_etype_idx]
@property Parameters
def _src_frame(self): ----------
# overrides DGLGraph._src_frame etypes : list[str or tuple]
return self._node_frames[self._current_srctype_idx] The edge types
@property Returns
def _dst_frame(self): -------
# overrides DGLGraph._dst_frame G : DGLHeteroGraph
return self._node_frames[self._current_dsttype_idx] The subgraph.
"""
etype_ids = [self.get_etype_id(etype) for etype in etypes]
meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids))
rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
meta_src = meta_src.tonumpy()
meta_dst = meta_dst.tonumpy()
induced_ntype_ids = list(set(meta_src) | set(meta_dst))
mapped_meta_src = [induced_ntype_ids[v] for v in meta_src]
mapped_meta_dst = [induced_ntype_ids[v] for v in meta_dst]
node_frames = [self._node_frames[i] for i in induced_ntype_ids]
edge_frames = [self._edge_frames[i] for i in etype_ids]
induced_ntypes = [self._ntypes[i] for i in induced_ntype_ids]
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg
def adjacency_matrix(self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None):
"""Return the adjacency matrix of edges of the given edge type.
@property By default, a row of returned adjacency matrix represents the
def _msg_frame(self): destination of an edge and the column represents the source.
# overrides DGLGraph._msg_frame
return self._msg_frames[self._current_etype_idx]
def add_nodes(self, node_type, num, data=None): When transpose is True, a row represents the source and a column
"""Add multiple new nodes of the same node type represents a destination.
Parameters Parameters
---------- ----------
node_type : str transpose : bool, optional (default=False)
Type of the added nodes. Must appear in the metagraph. A flag to transpose the returned adjacency matrix.
num : int ctx : context, optional (default=cpu)
Number of nodes to be added. The context of returned adjacency matrix.
data : dict, optional scipy_fmt : str, optional (default=None)
Feature data of the added nodes. If specified, return a scipy sparse matrix in the given format.
etype : str, optional
Examples The edge type. Can be omitted if there is only one edge type
-------- in the graph.
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
>>> g['game'].number_of_nodes() Returns
2 -------
>>> g.add_nodes(3, 'game') # add 3 new games SparseTensor or scipy.sparse.spmatrix
>>> g['game'].number_of_nodes() Adjacency matrix.
5
""" """
pass etid = self.get_etype_id(etype)
if scipy_fmt is None:
return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
else:
return self._graph.adjacency_matrix_scipy(etid, transpose, scipy_fmt, False)
def add_edge(self, etype, u, v, data=None): # Alias of ``adjacency_matrix``
"""Add an edge of ``etype`` between u of the source node type, and v adj = adjacency_matrix
of the destination node type..
Parameters def incidence_matrix(self, typestr, ctx=F.cpu(), etype=None):
---------- """Return the incidence matrix representation of edges with the given
etype : (str, str, str) edge type.
The source-edge-destination type triplet
u : int
The source node ID of type ``utype``. Must exist in the graph.
v : int
The destination node ID of type ``vtype``. Must exist in the
graph.
data : dict, optional
Feature data of the added edge.
Examples An incidence matrix is an n x m sparse matrix, where n is
-------- the number of nodes and m is the number of edges. Each nnz
The variable ``g`` is constructed from the example in value indicating whether the edge is incident to the node
DGLBaseHeteroGraph. or not.
>>> g['plays'].number_of_edges() There are three types of an incidence matrix :math:`I`:
4
>>> g.add_edge(2, 0, 'plays')
>>> g['plays'].number_of_edges()
5
"""
pass
def add_edges(self, u, v, etype, data=None): * ``in``:
"""Add multiple edges of ``etype`` between list of source nodes ``u``
and list of destination nodes ``v`` of type ``vtype``. A single edge
is added between every pair of ``u[i]`` and ``v[i]``.
Parameters - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`
---------- (or :math:`v` is the dst node of :math:`e`);
u : list, tensor - :math:`I[v, e] = 0` otherwise.
The source node IDs of type ``utype``. Must exist in the graph.
v : list, tensor
The destination node IDs of type ``vtype``. Must exist in the
graph.
etype : (str, str, str)
The source-edge-destination type triplet
data : dict, optional
Feature data of the added edge.
Examples * ``out``:
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
>>> g['plays'].number_of_edges() - :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`
4 (or :math:`v` is the src node of :math:`e`);
>>> g.add_edges([0, 2], [1, 0], 'plays') - :math:`I[v, e] = 0` otherwise.
>>> g['plays'].number_of_edges()
6 * ``both`` (only if source and destination node type are the same):
"""
pass - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`;
- :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`;
def from_networkx( - :math:`I[v, e] = 0` otherwise (including self-loop).
self,
nx_graph,
node_type_attr_name='type',
edge_type_attr_name='type',
node_id_attr_name='id',
edge_id_attr_name='id',
node_attrs=None,
edge_attrs=None):
"""Convert from networkx graph.
The networkx graph must satisfy the metagraph. That is, for any
edge in the networkx graph, the source/destination node type must
be the same as the source/destination node of the edge type in
the metagraph. An error will be raised otherwise.
Parameters Parameters
---------- ----------
nx_graph : networkx.DiGraph typestr : str
The networkx graph. Can be either ``in``, ``out`` or ``both``
node_type_attr_name : str ctx : context, optional (default=cpu)
The node attribute name for the node type. The context of returned incidence matrix.
The attribute contents must be strings. etype : str, optional
edge_type_attr_name : str The edge type. Can be omitted if there is only one edge type
The edge attribute name for the edge type. in the graph.
The attribute contents must be strings.
node_id_attr_name : str Returns
The node attribute name for node type-specific IDs. -------
The attribute contents must be integers. SparseTensor
If the IDs of the same type are not consecutive integers, its The incidence matrix.
nodes will be relabeled using consecutive integers. The new
node ordering will inherit that of the sorted IDs.
edge_id_attr_name : str or None
The edge attribute name for edge type-specific IDs.
The attribute contents must be integers.
If the IDs of the same type are not consecutive integers, its
nodes will be relabeled using consecutive integers. The new
node ordering will inherit that of the sorted IDs.
If None is provided, the edge order would be arbitrary.
node_attrs : iterable of str, optional
The node attributes whose data would be copied.
edge_attrs : iterable of str, optional
The edge attributes whose data would be copied.
""" """
pass etid = self.get_etype_id(etype)
return self._graph.incidence_matrix(etid, typestr, ctx)[0]
# Alias of ``incidence_matrix``
inc = incidence_matrix
def node_attr_schemes(self, ntype): #################################################################
# Features
#################################################################
def node_attr_schemes(self, ntype=None):
"""Return the node feature schemes. """Return the node feature schemes.
Each feature scheme is a named tuple that stores the shape and data type Each feature scheme is a named tuple that stores the shape and data type
...@@ -1342,8 +1474,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1342,8 +1474,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
ntype : str ntype : str, optional
The node type The node type. Could be omitted if there is only one node
type in the graph. Error will be raised otherwise.
(Default: None)
Returns Returns
------- -------
...@@ -1354,13 +1488,13 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1354,13 +1488,13 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
-------- --------
The following uses PyTorch backend. The following uses PyTorch backend.
>>> g.ndata['user']['h'] = torch.randn(3, 4) >>> g.nodes['user'].data['h'] = torch.randn(3, 4)
>>> g.node_attr_schemes('user') >>> g.node_attr_schemes('user')
{'h': Scheme(shape=(4,), dtype=torch.float32)} {'h': Scheme(shape=(4,), dtype=torch.float32)}
""" """
return self._node_frames[self._ntypes_invmap[ntype]].schemes return self._node_frames[self.get_ntype_id(ntype)].schemes
def edge_attr_schemes(self, etype): def edge_attr_schemes(self, etype=None):
"""Return the edge feature schemes. """Return the edge feature schemes.
Each feature scheme is a named tuple that stores the shape and data type Each feature scheme is a named tuple that stores the shape and data type
...@@ -1368,8 +1502,9 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1368,8 +1502,9 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
etype : (str, str, str) etype : str or tuple of str, optional
The source-edge-destination type triplet The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
...@@ -1380,60 +1515,75 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1380,60 +1515,75 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
-------- --------
The following uses PyTorch backend. The following uses PyTorch backend.
>>> g.edata['user', 'plays', 'game']['h'] = torch.randn(4, 4) >>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
>>> g.edge_attr_schemes(('user', 'plays', 'game')) >>> g.edge_attr_schemes(('user', 'plays', 'game'))
{'h': Scheme(shape=(4,), dtype=torch.float32)} {'h': Scheme(shape=(4,), dtype=torch.float32)}
""" """
return self._edge_frames[self._etypes_invmap[etype]].schemes return self._edge_frames[self.get_etype_id(etype)].schemes
@property def set_n_initializer(self, initializer, field=None, ntype=None):
def nodes(self): """Set the initializer for empty node features.
"""Return a node view that can used to set/get feature data of a
single node type.
Examples Initializer is a callable that returns a tensor given the shape, data type
-------- and device context.
To set features of User #0 and #2 in a heterogeneous graph:
>>> g.nodes['user'][[0, 2]].data['h'] = torch.zeros(2, 5)
"""
return HeteroNodeView(self)
@property When a subset of the nodes are assigned a new feature, initializer is
def ndata(self): used to create feature for rest of the nodes.
"""Return the data view of all the nodes of a single node type.
Parameters
----------
initializer : callable
The initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
ntype : str, optional
The node type. Could be omitted if there is only one node
type in the graph. Error will be raised otherwise.
(Default: None)
Examples Examples
-------- --------
To set features of games in a heterogeneous graph:
>>> g.ndata['game']['h'] = torch.zeros(2, 5)
"""
return HeteroNodeDataView(self)
@property Note
def edges(self): -----
"""Return an edges view that can used to set/get feature data of a User defined initializer must follow the signature of
single edge type. :func:`dgl.init.base_initializer() <dgl.init.base_initializer>`
Examples
--------
To set features of gameplays #1 (Bob -> Tetris) and #3 (Carol ->
Minecraft) in a heterogeneous graph:
>>> g.edges['user', 'plays', 'game'][[1, 3]].data['h'] = torch.zeros(2, 5)
""" """
return HeteroEdgeView(self) ntid = self.get_ntype_id(ntype)
self._node_frames[ntid].set_initializer(initializer, field)
@property def set_e_initializer(self, initializer, field=None, etype=None):
def edata(self): """Set the initializer for empty edge features.
"""Return the data view of all the edges of a single edge type.
Examples Initializer is a callable that returns a tensor given the shape, data
-------- type and device context.
>>> g.edata['developer', 'develops', 'game']['h'] = torch.zeros(2, 5)
When a subset of the edges are assigned a new feature, initializer is
used to create feature for rest of the edges.
Parameters
----------
initializer : callable
The initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Note
-----
User defined initializer must follow the signature of
:func:`dgl.init.base_initializer() <dgl.init.base_initializer>`
""" """
return HeteroEdgeDataView(self) etid = self.get_etype_id(etype)
self._edge_frames[etid].set_initializer(initializer, field)
def set_n_repr(self, ntype, data, u=ALL, inplace=False): def _set_n_repr(self, ntid, u, data, inplace=False):
"""Set node(s) representation of a single node type. """Internal API to set node features.
`data` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of nodes to be updated, is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
...@@ -1445,18 +1595,18 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1445,18 +1595,18 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
ntype : str ntid : int
The node type Node type id.
data : dict of tensor
Node representation.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool data : dict of tensor
Node representation.
inplace : bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
(Default: False)
""" """
ntype = self._ntypes_invmap[ntype]
if is_all(u): if is_all(u):
num_nodes = self._graph.number_of_nodes(ntype) num_nodes = self._graph.number_of_nodes(ntid)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
num_nodes = len(u) num_nodes = len(u)
...@@ -1468,19 +1618,19 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1468,19 +1618,19 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
if is_all(u): if is_all(u):
for key, val in data.items(): for key, val in data.items():
self._node_frames[ntype][key] = val self._node_frames[ntid][key] = val
else: else:
self._node_frames[ntype].update_rows(u, data, inplace=inplace) self._node_frames[ntid].update_rows(u, data, inplace=inplace)
def get_n_repr(self, ntype, u=ALL): def _get_n_repr(self, ntid, u):
"""Get node(s) representation of a single node type. """Get node(s) representation of a single node type.
The returned feature tensor batches multiple node features on the first dimension. The returned feature tensor batches multiple node features on the first dimension.
Parameters Parameters
---------- ----------
ntype : str ntid : int
The node type Node type id.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
...@@ -1489,22 +1639,19 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1489,22 +1639,19 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict dict
Representation dict from feature name to feature tensor. Representation dict from feature name to feature tensor.
""" """
if len(self.node_attr_schemes(ntype)) == 0:
return dict()
ntype_idx = self._ntypes_invmap[ntype]
if is_all(u): if is_all(u):
return dict(self._node_frames[ntype_idx]) return dict(self._node_frames[ntid])
else: else:
u = utils.toindex(u) u = utils.toindex(u)
return self._node_frames[ntype_idx].select_rows(u) return self._node_frames[ntid].select_rows(u)
def pop_n_repr(self, ntype, key): def _pop_n_repr(self, ntid, key):
"""Get and remove the specified node repr of a given node type. """Internal API to get and remove the specified node feature.
Parameters Parameters
---------- ----------
ntype : str ntid : int
The node type Node type id.
key : str key : str
The attribute name. The attribute name.
...@@ -1513,11 +1660,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1513,11 +1660,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor Tensor
The popped representation The popped representation
""" """
ntype = self._ntypes_invmap[ntype] return self._node_frames[ntid].pop(key)
return self._node_frames[ntype].pop(key)
def set_e_repr(self, etype, data, edges=ALL, inplace=False): def _set_e_repr(self, etid, edges, data, inplace=False):
"""Set edge(s) representation of a single edge type. """Internal API to set edge(s) features.
`data` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated, is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
...@@ -1528,10 +1674,8 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1528,10 +1674,8 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
etype : (str, str, str) etid : int
The source-edge-destination type triplet Edge type id.
data : tensor or dict of tensor
Edge representation.
edges : edges edges : edges
Edges can be either Edges can be either
...@@ -1540,10 +1684,12 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1540,10 +1684,12 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
* A tensor of edge ids of the given type. * A tensor of edge ids of the given type.
The default value is all the edges. The default value is all the edges.
inplace : bool data : tensor or dict of tensor
Edge representation.
inplace : bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
(Default: False)
""" """
etype_idx = self._etypes_invmap[etype]
# parse argument # parse argument
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
...@@ -1552,7 +1698,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1552,7 +1698,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v) _, _, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges)
...@@ -1562,7 +1708,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1562,7 +1708,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
' Got "%s" instead.' % type(data)) ' Got "%s" instead.' % type(data))
if is_all(eid): if is_all(eid):
num_edges = self._graph.number_of_edges(etype_idx) num_edges = self._graph.number_of_edges(etid)
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
...@@ -1575,18 +1721,18 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1575,18 +1721,18 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
if is_all(eid): if is_all(eid):
# update column # update column
for key, val in data.items(): for key, val in data.items():
self._edge_frames[etype_idx][key] = val self._edge_frames[etid][key] = val
else: else:
# update row # update row
self._edge_frames[etype_idx].update_rows(eid, data, inplace=inplace) self._edge_frames[etid].update_rows(eid, data, inplace=inplace)
def get_e_repr(self, etype, edges=ALL): def _get_e_repr(self, etid, edges):
"""Get edge(s) representation. """Internal API to get edge features.
Parameters Parameters
---------- ----------
etype : (str, str, str) etid : int
The source-edge-destination type triplet Edge type id.
edges : edges edges : edges
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids. The default value is all the edges.
...@@ -1596,9 +1742,6 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1596,9 +1742,6 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict dict
Representation dict Representation dict
""" """
etype_idx = self._etypes_invmap[etype]
if len(self.edge_attr_schemes(etype)) == 0:
return dict()
# parse argument # parse argument
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
...@@ -1607,23 +1750,23 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1607,23 +1750,23 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v) _, _, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges)
if is_all(eid): if is_all(eid):
return dict(self._edge_frames[etype_idx]) return dict(self._edge_frames[etid])
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
return self._edge_frames[etype_idx].select_rows(eid) return self._edge_frames[etid].select_rows(eid)
def pop_e_repr(self, etype, key): def _pop_e_repr(self, etid, key):
"""Get and remove the specified edge repr of a single edge type. """Get and remove the specified edge repr of a single edge type.
Parameters Parameters
---------- ----------
etype : (str, str, str) etid : int
The source-edge-destination type triplet Edge type id.
key : str key : str
The attribute name. The attribute name.
...@@ -1632,136 +1775,51 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1632,136 +1775,51 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor Tensor
The popped representation The popped representation
""" """
etype = self._etypes_invmap[etype] self._edge_frames[etid].pop(key)
self._edge_frames[etype].pop(key)
def register_message_func(self, func): #################################################################
"""Register global message function for each edge type provided. # Message passing
#################################################################
def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
"""Apply the function on the nodes with the same type to update their
features.
Once registered, ``func`` will be used as the default If None is provided for ``func``, nothing will happen.
message function in message passing operations, including
:func:`send`, :func:`send_and_recv`, :func:`pull`,
:func:`push`, :func:`update_all`.
Parameters Parameters
---------- ----------
func : callable func : callable
Message function on the edge. The function should be Apply function on the nodes. The function should be
an :mod:`Edge UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
v : int or iterable of int or tensor, optional
See Also The (type-specific) node (ids) on which to apply ``func``.
-------- ntype : str, optional
send The node type. Can be omitted if there is only one node type
send_and_recv in the graph.
pull inplace : bool, optional
push If True, update will be done in place, but autograd will break.
update_all
"""
raise NotImplementedError
def register_reduce_func(self, func):
"""Register global message reduce function for each edge type provided.
Once registered, ``func`` will be used as the default
message reduce function in message passing operations, including
:func:`recv`, :func:`send_and_recv`, :func:`push`, :func:`pull`,
:func:`update_all`.
Parameters
----------
func : callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
See Also
--------
recv
send_and_recv
push
pull
update_all
"""
raise NotImplementedError
def register_apply_node_func(self, func):
"""Register global node apply function for each node type provided.
Once registered, ``func`` will be used as the default apply
node function. Related operations include :func:`apply_nodes`,
:func:`recv`, :func:`send_and_recv`, :func:`push`, :func:`pull`,
:func:`update_all`.
Parameters
----------
func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
See Also
--------
apply_nodes
register_apply_edge_func
"""
raise NotImplementedError
def register_apply_edge_func(self, func):
"""Register global edge apply function for each edge type provided.
Once registered, ``func`` will be used as the default apply
edge function in :func:`apply_edges`.
Parameters
----------
func : callable
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
See Also
--------
apply_edges
register_apply_node_func
"""
raise NotImplementedError
def apply_nodes(self, func, v=ALL, inplace=False):
"""Apply the function on the nodes with the same type to update their
features.
If None is provided for ``func``, nothing will happen.
Parameters
----------
func : dict[str, callable] or None
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
v : dict[str, int or iterable of int or tensor], optional
The (type-specific) node (ids) on which to apply ``func``.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
Examples Examples
-------- --------
>>> g.ndata['user']['h'] = torch.ones(3, 5) >>> g.nodes['user'].data['h'] = torch.ones(3, 5)
>>> g.apply_nodes({'user': lambda nodes: {'h': nodes.data['h'] * 2}}) >>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
>>> g.ndata['user']['h'] >>> g.nodes['user'].data['h']
tensor([[2., 2., 2., 2., 2.], tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]) [2., 2., 2., 2., 2.]])
""" """
for ntype, nfunc in func.items(): ntid = self.get_ntype_id(ntype)
if is_all(v): if is_all(v):
v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype))) v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)))
else: else:
v_ntype = utils.toindex(v[ntype]) v_ntype = utils.toindex(v)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_nodes( scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
graph=self._create_view(self._ntypes_invmap[ntype], None), inplace=inplace)
v=v_ntype, Runtime.run(prog)
apply_func=nfunc,
inplace=inplace) def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, inplace=False):
"""Apply the function on the edges with the same type to update their """Apply the function on the edges with the same type to update their
features. features.
...@@ -1769,52 +1827,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1769,52 +1827,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
func : dict[(str, str, str), callable] or None func : callable or None
Apply function on the edge. The function should be Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
edges : dict[(str, str, str), any valid edge specification], optional edges : edges data, optional
Edges on which to apply ``func``. See :func:`send` for valid Edges on which to apply ``func``. See :func:`send` for valid
edge specification. edge specification.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
Examples Examples
-------- --------
>>> g.edata['user', 'plays', 'game']['h'] = torch.ones(4, 5) >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
>>> g.apply_edges( >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
... {('user', 'plays', 'game'): lambda edges: {'h': edges.data['h'] * 2}}) >>> g.edges[('user', 'plays', 'game')].data['h']
>>> g.edata['user', 'plays', 'game']['h']
tensor([[2., 2., 2., 2., 2.], tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]) [2., 2., 2., 2., 2.]])
""" """
for etype, efunc in func.items(): etid = self.get_etype_id(etype)
etype_idx = self._etypes_invmap[etype] stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges(etype_idx, 'eid') u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype))) eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid) u, v, _ = self._graph.find_edges(etid, eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_edges( scheduler.schedule_apply_edges(
graph=self._create_view(None, etype_idx), AdaptedHeteroGraph(self, stid, dtid, etid),
u=u, u, v, eid, func, inplace=inplace)
v=v, Runtime.run(prog)
eid=eid,
apply_func=efunc, def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
inplace=inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
"""Group the edges by nodes and apply the function of the grouped """Group the edges by nodes and apply the function of the grouped
edges to update their features. The edges are of the same edge type edges to update their features. The edges are of the same edge type
(hence having the same source and destination node type). (hence having the same source and destination node type).
...@@ -1823,47 +1879,47 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1823,47 +1879,47 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
---------- ----------
group_by : str group_by : str
Specify how to group edges. Expected to be either 'src' or 'dst' Specify how to group edges. Expected to be either 'src' or 'dst'
func : dict[(str, str, str), callable] func : callable
Apply function on the edge. The function should be Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. The input of `Edge UDF` should an :mod:`Edge UDF <dgl.udf>`. The input of `Edge UDF` should
be (bucket_size, degrees, *feature_shape), and be (bucket_size, degrees, *feature_shape), and
return the dict with values of the same shapes. return the dict with values of the same shapes.
edges : dict[(str, str, str), valid edges type], optional edges : edges data, optional
Edges on which to group and apply ``func``. See :func:`send` for valid Edges on which to group and apply ``func``. See :func:`send` for valid
edges type. Default is all the edges. edges type. Default is all the edges.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
""" """
if group_by not in ('src', 'dst'): if group_by not in ('src', 'dst'):
raise DGLError("Group_by should be either src or dst") raise DGLError("Group_by should be either src or dst")
for etype, efunc in func.items(): etid = self.get_etype_id(etype)
etype_idx = self._etypes_invmap[etype] stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges(etype_idx) u, v, _ = self._graph.edges(etid)
eid = utils.toindex(slice(0, self.number_of_edges(etype))) eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid) u, v, _ = self._graph.find_edges(etid, eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_group_apply_edge( scheduler.schedule_group_apply_edge(
graph=self._create_view(None, etype_idx), AdaptedHeteroGraph(self, stid, dtid, etid),
u=u, u, v, eid,
v=v, func, group_by,
eid=eid, inplace=inplace)
apply_func=efunc, Runtime.run(prog)
group_by=group_by,
inplace=inplace) def send(self, edges, message_func, etype=None):
Runtime.run(prog)
def send(self, edges=ALL, message_func=None):
"""Send messages along the given edges with the same edge type. """Send messages along the given edges with the same edge type.
``edges`` can be any of the following types: ``edges`` can be any of the following types:
...@@ -1903,101 +1959,195 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -1903,101 +1959,195 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent
along all edges between :math:`u` and :math:`v`. along all edges between :math:`u` and :math:`v`.
""" """
assert not utils.is_dict_like(message_func), \
"multiple-type message passing is not implemented"
assert message_func is not None assert message_func is not None
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
eid = utils.toindex(slice(0, self._graph.number_of_edges(self._current_etype_idx))) eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
u, v, _ = self._graph.edges(self._current_etype_idx) u, v, _ = self._graph.edges(etid)
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(self._current_etype_idx, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(self._current_etype_idx, eid) u, v, _ = self._graph.find_edges(etid, eid)
if len(eid) == 0: if len(eid) == 0:
# no edge to be triggered # no edge to be triggered
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid, scheduler.schedule_send(
message_func=message_func) AdaptedHeteroGraph(self, stid, dtid, etid),
u, v, eid,
message_func)
Runtime.run(prog) Runtime.run(prog)
def recv(self, def recv(self,
v=ALL, v,
reduce_func=None, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None,
inplace=False): inplace=False):
"""Receive and reduce incoming messages and update the features of node(s) :math:`v`. r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
Optionally, apply a function to update the node features after receive. It calculates:
.. math::
h_v^{new} = \sigma(\sum_{u\in\mathcal{N}_{t}(v)}m_{uv})
where :math:`\mathcal{N}_t(v)` defines the predecessors of node(s) ``v`` connected by
edge type :math:`t`, and :math:`m_{uv}` is the message on edge (u,v).
* ``reduce_func`` specifies :math:`\sum`.
* ``apply_node_func`` specifies :math:`\sigma`.
Other notes:
* `reduce_func` will be skipped for nodes with no incoming message. * `reduce_func` will be skipped for nodes with no incoming message.
* If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`. * If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
* If some ``v`` have no incoming message, their new feature value will be calculated * If some ``v`` have no incoming message, their new feature value will be calculated
by the column initializer (see :func:`set_n_initializer`). The feature shapes and by the column initializer (see :func:`set_n_initializer`). The feature shapes and
dtypes will be inferred. dtypes will be inferred.
* The node features will be updated by the result of the ``reduce_func``.
* Messages are consumed once received.
* The provided UDF may be called multiple times, so it is recommended to provide a
function with no side effects.
* The cross-type reducer will check the output field of each per-type reducer
and aggregate those who write to the **same** fields. If None is provided,
the default behavior is overwrite.
The node features will be updated by the result of the ``reduce_func``. Examples
--------
Only one type of nodes in the graph:
Messages are consumed once received. >>> import dgl.function as fn
>>> G.recv(v, fn.sum('m', 'h'))
The provided UDF maybe called multiple times so it is recommended to provide Specify reducer for each type and use cross-type reducer to accum results.
function with no side effect.
Only works if the graph has one edge type. For multiple types, >>> import dgl.function as fn
use >>> G.recv(v,
>>> ... {'plays' : fn.sum('m', 'h'), 'develops' : fn.max('m', 'h')},
>>> ... 'sum')
.. code:: Error will be thrown if per-type reducers cannot determine the node type of v.
g['edgetype'].recv(v, reduce_func, apply_node_func, inplace) >>> import dgl.function as fn
>>> # ambiguous, v is of both 'user' and 'game' types
>>> G.recv(v,
>>> ... {('user', 'follows', 'user') : fn.sum('m', 'h'),
>>> ... ('user', 'plays', 'game') : fn.max('m', 'h')},
>>> ... 'sum')
Parameters Parameters
---------- ----------
v : int, container or tensor, optional v : int, container or tensor
The node(s) to be updated. Default is receiving all the nodes. The node(s) to be updated. Default is receiving all the nodes.
reduce_func : callable, optional reduce_func : callable
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable apply_node_func : callable
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
""" """
assert not utils.is_dict_like(reduce_func) and \ etid = self.get_etype_id(etype)
not utils.is_dict_like(apply_node_func), \ stid, dtid = self._graph.metagraph.find_edge(etid)
"multiple-type message passing is not implemented"
assert reduce_func is not None
if is_all(v): if is_all(v):
v = F.arange(0, self._graph.number_of_nodes(self._current_dsttype_idx)) v = F.arange(0, self.number_of_nodes(dtid))
elif isinstance(v, int): elif isinstance(v, int):
v = [v] v = [v]
v = utils.toindex(v) v = utils.toindex(v)
if len(v) == 0: if len(v) == 0:
# no vertex to be triggered. # no vertex to be triggered.
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_recv(graph=self, scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
recv_nodes=v, v, reduce_func, apply_node_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace) inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def multi_recv(self, v, reducer_dict, cross_reducer, apply_func=None, inplace=False):
r"""Receive messages from multiple edge types and perform aggregation.
It calculates:
.. math::
h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}m_{uv})
* ``reducer_dict`` is a dictionary from edge type to the reduce function
:math:`\sum_{u\in\mathcal{N}_t(v)}` of each type.
* ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`.
* ``apply_func`` specifies :math:`\sigma`.
Examples
--------
TBD
Parameters
----------
v : int, container or tensor
The node(s) to be updated.
reducer_dict : dict of callable
Reduce function per edge type. The function should be
a :mod:`Node UDF <dgl.udf>`.
cross_reducer : str
Cross type reducer. One of "sum", "min", "max", "mean", "stack".
apply_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
# infer receive node type
ntype = infer_ntype_from_dict(self, reducer_dict)
ntid = self.get_ntype_id(ntype)
if is_all(v):
v = F.arange(0, self.number_of_nodes(ntid))
elif isinstance(v, int):
v = [v]
v = utils.toindex(v)
if len(v) == 0:
return
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
with ir.prog() as prog:
for ety, args in reducer_dict.items():
outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
args = pad_tuple(args, 2)
if args is None:
raise DGLError('Invalid per-type arguments. Should be either '
'(1) reduce_func or (2) (reduce_func, apply_func)')
rfunc, afunc = args
etid = self.get_etype_id(ety)
stid, dtid = self._graph.metagraph.find_edge(etid)
scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
v, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[ntid].update(merge_frames(all_out, cross_reducer))
# apply
if apply_func is not None:
self.apply_nodes(apply_func, v, ntype, inplace)
def send_and_recv(self, def send_and_recv(self,
edges, edges,
message_func=None, message_func,
reduce_func=None, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None,
inplace=False): inplace=False):
"""Send messages along edges with the same edge type, and let destinations """Send messages along edges with the same edge type, and let destinations
receive them. receive them.
...@@ -2021,53 +2171,128 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2021,53 +2171,128 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
edges : valid edges type edges : valid edges type
Edges on which to apply ``func``. See :func:`send` for valid Edges on which to apply ``func``. See :func:`send` for valid
edges type. edges type.
message_func : callable, optional message_func : callable
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional reduce_func : callable
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional apply_node_func : callable, optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
""" """
assert not utils.is_dict_like(message_func) and \ etid = self.get_etype_id(etype)
not utils.is_dict_like(reduce_func) and \ stid, dtid = self._graph.metagraph.find_edge(etid)
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
if isinstance(edges, tuple): if isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(self._current_etype_idx, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(self._current_etype_idx, eid) u, v, _ = self._graph.find_edges(etid, eid)
if len(u) == 0: if len(u) == 0:
# no edges to be triggered # no edges to be triggered
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_snr(graph=self, scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
edge_tuples=(u, v, eid), (u, v, eid),
message_func=message_func, message_func, reduce_func, apply_node_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace) inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def multi_send_and_recv(self, etype_dict, cross_reducer, apply_func=None, inplace=False):
r"""Send and receive messages along multiple edge types and perform aggregation.
It calculates:
.. math::
h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}\phi_t(
h_u, h_v, h_{uv}))
* ``etype_dict`` is a dictionary from edge type to a tuple of arguments for a
normal ``send_and_recv``.
* :math:`\mathcal{N}_t(v)` is defined by the edges given for type :math:`t`.
* ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`.
* ``apply_func`` specifies :math:`\sigma`.
Examples
--------
TBD
Parameters
----------
v : int, container or tensor
The node(s) to be updated.
etype_dict : dict of callable
``send_and_recv`` arguments per edge type.
cross_reducer : str
Cross type reducer. One of "sum", "min", "max", "mean", "stack".
apply_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
# infer receive node type
ntype = infer_ntype_from_dict(self, etype_dict)
dtid = self.get_ntype_id(ntype)
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
all_vs = []
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
stid, _ = self._graph.metagraph.find_edge(etid)
outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
args = pad_tuple(args, 4)
if args is None:
raise DGLError('Invalid per-type arguments. Should be '
'(edges, msg_func, reduce_func, [apply_func])')
edges, mfunc, rfunc, afunc = args
if isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etid, eid)
all_vs.append(v)
if len(u) == 0:
# no edges to be triggered
continue
scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
(u, v, eid),
mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
# apply
if apply_func is not None:
dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
self.apply_nodes(apply_func, dstnodes, ntype, inplace)
def pull(self, def pull(self,
v, v,
message_func=None, message_func,
reduce_func=None, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None,
inplace=False): inplace=False):
"""Pull messages from the node(s)' predecessors and then update their features. """Pull messages from the node(s)' predecessors and then update their features.
...@@ -2090,40 +2315,108 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2090,40 +2315,108 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
---------- ----------
v : int, container or tensor, optional v : int, container or tensor, optional
The node(s) to be updated. Default is receiving all the nodes. The node(s) to be updated. Default is receiving all the nodes.
message_func : callable, optional message_func : callable
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional reduce_func : callable
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional apply_node_func : callable, optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
""" """
assert not utils.is_dict_like(message_func) and \ # only one type of edges
not utils.is_dict_like(reduce_func) and \ etid = self.get_etype_id(etype)
not utils.is_dict_like(apply_node_func), \ stid, dtid = self._graph.metagraph.find_edge(etid)
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
v = utils.toindex(v) v = utils.toindex(v)
if len(v) == 0: if len(v) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_pull(graph=self, scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
pull_nodes=v, v,
message_func=message_func, message_func, reduce_func, apply_node_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace) inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def multi_pull(self, v, etype_dict, cross_reducer, apply_func=None, inplace=False):
r"""Pull and receive messages of the given nodes along multiple edge types
and perform aggregation.
It calculates:
.. math::
h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}\phi_t(
h_u, h_v, h_{uv}))
* ``etype_dict`` is a dictionary from edge type to a tuple of arguments for a
normal ``pull``.
* :math:`\mathcal{N}_t(v)` is the set of predecessors of ``v`` connected by edge
type :math:`t`.
* ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`.
* ``apply_func`` specifies :math:`\sigma`.
Examples
--------
TBD
Parameters
----------
v : int, container or tensor
The node(s) to be updated.
etype_dict : dict of callable
``pull`` arguments per edge type.
cross_reducer : str
Cross type reducer. One of "sum", "min", "max", "mean", "stack".
apply_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
v = utils.toindex(v)
if len(v) == 0:
return
# infer receive node type
ntype = infer_ntype_from_dict(self, etype_dict)
dtid = self.get_ntype_id(ntype)
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
stid, _ = self._graph.metagraph.find_edge(etid)
outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
args = pad_tuple(args, 3)
if args is None:
raise DGLError('Invalid per-type arguments. Should be '
'(msg_func, reduce_func, [apply_func])')
mfunc, rfunc, afunc = args
scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
v,
mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
# apply
if apply_func is not None:
self.apply_nodes(apply_func, v, ntype, inplace)
def push(self, def push(self,
u, u,
message_func=None, message_func,
reduce_func=None, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None,
inplace=False): inplace=False):
"""Send message from the node(s) to their successors and update them. """Send message from the node(s) to their successors and update them.
...@@ -2140,41 +2433,40 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2140,41 +2433,40 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
---------- ----------
u : int, container or tensor u : int, container or tensor
The node(s) to push messages out. The node(s) to push messages out.
message_func : callable, optional message_func : callable
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional reduce_func : callable
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional apply_node_func : callable, optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. If True, update will be done in place, but autograd will break.
""" """
assert not utils.is_dict_like(message_func) and \ # only one type of edges
not utils.is_dict_like(reduce_func) and \ etid = self.get_etype_id(etype)
not utils.is_dict_like(apply_node_func), \ stid, dtid = self._graph.metagraph.find_edge(etid)
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
u = utils.toindex(u) u = utils.toindex(u)
if len(u) == 0: if len(u) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_push(graph=self, scheduler.schedule_push(AdaptedHeteroGraph(self, stid, dtid, etid),
u=u, u,
message_func=message_func, message_func, reduce_func, apply_node_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace) inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def update_all(self, def update_all(self,
message_func=None, message_func,
reduce_func=None, reduce_func,
apply_node_func=None): apply_node_func=None,
etype=None):
"""Send messages through all edges and update all nodes. """Send messages through all edges and update all nodes.
Optionally, apply a function to update the node features after receive. Optionally, apply a function to update the node features after receive.
...@@ -2192,229 +2484,235 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2192,229 +2484,235 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters Parameters
---------- ----------
message_func : callable, optional message_func : callable
Message function on the edges. The function should be Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional reduce_func : callable
Reduce function on the node. The function should be Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional apply_node_func : callable, optional
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
""" """
assert not utils.is_dict_like(message_func) and \ # only one type of edges
not utils.is_dict_like(reduce_func) and \ etid = self.get_etype_id(etype)
not utils.is_dict_like(apply_node_func), \ stid, dtid = self._graph.metagraph.find_edge(etid)
"multiple-type message passing is not implemented"
assert message_func is not None
assert reduce_func is not None
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_update_all(graph=self, scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
message_func=message_func, message_func, reduce_func,
reduce_func=reduce_func, apply_node_func)
apply_func=apply_node_func)
Runtime.run(prog) Runtime.run(prog)
def prop_nodes(self, def multi_update_all(self, etype_dict, cross_reducer, apply_func=None):
nodes_generator, r"""Send and receive messages along all edges.
message_func=None,
reduce_func=None,
apply_node_func=None):
"""Node propagation in heterogeneous graph is not supported.
"""
raise NotImplementedError('not supported')
def prop_edges(self,
edges_generator,
message_func=None,
reduce_func=None,
apply_node_func=None):
"""Edge propagation in heterogeneous graph is not supported.
"""
raise NotImplementedError('not supported')
def subgraph(self, nodes):
"""Return the subgraph induced on given nodes.
Parameters It calculates:
----------
nodes : dict[str, list or iterable]
A dictionary of node types to node ID array to construct
subgraph.
All nodes must exist in the graph.
Returns .. math::
------- h_v^{new} = \sigma(\prod_{t\inT_e}\sum_{u\in\mathcal{N}_t(v)}\phi_t(
G : DGLHeteroSubGraph h_u, h_v, h_{uv}))
The subgraph.
The nodes are relabeled so that node `i` of type `t` in the
subgraph is mapped to the ``nodes[i]`` of type `t` in the
original graph.
The edges are also relabeled.
One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via `parent_nid` and `parent_eid` properties of the
subgraph.
"""
pass
def subgraphs(self, nodes): * ``etype_dict`` is a dictionary from edge type to a tuple of arguments for a
"""Return a list of subgraphs, each induced in the corresponding given normal ``update_all``.
nodes in the list. * :math:`\mathcal{N}_t(v)` is the set of predecessors of ``v`` connected by edge
type :math:`t`.
* ``cross_reducer`` specifies :math:`\prod_{t\inT_e}`
* ``apply_func`` specifies :math:`\sigma`.
Equivalent to Examples
``[self.subgraph(nodes_list) for nodes_list in nodes]`` --------
TBD
Parameters Parameters
---------- ----------
nodes : a list of dict[str, list or iterable] v : int, container or tensor
A list of type-ID dictionaries to construct corresponding The node(s) to be updated.
subgraphs. The dictionaries are of the same form as etype_dict : dict of callable
:func:`subgraph`. ``update_all`` arguments per edge type.
All nodes in all the list items must exist in the graph. cross_reducer : str
Cross type reducer. One of "sum", "min", "max", "mean", "stack".
Returns apply_node_func : callable
------- Apply function on the nodes. The function should be
G : A list of DGLHeteroSubGraph a :mod:`Node UDF <dgl.udf>`.
The subgraphs. inplace: bool, optional
If True, update will be done in place, but autograd will break.
""" """
pass
def edge_subgraph(self, edges): # TODO(minjie): currently loop over each edge type and reuse the old schedule.
"""Return the subgraph induced on given edges. # Should replace it with fused kernel.
all_out = defaultdict(list)
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
args = pad_tuple(args, 3)
if args is None:
raise DGLError('Invalid per-type arguments. Should be '
'(msg_func, reduce_func, [apply_func])')
mfunc, rfunc, afunc = args
scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
mfunc, rfunc, afunc,
outframe=outframe)
all_out[dtid].append(outframe)
Runtime.run(prog)
for dtid, frames in all_out.items():
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(frames, cross_reducer))
# apply
if apply_func is not None:
self.apply_nodes(apply_func, ALL, self.ntypes[dtid], inplace=False)
def prop_nodes(self,
nodes_generator,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
"""Propagate messages using graph traversal by triggering
:func:`pull()` on nodes.
The traversal order is specified by the ``nodes_generator``. It generates
node frontiers, which is a list or a tensor of nodes. The nodes in the
same frontier will be triggered together, while nodes in different frontiers
will be triggered according to the generating order.
Parameters Parameters
---------- ----------
edges : dict[etype, list or iterable] node_generators : iterable, each element is a list or a tensor of node ids
A dictionary of edge types to edge ID array to construct The generator of node frontiers. It specifies which nodes perform
subgraph. :func:`pull` at each timestep.
All edges must exist in the subgraph. message_func : callable
The edge type is characterized by a triplet of source type name, Message function on the edges. The function should be
destination type name, and edge type name. an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns See Also
------- --------
G : DGLHeteroSubGraph prop_edges
The subgraph.
The edges are relabeled so that edge `i` of type `t` in the
subgraph is mapped to the ``edges[i]`` of type `t` in the
original graph.
One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via `parent_nid` and `parent_eid` properties of the
subgraph.
""" """
pass for node_frontier in nodes_generator:
self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype)
def adjacency_matrix_scipy(self, etype, transpose=False, fmt='csr'):
"""Return the scipy adjacency matrix representation of edges with the
given edge type.
By default, a row of returned adjacency matrix represents the destination def prop_edges(self,
of an edge and the column represents the source. edges_generator,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
"""Propagate messages using graph traversal by triggering
:func:`send_and_recv()` on edges.
When transpose is True, a row represents the source and a column represents The traversal order is specified by the ``edges_generator``. It generates
a destination. edge frontiers. The edge frontiers should be of *valid edges type*.
See :func:`send` for more details.
The elements in the adajency matrix are edge ids. Edges in the same frontier will be triggered together, while edges in
different frontiers will be triggered according to the generating order.
Parameters Parameters
---------- ----------
etype : tuple[str, str, str] edges_generator : generator
The edge type, characterized by a triplet of source type name, The generator of edge frontiers.
destination type name, and edge type name. message_func : callable
transpose : bool, optional (default=False) Message function on the edges. The function should be
A flag to transpose the returned adjacency matrix. an :mod:`Edge UDF <dgl.udf>`.
fmt : str, optional (default='csr') reduce_func : callable
Indicates the format of returned adjacency matrix. Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns See Also
------- --------
scipy.sparse.spmatrix prop_nodes
The scipy representation of adjacency matrix.
""" """
pass for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier, message_func, reduce_func,
apply_node_func, etype=etype)
def adjacency_matrix(self, etype, transpose=False, ctx=F.cpu()): #################################################################
"""Return the adjacency matrix representation of edges with the # Misc
given edge type. #################################################################
By default, a row of returned adjacency matrix represents the def to_networkx(self, node_attrs=None, edge_attrs=None):
destination of an edge and the column represents the source. """Convert this graph to networkx graph.
When transpose is True, a row represents the source and a column The edge id will be saved as the 'id' edge attribute.
represents a destination.
Parameters Parameters
---------- ----------
etype : tuple[str, str, str] node_attrs : iterable of str, optional
The edge type, characterized by a triplet of source type name, The node attributes to be copied.
destination type name, and edge type name. edge_attrs : iterable of str, optional
transpose : bool, optional (default=False) The edge attributes to be copied.
A flag to transpose the returned adjacency matrix.
ctx : context, optional (default=cpu)
The context of returned adjacency matrix.
Returns Returns
------- -------
SparseTensor networkx.DiGraph
The adjacency matrix. The nx graph
"""
pass
def incidence_matrix(self, etype, typestr, ctx=F.cpu()):
"""Return the incidence matrix representation of edges with the given
edge type.
An incidence matrix is an n x m sparse matrix, where n is
the number of nodes and m is the number of edges. Each nnz
value indicating whether the edge is incident to the node
or not.
There are three types of an incidence matrix :math:`I`:
* ``in``:
- :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v` Examples
(or :math:`v` is the dst node of :math:`e`); --------
- :math:`I[v, e] = 0` otherwise.
* ``out``:
- :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`
(or :math:`v` is the src node of :math:`e`);
- :math:`I[v, e] = 0` otherwise.
* ``both``:
- :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`; .. note:: Here we use pytorch syntax for demo. The general idea applies
- :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`; to other frameworks with minor syntax change (e.g. replace
- :math:`I[v, e] = 0` otherwise (including self-loop). ``torch.tensor`` with ``mxnet.ndarray``).
Parameters >>> import torch as th
---------- >>> g = DGLGraph()
etype : tuple[str, str, str] >>> g.add_nodes(5, {'n1': th.randn(5, 10)})
The edge type, characterized by a triplet of source type name, >>> g.add_edges([0,1,3,4], [2,4,0,3], {'e1': th.randn(4, 6)})
destination type name, and edge type name. >>> nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])
typestr : str
Can be either ``in``, ``out`` or ``both``
ctx : context, optional (default=cpu)
The context of returned incidence matrix.
Returns See Also
------- --------
SparseTensor dgl.to_networkx
The incidence matrix. """
""" # TODO(minjie): multi-type support
pass assert len(self.ntypes) == 1
assert len(self.etypes) == 1
def filter_nodes(self, ntype, predicate, nodes=ALL): src, dst = self.edges()
src = F.asnumpy(src)
dst = F.asnumpy(dst)
nx_graph = nx.MultiDiGraph() if self.is_multigraph else nx.DiGraph()
nx_graph.add_nodes_from(range(self.number_of_nodes()))
for eid, (u, v) in enumerate(zip(src, dst)):
nx_graph.add_edge(u, v, id=eid)
if node_attrs is not None:
for nid, attr in nx_graph.nodes(data=True):
feat_dict = self._get_n_repr(0, nid)
attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
if edge_attrs is not None:
for _, _, attr in nx_graph.edges(data=True):
eid = attr['id']
feat_dict = self._get_e_repr(0, eid)
attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
return nx_graph
def filter_nodes(self, predicate, nodes=ALL, ntype=None):
"""Return a tensor of node IDs with the given node type that satisfy """Return a tensor of node IDs with the given node type that satisfy
the given predicate. the given predicate.
Parameters Parameters
---------- ----------
ntype : str
The node type.
predicate : callable predicate : callable
A function of signature ``func(nodes) -> tensor``. A function of signature ``func(nodes) -> tensor``.
``nodes`` are :class:`NodeBatch` objects as in :mod:`~dgl.udf`. ``nodes`` are :class:`NodeBatch` objects as in :mod:`~dgl.udf`.
...@@ -2423,23 +2721,37 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2423,23 +2721,37 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
the batch satisfies the predicate. the batch satisfies the predicate.
nodes : int, iterable or tensor of ints nodes : int, iterable or tensor of ints
The nodes to filter on. Default value is all the nodes. The nodes to filter on. Default value is all the nodes.
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
Returns Returns
------- -------
tensor tensor
The filtered nodes. The nodes that satisfy the predicate.
""" """
pass ntid = self.get_ntype_id(ntype)
if is_all(nodes):
v = utils.toindex(slice(0, self._graph.number_of_nodes(ntid)))
else:
v = utils.toindex(nodes)
n_repr = self._get_n_repr(ntid, v)
nbatch = NodeBatch(v, n_repr)
n_mask = F.copy_to(predicate(nbatch), F.cpu())
def filter_edges(self, etype, predicate, edges=ALL): if is_all(nodes):
return F.nonzero_1d(n_mask)
else:
nodes = F.tensor(nodes)
return F.boolean_mask(nodes, n_mask)
def filter_edges(self, predicate, edges=ALL, etype=None):
"""Return a tensor of edge IDs with the given edge type that satisfy """Return a tensor of edge IDs with the given edge type that satisfy
the given predicate. the given predicate.
Parameters Parameters
---------- ----------
etype : tuple[str, str, str]
The edge type, characterized by a triplet of source type name,
destination type name, and edge type name.
predicate : callable predicate : callable
A function of signature ``func(edges) -> tensor``. A function of signature ``func(edges) -> tensor``.
``edges`` are :class:`EdgeBatch` objects as in :mod:`~dgl.udf`. ``edges`` are :class:`EdgeBatch` objects as in :mod:`~dgl.udf`.
...@@ -2449,114 +2761,456 @@ class DGLHeteroGraph(DGLBaseHeteroGraph): ...@@ -2449,114 +2761,456 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
edges : valid edges type edges : valid edges type
Edges on which to apply ``func``. See :func:`send` for valid Edges on which to apply ``func``. See :func:`send` for valid
edges type. Default value is all the edges. edges type. Default value is all the edges.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
tensor tensor
The filtered edges represented by their ids. The edges that satisfy the predicate.
""" """
pass etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges):
u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etid, eid)
def readonly(self, readonly_state=True): src_data = self._get_n_repr(stid, u)
"""Set this graph's readonly state in-place. edge_data = self._get_e_repr(etid, eid)
dst_data = self._get_n_repr(dtid, v)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges):
return F.nonzero_1d(e_mask)
else:
edges = F.tensor(edges)
return F.boolean_mask(edges, e_mask)
def to(self, ctx): # pylint: disable=invalid-name
"""Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic
Parameters Parameters
---------- ----------
readonly_state : bool, optional ctx : framework-specific context object
New readonly state of the graph, defaults to True. The context to move data to.
Examples
--------
The following example uses PyTorch backend.
>>> import torch
>>> G = dgl.DGLGraph()
>>> G.add_nodes(5, {'h': torch.ones((5, 2))})
>>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
>>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
>>> G.to(torch.device('cuda:0'))
""" """
pass for i in range(len(self._node_frames)):
for k in self._node_frames[i].keys():
self._node_frames[i][k] = F.copy_to(self._node_frames[i][k], ctx)
for i in range(len(self._edge_frames)):
for k in self._edge_frames[i].keys():
self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx)
# TODO: replace this after implementing frame def local_var(self):
# pylint: disable=useless-super-delegation """Return a graph object that can be used in a local function scope.
def __repr__(self):
return super(DGLHeteroGraph, self).__repr__() The returned graph object shares the feature data and graph structure of this graph.
However, any out-place mutation to the feature data will not reflect to this graph,
thus making it easier to use in a function scope.
If set, the local graph object will use same initializers for node features and
edge features.
Examples
--------
The following example uses PyTorch backend.
Avoid accidentally overriding existing feature data. This is quite common when
implementing a NN module:
>>> def foo(g):
>>> g = g.local_var()
>>> g.ndata['h'] = torch.ones((g.number_of_nodes(), 3))
>>> return g.ndata['h']
>>>
>>> g = ... # some graph
>>> g.ndata['h'] = torch.zeros((g.number_of_nodes(), 3))
>>> newh = foo(g) # get tensor of all ones
>>> print(g.ndata['h']) # still get tensor of all zeros
Automatically garbage collect locally-defined tensors without the need to manually
``pop`` the tensors.
>>> def foo(g):
>>> g = g.local_var()
>>> # This 'xxx' feature will stay local and be GCed when the function exits
>>> g.ndata['xxx'] = torch.ones((g.number_of_nodes(), 3))
>>> return g.ndata['xxx']
>>>
>>> g = ... # some graph
>>> xxx = foo(g)
>>> print('xxx' in g.ndata)
False
Notes
-----
Internally, the returned graph shares the same feature tensors, but construct a new
dictionary structure (aka. Frame) so adding/removing feature tensors from the returned
graph will not reflect to the original graph. However, inplace operations do change
the shared tensor values, so will be reflected to the original graph. This function
also has little overhead when the number of feature tensors in this graph is small.
# pylint: disable=abstract-method See Also
class DGLHeteroSubGraph(DGLHeteroGraph): --------
local_var
Returns
-------
DGLGraph
The graph object that can be used as a local variable.
"""
local_node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
local_edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
# Use same per-column initializers and default initializer.
# If registered, a column (based on key) initializer will be used first,
# otherwise the default initializer will be used.
for fr1, fr2 in zip(local_node_frames, self._node_frames):
sync_frame_initializer(fr1._frame, fr2._frame)
for fr1, fr2 in zip(local_edge_frames, self._edge_frames):
sync_frame_initializer(fr1._frame, fr2._frame)
return DGLHeteroGraph(self._graph, self.ntypes, self.etypes,
local_node_frames,
local_edge_frames)
@contextmanager
def local_scope(self):
    """Enter a local scope context for this graph.

    While the scope is active the graph uses fresh node/edge frames built on
    top of the existing feature data, so any out-place mutation to the
    feature data (adding, replacing or removing a feature) will not reflect
    to the original graph. This makes the graph easier to use in a function
    scope.

    If set, the local scope will use same initializers for node features and
    edge features.

    The original frames are restored when the scope exits, including when the
    body of the ``with`` statement raises an exception.

    Examples
    --------
    The following example uses PyTorch backend.

    Avoid accidentally overriding existing feature data. This is quite common when
    implementing a NN module:

    >>> def foo(g):
    >>>     with g.local_scope():
    >>>         g.ndata['h'] = torch.ones((g.number_of_nodes(), 3))
    >>>         return g.ndata['h']
    >>>
    >>> g = ... # some graph
    >>> g.ndata['h'] = torch.zeros((g.number_of_nodes(), 3))
    >>> newh = foo(g)  # get tensor of all ones
    >>> print(g.ndata['h'])  # still get tensor of all zeros

    Automatically garbage collect locally-defined tensors without the need to manually
    ``pop`` the tensors.

    >>> def foo(g):
    >>>     with g.local_scope():
    >>>         # This 'xxx' feature will stay local and be GCed when the function exits
    >>>         g.ndata['xxx'] = torch.ones((g.number_of_nodes(), 3))
    >>>         return g.ndata['xxx']
    >>>
    >>> g = ... # some graph
    >>> xxx = foo(g)
    >>> print('xxx' in g.ndata)
    False

    See Also
    --------
    local_var
    """
    old_nframes = self._node_frames
    old_eframes = self._edge_frames
    # Build the local frames completely before installing them, so a failure
    # during construction leaves the graph untouched.
    local_nframes = [FrameRef(Frame(fr._frame)) for fr in old_nframes]
    local_eframes = [FrameRef(Frame(fr._frame)) for fr in old_eframes]
    # Use same per-column initializers and default initializer.
    # If registered, a column (based on key) initializer will be used first,
    # otherwise the default initializer will be used.
    for fr1, fr2 in zip(local_nframes, old_nframes):
        sync_frame_initializer(fr1._frame, fr2._frame)
    for fr1, fr2 in zip(local_eframes, old_eframes):
        sync_frame_initializer(fr1._frame, fr2._frame)
    self._node_frames = local_nframes
    self._edge_frames = local_eframes
    try:
        yield
    finally:
        # Always restore the original frames; without this an exception in
        # the caller's ``with`` body would leave the graph on the local frames.
        self._node_frames = old_nframes
        self._edge_frames = old_eframes
############################################################
# Internal APIs
############################################################
def make_canonical_etypes(etypes, ntypes, metagraph):
    """Internal function to expand edge type names into canonical triplets.

    Every edge of the metagraph is turned into a
    ``(srctype, etype, dsttype)`` tuple by looking up the names of its
    source node, its edge ID and its destination node.

    Parameters
    ----------
    etypes : list of str
        Edge type names, indexed by metagraph edge ID.
    ntypes : list of str
        Node type names, indexed by metagraph node ID.
    metagraph : GraphIndex
        Meta graph.

    Returns
    -------
    list of tuples (srctype, etype, dsttype)

    Raises
    ------
    DGLError
        If the number of names does not match the number of edges/nodes
        in the metagraph.
    """
    # sanity check: one name per metagraph edge ...
    if metagraph.number_of_edges() != len(etypes):
        raise DGLError('Length of edge type list must match the number of '
                       'edges in the metagraph. {} vs {}'.format(
                           len(etypes), metagraph.number_of_edges()))
    # ... and one name per metagraph node
    if metagraph.number_of_nodes() != len(ntypes):
        raise DGLError('Length of nodes type list must match the number of '
                       'nodes in the metagraph. {} vs {}'.format(
                           len(ntypes), metagraph.number_of_nodes()))
    src_ids, dst_ids, edge_ids = metagraph.edges()
    canonical = []
    for src_id, dst_id, edge_id in zip(src_ids, dst_ids, edge_ids):
        canonical.append((ntypes[src_id], etypes[edge_id], ntypes[dst_id]))
    return canonical
def infer_ntype_from_dict(graph, etype_dict):
    """Infer the destination node type shared by the keys of a dictionary.

    Each key is an edge type. All of them must point to the same destination
    node type, which is then returned; otherwise an error is thrown.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    etype_dict : dict
        Dictionary whose key is edge type

    Returns
    -------
    str
        Node type. ``None`` if the dictionary is empty.
    """
    shared_dst = None
    for edge_type in etype_dict:
        _src, _rel, dst_type = graph.to_canonical_etype(edge_type)
        if shared_dst is None:
            # first destination type seen becomes the reference
            shared_dst = dst_type
        elif shared_dst != dst_type:
            raise DGLError("Cannot infer destination node type from the dictionary. "
                           "A valid specification must make sure that all the edge "
                           "type keys share the same destination node type.")
    return shared_dst
def pad_tuple(tup, length, pad_val=None):
    """Extend a tuple with ``pad_val`` until it has ``length`` entries.

    A non-tuple input is first wrapped into a one-element tuple.

    Returns None when the input already has more than ``length`` entries,
    i.e. when padding is impossible.
    """
    items = tup if isinstance(tup, tuple) else (tup, )
    missing = length - len(items)
    if missing < 0:
        # too long: cannot be padded down
        return None
    if missing == 0:
        return items
    return items + (pad_val,) * missing
def merge_frames(frames, reducer):
    """Merge input frames into one. Resolve conflict fields using reducer.

    A single input frame is returned as is. Otherwise every field found in
    any of the frames is copied to the result; a field present in more than
    one frame is combined with the reducer.

    Parameters
    ----------
    frames : list of FrameRef
        Input frames
    reducer : str
        One of "sum", "max", "min", "mean", "stack"

    Returns
    -------
    FrameRef
        Merged frame
    """
    if len(frames) == 1:
        return frames[0]

    if reducer == 'stack':
        # TODO(minjie): Stack order does not matter. However, it must
        #   be consistent! Need to enforce one type of order.
        def merger(columns):
            expanded = [F.unsqueeze(col, 1) for col in columns]
            return F.stack(expanded, 1)
    else:
        # look the reducer up by name in the backend
        redfn = getattr(F, reducer, None)
        if redfn is None:
            raise DGLError('Invalid cross type reducer. Must be one of '
                           '"sum", "max", "min", "mean" or "stack".')
        def merger(columns):
            return redfn(F.stack(columns, 0), 0)

    merged = FrameRef(frame_like(frames[0]._frame))
    # collect the union of field names over all frames
    keys = set()
    for frm in frames:
        keys.update(frm.keys())
    for field in keys:
        columns = [frm[field] for frm in frames if field in frm]
        # only fields that appear in several frames need reducing
        merged[field] = merger(columns) if len(columns) > 1 else columns[0]
    return merged
def combine_frames(frames, ids):
"""Merge the frames into one frame, taking the common columns.
Return None if there is no common columns.
Parameters Parameters
---------- ----------
parent : DGLHeteroGraph frames : List[FrameRef]
The parent graph. List of frames
parent_nid : dict[str, utils.Index] ids : List[int]
The type-specific parent node IDs for each type. List of frame IDs
parent_eid : dict[etype, utils.Index]
The type-specific parent edge IDs for each type. Returns
graph_idx : GraphIndex -------
The graph index FrameRef
shared : bool, optional The resulting frame
Whether the subgraph shares node/edge features with the parent graph
""" """
# pylint: disable=unused-argument, super-init-not-called # find common columns and check if their schemes match
def __init__( schemes = {key: scheme for key, scheme in frames[ids[0]].schemes.items()}
self, for frame_id in ids:
parent, frame = frames[frame_id]
parent_nid, for key, scheme in list(schemes.items()):
parent_eid, if key in frame.schemes:
graph_idx, if frame.schemes[key] != scheme:
shared=False): raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
pass (key, frame.schemes[key], scheme))
else:
del schemes[key]
if len(schemes) == 0:
return None
# concatenate the columns
to_cat = lambda key: [frames[i][key] for i in ids if frames[i].num_rows > 0]
cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
return FrameRef(Frame(cols))
def combine_names(names, ids=None):
    """Join the selected names, in sorted order, into one ``+``-separated name.

    Parameters
    ----------
    names : list of str
        String names
    ids : numpy.ndarray, optional
        Selected index. All names are used if omitted.

    Returns
    -------
    str
    """
    chosen = names if ids is None else [names[i] for i in ids]
    return '+'.join(sorted(chosen))
class AdaptedHeteroGraph(GraphAdapter):
"""Adapt DGLGraph to interface required by scheduler.
Parameters
----------
graph : DGLHeteroGraph
Graph
stid : int
Source node type id
dtid : int
Destination node type id
etid : int
Edge type id
"""
def __init__(self, graph, stid, dtid, etid):
self.graph = graph
self.stid = stid
self.dtid = dtid
self.etid = etid
@property @property
def parent_nid(self): def gidx(self):
"""Get the parent node ids. return self.graph._graph
The returned tensor dictionary can be used as a map from the node id def num_src(self):
in this subgraph to the node id in the parent graph. """Number of source nodes."""
return self.graph._graph.number_of_nodes(self.stid)
Returns def num_dst(self):
------- """Number of destination nodes."""
dict[str, Tensor] return self.graph._graph.number_of_nodes(self.dtid)
The parent node id array for each type.
""" def num_edges(self):
pass """Number of edges."""
return self.graph._graph.number_of_edges(self.etid)
@property @property
def parent_eid(self): def srcframe(self):
"""Get the parent edge ids. """Frame to store source node features."""
return self.graph._node_frames[self.stid]
The returned tensor dictionary can be used as a map from the edge id @property
in this subgraph to the edge id in the parent graph. def dstframe(self):
"""Frame to store source node features."""
return self.graph._node_frames[self.dtid]
Returns @property
------- def edgeframe(self):
dict[etype, Tensor] """Frame to store edge features."""
The parent edge id array for each type. return self.graph._edge_frames[self.etid]
The edge types are characterized by a triplet of source type
name, destination type name, and edge type name.
"""
pass
def copy_to_parent(self, inplace=False): @property
"""Write node/edge features to the parent graph. def msgframe(self):
"""Frame to store messages."""
return self.graph._msg_frames[self.etid]
Parameters @property
---------- def msgindicator(self):
inplace : bool """Message indicator tensor."""
If true, use inplace write (no gradient but faster) return self.graph._get_msg_index(self.etid)
"""
pass
def copy_from_parent(self): @msgindicator.setter
"""Copy node/edge features from the parent graph. def msgindicator(self, val):
"""Set new message indicator tensor."""
self.graph._set_msg_index(self.etid, val)
All old features will be removed. def in_edges(self, nodes):
""" return self.graph._graph.in_edges(self.etid, nodes)
pass
def map_to_subgraph_nid(self, parent_vids): def out_edges(self, nodes):
"""Map the node IDs in the parent graph to the node IDs in the return self.graph._graph.out_edges(self.etid, nodes)
subgraph.
Parameters def edges(self, form):
---------- return self.graph._graph.edges(self.etid, form)
parent_vids : dict[str, list or tensor]
The dictionary of node types to parent node ID array.
Returns def get_immutable_gidx(self, ctx):
------- return self.graph._graph.get_unitgraph(self.etid, ctx)
dict[str, tensor]
The node ID array in the subgraph of each node type. def bits_needed(self):
""" return self.graph._graph.bits_needed(self.etid)
pass
"""Module for heterogeneous graph index class definition.""" """Module for heterogeneous graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np
import scipy
from ._ffi.object import register_object, ObjectBase from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
...@@ -48,7 +51,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -48,7 +51,7 @@ class HeteroGraphIndex(ObjectBase):
return self.metagraph.number_of_edges() return self.metagraph.number_of_edges()
def get_relation_graph(self, etype): def get_relation_graph(self, etype):
"""Get the bipartite graph of the given edge/relation type. """Get the unitgraph graph of the given edge/relation type.
Parameters Parameters
---------- ----------
...@@ -58,10 +61,26 @@ class HeteroGraphIndex(ObjectBase): ...@@ -58,10 +61,26 @@ class HeteroGraphIndex(ObjectBase):
Returns Returns
------- -------
HeteroGraphIndex HeteroGraphIndex
The bipartite graph. The unitgraph graph.
""" """
return _CAPI_DGLHeteroGetRelationGraph(self, int(etype)) return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))
def flatten_relations(self, etypes):
    """Flatten the relation graphs of the requested edge types into one
    unit graph.

    Parameters
    ----------
    etypes : list[int]
        The edge/relation types whose graphs are combined.

    Returns
    -------
    HeteroGraphIndex
        The resulting unit graph.
    """
    flattened = _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)
    return flattened
def add_nodes(self, ntype, num): def add_nodes(self, ntype, num):
"""Add nodes. """Add nodes.
...@@ -131,7 +150,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -131,7 +150,7 @@ class HeteroGraphIndex(ObjectBase):
return _CAPI_DGLHeteroNumBits(self) return _CAPI_DGLHeteroNumBits(self)
def bits_needed(self, etype): def bits_needed(self, etype):
"""Return the number of integer bits needed to represent the bipartite graph. """Return the number of integer bits needed to represent the unitgraph graph.
Parameters Parameters
---------- ----------
...@@ -658,6 +677,146 @@ class HeteroGraphIndex(ObjectBase): ...@@ -658,6 +677,146 @@ class HeteroGraphIndex(ObjectBase):
else: else:
raise Exception("unknown format") raise Exception("unknown format")
def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None):
    """Return the scipy adjacency matrix representation of this graph.

    By default, a row of returned adjacency matrix represents the destination
    of an edge and the column represents the source.

    When transpose is True, a row represents the source and a column represents
    a destination.

    Parameters
    ----------
    etype : int
        Edge type
    transpose : bool
        A flag to transpose the returned adjacency matrix.
    fmt : str
        Indicates the format of returned adjacency matrix.
        Either "csr" or "coo".
    return_edge_ids : bool, optional
        Indicates whether to return edge IDs or 1 as elements.
        If None (the default), a FutureWarning is issued and edge IDs
        are returned.

    Returns
    -------
    scipy.sparse.spmatrix
        The scipy representation of adjacency matrix.

    Raises
    ------
    DGLError
        If ``transpose`` is not a bool.
    Exception
        If ``fmt`` is neither "csr" nor "coo".
    """
    # A bare ``import scipy`` does not guarantee that the ``sparse`` submodule
    # has been loaded, so import it explicitly before touching it.
    from scipy import sparse as sp_sparse  # pylint: disable=import-outside-toplevel

    if not isinstance(transpose, bool):
        raise DGLError('Expect bool value for "transpose" arg,'
                       ' but got %s.' % (type(transpose)))

    if return_edge_ids is None:
        # NOTE(review): dgl_warning is not among the imports visible at the top
        # of this file -- confirm it is in scope.
        dgl_warning(
            "Adjacency matrix by default currently returns edge IDs."
            " As a result there is one 0 entry which is not eliminated."
            " In the next release it will return 1s by default,"
            " and 0 will be eliminated otherwise.",
            FutureWarning)
        return_edge_ids = True

    rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
    srctype, dsttype = self.metagraph.find_edge(etype)
    # Rows index destination nodes unless the matrix is transposed.
    nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype)
    ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
    nnz = self.number_of_edges(etype)
    if fmt == "csr":
        indptr = utils.toindex(rst(0)).tonumpy()
        indices = utils.toindex(rst(1)).tonumpy()
        data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
        return sp_sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols))
    elif fmt == 'coo':
        idx = utils.toindex(rst(0)).tonumpy()
        # The flat index array is split into its row half and column half.
        row, col = np.reshape(idx, (2, nnz))
        data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row)
        return sp_sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols))
    else:
        raise Exception("unknown format")
def incidence_matrix(self, etype, typestr, ctx):
    """Build the incidence matrix of one edge type.

    The incidence matrix is an n x m sparse matrix, with n the number
    of nodes and m the number of edges. A non-zero entry tells whether
    the edge is incident to the node.

    Three flavours of incidence matrix `I` are available:

    * "in":
      - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);
      - I[v, e] = 0 otherwise.
    * "out":
      - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);
      - I[v, e] = 0 otherwise.
    * "both":
      - I[v, e] = 1 if e is the in-edge of v;
      - I[v, e] = -1 if e is the out-edge of v;
      - I[v, e] = 0 otherwise (including self-loop).

    Parameters
    ----------
    etype : int
        Edge type
    typestr : str
        One of "in", "out" or "both"
    ctx : context
        The context of returned incidence matrix.

    Returns
    -------
    SparseTensor
        The incidence matrix.
    utils.Index
        A index for data shuffling due to sparse format change. None
        when no shuffle is required.
    """
    src, dst, eid = self.edges(etype)
    # the index of the ctx will be cached
    src = src.tousertensor(ctx)
    dst = dst.tousertensor(ctx)
    eid = eid.tousertensor(ctx)
    srctype, dsttype = self.metagraph.find_edge(etype)
    m = self.number_of_edges(etype)

    if typestr == 'in' or typestr == 'out':
        # One entry per edge, placed on the row of its dst ("in") or src ("out").
        if typestr == 'in':
            n = self.number_of_nodes(dsttype)
            endpoint = dst
        else:
            n = self.number_of_nodes(srctype)
            endpoint = src
        row = F.unsqueeze(endpoint, 0)
        col = F.unsqueeze(eid, 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    elif typestr == 'both':
        assert srctype == dsttype, \
            "'both' is supported only if source and destination type are the same"
        n = self.number_of_nodes(srctype)
        # self loops contribute nothing, so drop them first
        not_loop = F.logical_not(F.equal(src, dst))
        src = F.boolean_mask(src, not_loop)
        dst = F.boolean_mask(dst, not_loop)
        eid = F.boolean_mask(eid, not_loop)
        n_entries = F.shape(src)[0]
        # every remaining edge yields two entries: one on src, one on dst
        row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
        col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        neg = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        pos = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        dat = F.cat([neg, pos], dim=0)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    else:
        raise DGLError('Invalid incidence matrix type: %s' % str(typestr))

    if shuffle_idx is not None:
        shuffle_idx = utils.toindex(shuffle_idx)
    return inc, shuffle_idx
def node_subgraph(self, induced_nodes): def node_subgraph(self, induced_nodes):
"""Return the induced node subgraph. """Return the induced node subgraph.
...@@ -696,16 +855,16 @@ class HeteroGraphIndex(ObjectBase): ...@@ -696,16 +855,16 @@ class HeteroGraphIndex(ObjectBase):
eids = [edges.todgltensor() for edges in induced_edges] eids = [edges.todgltensor() for edges in induced_edges]
return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes) return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='bipartite') @utils.cached_member(cache='_cache', prefix='unitgraph')
def get_bipartite(self, etype, ctx): def get_unitgraph(self, etype, ctx):
"""Create a bipartite graph from given edge type and copy to the given device """Create a unitgraph graph from given edge type and copy to the given device
context. context.
Note: this internal function is for DGL scheduler use only Note: this internal function is for DGL scheduler use only
Parameters Parameters
---------- ----------
etype : int, or None etype : int
If the graph index is a Bipartite graph index, this argument must be None. If the graph index is a Bipartite graph index, this argument must be None.
Otherwise, it represents the edge type. Otherwise, it represents the edge type.
ctx : DGLContext ctx : DGLContext
...@@ -715,7 +874,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -715,7 +874,7 @@ class HeteroGraphIndex(ObjectBase):
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
g = self.get_relation_graph(etype) if etype is not None else self g = self.get_relation_graph(etype)
return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx) return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx)
def get_csr_shuffle_order(self, etype): def get_csr_shuffle_order(self, etype):
...@@ -778,11 +937,17 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -778,11 +937,17 @@ class HeteroSubgraphIndex(ObjectBase):
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self) ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret] return [utils.toindex(v.data) for v in ret]
def create_bipartite_from_coo(num_src, num_dst, row, col): #################################################################
"""Create a bipartite graph index from COO format # Creators
#################################################################
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col):
"""Create a unitgraph graph index from COO format
Parameters Parameters
---------- ----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int num_src : int
Number of nodes in the src type. Number of nodes in the src type.
num_dst : int num_dst : int
...@@ -796,14 +961,16 @@ def create_bipartite_from_coo(num_src, num_dst, row, col): ...@@ -796,14 +961,16 @@ def create_bipartite_from_coo(num_src, num_dst, row, col):
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
return _CAPI_DGLHeteroCreateBipartiteFromCOO( return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_src), int(num_dst), row.todgltensor(), col.todgltensor()) int(num_ntypes), int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids): def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edge_ids):
"""Create a bipartite graph index from CSR format """Create a unitgraph graph index from CSR format
Parameters Parameters
---------- ----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int num_src : int
Number of nodes in the src type. Number of nodes in the src type.
num_dst : int num_dst : int
...@@ -819,11 +986,11 @@ def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids): ...@@ -819,11 +986,11 @@ def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
return _CAPI_DGLHeteroCreateBipartiteFromCSR( return _CAPI_DGLHeteroCreateUnitGraphFromCSR(
int(num_src), int(num_dst), int(num_ntypes), int(num_src), int(num_dst),
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor()) indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
def create_heterograph(metagraph, rel_graphs): def create_heterograph_from_relations(metagraph, rel_graphs):
"""Create a heterograph from metagraph and graphs of every relation. """Create a heterograph from metagraph and graphs of every relation.
Parameters Parameters
......
...@@ -3,3 +3,4 @@ from __future__ import absolute_import ...@@ -3,3 +3,4 @@ from __future__ import absolute_import
from . import scheduler from . import scheduler
from .runtime import Runtime from .runtime import Runtime
from .adapter import GraphAdapter
"""Temporary adapter to unify DGLGraph and HeteroGraph for scheduler.
NOTE(minjie): remove once all scheduler code is migrated to heterograph
"""
from __future__ import absolute_import
from abc import ABC, abstractmethod
class GraphAdapter(ABC):
    """Temporary abstract interface that lets the schedulers treat DGLGraph
    and DGLHeteroGraph uniformly."""

    @property
    @abstractmethod
    def gidx(self):
        """The graph index object."""

    @abstractmethod
    def num_src(self):
        """Return the number of source nodes."""

    @abstractmethod
    def num_dst(self):
        """Return the number of destination nodes."""

    @abstractmethod
    def num_edges(self):
        """Return the number of edges."""

    @property
    @abstractmethod
    def srcframe(self):
        """The frame holding source node features."""

    @property
    @abstractmethod
    def dstframe(self):
        """The frame holding destination node features."""

    @property
    @abstractmethod
    def edgeframe(self):
        """The frame holding edge features."""

    @property
    @abstractmethod
    def msgframe(self):
        """The frame holding messages."""

    @property
    @abstractmethod
    def msgindicator(self):
        """The message indicator tensor."""

    @msgindicator.setter
    @abstractmethod
    def msgindicator(self, val):
        """Replace the message indicator tensor with ``val``."""

    @abstractmethod
    def in_edges(self, nodes):
        """Get the in edges of the given nodes.

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def out_edges(self, nodes):
        """Get the out edges of the given nodes.

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def edges(self, form):
        """Get all the edges.

        Parameters
        ----------
        form : str
            "eid", "uv", etc.

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def get_immutable_gidx(self, ctx):
        """Get the immutable graph index used for kernel computation.

        Parameters
        ----------
        ctx : DGLContext
            The context of the returned graph.

        Returns
        -------
        GraphIndex
        """

    @abstractmethod
    def bits_needed(self):
        """Return how many integer bits are needed to represent the graph.

        Returns
        -------
        int
            The number of bits needed
        """
...@@ -10,7 +10,6 @@ from . import ir ...@@ -10,7 +10,6 @@ from . import ir
from .ir import var from .ir import var
def gen_degree_bucketing_schedule( def gen_degree_bucketing_schedule(
graph,
reduce_udf, reduce_udf,
message_ids, message_ids,
dst_nodes, dst_nodes,
...@@ -28,8 +27,6 @@ def gen_degree_bucketing_schedule( ...@@ -28,8 +27,6 @@ def gen_degree_bucketing_schedule(
Parameters Parameters
---------- ----------
graph : DGLGraph
DGLGraph to use
reduce_udf : callable reduce_udf : callable
The UDF to reduce messages. The UDF to reduce messages.
message_ids : utils.Index message_ids : utils.Index
...@@ -56,7 +53,7 @@ def gen_degree_bucketing_schedule( ...@@ -56,7 +53,7 @@ def gen_degree_bucketing_schedule(
fd_list = [] fd_list = []
for deg, vbkt, mid in zip(degs, buckets, msg_ids): for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc # create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt) rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt)
# vars # vars
vbkt = var.IDX(vbkt) vbkt = var.IDX(vbkt)
mid = var.IDX(mid) mid = var.IDX(mid)
...@@ -144,7 +141,7 @@ def _process_node_buckets(buckets): ...@@ -144,7 +141,7 @@ def _process_node_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt): def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
"""Internal function to generate the per degree bucket node UDF.""" """Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data): def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key): def _reshaped_getter(key):
...@@ -152,12 +149,11 @@ def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt): ...@@ -152,12 +149,11 @@ def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
new_shape = (len(vbkt), deg) + F.shape(msg)[1:] new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys()) reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nbatch = NodeBatch(graph, vbkt, node_data, reshaped_mail_data) nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data)
return reduce_udf(nbatch) return reduce_udf(nbatch)
return _rfunc_wrapper return _rfunc_wrapper
def gen_group_apply_edge_schedule( def gen_group_apply_edge_schedule(
graph,
apply_func, apply_func,
u, v, eid, u, v, eid,
group_by, group_by,
...@@ -175,8 +171,6 @@ def gen_group_apply_edge_schedule( ...@@ -175,8 +171,6 @@ def gen_group_apply_edge_schedule(
Parameters Parameters
---------- ----------
graph : DGLGraph
DGLGraph to use
apply_func: callable apply_func: callable
The edge_apply_func UDF The edge_apply_func UDF
u: utils.Index u: utils.Index
...@@ -209,7 +203,7 @@ def gen_group_apply_edge_schedule( ...@@ -209,7 +203,7 @@ def gen_group_apply_edge_schedule(
fd_list = [] fd_list = []
for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids): for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids):
# create per-bkt efunc # create per-bkt efunc
_efunc = var.FUNC(_create_per_bkt_efunc(graph, apply_func, deg, _efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg,
u_bkt, v_bkt, eid_bkt)) u_bkt, v_bkt, eid_bkt))
# vars # vars
var_u = var.IDX(u_bkt) var_u = var.IDX(u_bkt)
...@@ -280,7 +274,7 @@ def _process_edge_buckets(buckets): ...@@ -280,7 +274,7 @@ def _process_edge_buckets(buckets):
eids = split(eids) eids = split(eids)
return degs, uids, vids, eids return degs, uids, vids, eids
def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid): def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
"""Internal function to generate the per degree bucket edge UDF.""" """Internal function to generate the per degree bucket edge UDF."""
batch_size = len(u) // deg batch_size = len(u) // deg
def _efunc_wrapper(src_data, edge_data, dst_data): def _efunc_wrapper(src_data, edge_data, dst_data):
...@@ -302,7 +296,7 @@ def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid): ...@@ -302,7 +296,7 @@ def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid):
edge_data.keys()) edge_data.keys())
reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data), reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data),
dst_data.keys()) dst_data.keys())
ebatch = EdgeBatch(graph, (u, v, eid), reshaped_src_data, ebatch = EdgeBatch((u, v, eid), reshaped_src_data,
reshaped_edge_data, reshaped_dst_data) reshaped_edge_data, reshaped_dst_data)
return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()} return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()}
return _efunc_wrapper return _efunc_wrapper
......
...@@ -8,8 +8,6 @@ from .. import backend as F ...@@ -8,8 +8,6 @@ from .. import backend as F
from ..frame import frame_like, FrameRef from ..frame import frame_like, FrameRef
from ..function.base import BuiltinFunction from ..function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch from ..udf import EdgeBatch, NodeBatch
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex
from . import ir from . import ir
from .ir import var from .ir import var
...@@ -30,22 +28,16 @@ __all__ = [ ...@@ -30,22 +28,16 @@ __all__ = [
"schedule_pull" "schedule_pull"
] ]
def _dispatch(graph, method, *args, **kwargs): def schedule_send(graph,
graph_index = graph._graph u, v, eid,
if isinstance(graph_index, GraphIndex): message_func,
return getattr(graph._graph, method)(*args, **kwargs) msgframe=None):
elif isinstance(graph_index, HeteroGraphIndex): """Schedule send
return getattr(graph._graph, method)(graph._current_etype_idx, *args, **kwargs)
else:
raise TypeError('unknown type %s' % type(graph_index))
def schedule_send(graph, u, v, eid, message_func):
"""get send schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
u : utils.Index u : utils.Index
Source nodes Source nodes
v : utils.Index v : utils.Index
...@@ -54,11 +46,13 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -54,11 +46,13 @@ def schedule_send(graph, u, v, eid, message_func):
Ids of sending edges Ids of sending edges
message_func: callable or list of callable message_func: callable or list of callable
The message function The message function
msgframe : FrameRef, optional
The storage to write messages to. If None, use graph.msgframe.
""" """
var_mf = var.FEAT_DICT(graph._msg_frame) var_mf = var.FEAT_DICT(msgframe if msgframe is not None else graph.msgframe)
var_src_nf = var.FEAT_DICT(graph._src_frame) var_src_nf = var.FEAT_DICT(graph.srcframe)
var_dst_nf = var.FEAT_DICT(graph._dst_frame) var_dst_nf = var.FEAT_DICT(graph.dstframe)
var_ef = var.FEAT_DICT(graph._edge_frame) var_ef = var.FEAT_DICT(graph.edgeframe)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
var_msg = _gen_send(graph=graph, var_msg = _gen_send(graph=graph,
...@@ -73,19 +67,20 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -73,19 +67,20 @@ def schedule_send(graph, u, v, eid, message_func):
# write tmp msg back # write tmp msg back
ir.WRITE_ROW_(var_mf, var_eid, var_msg) ir.WRITE_ROW_(var_mf, var_eid, var_msg)
# set message indicator to 1 # set message indicator to 1
graph._set_msg_index(graph._get_msg_index().set_items(eid, 1)) graph.msgindicator = graph.msgindicator.set_items(eid, 1)
def schedule_recv(graph, def schedule_recv(graph,
recv_nodes, recv_nodes,
reduce_func, reduce_func,
apply_func, apply_func,
inplace): inplace,
outframe=None):
"""Schedule recv. """Schedule recv.
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
recv_nodes: utils.Index recv_nodes: utils.Index
Nodes to recv. Nodes to recv.
reduce_func: callable or list of callable reduce_func: callable or list of callable
...@@ -94,10 +89,12 @@ def schedule_recv(graph, ...@@ -94,10 +89,12 @@ def schedule_recv(graph,
The apply node function The apply node function
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
""" """
src, dst, eid = _dispatch(graph, 'in_edges', recv_nodes) src, dst, eid = graph.in_edges(recv_nodes)
if len(eid) > 0: if len(eid) > 0:
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero() nonzero_idx = graph.msgindicator.get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx) eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx) src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx) dst = dst.get_items(nonzero_idx)
...@@ -106,9 +103,11 @@ def schedule_recv(graph, ...@@ -106,9 +103,11 @@ def schedule_recv(graph,
# 1) all recv nodes are 0-degree nodes # 1) all recv nodes are 0-degree nodes
# 2) no send has been called # 2) no send has been called
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace) schedule_apply_nodes(recv_nodes, apply_func, graph.dstframe,
inplace, outframe)
else: else:
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
# sort and unique the argument # sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes)
...@@ -117,23 +116,24 @@ def schedule_recv(graph, ...@@ -117,23 +116,24 @@ def schedule_recv(graph,
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid), reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes) recv_nodes)
# apply # apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else: else:
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_out_nf, var_recv_nodes, final_feat)
# set message indicator to 0 # set message indicator to 0
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0)) graph.msgindicator = graph.msgindicator.set_items(eid, 0)
if not graph._get_msg_index().has_nonzero(): if not graph.msgindicator.has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf')) ir.CLEAR_FRAME_(var.FEAT_DICT(graph.msgframe, name='mf'))
def schedule_snr(graph, def schedule_snr(graph,
edge_tuples, edge_tuples,
message_func, message_func,
reduce_func, reduce_func,
apply_func, apply_func,
inplace): inplace,
outframe=None):
"""Schedule send_and_recv. """Schedule send_and_recv.
Currently it builds a subgraph from edge_tuples with the same number of Currently it builds a subgraph from edge_tuples with the same number of
...@@ -142,8 +142,8 @@ def schedule_snr(graph, ...@@ -142,8 +142,8 @@ def schedule_snr(graph,
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
edge_tuples: tuple edge_tuples: tuple
A tuple of (src ids, dst ids, edge ids) representing edges to perform A tuple of (src ids, dst ids, edge ids) representing edges to perform
send_and_recv send_and_recv
...@@ -155,12 +155,15 @@ def schedule_snr(graph, ...@@ -155,12 +155,15 @@ def schedule_snr(graph,
The apply node function The apply node function
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
""" """
u, v, eid = edge_tuples u, v, eid = edge_tuples
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes)
# create vars # create vars
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
var_u = var.IDX(u) var_u = var.IDX(u)
var_v = var.IDX(v) var_v = var.IDX(v)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
...@@ -168,12 +171,11 @@ def schedule_snr(graph, ...@@ -168,12 +171,11 @@ def schedule_snr(graph,
# generate send and reduce schedule # generate send and reduce schedule
uv_getter = lambda: (var_u, var_v) uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv( adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
edge_tuples, graph._number_of_src_nodes(), graph._number_of_dst_nodes()) edge_tuples, graph.num_src(), graph.num_dst())
out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits) out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits)
reduced_feat = _gen_send_reduce(graph=graph, reduced_feat = _gen_send_reduce(src_node_frame=graph.srcframe,
src_node_frame=graph._src_frame, dst_node_frame=graph.dstframe,
dst_node_frame=graph._dst_frame, edge_frame=graph.edgeframe,
edge_frame=graph._edge_frame,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
var_send_edges=var_eid, var_send_edges=var_eid,
...@@ -182,52 +184,56 @@ def schedule_snr(graph, ...@@ -182,52 +184,56 @@ def schedule_snr(graph,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator)
# generate apply schedule # generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, reduced_feat, final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, reduced_feat,
apply_func) apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else: else:
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_out_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph, def schedule_update_all(graph,
message_func, message_func,
reduce_func, reduce_func,
apply_func): apply_func,
"""get send and recv schedule outframe=None):
"""Get send and recv schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
message_func: callable or list of callable message_func: callable or list of callable
The message function The message function
reduce_func: callable or list of callable reduce_func: callable or list of callable
The reduce function The reduce function
apply_func: callable apply_func: callable
The apply node function The apply node function
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
""" """
if graph._number_of_edges() == 0: if graph.num_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes # All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None: if apply_func is not None:
nodes = utils.toindex(slice(0, graph._number_of_dst_nodes())) nodes = utils.toindex(slice(0, graph.num_dst()))
schedule_apply_nodes(graph, nodes, apply_func, inplace=False) schedule_apply_nodes(nodes, apply_func, graph.dstframe,
inplace=False, outframe=outframe)
else: else:
eid = utils.toindex(slice(0, graph._number_of_edges())) # ALL eid = utils.toindex(slice(0, graph.num_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph._number_of_dst_nodes())) # ALL recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL
# create vars # create vars
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes') var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# generate send + reduce # generate send + reduce
def uv_getter(): def uv_getter():
src, dst, _ = _dispatch(graph, 'edges', 'eid') src, dst, _ = graph.edges('eid')
return var.IDX(src), var.IDX(dst) return var.IDX(src), var.IDX(dst)
adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph) adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph)
out_map_creator = lambda nbits: None out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph, reduced_feat = _gen_send_reduce(src_node_frame=graph.srcframe,
src_node_frame=graph._src_frame, dst_node_frame=graph.dstframe,
dst_node_frame=graph._dst_frame, edge_frame=graph.edgeframe,
edge_frame=graph._edge_frame,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
var_send_edges=var_eid, var_send_edges=var_eid,
...@@ -236,50 +242,54 @@ def schedule_update_all(graph, ...@@ -236,50 +242,54 @@ def schedule_update_all(graph,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func)
ir.WRITE_DICT_(var_dst_nf, final_feat) ir.WRITE_DICT_(var_out_nf, final_feat)
def schedule_apply_nodes(graph, def schedule_apply_nodes(v,
v,
apply_func, apply_func,
inplace): node_frame,
"""get apply nodes schedule inplace,
outframe=None):
"""Get apply nodes schedule
Parameters Parameters
---------- ----------
graph: DGLGraph
The DGLGraph to use
v : utils.Index v : utils.Index
Nodes to apply Nodes to apply
apply_func: callable apply_func : callable
The apply node function The apply node function
node_frame : FrameRef
Node feature frame.
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use the given node_frame.
Returns Returns
------- -------
A list of executors for DGL Runtime A list of executors for DGL Runtime
""" """
var_v = var.IDX(v) var_v = var.IDX(v)
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_nf = var.FEAT_DICT(node_frame, name='nf')
var_out_nf = var_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
v_nf = ir.READ_ROW(var_nf, var_v) v_nf = ir.READ_ROW(var_nf, var_v)
def _afunc_wrapper(node_data): def _afunc_wrapper(node_data):
nbatch = NodeBatch(graph, v, node_data) nbatch = NodeBatch(v, node_data)
return apply_func(nbatch) return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper) afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf) applied_feat = ir.NODE_UDF(afunc, v_nf)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_v, applied_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_v, applied_feat)
else: else:
ir.WRITE_ROW_(var_nf, var_v, applied_feat) ir.WRITE_ROW_(var_out_nf, var_v, applied_feat)
def schedule_nodeflow_apply_nodes(graph, def schedule_nodeflow_apply_nodes(graph,
layer_id, layer_id,
v, v,
apply_func, apply_func,
inplace): inplace):
"""get apply nodes schedule in NodeFlow. """Get apply nodes schedule in NodeFlow.
Parameters Parameters
---------- ----------
...@@ -302,7 +312,7 @@ def schedule_nodeflow_apply_nodes(graph, ...@@ -302,7 +312,7 @@ def schedule_nodeflow_apply_nodes(graph,
var_v = var.IDX(v) var_v = var.IDX(v)
v_nf = ir.READ_ROW(var_nf, var_v) v_nf = ir.READ_ROW(var_nf, var_v)
def _afunc_wrapper(node_data): def _afunc_wrapper(node_data):
nbatch = NodeBatch(graph, v, node_data) nbatch = NodeBatch(v, node_data)
return apply_func(nbatch) return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper) afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf) applied_feat = ir.NODE_UDF(afunc, v_nf)
...@@ -315,13 +325,14 @@ def schedule_nodeflow_apply_nodes(graph, ...@@ -315,13 +325,14 @@ def schedule_nodeflow_apply_nodes(graph,
def schedule_apply_edges(graph, def schedule_apply_edges(graph,
u, v, eid, u, v, eid,
apply_func, apply_func,
inplace): inplace,
"""get apply edges schedule outframe=None):
"""Get apply edges schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
u : utils.Index u : utils.Index
Source nodes of edges to apply Source nodes of edges to apply
v : utils.Index v : utils.Index
...@@ -332,23 +343,25 @@ def schedule_apply_edges(graph, ...@@ -332,23 +343,25 @@ def schedule_apply_edges(graph,
The apply edge function The apply edge function
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.edge_frame.
Returns Returns
------- -------
A list of executors for DGL Runtime A list of executors for DGL Runtime
""" """
# vars # vars
var_src_nf = var.FEAT_DICT(graph._src_frame) var_src_nf = var.FEAT_DICT(graph.srcframe, 'uframe')
var_dst_nf = var.FEAT_DICT(graph._dst_frame) var_dst_nf = var.FEAT_DICT(graph.dstframe, 'vframe')
var_ef = var.FEAT_DICT(graph._edge_frame) var_ef = var.FEAT_DICT(graph.edgeframe, 'eframe')
var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, 'out_ef')
var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func, var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func,
var_src_nf=var_src_nf, var_dst_nf=var_dst_nf, var_src_nf=var_src_nf, var_dst_nf=var_dst_nf,
var_ef=var_ef) var_ef=var_ef)
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# schedule apply edges # schedule apply edges
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out) ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out)
else: else:
ir.WRITE_ROW_(var_ef, var_eid, var_out) ir.WRITE_ROW_(var_ef, var_eid, var_out)
...@@ -356,7 +369,7 @@ def schedule_nodeflow_apply_edges(graph, block_id, ...@@ -356,7 +369,7 @@ def schedule_nodeflow_apply_edges(graph, block_id,
u, v, eid, u, v, eid,
apply_func, apply_func,
inplace): inplace):
"""get apply edges schedule in NodeFlow. """Get apply edges schedule in NodeFlow.
Parameters Parameters
---------- ----------
...@@ -397,13 +410,14 @@ def schedule_push(graph, ...@@ -397,13 +410,14 @@ def schedule_push(graph,
message_func, message_func,
reduce_func, reduce_func,
apply_func, apply_func,
inplace): inplace,
"""get push schedule outframe=None):
"""Get push schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
u : utils.Index u : utils.Index
Source nodes for push Source nodes for push
message_func: callable or list of callable message_func: callable or list of callable
...@@ -414,26 +428,30 @@ def schedule_push(graph, ...@@ -414,26 +428,30 @@ def schedule_push(graph,
The apply node function The apply node function
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
""" """
u, v, eid = _dispatch(graph, 'out_edges', u) u, v, eid = graph.out_edges(u)
if len(eid) == 0: if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled. # All the pushing nodes have no out edges. No computation is scheduled.
return return
schedule_snr(graph, (u, v, eid), schedule_snr(graph, (u, v, eid),
message_func, reduce_func, apply_func, inplace) message_func, reduce_func, apply_func,
inplace, outframe)
def schedule_pull(graph, def schedule_pull(graph,
pull_nodes, pull_nodes,
message_func, message_func,
reduce_func, reduce_func,
apply_func, apply_func,
inplace): inplace,
"""get pull schedule outframe=None):
"""Get pull schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
pull_nodes : utils.Index pull_nodes : utils.Index
Destination nodes for pull Destination nodes for pull
message_func: callable or list of callable message_func: callable or list of callable
...@@ -444,20 +462,23 @@ def schedule_pull(graph, ...@@ -444,20 +462,23 @@ def schedule_pull(graph,
The apply node function The apply node function
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
""" """
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs # TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat # can be specialized to SPMV. This needs support for creating adjmat
# directly from pull node frontier. # directly from pull node frontier.
u, v, eid = _dispatch(graph, 'in_edges', pull_nodes) u, v, eid = graph.in_edges(pull_nodes)
if len(eid) == 0: if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply. # All the nodes are 0deg; downgrades to apply.
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(graph, pull_nodes, apply_func, inplace) schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace, outframe)
else: else:
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor())) pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes) pull_nodes = utils.toindex(pull_nodes)
# create vars # create vars
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes') var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u) var_u = var.IDX(u)
var_v = var.IDX(v) var_v = var.IDX(v)
...@@ -465,31 +486,33 @@ def schedule_pull(graph, ...@@ -465,31 +486,33 @@ def schedule_pull(graph,
# generate send and reduce schedule # generate send and reduce schedule
uv_getter = lambda: (var_u, var_v) uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv( adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes()) (u, v, eid), graph.num_src(), graph.num_dst())
out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits) out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits)
reduced_feat = _gen_send_reduce(graph, graph._src_frame, reduced_feat = _gen_send_reduce(graph.srcframe,
graph._dst_frame, graph._edge_frame, graph.dstframe, graph.edgeframe,
message_func, reduce_func, var_eid, message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator, var_pull_nodes, uv_getter, adj_creator,
out_map_creator) out_map_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_dst_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(var_pull_nodes, var_dst_nf,
reduced_feat, apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_pull_nodes, final_feat)
else: else:
ir.WRITE_ROW_(var_dst_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_(var_out_nf, var_pull_nodes, final_feat)
def schedule_group_apply_edge(graph, def schedule_group_apply_edge(graph,
u, v, eid, u, v, eid,
apply_func, apply_func,
group_by, group_by,
inplace): inplace,
"""group apply edges schedule outframe=None):
"""Group apply edges schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: GraphAdaptor
The DGLGraph to use Graph
u : utils.Index u : utils.Index
Source nodes of edges to apply Source nodes of edges to apply
v : utils.Index v : utils.Index
...@@ -502,23 +525,22 @@ def schedule_group_apply_edge(graph, ...@@ -502,23 +525,22 @@ def schedule_group_apply_edge(graph,
Specify how to group edges. Expected to be either 'src' or 'dst' Specify how to group edges. Expected to be either 'src' or 'dst'
inplace: bool inplace: bool
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional
Returns The storage to write output data. If None, use graph.edgeframe.
-------
A list of executors for DGL Runtime
""" """
# vars # vars
var_src_nf = var.FEAT_DICT(graph._src_frame, name='src_nf') var_src_nf = var.FEAT_DICT(graph.srcframe, name='src_nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef') var_ef = var.FEAT_DICT(graph.edgeframe, name='ef')
var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, name='out_ef')
var_out = var.FEAT_DICT(name='new_ef') var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by, db.gen_group_apply_edge_schedule(apply_func, u, v, eid, group_by,
var_src_nf, var_dst_nf, var_ef, var_out) var_src_nf, var_dst_nf, var_ef, var_out)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out) ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out)
else: else:
ir.WRITE_ROW_(var_ef, var_eid, var_out) ir.WRITE_ROW_(var_out_ef, var_eid, var_out)
def schedule_nodeflow_update_all(graph, def schedule_nodeflow_update_all(graph,
...@@ -526,7 +548,7 @@ def schedule_nodeflow_update_all(graph, ...@@ -526,7 +548,7 @@ def schedule_nodeflow_update_all(graph,
message_func, message_func,
reduce_func, reduce_func,
apply_func): apply_func):
"""get update_all schedule in a block. """Get update_all schedule in a block.
Parameters Parameters
---------- ----------
...@@ -555,8 +577,7 @@ def schedule_nodeflow_update_all(graph, ...@@ -555,8 +577,7 @@ def schedule_nodeflow_update_all(graph,
return var.IDX(utils.toindex(src)), var.IDX(utils.toindex(dst)) return var.IDX(utils.toindex(src)), var.IDX(utils.toindex(dst))
adj_creator = lambda: spmv.build_gidx_and_mapping_block(graph, block_id) adj_creator = lambda: spmv.build_gidx_and_mapping_block(graph, block_id)
out_map_creator = lambda nbits: None out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph, reduced_feat = _gen_send_reduce(src_node_frame=graph._get_node_frame(block_id),
src_node_frame=graph._get_node_frame(block_id),
dst_node_frame=graph._get_node_frame(block_id + 1), dst_node_frame=graph._get_node_frame(block_id + 1),
edge_frame=graph._get_edge_frame(block_id), edge_frame=graph._get_edge_frame(block_id),
message_func=message_func, message_func=message_func,
...@@ -567,7 +588,7 @@ def schedule_nodeflow_update_all(graph, ...@@ -567,7 +588,7 @@ def schedule_nodeflow_update_all(graph,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_dest_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(var_dest_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat) ir.WRITE_DICT_(var_nf, final_feat)
...@@ -626,8 +647,7 @@ def schedule_nodeflow_compute(graph, ...@@ -626,8 +647,7 @@ def schedule_nodeflow_compute(graph,
graph, block_id, (u, v, eid)) graph, block_id, (u, v, eid))
out_map_creator = lambda nbits: _build_idx_map(utils.toindex(dest_nodes), nbits) out_map_creator = lambda nbits: _build_idx_map(utils.toindex(dest_nodes), nbits)
reduced_feat = _gen_send_reduce(graph=graph, reduced_feat = _gen_send_reduce(src_node_frame=graph._get_node_frame(block_id),
src_node_frame=graph._get_node_frame(block_id),
dst_node_frame=graph._get_node_frame(block_id + 1), dst_node_frame=graph._get_node_frame(block_id + 1),
edge_frame=graph._get_edge_frame(block_id), edge_frame=graph._get_edge_frame(block_id),
message_func=message_func, message_func=message_func,
...@@ -638,7 +658,7 @@ def schedule_nodeflow_compute(graph, ...@@ -638,7 +658,7 @@ def schedule_nodeflow_compute(graph,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_dest_nodes, var_nf, final_feat = _apply_with_accum(var_dest_nodes, var_nf,
reduced_feat, apply_func) reduced_feat, apply_func)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_dest_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_nf, var_dest_nodes, final_feat)
...@@ -680,7 +700,7 @@ def _standardize_func_usage(func, func_name): ...@@ -680,7 +700,7 @@ def _standardize_func_usage(func, func_name):
' Got: %s' % (func_name, str(func))) ' Got: %s' % (func_name, str(func)))
return func return func
def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func): def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
"""Apply with accumulated features. """Apply with accumulated features.
Paramters Paramters
...@@ -702,7 +722,7 @@ def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func): ...@@ -702,7 +722,7 @@ def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func):
v_nf = ir.UPDATE_DICT(v_nf, var_accum) v_nf = ir.UPDATE_DICT(v_nf, var_accum)
def _afunc_wrapper(node_data): def _afunc_wrapper(node_data):
nbatch = NodeBatch(graph, var_nodes.data, node_data) nbatch = NodeBatch(var_nodes.data, node_data)
return apply_func(nbatch) return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper) afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf) applied_feat = ir.NODE_UDF(afunc, v_nf)
...@@ -716,7 +736,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -716,7 +736,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : GraphAdaptor
reduce_func : callable reduce_func : callable
edge_tuples : tuple of utils.Index edge_tuples : tuple of utils.Index
recv_nodes : utils.Index recv_nodes : utils.Index
...@@ -734,16 +754,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -734,16 +754,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# node frame. # node frame.
# TODO(minjie): should replace this with an IR call to make the program # TODO(minjie): should replace this with an IR call to make the program
# stateless. # stateless.
tmpframe = FrameRef(frame_like(graph._dst_frame._frame, len(recv_nodes))) tmpframe = FrameRef(frame_like(graph.dstframe._frame, len(recv_nodes)))
# vars # vars
var_msg = var.FEAT_DICT(graph._msg_frame, 'msg') var_msg = var.FEAT_DICT(graph.msgframe, 'msg')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, 'nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, 'nf')
var_out = var.FEAT_DICT(data=tmpframe) var_out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list: if rfunc_is_list:
adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv( adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv(
(src, dst, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes()) (src, dst, eid), graph.num_src(), graph.num_dst())
# using edge map instead of message map because messages are in global # using edge map instead of message map because messages are in global
# message frame # message frame
var_out_map = _build_idx_map(recv_nodes, nbits) var_out_map = _build_idx_map(recv_nodes, nbits)
...@@ -757,12 +777,11 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -757,12 +777,11 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
return var_out return var_out
else: else:
# gen degree bucketing schedule for UDF recv # gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes, db.gen_degree_bucketing_schedule(rfunc, eid, dst, recv_nodes,
var_dst_nf, var_msg, var_out) var_dst_nf, var_msg, var_out)
return var_out return var_out
def _gen_send_reduce( def _gen_send_reduce(
graph,
src_node_frame, src_node_frame,
dst_node_frame, dst_node_frame,
edge_frame, edge_frame,
...@@ -793,8 +812,6 @@ def _gen_send_reduce( ...@@ -793,8 +812,6 @@ def _gen_send_reduce(
Parameters Parameters
---------- ----------
graph : DGLGraph
The graph
src_node_frame : NodeFrame src_node_frame : NodeFrame
The node frame of the source nodes. The node frame of the source nodes.
dst_node_frame : NodeFrame dst_node_frame : NodeFrame
...@@ -899,7 +916,7 @@ def _gen_send_reduce( ...@@ -899,7 +916,7 @@ def _gen_send_reduce(
edge_map=edge_map) edge_map=edge_map)
else: else:
# generate UDF send schedule # generate UDF send schedule
var_mf = _gen_udf_send(graph, var_src_nf, var_dst_nf, var_ef, var_u, var_mf = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc) var_v, var_eid, mfunc)
# 6. Generate reduce # 6. Generate reduce
...@@ -916,18 +933,18 @@ def _gen_send_reduce( ...@@ -916,18 +933,18 @@ def _gen_send_reduce(
else: else:
# gen degree bucketing schedule for UDF recv # gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) mid = utils.toindex(slice(0, len(var_v.data)))
db.gen_degree_bucketing_schedule(graph, rfunc, mid, var_v.data, db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data,
reduce_nodes, var_dst_nf, var_mf, reduce_nodes, var_dst_nf, var_mf,
var_out) var_out)
return var_out return var_out
def _gen_udf_send(graph, var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc): def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc):
"""Internal function to generate send schedule for UDF message function.""" """Internal function to generate send schedule for UDF message function."""
fdsrc = ir.READ_ROW(var_src_nf, u) fdsrc = ir.READ_ROW(var_src_nf, u)
fddst = ir.READ_ROW(var_dst_nf, v) fddst = ir.READ_ROW(var_dst_nf, v)
fdedge = ir.READ_ROW(var_ef, eid) fdedge = ir.READ_ROW(var_ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data): def _mfunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch(graph, (u.data, v.data, eid.data), ebatch = EdgeBatch((u.data, v.data, eid.data),
src_data, edge_data, dst_data) src_data, edge_data, dst_data)
return mfunc(ebatch) return mfunc(ebatch)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper) _mfunc_wrapper = var.FUNC(_mfunc_wrapper)
...@@ -944,15 +961,15 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): ...@@ -944,15 +961,15 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if mfunc_is_list: if mfunc_is_list:
if eid.is_slice(0, graph._number_of_edges()): if eid.is_slice(0, graph.num_edges()):
# full graph case # full graph case
res = spmv.build_gidx_and_mapping_graph(graph) res = spmv.build_gidx_and_mapping_graph(graph)
else: else:
res = spmv.build_gidx_and_mapping_uv( res = spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes()) (u, v, eid), graph.num_src(), graph.num_dst())
adj, edge_map, _ = res adj, edge_map, _ = res
# create a tmp message frame # create a tmp message frame
tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid))) tmp_mfr = FrameRef(frame_like(graph.edgeframe._frame, len(eid)))
var_out = var.FEAT_DICT(data=tmp_mfr) var_out = var.FEAT_DICT(data=tmp_mfr)
spmv.gen_v2e_spmv_schedule(graph=adj, spmv.gen_v2e_spmv_schedule(graph=adj,
mfunc=mfunc, mfunc=mfunc,
...@@ -964,7 +981,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): ...@@ -964,7 +981,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
edge_map=edge_map) edge_map=edge_map)
else: else:
# UDF send # UDF send
var_out = _gen_udf_send(graph, var_src_nf, var_dst_nf, var_ef, var_u, var_out = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc) var_v, var_eid, mfunc)
return var_out return var_out
...@@ -1004,5 +1021,4 @@ def _build_idx_map(idx, nbits): ...@@ -1004,5 +1021,4 @@ def _build_idx_map(idx, nbits):
old_to_new = F.zerocopy_to_dgl_ndarray(old_to_new) old_to_new = F.zerocopy_to_dgl_ndarray(old_to_new)
return utils.CtxCachedObject(lambda ctx: nd.array(old_to_new, ctx=ctx)) return utils.CtxCachedObject(lambda ctx: nd.array(old_to_new, ctx=ctx))
_init_api("dgl.runtime.scheduler") _init_api("dgl.runtime.scheduler")
...@@ -6,8 +6,7 @@ from ..base import DGLError ...@@ -6,8 +6,7 @@ from ..base import DGLError
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from .. import ndarray as nd from .. import ndarray as nd
from ..graph_index import GraphIndex from ..heterograph_index import create_unitgraph_from_coo
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from . import ir from . import ir
from .ir import var from .ir import var
...@@ -129,8 +128,8 @@ def build_gidx_and_mapping_graph(graph): ...@@ -129,8 +128,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters Parameters
---------- ----------
graph : DGLGraph or DGLHeteroGraph graph : GraphAdapter
The homogeneous graph, or a bipartite view of the heterogeneous graph. Graph
Returns Returns
------- -------
...@@ -142,30 +141,21 @@ def build_gidx_and_mapping_graph(graph): ...@@ -142,30 +141,21 @@ def build_gidx_and_mapping_graph(graph):
nbits : int nbits : int
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
gidx = graph._graph return graph.get_immutable_gidx, None, graph.bits_needed()
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst): def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges """Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of The matrix is of shape (num_src, num_dst).
nodes in the graph. Therefore, when doing SPMV, the src node data should be
all the node features.
Parameters Parameters
--------- ---------
edge_tuples : tuple of three utils.Index edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid) A tuple of (u, v, eid)
num_src, num_dst : int num_src : int
The number of source and destination nodes. Number of source nodes.
num_dst : int
Number of destination nodes.
Returns Returns
------- -------
...@@ -178,7 +168,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst): ...@@ -178,7 +168,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
u, v, eid = edge_tuples u, v, eid = edge_tuples
gidx = create_bipartite_from_coo(num_src, num_dst, u, v) gidx = create_unitgraph_from_coo(2, num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order(0) forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor() eid = eid.tousertensor()
nbits = gidx.bits_needed(0) nbits = gidx.bits_needed(0)
...@@ -189,8 +179,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst): ...@@ -189,8 +179,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
edge_map = utils.CtxCachedObject( edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx), lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx))) nd.array(backward_map, ctx=ctx)))
return partial(gidx.get_bipartite, None), edge_map, nbits return partial(gidx.get_unitgraph, 0), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None): def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
"""Build immutable graph index and mapping for node flow """Build immutable graph index and mapping for node flow
......
"""User-defined function related data structures.""" """User-defined function related data structures."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import is_all
from . import backend as F
from . import utils
class EdgeBatch(object): class EdgeBatch(object):
"""The class that can represent a batch of edges. """The class that can represent a batch of edges.
Parameters Parameters
---------- ----------
g : DGLGraph
The graph object.
edges : tuple of utils.Index edges : tuple of utils.Index
The edge tuple (u, v, eid). eid can be ALL The edge tuple (u, v, eid). eid can be ALL
src_data : dict src_data : dict
...@@ -24,8 +18,7 @@ class EdgeBatch(object): ...@@ -24,8 +18,7 @@ class EdgeBatch(object):
The dst node features, in the form of ``dict`` The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values with ``str`` keys and ``tensor`` values
""" """
def __init__(self, g, edges, src_data, edge_data, dst_data): def __init__(self, edges, src_data, edge_data, dst_data):
self._g = g
self._edges = edges self._edges = edges
self._src_data = src_data self._src_data = src_data
self._edge_data = edge_data self._edge_data = edge_data
...@@ -75,9 +68,6 @@ class EdgeBatch(object): ...@@ -75,9 +68,6 @@ class EdgeBatch(object):
destination node and the edge id for the ith edge destination node and the edge id for the ith edge
in the batch. in the batch.
""" """
if is_all(self._edges[2]):
self._edges = self._edges[:2] + (utils.toindex(F.arange(
0, self._g.number_of_edges())),)
u, v, eid = self._edges u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor()) return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
...@@ -104,9 +94,7 @@ class NodeBatch(object): ...@@ -104,9 +94,7 @@ class NodeBatch(object):
Parameters Parameters
---------- ----------
g : DGLGraph nodes : utils.Index
The graph object.
nodes : utils.Index or ALL
The node ids. The node ids.
data : dict data : dict
The node features, in the form of ``dict`` The node features, in the form of ``dict``
...@@ -115,8 +103,7 @@ class NodeBatch(object): ...@@ -115,8 +103,7 @@ class NodeBatch(object):
The messages, , in the form of ``dict`` The messages, , in the form of ``dict``
with ``str`` keys and ``tensor`` values with ``str`` keys and ``tensor`` values
""" """
def __init__(self, g, nodes, data, msgs=None): def __init__(self, nodes, data, msgs=None):
self._g = g
self._nodes = nodes self._nodes = nodes
self._data = data self._data = data
self._msgs = msgs self._msgs = msgs
...@@ -154,9 +141,6 @@ class NodeBatch(object): ...@@ -154,9 +141,6 @@ class NodeBatch(object):
tensor tensor
The nodes. The nodes.
""" """
if is_all(self._nodes):
self._nodes = utils.toindex(F.arange(
0, self._g.number_of_nodes()))
return self._nodes.tousertensor() return self._nodes.tousertensor()
def batch_size(self): def batch_size(self):
...@@ -166,10 +150,7 @@ class NodeBatch(object): ...@@ -166,10 +150,7 @@ class NodeBatch(object):
------- -------
int int
""" """
if is_all(self._nodes): return len(self._nodes)
return self._g.number_of_nodes()
else:
return len(self._nodes)
def __len__(self): def __len__(self):
"""Return the number of nodes in this node batch. """Return the number of nodes in this node batch.
......
...@@ -505,3 +505,14 @@ def to_nbits_int(tensor, nbits): ...@@ -505,3 +505,14 @@ def to_nbits_int(tensor, nbits):
return F.astype(tensor, F.int32) return F.astype(tensor, F.int32)
else: else:
return F.astype(tensor, F.int64) return F.astype(tensor, F.int64)
def make_invmap(array, use_numpy=True):
"""Find the unique elements of the array and return another array with indices
to the array of unique elements."""
if use_numpy:
uniques = np.unique(array)
else:
uniques = list(set(array))
invmap = {x: i for i, x in enumerate(uniques)}
remapped = np.array([invmap[x] for x in array])
return uniques, invmap, remapped
...@@ -10,6 +10,7 @@ from .base import ALL, is_all, DGLError ...@@ -10,6 +10,7 @@ from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
NodeSpace = namedtuple('NodeSpace', ['data']) NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class NodeView(object): class NodeView(object):
"""A NodeView class to act as G.nodes for a DGLGraph. """A NodeView class to act as G.nodes for a DGLGraph.
...@@ -79,8 +80,6 @@ class NodeDataView(MutableMapping): ...@@ -79,8 +80,6 @@ class NodeDataView(MutableMapping):
data = self._graph.get_n_repr(self._nodes) data = self._graph.get_n_repr(self._nodes)
return repr({key : data[key] for key in self._graph._node_frame}) return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class EdgeView(object): class EdgeView(object):
"""A EdgeView class to act as G.edges for a DGLGraph. """A EdgeView class to act as G.edges for a DGLGraph.
...@@ -256,111 +255,57 @@ class HeteroNodeView(object): ...@@ -256,111 +255,57 @@ class HeteroNodeView(object):
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
def __getitem__(self, ntype): def __getitem__(self, key):
return HeteroNodeTypeView(self._graph, ntype) if isinstance(key, slice):
class HeteroNodeTypeView(object):
"""A NodeView class to act as G.nodes[ntype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __len__(self):
return self._graph.number_of_nodes(self._graph._ntypes_invmap[self._ntype])
def __getitem__(self, nodes):
if isinstance(nodes, slice):
# slice # slice
if not (nodes.start is None and nodes.stop is None if not (key.start is None and key.stop is None
and nodes.step is None): and key.step is None):
raise DGLError('Currently only full slice ":" is supported') raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, ALL)) nodes = ALL
ntype = None
elif isinstance(key, tuple):
nodes, ntype = key
elif isinstance(key, str):
nodes = ALL
ntype = key
else: else:
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, nodes)) nodes = key
ntype = None
return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, nodes))
def __call__(self): def __call__(self, ntype=None):
"""Return the nodes.""" """Return the nodes."""
return F.arange(0, len(self)) return F.arange(0, self._graph.number_of_nodes(ntype))
class HeteroNodeTypeDataView(MutableMapping):
"""The data view class when G.nodes[ntype][...].data is called.
See Also class HeteroNodeDataView(MutableMapping):
-------- """The data view class when G.ndata[ntype] is called."""
dgl.DGLGraph.nodes __slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
"""
__slots__ = ['_graph', '_ntype', '_nodes']
def __init__(self, graph, ntype, nodes): def __init__(self, graph, ntype, nodes):
self._graph = graph self._graph = graph
self._ntype = ntype self._ntype = ntype
self._ntid = self._graph.get_ntype_id(ntype)
self._nodes = nodes self._nodes = nodes
def __getitem__(self, key): def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype, self._nodes)[key] return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}, self._nodes)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __iter__(self):
return iter(self._graph.get_n_repr(self._ntype, self._nodes))
def __repr__(self):
data = self._graph.get_n_repr(self._ntype, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroNodeDataView(object):
"""The data view class when G.ndata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroNodeDataTypeView(self._graph, key)
class HeteroNodeDataTypeView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}) self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph.pop_n_repr(self._ntype, key) self._graph._pop_n_repr(self._ntid, key)
def __len__(self): def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]) return len(self._graph._node_frames[self._ntid])
def __iter__(self): def __iter__(self):
return iter(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]) return iter(self._graph._node_frames[self._ntid])
def __repr__(self): def __repr__(self):
data = self._graph.get_n_repr(self._ntype) data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key] return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]}) for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph.""" """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
...@@ -369,108 +314,59 @@ class HeteroEdgeView(object): ...@@ -369,108 +314,59 @@ class HeteroEdgeView(object):
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
def __getitem__(self, etype): def __getitem__(self, key):
return HeteroEdgeTypeView(self._graph, etype) if isinstance(key, slice):
class HeteroEdgeTypeView(object):
"""A EdgeView class to act as G.edges[etype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __len__(self):
return self._graph.number_of_edges(self._graph._etypes_invmap[self._etype])
def __getitem__(self, edges):
if isinstance(edges, slice):
# slice # slice
if not (edges.start is None and edges.stop is None if not (key.start is None and key.stop is None
and edges.step is None): and key.step is None):
raise DGLError('Currently only full slice ":" is supported') raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, ALL)) edges = ALL
etype = None
elif isinstance(key, tuple):
if len(key) == 3:
edges = ALL
etype = key
else:
edges = key
etype = None
elif isinstance(key, (str, tuple)):
edges = ALL
etype = key
else: else:
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, edges)) edges = key
etype = None
return EdgeSpace(data=HeteroEdgeDataView(self._graph, etype, edges))
def __call__(self): def __call__(self, *args, **kwargs):
"""Return the edges.""" """Return all the edges."""
return F.arange(0, len(self)) return self._graph.all_edges(*args, **kwargs)
class HeteroEdgeTypeDataView(MutableMapping):
"""The data view class when G.edges[etype][...].data is called.
See Also class HeteroEdgeDataView(MutableMapping):
-------- """The data view class when G.edata[etype] is called."""
dgl.DGLGraph.edges __slots__ = ['_graph', '_etype', '_etid', '_edges']
"""
__slots__ = ['_graph', '_etype', '_edges']
def __init__(self, graph, etype, edges): def __init__(self, graph, etype, edges):
self._graph = graph self._graph = graph
self._etype = etype self._etype = etype
self._etid = self._graph.get_etype_id(etype)
self._edges = edges self._edges = edges
def __getitem__(self, key): def __getitem__(self, key):
return self._graph.get_e_repr(self._etype, self._edges)[key] return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}, self._edges)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of edges. Please use `del G.edata[key]` instead.')
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __iter__(self):
return iter(self._graph.get_e_repr(self._etype, self._edges))
def __repr__(self):
data = self._graph.get_e_repr(self._etype, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
class HeteroEdgeDataView(object):
"""The data view class when G.edata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroEdgeDataTypeView(self._graph, key)
class HeteroEdgeDataTypeView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}) self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph.pop_e_repr(self._etype, key) self._graph._pop_e_repr(self._etid, key)
def __len__(self): def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]) return len(self._graph._edge_frames[self._etid])
def __iter__(self): def __iter__(self):
return iter(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]) return iter(self._graph._edge_frames[self._etid])
def __repr__(self): def __repr__(self):
data = self._graph.get_e_repr(self._etype) data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key] return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]}) for key in self._graph._edge_frames[self._etid]})
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
* \brief Heterograph implementation * \brief Heterograph implementation
*/ */
#include "./heterograph.h" #include "./heterograph.h"
#include <dgl/array.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./bipartite.h" #include "./unit_graph.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -50,7 +51,7 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -50,7 +51,7 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
// following heterograph: // following heterograph:
// //
// Meta graph: A -> B -> C // Meta graph: A -> B -> C
// Bipartite graphs: // UnitGraph graphs:
// * A -> B: (0, 0), (0, 1) // * A -> B: (0, 0), (0, 1)
// * B -> C: (1, 0), (1, 1) // * B -> C: (1, 0), (1, 1)
// //
...@@ -91,7 +92,8 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -91,7 +92,8 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
auto pair = hg->meta_graph()->FindEdge(etype); auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
subrels[etype] = Bipartite::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
ret.induced_vertices[src_vtype]->shape[0], ret.induced_vertices[src_vtype]->shape[0],
ret.induced_vertices[dst_vtype]->shape[0], ret.induced_vertices[dst_vtype]->shape[0],
subedges[etype].src, subedges[etype].src,
...@@ -108,10 +110,9 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -108,10 +110,9 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
// Sanity check // Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size()); CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed."; CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graph must be bipartite graphs // all relation graphs must have only one edge type
for (const auto rg : rel_graphs) { for (const auto rg : rel_graphs) {
CHECK_EQ(rg->NumVertexTypes(), 2) << "Each relation graph must be a bipartite graph."; CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph.";
} }
// create num verts per type // create num verts per type
num_verts_per_type_.resize(meta_graph->NumVertices(), -1); num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
...@@ -125,17 +126,20 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -125,17 +126,20 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
dgl_type_t srctype = srctypes[i]; dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i]; dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i]; dgl_type_t etype = etypes[i];
const auto& rg = rel_graphs[etype];
const auto sty = 0;
const auto dty = rg->NumVertexTypes() == 1? 0 : 1;
size_t nv; size_t nv;
// # nodes of source type // # nodes of source type
nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType); nv = rg->NumVertices(sty);
if (num_verts_per_type_[srctype] < 0) if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv; num_verts_per_type_[srctype] = nv;
else else
CHECK_EQ(num_verts_per_type_[srctype], nv) CHECK_EQ(num_verts_per_type_[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype; << "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type // # nodes of destination type
nv = rel_graphs[etype]->NumVertices(Bipartite::kDstVType); nv = rg->NumVertices(dty);
if (num_verts_per_type_[dsttype] < 0) if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv; num_verts_per_type_[dsttype] = nv;
else else
...@@ -171,8 +175,10 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con ...@@ -171,8 +175,10 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
auto pair = meta_graph_->FindEdge(etype); auto pair = meta_graph_->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph( const std::vector<IdArray> rel_vids = (src_vtype == dst_vtype) ?
{vids[src_vtype], vids[dst_vtype]}); std::vector<IdArray>({vids[src_vtype]}) :
std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
subrels[etype] = rel_vsg.graph; subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0]; ret.induced_edges[etype] = rel_vsg.induced_edges[0];
} }
...@@ -189,18 +195,106 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph( ...@@ -189,18 +195,106 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph(
} }
} }
// creator implementation FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etypes) const {
HeteroGraphPtr CreateBipartiteFromCOO( std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
int64_t num_src, int64_t num_dst, IdArray row, IdArray col) { size_t src_nodes = 0, dst_nodes = 0;
return Bipartite::CreateFromCOO(num_src, num_dst, row, col); std::vector<dgl_id_t> result_src, result_dst;
} std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<dgl_id_t> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type to the number of nodes after this
// loop.
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t num_srctype_nodes = NumVertices(srctype);
size_t num_dsttype_nodes = NumVertices(dsttype);
if (srctype_offsets.count(srctype) == 0) {
srctype_offsets[srctype] = num_srctype_nodes;
srctype_set.push_back(srctype);
}
if (dsttype_offsets.count(dsttype) == 0) {
dsttype_offsets[dsttype] = num_dsttype_nodes;
dsttype_set.push_back(dsttype);
}
}
// Sort the node types so that we can compare the sets and decide whether a homograph
// should be returned.
std::sort(srctype_set.begin(), srctype_set.end());
std::sort(dsttype_set.begin(), dsttype_set.end());
bool homograph = (srctype_set.size() == dsttype_set.size()) &&
std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
// XXXtype_offsets contain the mapping from node type to node ID offsets after these
// two loops.
for (size_t i = 0; i < srctype_set.size(); ++i) {
dgl_type_t ntype = srctype_set[i];
size_t num_nodes = srctype_offsets[ntype];
srctype_offsets[ntype] = src_nodes;
src_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_srctype.push_back(ntype);
induced_srcid.push_back(j);
}
}
for (size_t i = 0; i < dsttype_set.size(); ++i) {
dgl_type_t ntype = dsttype_set[i];
size_t num_nodes = dsttype_offsets[ntype];
dsttype_offsets[ntype] = dst_nodes;
dst_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_dsttype.push_back(ntype);
induced_dstid.push_back(j);
}
}
HeteroGraphPtr CreateBipartiteFromCSR( for (dgl_type_t etype : etypes) {
int64_t num_src, int64_t num_dst, auto src_dsttype = meta_graph_->FindEdge(etype);
IdArray indptr, IdArray indices, IdArray edge_ids) { dgl_type_t srctype = src_dsttype.first;
return Bipartite::CreateFromCSR(num_src, num_dst, indptr, indices, edge_ids); dgl_type_t dsttype = src_dsttype.second;
size_t srctype_offset = srctype_offsets[srctype];
size_t dsttype_offset = dsttype_offsets[dsttype];
EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const dgl_id_t* edges_eid_data = static_cast<const dgl_id_t*>(edges.id->data);
// TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset);
result_dst.push_back(edges_dst_data[i] + dsttype_offset);
induced_etype.push_back(etype);
induced_eid.push_back(edges_eid_data[i]);
}
}
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2,
src_nodes,
dst_nodes,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(gptr);
result->induced_srctype = aten::VecToIdArray(induced_srctype);
result->induced_srctype_set = aten::VecToIdArray(srctype_set);
result->induced_srcid = aten::VecToIdArray(induced_srcid);
result->induced_etype = aten::VecToIdArray(induced_etype);
result->induced_etype_set = aten::VecToIdArray(etypes);
result->induced_eid = aten::VecToIdArray(induced_eid);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype);
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set);
result->induced_dstid = aten::VecToIdArray(induced_dstid);
return FlattenedHeteroGraphPtr(result);
} }
// creator implementation
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs)); return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
...@@ -208,24 +302,27 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -208,24 +302,27 @@ HeteroGraphPtr CreateHeteroGraph(
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCOO") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0]; int64_t nvtypes = args[0];
int64_t num_dst = args[1]; int64_t num_src = args[1];
IdArray row = args[2]; int64_t num_dst = args[2];
IdArray col = args[3]; IdArray row = args[3];
auto hgptr = CreateBipartiteFromCOO(num_src, num_dst, row, col); IdArray col = args[4];
auto hgptr = UnitGraph::CreateFromCOO(nvtypes, num_src, num_dst, row, col);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCSR") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0]; int64_t nvtypes = args[0];
int64_t num_dst = args[1]; int64_t num_src = args[1];
IdArray indptr = args[2]; int64_t num_dst = args[2];
IdArray indices = args[3]; IdArray indptr = args[3];
IdArray edge_ids = args[4]; IdArray indices = args[4];
auto hgptr = CreateBipartiteFromCSR(num_src, num_dst, indptr, indices, edge_ids); IdArray edge_ids = args[5];
auto hgptr = UnitGraph::CreateFromCSR(
nvtypes, num_src, num_dst, indptr, indices, edge_ids);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
...@@ -252,7 +349,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") ...@@ -252,7 +349,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
*rv = HeteroGraphRef(hg->GetRelationGraph(etype)); if (hg->NumEdgeTypes() == 1) {
CHECK_EQ(etype, 0);
*rv = hg;
} else {
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> etypes = args[1];
std::vector<dgl_id_t> etypes_vec;
for (Value val : etypes)
etypes_vec.push_back(val->data);
*rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
...@@ -551,7 +664,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits") ...@@ -551,7 +664,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
int bits = args[1]; int bits = args[1];
HeteroGraphPtr hg_new = Bipartite::AsNumBits(hg.sptr(), bits); HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits);
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
...@@ -563,7 +676,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") ...@@ -563,7 +676,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
DLContext ctx; DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
HeteroGraphPtr hg_new = Bipartite::CopyTo(hg.sptr(), ctx); HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
......
...@@ -20,14 +20,6 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -20,14 +20,6 @@ class HeteroGraph : public BaseHeteroGraph {
public: public:
HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs); HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
uint64_t NumVertexTypes() const override {
return meta_graph_->NumVertices();
}
uint64_t NumEdgeTypes() const override {
return meta_graph_->NumEdges();
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype; CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
return relation_graphs_[etype]; return relation_graphs_[etype];
...@@ -172,8 +164,10 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -172,8 +164,10 @@ class HeteroGraph : public BaseHeteroGraph {
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
private: private:
/*! \brief A map from edge type to bipartite graph */ /*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_; std::vector<HeteroGraphPtr> relation_graphs_;
/*! \brief A map from vert type to the number of verts in the type */ /*! \brief A map from vert type to the number of verts in the type */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment