Unverified Commit 9b4d6079 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Hetero] New syntax (#824)

* WIP. remove graph arg in NodeBatch and EdgeBatch

* refactor: use graph adapter for scheduler

* WIP: recv

* draft impl

* stuck at bipartite

* bipartite->unitgraph; support dsttype == srctype

* pass test_query

* pass test_query

* pass test_view

* test apply

* pass udf message passing tests

* pass quan's test using builtins

* WIP: wildcard slicing

* new construct methods

* broken

* good

* add stack cross reducer

* fix bug; fix mx

* fix bug in csrmm2 when the CSR is not square

* lint

* removed FlattenedHeteroGraph class

* WIP

* prop nodes, prop edges, filter nodes/edges

* add DGLGraph tests to heterograph. Fix several bugs

* finish nx<->hetero graph conversion

* create bipartite from nx

* more spec on hetero/homo conversion

* silly fixes

* check node and edge types

* repr

* to api

* adj APIs

* inc

* fix some lints and bugs

* fix some lints

* hetero/homo conversion

* fix flatten test

* more spec in hetero_from_homo and test

* flatten using concat names

* WIP: creators

* rewrite hetero_from_homo in a more efficient way

* remove useless variables

* fix lint

* subgraphs and typed subgraphs

* lint & removed heterosubgraph class

* lint x2

* disable heterograph mutation test

* docstring update

* add edge id for nx graph test

* fix mx unittests

* fix bug

* try fix

* fix unittest when cross_reducer is stack

* fix ci

* fix nx bipartite bug; docstring

* fix scipy creation bug

* lint

* fix bug when converting heterograph from homograph

* fix bug in hetero_from_homo about ntype order

* trailing white

* docstring fixes for add_foo and data views

* docstring for relation slice

* to_hetero and to_homo with feature support

* lint

* lint

* DGLGraph compatibility

* incidence matrix & docstring fixes

* example string fixes

* feature in hetero_from_relations

* deduplication of edge types in to_hetero

* fix lint

* fix
parent ddb5d804
......@@ -21,7 +21,9 @@ namespace dgl {
// Forward declaration
class BaseHeteroGraph;
class FlattenedHeteroGraph;
typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;
typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;
struct HeteroSubgraph;
/*!
......@@ -46,10 +48,14 @@ class BaseHeteroGraph : public runtime::Object {
////////////////////////// query/operations on meta graph ////////////////////////
/*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const = 0;
virtual uint64_t NumVertexTypes() const {
return meta_graph_->NumVertices();
}
/*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const = 0;
virtual uint64_t NumEdgeTypes() const {
return meta_graph_->NumEdges();
}
/*! \return the meta graph */
virtual GraphPtr meta_graph() const {
......@@ -351,6 +357,17 @@ class BaseHeteroGraph : public runtime::Object {
virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*!
* \brief Flatten the unit graphs of the requested edge types into a single unit graph.
*
* \note This base-class default does not perform any conversion: it only raises a
*       fatal "Flatten operation unsupported" log. The trailing ``return nullptr``
*       satisfies the return type.
*       NOTE(review): presumably graph classes that support flattening override
*       this method -- confirm against the concrete subclasses.
*
* \param etypes The list of edge type IDs.
* \return The flattened graph, with induced source/edge/destination types/IDs.
*/
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const {
LOG(FATAL) << "Flatten operation unsupported";
return nullptr;
}
static constexpr const char* _type_key = "graph.HeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);
......@@ -381,6 +398,62 @@ struct HeteroSubgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object);
};
/*!
* \brief The flattened heterograph.
*
* Bundles the flattened graph together with the arrays that map its source
* nodes, edges and destination nodes back to types and IDs in the parent graph.
*/
struct FlattenedHeteroGraph : public runtime::Object {
/*! \brief The graph */
HeteroGraphRef graph;
/*!
* \brief Mapping from source node ID to node type in parent graph
* \note The induced type array guarantees that the same type always appears contiguously.
*/
IdArray induced_srctype;
/*!
* \brief The set of node types in parent graph appearing in source nodes.
*/
IdArray induced_srctype_set;
/*! \brief Mapping from source node ID to local node ID in parent graph */
IdArray induced_srcid;
/*!
* \brief Mapping from edge ID to edge type in parent graph
* \note The induced type array guarantees that the same type always appears contiguously.
*/
IdArray induced_etype;
/*!
* \brief The set of edge types in parent graph appearing in edges.
*/
IdArray induced_etype_set;
/*! \brief Mapping from edge ID to local edge ID in parent graph */
IdArray induced_eid;
/*!
* \brief Mapping from destination node ID to node type in parent graph
* \note The induced type array guarantees that the same type always appears contiguously.
*/
IdArray induced_dsttype;
/*!
* \brief The set of node types in parent graph appearing in destination nodes.
*/
IdArray induced_dsttype_set;
/*! \brief Mapping from destination node ID to local node ID in parent graph */
IdArray induced_dstid;
/*!
* \brief Report every member to the attribute visitor, keyed by the member's own name.
* \param v The attribute visitor.
*/
void VisitAttrs(runtime::AttrVisitor *v) final {
v->Visit("graph", &graph);
v->Visit("induced_srctype", &induced_srctype);
v->Visit("induced_srctype_set", &induced_srctype_set);
v->Visit("induced_srcid", &induced_srcid);
v->Visit("induced_etype", &induced_etype);
v->Visit("induced_etype_set", &induced_etype_set);
v->Visit("induced_eid", &induced_eid);
v->Visit("induced_dsttype", &induced_dsttype);
v->Visit("induced_dsttype_set", &induced_dsttype_set);
v->Visit("induced_dstid", &induced_dstid);
}
/*! \brief Type key under which this object type is declared */
static constexpr const char* _type_key = "graph.FlattenedHeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(FlattenedHeteroGraph, runtime::Object);
};
DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Define HeteroSubgraphRef
DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);
......
......@@ -18,6 +18,7 @@ namespace runtime {
// forward declaration
class Object;
class ObjectRef;
class NDArray;
/*!
* \brief Visitor class to each object attribute.
......@@ -33,6 +34,7 @@ class AttrVisitor {
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
......
......@@ -13,9 +13,10 @@ from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
from ._ffi.base import DGLError, __version__
from .base import ALL
from .base import ALL, NTYPE, NID, ETYPE, EID
from .backend import load_backend
from .batched_graph import *
from .convert import *
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from .nodeflow import *
......
......@@ -8,6 +8,13 @@ from ._ffi.function import _init_internal_api
# A special symbol for selecting all nodes or edges.
ALL = "__ALL__"
# An alias for [:]
SLICE_FULL = slice(None, None, None)
# Reserved column names for storing parent node/edge types and IDs in flattened heterographs
NTYPE = '_TYPE'
NID = '_ID'
ETYPE = '_TYPE'
EID = '_ID'
def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges."""
......
"""Module for converting graph from/to other object."""
from collections import defaultdict
import numpy as np
import scipy as sp
import networkx as nx
from . import backend as F
from . import heterograph_index
from .heterograph import DGLHeteroGraph, combine_frames
from . import graph_index
from . import utils
from ._ffi.base import DGLError
from .base import NTYPE, ETYPE, NID, EID
__all__ = [
'graph',
'bipartite',
'hetero_from_relations',
'to_hetero',
'to_homo',
'to_networkx',
]
def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
    """Create a graph.

    The graph has only one type of nodes and edges.

    In the sparse matrix perspective, :func:`dgl.graph` creates a graph
    whose adjacency matrix must be square while :func:`dgl.bipartite`
    creates a graph that does not necessarily have square adjacency matrix.

    Examples
    --------
    Create from edges pairs:

    >>> g = dgl.graph([(0, 2), (0, 3), (1, 2)])

    Create from pair of vertex IDs lists

    >>> u = [0, 0, 1]
    >>> v = [2, 3, 2]
    >>> g = dgl.graph((u, v))

    The IDs can also be stored in framework-specific tensors

    >>> import torch
    >>> u = torch.tensor([0, 0, 1])
    >>> v = torch.tensor([2, 3, 2])
    >>> g = dgl.graph((u, v))

    Create from scipy sparse matrix

    >>> from scipy.sparse import coo_matrix
    >>> spmat = coo_matrix(([1,1,1], ([0, 0, 1], [2, 3, 2])), shape=(4, 4))
    >>> g = dgl.graph(spmat)

    Create from networkx graph

    >>> import networkx as nx
    >>> nxg = nx.path_graph(3)
    >>> g = dgl.graph(nxg)

    Specify node and edge type names

    >>> g = dgl.graph(..., 'user', 'follows')
    >>> g.ntypes
    ['user']
    >>> g.etypes
    ['follows']
    >>> g.canonical_etypes
    [('user', 'follows', 'user')]

    Parameters
    ----------
    data : graph data
        Data to initialize graph structure. Supported data formats are
        (1) list of edge pairs (e.g. [(0, 2), (3, 1), ...])
        (2) pair of vertex IDs representing end nodes (e.g. ([0, 3, ...], [2, 1, ...]))
        (3) scipy sparse matrix
        (4) networkx graph
    ntype : str, optional
        Node type name. (Default: _N)
    etype : str, optional
        Edge type name. (Default: _E)
    card : int, optional
        Cardinality (number of nodes in the graph). If None, infer from input data.
        Only forwarded for formats (1) and (2); it is not passed on for scipy
        matrices or networkx graphs. (Default: None)
    kwargs : key-word arguments, optional
        Other key word arguments. Only forwarded when ``data`` is a networkx graph.

    Returns
    -------
    DGLHeteroGraph

    Raises
    ------
    DGLError
        If ``data`` is none of the supported formats.
    """
    # Source and destination nodes share one type, hence one cardinality.
    if card is not None:
        urange, vrange = card, card
    else:
        urange, vrange = None, None
    if isinstance(data, tuple):
        u, v = data
        return create_from_edges(u, v, ntype, etype, ntype, urange, vrange)
    elif isinstance(data, list):
        return create_from_edge_list(data, ntype, etype, ntype, urange, vrange)
    elif isinstance(data, sp.sparse.spmatrix):
        return create_from_scipy(data, ntype, etype, ntype)
    elif isinstance(data, nx.Graph):
        return create_from_networkx(data, ntype, etype, **kwargs)
    else:
        # Build one message string; passing the type as a second exception
        # argument would render the error as a tuple.
        raise DGLError('Unsupported graph data type: {}'.format(type(data)))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
    """Create a bipartite graph.

    The result graph is directed and edges must be from ``utype`` nodes
    to ``vtype`` nodes. Nodes of each type have their own ID counts.

    In the sparse matrix perspective, :func:`dgl.graph` creates a graph
    whose adjacency matrix must be square while :func:`dgl.bipartite`
    creates a graph that does not necessarily have square adjacency matrix.

    Examples
    --------
    Create from edges pairs:

    >>> g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game')
    >>> g.ntypes
    ['user', 'game']
    >>> g.etypes
    ['plays']
    >>> g.canonical_etypes
    [('user', 'plays', 'game')]
    >>> g.number_of_nodes('user')
    2
    >>> g.number_of_nodes('game')
    4
    >>> g.number_of_edges('plays')  # 'plays' could be omitted here
    3

    Create from pair of vertex IDs lists

    >>> u = [0, 0, 1]
    >>> v = [2, 3, 2]
    >>> g = dgl.bipartite((u, v))

    The IDs can also be stored in framework-specific tensors

    >>> import torch
    >>> u = torch.tensor([0, 0, 1])
    >>> v = torch.tensor([2, 3, 2])
    >>> g = dgl.bipartite((u, v))

    Create from scipy sparse matrix. Since scipy sparse matrix has explicit
    shape, the cardinality of the result graph is derived from that.

    >>> from scipy.sparse import coo_matrix
    >>> spmat = coo_matrix(([1,1,1], ([0, 0, 1], [2, 3, 2])), shape=(4, 4))
    >>> g = dgl.bipartite(spmat, 'user', 'plays', 'game')
    >>> g.number_of_nodes('user')
    4
    >>> g.number_of_nodes('game')
    4

    Create from networkx graph. The given graph must follow the bipartite
    graph convention in networkx. Each node has a ``bipartite`` attribute
    with values 0 or 1. The result graph has two types of nodes and only
    edges from ``bipartite=0`` to ``bipartite=1`` will be included.

    >>> import networkx as nx
    >>> nxg = nx.complete_bipartite_graph(3, 4)
    >>> g = dgl.bipartite(nxg, 'user', 'plays', 'game')
    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_nodes('game')
    4
    >>> g.edges()
    (tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))

    Parameters
    ----------
    data : graph data
        Data to initialize graph structure. Supported data formats are
        (1) list of edge pairs (e.g. [(0, 2), (3, 1), ...])
        (2) pair of vertex IDs representing end nodes (e.g. ([0, 3, ...], [2, 1, ...]))
        (3) scipy sparse matrix
        (4) networkx graph
    utype : str, optional
        Source node type name. (Default: _U)
    etype : str, optional
        Edge type name. (Default: _E)
    vtype : str, optional
        Destination node type name. (Default: _V)
    card : pair of int, optional
        Cardinality (number of nodes in the source and destination group). If None,
        infer from input data. Only forwarded for formats (1) and (2). (Default: None)
    kwargs : key-word arguments, optional
        Other key word arguments. Only forwarded when ``data`` is a networkx graph.

    Returns
    -------
    DGLHeteroGraph

    Raises
    ------
    DGLError
        If ``utype`` equals ``vtype`` or ``data`` is none of the supported formats.
    """
    if utype == vtype:
        raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
    if card is not None:
        urange, vrange = card
    else:
        urange, vrange = None, None
    if isinstance(data, tuple):
        u, v = data
        return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
    elif isinstance(data, list):
        return create_from_edge_list(data, utype, etype, vtype, urange, vrange)
    elif isinstance(data, sp.sparse.spmatrix):
        return create_from_scipy(data, utype, etype, vtype)
    elif isinstance(data, nx.Graph):
        return create_from_networkx_bipartite(data, utype, etype, vtype, **kwargs)
    else:
        # Build one message string; passing the type as a second exception
        # argument would render the error as a tuple.
        raise DGLError('Unsupported graph data type: {}'.format(type(data)))
def hetero_from_relations(rel_graphs):
    """Assemble a heterograph out of per-relation graphs.

    Each input graph must carry exactly one edge type. Node types are numbered
    in order of first appearance (the source type of a relation is registered
    before its destination type) and the i-th input graph becomes the i-th edge
    type of the result. Node data of every input graph is merged into the node
    data of the same-named type in the result, and the edge frame of the i-th
    input is merged into the i-th edge frame of the result.

    TODO(minjie): this API can be generalized as a union operation of
    the input graphs

    Parameters
    ----------
    rel_graphs : list of DGLHeteroGraph
        Graph for each relation.

    Returns
    -------
    DGLHeteroGraph
        A heterograph.
    """
    # Derive the meta graph: one meta edge per relation graph.
    ntype_to_id = {}
    node_type_names = []
    edge_type_names = []
    meta_edge_list = []
    for rel in rel_graphs:
        assert len(rel.etypes) == 1
        src_name, edge_name, dst_name = rel.canonical_etypes[0]
        for name in (src_name, dst_name):
            if name not in ntype_to_id:
                ntype_to_id[name] = len(node_type_names)
                node_type_names.append(name)
        meta_edge_list.append((ntype_to_id[src_name], ntype_to_id[dst_name]))
        edge_type_names.append(edge_name)
    metagraph = graph_index.from_edge_list(meta_edge_list, True, True)

    # Build the heterograph index from the per-relation indices.
    rel_indices = [rel._graph for rel in rel_graphs]
    hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_indices)
    result = DGLHeteroGraph(hgidx, node_type_names, edge_type_names)

    # Carry node and edge features over from the inputs.
    for rel_id, rel in enumerate(rel_graphs):
        for name in rel.ntypes:
            result.nodes[name].data.update(rel.nodes[name].data)
        result._edge_frames[rel_id].update(rel._edge_frames[0])
    return result
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
    """Convert the given graph to a heterogeneous graph.

    The input graph should have only one type of nodes and edges. Each node and edge
    stores an integer feature (under ``ntype_field`` and ``etype_field``), representing
    the type id, which can be used to retrieve the type names stored
    in the given ``ntypes`` and ``etypes`` arguments.

    Parameters
    ----------
    G : DGLHeteroGraph
        Input homogenous graph.
    ntypes : list of str
        The node type names.
    etypes : list of str
        The edge type names.
    ntype_field : str, optional
        The feature field used to store node type. (Default: dgl.NTYPE)
    etype_field : str, optional
        The feature field used to store edge type. (Default: dgl.ETYPE)
    metagraph : networkx MultiDiGraph, optional
        Metagraph of the returned heterograph.
        If provided, DGL assumes that G can indeed be described with the given metagraph.
        If None, DGL will infer the metagraph from the given inputs, which would be
        potentially slower for large graphs.

    Returns
    -------
    DGLHeteroGraph
        A heterograph.
        The parent node and edge ID are stored in the column dgl.NID and dgl.EID
        respectively for all node/edge types.

    Raises
    ------
    DGLError
        If the input graph has more than one node type or edge type.

    Notes
    -----
    The returned node and edge types may not necessarily be in the same order as
    ``ntypes`` and ``etypes``. And edge types may be duplicated if the source
    and destination types differ.

    The node IDs of a single type in the returned heterogeneous graph is ordered
    the same as the nodes with the same ``ntype_field`` feature. Edge IDs of
    a single type is similar.
    """
    # TODO(minjie): use hasattr to support DGLGraph input; should be fixed once
    #  DGLGraph is merged with DGLHeteroGraph
    if (hasattr(G, 'ntypes') and len(G.ntypes) > 1
            or hasattr(G, 'etypes') and len(G.etypes) > 1):
        raise DGLError('The input graph should be homogenous and have only one '
                       'type of nodes and edges.')

    num_ntypes = len(ntypes)
    ntype_ids = F.asnumpy(G.ndata[ntype_field])
    etype_ids = F.asnumpy(G.edata[etype_field])

    # relabel nodes to per-type local IDs
    ntype_count = np.bincount(ntype_ids, minlength=num_ntypes)
    ntype_offset = np.insert(np.cumsum(ntype_count), 0, 0)
    # A stable sort is required so that nodes of the same type keep their
    # relative order from the input graph, as promised in the Notes section;
    # numpy's default sort kind is not stable.
    ntype_ids_sortidx = np.argsort(ntype_ids, kind='mergesort')
    ntype_local_ids = np.zeros_like(ntype_ids)
    node_groups = []
    for i in range(num_ntypes):
        node_group = ntype_ids_sortidx[ntype_offset[i]:ntype_offset[i+1]]
        node_groups.append(node_group)
        ntype_local_ids[node_group] = np.arange(ntype_count[i])

    src, dst = G.all_edges(order='eid')
    src = F.asnumpy(src)
    dst = F.asnumpy(dst)
    src_local = ntype_local_ids[src]
    dst_local = ntype_local_ids[dst]
    srctype_ids = ntype_ids[src]
    dsttype_ids = ntype_ids[dst]
    # One (source type, edge type, destination type) ID triple per edge.
    canon_etype_ids = np.stack([srctype_ids, etype_ids, dsttype_ids], 1)

    # infer metagraph
    if metagraph is None:
        canonical_etids, _, etype_remapped = \
            utils.make_invmap(list(tuple(_) for _ in canon_etype_ids), False)
        etype_mask = (etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None])
    else:
        ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)}
        etypes_invmap = {et: i for i, et in enumerate(etypes)}
        # Translate each metagraph edge (src name, dst name, edge key) into an
        # ID triple laid out like the rows of ``canon_etype_ids``.
        canonical_etids = []
        for srctype, dsttype, etype in metagraph.edges(keys=True):
            srctype_id = ntypes_invmap[srctype]
            etype_id = etypes_invmap[etype]
            dsttype_id = ntypes_invmap[dsttype]
            canonical_etids.append((srctype_id, etype_id, dsttype_id))
        canonical_etids = np.array(canonical_etids)
        etype_mask = (canon_etype_ids[None, :] == canonical_etids[:, None]).all(2)
    # Row i of the mask selects the edges that belong to canonical edge type i.
    edge_groups = [etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))]

    rel_graphs = []
    for i, (stid, etid, dtid) in enumerate(canonical_etids):
        src_of_etype = src_local[edge_groups[i]]
        dst_of_etype = dst_local[edge_groups[i]]
        if stid == dtid:
            rel_graph = graph(
                (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
                card=ntype_count[stid])
        else:
            rel_graph = bipartite(
                (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
                card=(ntype_count[stid], ntype_count[dtid]))
        rel_graphs.append(rel_graph)

    hg = hetero_from_relations(rel_graphs)

    # Record the parent node/edge IDs of every node/edge in the result.
    ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
    for ntid, ntype in enumerate(hg.ntypes):
        hg._node_frames[ntid][NID] = F.tensor(ntype2ngrp[ntype])

    for etid in range(len(hg.canonical_etypes)):
        hg._edge_frames[etid][EID] = F.tensor(edge_groups[etid])

    # features
    for key, data in G.ndata.items():
        for ntid, ntype in enumerate(hg.ntypes):
            rows = F.copy_to(F.tensor(ntype2ngrp[ntype]), F.context(data))
            hg._node_frames[ntid][key] = F.gather_row(data, rows)
    for key, data in G.edata.items():
        for etid in range(len(hg.canonical_etypes)):
            rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data))
            hg._edge_frames[etid][key] = F.gather_row(data, rows)

    return hg
def to_homo(G):
    """Convert the given graph to a homogeneous graph.

    The returned graph has only one type of nodes and edges.

    Node and edge types are stored as features in the returned graph. Each feature
    is an integer representing the type id, which can be used to retrieve the type
    names stored in ``G.ntypes`` and ``G.etypes`` arguments.

    Node IDs of the result are obtained by shifting the per-type node IDs by the
    total number of nodes of all preceding node types.

    Parameters
    ----------
    G : DGLHeteroGraph
        Input heterogenous graph.

    Returns
    -------
    DGLHeteroGraph
        A homogenous graph.
        The parent node and edge type/ID are stored in columns dgl.NTYPE/dgl.NID and
        dgl.ETYPE/dgl.EID respectively.
    """
    num_nodes_per_ntype = [G.number_of_nodes(ntype) for ntype in G.ntypes]
    offset_per_ntype = np.insert(np.cumsum(num_nodes_per_ntype), 0, 0)
    # The last offset is the number of nodes over all types.
    total_num_nodes = int(offset_per_ntype[-1])
    srcs = []
    dsts = []
    etype_ids = []
    eids = []
    ntype_ids = []
    nids = []

    for ntype_id, ntype in enumerate(G.ntypes):
        num_nodes = G.number_of_nodes(ntype)
        ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu()))
        nids.append(F.arange(0, num_nodes))

    for etype_id, etype in enumerate(G.canonical_etypes):
        srctype, _, dsttype = etype
        src, dst = G.all_edges(etype=etype, order='eid')
        num_edges = len(src)
        srcs.append(src + offset_per_ntype[G.get_ntype_id(srctype)])
        dsts.append(dst + offset_per_ntype[G.get_ntype_id(dsttype)])
        etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
        eids.append(F.arange(0, num_edges))

    # Pass the node count explicitly: inferring it from the largest edge endpoint
    # would drop trailing nodes without edges and make the node feature columns
    # assigned below longer than the graph.
    retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes)
    retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
    retg.ndata[NID] = F.cat(nids, 0)
    retg.edata[ETYPE] = F.cat(etype_ids, 0)
    retg.edata[EID] = F.cat(eids, 0)

    # features
    comb_nf = combine_frames(G._node_frames, range(len(G.ntypes)))
    comb_ef = combine_frames(G._edge_frames, range(len(G.etypes)))
    if comb_nf is not None:
        retg.ndata.update(comb_nf)
    if comb_ef is not None:
        retg.edata.update(comb_ef)

    return retg
############################################################
# Internal APIs
############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
    """Internal function to create a graph from incident nodes with types.

    utype could be equal to vtype

    Parameters
    ----------
    u : iterable of int
        List of source node IDs.
    v : iterable of int
        List of destination node IDs.
    utype : str
        Source node type name.
    etype : str
        Edge type name.
    vtype : str
        Destination node type name.
    urange : int, optional
        The source node ID range. If None, the value is the maximum
        of the source node IDs in the edge list plus 1 (0 when there are
        no edges). (Default: None)
    vrange : int, optional
        The destination node ID range. If None, the value is the
        maximum of the destination node IDs in the edge list plus 1 (0 when
        there are no edges). (Default: None)

    Returns
    -------
    DGLHeteroGraph
    """
    u = utils.toindex(u)
    v = utils.toindex(v)

    def _infer_range(ids):
        # An empty ID list has no maximum; it implies an empty node range.
        if len(ids) == 0:
            return 0
        return int(F.asnumpy(F.max(ids.tousertensor(), dim=0))) + 1

    # Only infer a range that was not given. A truthiness test would treat an
    # explicit range of 0 (e.g. a node type without nodes) as missing and then
    # fail when taking the maximum of an empty ID list.
    if urange is None:
        urange = _infer_range(u)
    if vrange is None:
        vrange = _infer_range(v)
    if utype == vtype:
        # Both endpoints live in the same node set, so they share one range.
        urange = vrange = max(urange, vrange)
        num_ntypes = 1
    else:
        num_ntypes = 2
    hgidx = heterograph_index.create_unitgraph_from_coo(num_ntypes, urange, vrange, u, v)
    if utype == vtype:
        return DGLHeteroGraph(hgidx, [utype], [etype])
    else:
        return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
    """Internal function to create a graph from a list of edge tuples with types.

    The (src, dst) pairs are split into a source ID list and a destination ID
    list, which are then handed to :func:`create_from_edges`.
    utype could be equal to vtype

    Parameters
    ----------
    elist : iterable of int pairs
        List of (src, dst) node ID pairs.
    utype : str
        Source node type name.
    etype : str
        Edge type name.
    vtype : str
        Destination node type name.
    urange : int, optional
        The source node ID range. If None, the value is the maximum
        of the source node IDs in the edge list plus 1. (Default: None)
    vrange : int, optional
        The destination node ID range. If None, the value is the
        maximum of the destination node IDs in the edge list plus 1. (Default: None)

    Returns
    -------
    DGLHeteroGraph
    """
    if len(elist) > 0:
        # Transpose the pair list into one tuple of sources and one of destinations.
        src_column, dst_column = zip(*elist)
        src_ids, dst_ids = list(src_column), list(dst_column)
    else:
        # zip(*[]) yields nothing to unpack, so the empty case is spelled out.
        src_ids, dst_ids = [], []
    return create_from_edges(src_ids, dst_ids, utype, etype, vtype, urange, vrange)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
    """Internal function to create a graph from a scipy sparse matrix with types.

    COO matrices are consumed directly through their row/column arrays; every
    other format is converted to CSR first. The node counts are taken from the
    matrix shape.

    Parameters
    ----------
    spmat : scipy.sparse.spmatrix
        The adjacency matrix whose rows represent sources and columns
        represent destinations.
    utype : str
        Source node type name.
    etype : str
        Edge type name.
    vtype : str
        Destination node type name.
    with_edge_id : bool
        If True, the entries in the sparse matrix are treated as edge IDs.
        Otherwise, the entries are ignored and edges will be added in
        (source, destination) order.

    Returns
    -------
    DGLHeteroGraph
    """
    n_rows, n_cols = spmat.shape
    num_ntypes = 2 if utype != vtype else 1
    if spmat.getformat() != 'coo':
        csr = spmat.tocsr()
        indptr = utils.toindex(csr.indptr)
        indices = utils.toindex(csr.indices)
        # TODO(minjie): with_edge_id is only reasonable for csr matrix. How to fix?
        if with_edge_id:
            edge_ids = csr.data
        else:
            edge_ids = list(range(len(indices)))
        hgidx = heterograph_index.create_unitgraph_from_csr(
            num_ntypes, n_rows, n_cols, indptr, indices, utils.toindex(edge_ids))
    else:
        row_ids = utils.toindex(spmat.row)
        col_ids = utils.toindex(spmat.col)
        hgidx = heterograph_index.create_unitgraph_from_coo(
            num_ntypes, n_rows, n_cols, row_ids, col_ids)
    # A single node type is listed once; otherwise source type precedes destination type.
    node_type_names = [utype] if num_ntypes == 1 else [utype, vtype]
    return DGLHeteroGraph(hgidx, node_type_names, [etype])
def create_from_networkx(nx_graph,
                         ntype, etype,
                         edge_id_attr_name='id',
                         node_attrs=None,
                         edge_attrs=None):
    """Create graph that has only one set of nodes and edges.

    Undirected inputs are first converted to directed graphs and nodes are
    relabeled with consecutive integers in sorted order.

    Parameters
    ----------
    nx_graph : networkx graph
        The input graph.
    ntype : str
        Node type name.
    etype : str
        Edge type name.
    edge_id_attr_name : str, optional
        Name of the edge attribute that stores the edge ID. The attribute is
        considered present if the first edge of the graph carries it; edges are
        then placed at the positions given by their IDs. (Default: 'id')
    node_attrs : iterable of str, optional
        Node attributes to copy into the node data of the result. (Default: None)
    edge_attrs : iterable of str, optional
        Edge attributes to copy into the edge data of the result. (Default: None)

    Returns
    -------
    DGLHeteroGraph

    Raises
    ------
    DGLError
        If a pre-specified edge ID is not smaller than the number of edges, or
        if some edge lacks a value for one of ``edge_attrs``.
    """
    if not nx_graph.is_directed():
        nx_graph = nx_graph.to_directed()

    # Relabel nodes using consecutive integers
    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')

    # nx_graph.edges(data=True) returns src, dst, attr_dict
    if nx_graph.number_of_edges() > 0:
        has_edge_id = edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
    else:
        has_edge_id = False

    if has_edge_id:
        num_edges = nx_graph.number_of_edges()
        src = np.zeros((num_edges,), dtype=np.int64)
        dst = np.zeros((num_edges,), dtype=np.int64)
        for u, v, attr in nx_graph.edges(data=True):
            eid = attr[edge_id_attr_name]
            src[eid] = u
            dst[eid] = v
    else:
        src = []
        dst = []
        for e in nx_graph.edges:
            src.append(e[0])
            dst.append(e[1])
    src = utils.toindex(src)
    dst = utils.toindex(dst)
    num_nodes = nx_graph.number_of_nodes()
    g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes)

    # handle features
    # copy attributes
    def _batcher(lst):
        # Stack tensors along a new leading axis; convert anything else directly.
        if F.is_tensor(lst[0]):
            return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
        else:
            return F.tensor(lst)
    if node_attrs is not None:
        # mapping from feature name to a list of tensors to be concatenated
        attr_dict = defaultdict(list)
        for nid in range(g.number_of_nodes()):
            for attr in node_attrs:
                attr_dict[attr].append(nx_graph.nodes[nid][attr])
        for attr in node_attrs:
            g.ndata[attr] = _batcher(attr_dict[attr])

    if edge_attrs is not None:
        # mapping from feature name to a list of tensors to be concatenated
        attr_dict = defaultdict(lambda: [None] * g.number_of_edges())
        # each defaultdict value is initialized to be a list of None
        # None here serves as placeholder to be replaced by feature with
        # corresponding edge id
        if has_edge_id:
            num_edges = g.number_of_edges()
            for _, _, attrs in nx_graph.edges(data=True):
                # Read the ID from the configured attribute, not a hard-coded
                # 'id' key, so that custom ``edge_id_attr_name`` values work.
                eid = attrs[edge_id_attr_name]
                if eid >= num_edges:
                    raise DGLError('Expect the pre-specified edge ids to be'
                                   ' smaller than the number of edges --'
                                   ' {}, got {}.'.format(num_edges, eid))
                for key in edge_attrs:
                    attr_dict[key][eid] = attrs[key]
        else:
            # XXX: assuming networkx iteration order is deterministic
            #      so the order is the same as graph_index.from_networkx
            for eid, (_, _, attrs) in enumerate(nx_graph.edges(data=True)):
                for key in edge_attrs:
                    attr_dict[key][eid] = attrs[key]
        for attr in edge_attrs:
            for val in attr_dict[attr]:
                if val is None:
                    raise DGLError('Not all edges have attribute {}.'.format(attr))
            g.edata[attr] = _batcher(attr_dict[attr])

    return g
def create_from_networkx_bipartite(nx_graph,
                                   utype, etype, vtype,
                                   edge_id_attr_name='id',
                                   node_attrs=None,
                                   edge_attrs=None):
    """Create a bipartite graph with two sets of nodes and one set of edges.

    The input graph must follow the bipartite graph convention of networkx.
    Each node has an attribute ``bipartite`` with values 0 and 1 indicating which
    set it belongs to.

    Only edges from node set 0 to node set 1 are added to the returned graph.
    Nodes of each set are relabeled with consecutive integers in sorted order.

    Parameters
    ----------
    nx_graph : networkx graph
        The input graph.
    utype : str
        Source node type name (node set 0).
    etype : str
        Edge type name.
    vtype : str
        Destination node type name (node set 1).
    edge_id_attr_name : str, optional
        Name of the edge attribute that stores the edge ID. The attribute is
        considered present if the first edge of the graph carries it; edges are
        then placed at the positions given by their IDs. (Default: 'id')
    node_attrs : None
        Not supported yet; must be None.
    edge_attrs : None
        Not supported yet; must be None.

    Returns
    -------
    DGLHeteroGraph
    """
    if not nx_graph.is_directed():
        nx_graph = nx_graph.to_directed()

    top_nodes = {n for n, d in nx_graph.nodes(data=True) if d['bipartite'] == 0}
    bottom_nodes = set(nx_graph) - top_nodes
    top_nodes = sorted(top_nodes)
    bottom_nodes = sorted(bottom_nodes)
    top_map = {n : i for i, n in enumerate(top_nodes)}
    bottom_map = {n : i for i, n in enumerate(bottom_nodes)}

    if nx_graph.number_of_edges() > 0:
        has_edge_id = edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
    else:
        has_edge_id = False

    if has_edge_id:
        # Keep only edges leaving node set 0, exactly as the branch without edge
        # IDs does. The directed view of an undirected graph also contains the
        # reverse (set 1 -> set 0) copy of every edge, which must neither be
        # counted nor looked up in ``top_map``.
        forward_edges = [(u, v, attr) for u, v, attr in nx_graph.edges(data=True)
                         if u in top_map]
        num_edges = len(forward_edges)
        src = np.zeros((num_edges,), dtype=np.int64)
        dst = np.zeros((num_edges,), dtype=np.int64)
        for u, v, attr in forward_edges:
            eid = attr[edge_id_attr_name]
            src[eid] = top_map[u]
            dst[eid] = bottom_map[v]
    else:
        src = []
        dst = []
        for e in nx_graph.edges:
            if e[0] in top_map:
                src.append(top_map[e[0]])
                dst.append(bottom_map[e[1]])
    src = utils.toindex(src)
    dst = utils.toindex(dst)
    g = create_from_edges(src, dst, utype, etype, vtype, len(top_nodes), len(bottom_nodes))

    # TODO attributes
    assert node_attrs is None
    assert edge_attrs is None
    return g
def to_networkx(g, node_attrs=None, edge_attrs=None):
    """Convert to networkx graph.

    Functional counterpart of ``g.to_networkx``: both attribute arguments are
    passed through positionally and the method's result is returned unchanged.

    Parameters
    ----------
    g : graph object
        The graph to convert; it must provide a ``to_networkx`` method.
    node_attrs : optional
        Forwarded as the first argument of ``g.to_networkx``. (Default: None)
    edge_attrs : optional
        Forwarded as the second argument of ``g.to_networkx``. (Default: None)

    Returns
    -------
    Whatever ``g.to_networkx`` returns.

    See Also
    --------
    DGLHeteroGraph.to_networkx
    """
    converted = g.to_networkx(node_attrs, edge_attrs)
    return converted
......@@ -186,8 +186,8 @@ class Frame(MutableMapping):
update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container.
num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided, ``num_rows``
will be ignored and inferred from the given data.
The number of rows in this frame. If ``data`` is provided and is not empty,
``num_rows`` will be ignored and inferred from the given data.
"""
def __init__(self, data=None, num_rows=0):
if data is None:
......@@ -202,7 +202,7 @@ class Frame(MutableMapping):
elif len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values())))
else:
self._num_rows = 0
self._num_rows = num_rows
# sanity check
for name, col in self._columns.items():
if len(col) != self._num_rows:
......@@ -880,23 +880,23 @@ class FrameRef(MutableMapping):
"""
return self._index.get_items(query)
def frame_like(other, num_rows):
"""Create a new frame that has the same scheme as the given one.
def frame_like(other, num_rows=None):
"""Create an empty frame that has the same initializer as the given one.
Parameters
----------
other : Frame
The given frame.
num_rows : int
The number of rows of the new one.
The number of rows of the new one. If None, use other.num_rows
(Default: None)
Returns
-------
Frame
The new frame.
"""
# TODO(minjie): scheme is not inherited at the moment. Fix this
# when moving per-col initializer to column scheme.
num_rows = other.num_rows if num_rows is None else num_rows
newf = Frame(num_rows=num_rows)
# set global initializr
if other.get_initializer() is None:
......
......@@ -11,7 +11,7 @@ from . import backend as F
from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from . import graph_index
from .runtime import ir, scheduler, Runtime
from .runtime import ir, scheduler, Runtime, GraphAdapter
from . import utils
from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
......@@ -49,14 +49,6 @@ class DGLBaseGraph(object):
"""
return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
......@@ -73,10 +65,6 @@ class DGLBaseGraph(object):
"""
return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self):
"""Return the number of edges in the graph.
......@@ -951,14 +939,6 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index):
self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
......@@ -2089,9 +2069,9 @@ class DGLGraph(DGLBaseGraph):
else:
v = utils.toindex(v)
with ir.prog() as prog:
scheduler.schedule_apply_nodes(graph=self,
v=v,
scheduler.schedule_apply_nodes(v=v,
apply_func=func,
node_frame=self._node_frame,
inplace=inplace)
Runtime.run(prog)
......@@ -2159,12 +2139,7 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(graph=self,
u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
scheduler.schedule_apply_edges(AdaptedDGLGraph(self), u, v, eid, func, inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
......@@ -2241,10 +2216,8 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog:
scheduler.schedule_group_apply_edge(graph=self,
u=u,
v=v,
eid=eid,
scheduler.schedule_group_apply_edge(graph=AdaptedDGLGraph(self),
u=u, v=v, eid=eid,
apply_func=func,
group_by=group_by,
inplace=inplace)
......@@ -2308,7 +2281,7 @@ class DGLGraph(DGLBaseGraph):
return
with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
scheduler.schedule_send(graph=AdaptedDGLGraph(self), u=u, v=v, eid=eid,
message_func=message_func)
Runtime.run(prog)
......@@ -2407,7 +2380,7 @@ class DGLGraph(DGLBaseGraph):
return
with ir.prog() as prog:
scheduler.schedule_recv(graph=self,
scheduler.schedule_recv(graph=AdaptedDGLGraph(self),
recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
......@@ -2515,7 +2488,7 @@ class DGLGraph(DGLBaseGraph):
return
with ir.prog() as prog:
scheduler.schedule_snr(graph=self,
scheduler.schedule_snr(graph=AdaptedDGLGraph(self),
edge_tuples=(u, v, eid),
message_func=message_func,
reduce_func=reduce_func,
......@@ -2618,7 +2591,7 @@ class DGLGraph(DGLBaseGraph):
if len(v) == 0:
return
with ir.prog() as prog:
scheduler.schedule_pull(graph=self,
scheduler.schedule_pull(graph=AdaptedDGLGraph(self),
pull_nodes=v,
message_func=message_func,
reduce_func=reduce_func,
......@@ -2715,7 +2688,7 @@ class DGLGraph(DGLBaseGraph):
if len(u) == 0:
return
with ir.prog() as prog:
scheduler.schedule_push(graph=self,
scheduler.schedule_push(graph=AdaptedDGLGraph(self),
u=u,
message_func=message_func,
reduce_func=reduce_func,
......@@ -2762,7 +2735,7 @@ class DGLGraph(DGLBaseGraph):
assert reduce_func is not None
with ir.prog() as prog:
scheduler.schedule_update_all(graph=self,
scheduler.schedule_update_all(graph=AdaptedDGLGraph(self),
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func)
......@@ -3219,7 +3192,7 @@ class DGLGraph(DGLBaseGraph):
v = utils.toindex(nodes)
n_repr = self.get_n_repr(v)
nbatch = NodeBatch(self, v, n_repr)
nbatch = NodeBatch(v, n_repr)
n_mask = F.copy_to(predicate(nbatch), F.cpu())
if is_all(nodes):
......@@ -3277,8 +3250,8 @@ class DGLGraph(DGLBaseGraph):
filter_nodes
"""
if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges('eid')
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
......@@ -3292,7 +3265,7 @@ class DGLGraph(DGLBaseGraph):
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
ebatch = EdgeBatch(self, (u, v, eid), src_data, edge_data, dst_data)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges):
......@@ -3492,3 +3465,79 @@ class DGLGraph(DGLBaseGraph):
yield
self._node_frame = old_nframe
self._edge_frame = old_eframe
############################################################
# Internal APIs
############################################################
class AdaptedDGLGraph(GraphAdapter):
    """Adapt DGLGraph to interface required by scheduler.

    Every member simply forwards to the wrapped graph. The source side and
    the destination side both map to the graph's single node frame and to
    its total node count.

    Parameters
    ----------
    graph : DGLGraph
        Graph
    """
    def __init__(self, graph):
        # Keep a reference to the wrapped graph; nothing is copied.
        self.graph = graph

    @property
    def gidx(self):
        """Graph index object of the wrapped graph (``graph._graph``)."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph.number_of_nodes()

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph.number_of_nodes()

    def num_edges(self):
        """Number of edges."""
        return self.graph.number_of_edges()

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frame

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        # Same frame object as ``srcframe``: the wrapped graph has one node frame.
        return self.graph._node_frame

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frame

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frame

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index()

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(val)

    def in_edges(self, nodes):
        """Forward to ``in_edges(nodes)`` of the wrapped graph's graph index."""
        return self.graph._graph.in_edges(nodes)

    def out_edges(self, nodes):
        """Forward to ``out_edges(nodes)`` of the wrapped graph's graph index."""
        return self.graph._graph.out_edges(nodes)

    def edges(self, form):
        """Forward to ``edges(form)`` of the wrapped graph's graph index."""
        return self.graph._graph.edges(form)

    def get_immutable_gidx(self, ctx):
        """Forward to ``get_immutable_gidx(ctx)`` of the wrapped graph's graph index."""
        return self.graph._graph.get_immutable_gidx(ctx)

    def bits_needed(self):
        """Forward to ``bits_needed()`` of the wrapped graph's graph index."""
        return self.graph._graph.bits_needed()
......@@ -1129,10 +1129,13 @@ def from_edge_list(elist, is_multigraph, readonly):
Parameters
---------
elist : list
List of (u, v) edge tuple.
elist : list, tuple
List of (u, v) edge tuple, or a tuple of src/dst lists
"""
src, dst = zip(*elist)
if isinstance(elist, tuple):
src, dst = elist
else:
src, dst = zip(*elist)
src = np.array(src)
dst = np.array(dst)
src_ids = utils.toindex(src)
......
"""Classes for heterogeneous graphs."""
from collections import defaultdict
from contextlib import contextmanager
import networkx as nx
import scipy.sparse as ssp
from . import heterograph_index, graph_index
import numpy as np
from . import graph_index
from . import heterograph_index
from . import utils
from . import backend as F
from . import init
from .runtime import ir, scheduler, Runtime
from .frame import Frame, FrameRef
from .runtime import ir, scheduler, Runtime, GraphAdapter
from .frame import Frame, FrameRef, frame_like, sync_frame_initializer
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, is_all, DGLError
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError
__all__ = ['DGLHeteroGraph', 'combine_names']
class DGLHeteroGraph(object):
"""Base heterogeneous graph class.
Do NOT instantiate from this class directly; use :mod:`conversion methods
<dgl.convert>` instead.
A Heterogeneous graph is defined as a graph with node types and edge
types.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
Suppose that one maps the users, games and developers to the following
IDs:
User name Alice Bob Carol
User ID 0 1 2
Game name Tetris Minecraft
Game ID 0 1
Developer name Nintendo Mojang
Developer ID 0 1
One can construct the graph as follows:
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> dev_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, dev_g])
:func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
data types including: edge list, edge tuples, networkx graph and scipy sparse matrix.
Click the function name for more details.
Then one can query the graph structure by specifying the ``ntype`` or ``etype`` arguments:
>>> g.number_of_nodes('user')
3
>>> g.number_of_edges('plays')
4
>>> g.out_degrees(etype='develops') # out-degrees of source nodes of 'develops' relation
tensor([1, 1])
>>> g.in_edges(0, etype='develops') # in-edges of destination node 0 of 'develops' relation
(tensor([0]), tensor([0]))
Or on the sliced graph for an edge type:
__all__ = ['DGLHeteroGraph']
>>> g['plays'].number_of_edges()
4
>>> g['develops'].out_degrees()
tensor([1, 1])
>>> g['develops'].in_edges(0)
Node type names must be distinct (no two types have the same name). Edge types could
have the same name but they must be distinguishable by the ``(src_type, edge_type, dst_type)``
triplet (called *canonical edge type*).
For example, suppose a graph that has two types of relation "user-watches-movie"
and "user-watches-TV" as follows:
>>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV')
>>> GG = dgl.hetero_from_relations([g0, g1])
To distinguish between the two "watches" edge type, one must specify a full triplet:
>>> GG.number_of_edges(('user', 'watches', 'movie'))
3
>>> GG.number_of_edges(('user', 'watches', 'TV'))
2
>>> GG['user', 'watches', 'movie'].out_degrees()
tensor([1, 2])
Using only one single edge type string "watches" is ambiguous and will cause error:
>>> GG.number_of_edges('watches') # AMBIGUOUS!!
In many cases, there is only one type of nodes or one type of edges, and the ``ntype``
and ``etype`` argument could be omitted. This is very common when using the sliced
graph, which usually contains only one edge type, and sometimes only one node type:
# TODO: depending on the progress of unifying DGLGraph and Bipartite, we may or may not
# need the code of heterogeneous graph views.
# pylint: disable=unnecessary-pass
class DGLBaseHeteroGraph(object):
"""Base Heterogeneous graph class.
>>> g['follows'].number_of_nodes() # OK!! because g['follows'] only has one node type 'user'
3
>>> g['plays'].number_of_nodes() # ERROR!! There are two types 'user' and 'game'.
>>> g['plays'].number_of_edges() # OK!! because there is only one edge type 'plays'
Parameters
----------
graph : graph index, optional
The graph index
ntypes : list[str]
The node type names
etypes : list[str]
The edge type names
_ntypes_invmap, _etypes_invmap, _view_ntype_idx, _view_etype_idx :
Internal arguments
"""
gidx : HeteroGraphIndex
Graph index object.
ntypes : list of str
Node type list. The i^th element stores the type name
of node type i.
etypes : list of str
Edge type list. The i^th element stores the type name
of edge type i.
node_frames : list of FrameRef, optional
Node feature storage. The i^th element stores the node features
of node type i. If None, empty frame is created. (default: None)
edge_frames : list of FrameRef, optional
Edge feature storage. The i^th element stores the edge features
of edge type i. If None, empty frame is created. (default: None)
multigraph : bool, optional
Whether the graph would be a multigraph. If none, the flag will be determined
by scanning the whole graph. (default: None)
readonly : bool, optional
Whether the graph structure is read-only (default: True).
Notes
-----
Currently, all heterogeneous graphs are readonly.
"""
# pylint: disable=unused-argument
def __init__(self, graph, ntypes, etypes,
_ntypes_invmap=None, _etypes_invmap=None,
_view_ntype_idx=None, _view_etype_idx=None):
super(DGLBaseHeteroGraph, self).__init__()
def __init__(self,
gidx,
ntypes,
etypes,
node_frames=None,
edge_frames=None,
multigraph=None,
readonly=True):
assert readonly, "Only readonly heterogeneous graphs are supported"
self._graph = graph
self._graph = gidx
self._nx_metagraph = None
self._ntypes = ntypes
self._etypes = etypes
# inverse mapping from ntype str to int
self._ntypes_invmap = _ntypes_invmap or \
{ntype: i for i, ntype in enumerate(ntypes)}
# inverse mapping from etype str to int
self._etypes_invmap = _etypes_invmap or \
{etype: i for i, etype in enumerate(etypes)}
self._canonical_etypes = make_canonical_etypes(etypes, ntypes, self._graph.metagraph)
# An internal map from etype to canonical etype tuple.
# If two etypes have the same name, an empty tuple is stored instead to indicate ambiguity.
self._etype2canonical = {}
for i, ety in enumerate(etypes):
if ety in self._etype2canonical:
self._etype2canonical[ety] = tuple()
else:
self._etype2canonical[ety] = self._canonical_etypes[i]
self._ntypes_invmap = {t : i for i, t in enumerate(ntypes)}
self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}
# Indicates which node/edge type (int) it is viewing.
self._view_ntype_idx = _view_ntype_idx
self._view_etype_idx = _view_etype_idx
# node and edge frame
if node_frames is None:
node_frames = [None] * len(self._ntypes)
node_frames = [FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
if frame is None else frame
for i, frame in enumerate(node_frames)]
self._node_frames = node_frames
self._cache = {}
if edge_frames is None:
edge_frames = [None] * len(self._etypes)
edge_frames = [FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
if frame is None else frame
for i, frame in enumerate(edge_frames)]
self._edge_frames = edge_frames
def _create_view(self, ntype_idx, etype_idx):
return DGLBaseHeteroGraph(
self._graph, self._ntypes, self._etypes,
self._ntypes_invmap, self._etypes_invmap,
ntype_idx, etype_idx)
# message indicators
self._msg_indices = [None] * len(self._etypes)
self._msg_frames = []
for i in range(len(self._etypes)):
frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
frame.set_initializer(init.zero_initializer)
self._msg_frames.append(frame)
@property
def is_node_type_view(self):
"""Whether this is a node type view of a heterograph."""
return self._view_ntype_idx is not None
def _get_msg_index(self, etid):
if self._msg_indices[etid] is None:
self._msg_indices[etid] = utils.zero_index(
size=self._graph.number_of_edges(etid))
return self._msg_indices[etid]
@property
def is_edge_type_view(self):
"""Whether this is an edge type view of a heterograph."""
return self._view_etype_idx is not None
def _set_msg_index(self, etid, index):
    """Store ``index`` as the message index of edge type ``etid``.

    Parameters
    ----------
    etid : int
        Edge type ID; used to index ``self._msg_indices``.
    index
        The new message index. It replaces whatever was stored before.
    """
    self._msg_indices[etid] = index
@property
def is_view(self):
"""Whether this is a node/view of a heterograph."""
return self.is_node_type_view or self.is_edge_type_view
def __repr__(self):
if len(self.ntypes) == 1 and len(self.etypes) == 1:
ret = ('Graph(num_nodes={node}, num_edges={edge},\n'
' ndata_schemes={ndata}\n'
' edata_schemes={edata})')
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes()))
else:
ret = ('Graph(num_nodes={node},\n'
' num_edges={edge},\n'
' metagraph={meta})')
nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i)
for i in range(len(self.ntypes))}
nedge_dict = {self.etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))}
meta = str(self.metagraph.edges())
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
#################################################################
# Mutation operations
#################################################################
def add_nodes(self, num, data=None, ntype=None):
    """Add multiple new nodes of the same node type.

    Currently not supported. All arguments are ignored and the call
    always fails.

    Raises
    ------
    DGLError
        Always, because mutation is not supported in heterograph.
    """
    raise DGLError('Mutation is not supported in heterograph.')
def add_edge(self, u, v, data=None, etype=None):
    """Add an edge of ``etype`` between u of the source node type, and v
    of the destination node type.

    Currently not supported. All arguments are ignored and the call
    always fails.

    Raises
    ------
    DGLError
        Always, because mutation is not supported in heterograph.
    """
    raise DGLError('Mutation is not supported in heterograph.')
def add_edges(self, u, v, data=None, etype=None):
    """Add multiple edges of ``etype`` between list of source nodes ``u``
    and list of destination nodes ``v``. A single edge is added between
    every pair of ``u[i]`` and ``v[i]``.

    Currently not supported. All arguments are ignored and the call
    always fails.

    Raises
    ------
    DGLError
        Always, because mutation is not supported in heterograph.
    """
    raise DGLError('Mutation is not supported in heterograph.')
#################################################################
# Metagraph query
#################################################################
@property
def all_node_types(self):
"""Return the list of node types of the entire heterograph."""
def ntypes(self):
"""Return the list of node types of this graph."""
return self._ntypes
@property
def all_edge_types(self):
"""Return the list of edge types of the entire heterograph."""
def etypes(self):
"""Return the list of edge types of this graph."""
return self._etypes
@property
def canonical_etypes(self):
    """Return the list of canonical edge types of this graph.

    A canonical edge type is a tuple of strings (src_type, edge_type, dst_type).

    Returns
    -------
    The ``_canonical_etypes`` attribute itself (not a copy).
    """
    return self._canonical_etypes
@property
def metagraph(self):
    """Return the metagraph as networkx.MultiDiGraph.

    The nodes are labeled with node type names.
    The edges have their keys holding the edge type names.

    The result is built on first access and cached in
    ``self._nx_metagraph``; later accesses return the same object.

    Returns
    -------
    networkx.MultiDiGraph
    """
    if self._nx_metagraph is None:
        index_meta = self._graph.metagraph.to_networkx()
        built = nx.MultiDiGraph()
        for edge_key in index_meta.edges:
            # The 'id' edge attribute is the edge type ID, which indexes
            # the canonical edge type list.
            etid = index_meta.edges[edge_key]['id']
            srctype, etype, dsttype = self.canonical_etypes[etid]
            built.add_edge(srctype, dsttype, etype)
        # Publish the cache only once it is complete, so that a failure
        # above cannot leave a partially built metagraph behind.
        self._nx_metagraph = built
    return self._nx_metagraph
def to_canonical_etype(self, etype):
"""Convert edge type to canonical etype: (srctype, etype, dsttype).
The input can already be a canonical tuple.
Parameters
----------
etype : str or tuple of str
Edge type
Returns
-------
tuple of str
"""
nx_graph = self._graph.metagraph.to_networkx()
nx_return_graph = nx.MultiDiGraph()
for u_v in nx_graph.edges:
etype = self._etypes[nx_graph.edges[u_v]['id']]
srctype = self._ntypes[u_v[0]]
dsttype = self._ntypes[u_v[1]]
assert etype[0] == srctype
assert etype[2] == dsttype
nx_return_graph.add_edge(srctype, dsttype, etype[1])
return nx_return_graph
def _endpoint_types(self, etype):
"""Return the source and destination node type (int) of given edge
type (int)."""
return self._graph.metagraph.find_edge(etype)
def _node_types(self):
if self.is_node_type_view:
return [self._view_ntype_idx]
elif self.is_edge_type_view:
srctype_idx, dsttype_idx = self._endpoint_types(self._view_etype_idx)
return [srctype_idx, dsttype_idx] if srctype_idx != dsttype_idx else [srctype_idx]
if isinstance(etype, tuple):
return etype
else:
return range(len(self._ntypes))
ret = self._etype2canonical.get(etype, None)
if ret is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype))
if len(ret) == 0:
raise DGLError('Edge type "%s" is ambiguous. Please use canonical etype '
'type in the form of (srctype, etype, dsttype)' % etype)
return ret
def get_ntype_id(self, ntype):
"""Return the id of the given node type.
ntype can also be None. If so, there should be only one node type in the
graph.
def node_types(self):
"""Return the list of node types appearing in the current view.
Parameters
----------
ntype : str
Node type
Returns
-------
list[str]
List of node types
int
"""
if ntype is None:
if self._graph.number_of_ntypes() != 1:
raise DGLError('Node type name must be specified if there are more than one '
'node types.')
return 0
ntid = self._ntypes_invmap.get(ntype, None)
if ntid is None:
raise DGLError('Node type "{}" does not exist.'.format(ntype))
return ntid
Examples
--------
Getting all node types.
>>> g.node_types()
['user', 'game', 'developer']
Getting all node types appearing in the subgraph induced by "users"
(which should only yield "user").
>>> g['user'].node_types()
['user']
The node types appearing in subgraph induced by "plays" relationship,
which should only give "user" and "game".
>>> g['plays'].node_types()
['user', 'game']
"""
ntypes = self._node_types()
if isinstance(ntypes, range):
# assuming that the range object always covers the entire node type list
return self._ntypes
else:
return [self._ntypes[i] for i in ntypes]
def _edge_types(self):
if self.is_node_type_view:
etype_indices = self._graph.metagraph.edge_id(
self._view_ntype_idx, self._view_ntype_idx)
return etype_indices
elif self.is_edge_type_view:
return [self._view_etype_idx]
else:
return range(len(self._etypes))
def get_etype_id(self, etype):
"""Return the id of the given edge type.
def edge_types(self):
"""Return the list of edge types appearing in the current view.
etype can also be None. If so, there should be only one edge type in the
graph.
Parameters
----------
etype : str or tuple of str
Edge type
Returns
-------
list[str]
List of edge types
int
"""
if etype is None:
if self._graph.number_of_etypes() != 1:
raise DGLError('Edge type name must be specified if there are more than one '
'edge types.')
return 0
etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
if etid is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype))
return etid
Examples
--------
Getting all edge types.
>>> g.edge_types()
['follows', 'plays', 'develops']
Getting all edge types appearing in subgraph induced by "users".
>>> g['user'].edge_types()
['follows']
The edge types appearing in subgraph induced by "plays" relationship,
which should only give "plays".
>>> g['plays'].edge_types()
['plays']
"""
etypes = self._edge_types()
if isinstance(etypes, range):
return self._etypes
else:
return [self._etypes[i] for i in etypes]
#################################################################
# View
#################################################################
@property
@utils.cached_member('_cache', '_current_ntype_idx')
def _current_ntype_idx(self):
"""Checks the uniqueness of node type in the view and get the index
of that node type.
def nodes(self):
"""Return a node view that can be used to set/get feature data of a
single node type.
This allows reading/writing node frame data.
Examples
--------
To set features of all Users:
>>> g.nodes['user'].data['h'] = torch.zeros(3, 5)
"""
node_types = self._node_types()
assert len(node_types) == 1, "only available for subgraphs with one node type"
return node_types[0]
return HeteroNodeView(self)
@property
@utils.cached_member('_cache', '_current_etype_idx')
def _current_etype_idx(self):
"""Checks the uniqueness of edge type in the view and get the index
of that edge type.
def ndata(self):
"""Return the data view of all the nodes.
This allows reading/writing edge frame data and message passing routines.
Only works if the graph has only one node type.
Examples
--------
To set features of all nodes in a heterogeneous graph with only one node type:
>>> g.ndata['h'] = torch.zeros(2, 5)
"""
edge_types = self._edge_types()
assert len(edge_types) == 1, "only available for subgraphs with one edge type"
return edge_types[0]
return HeteroNodeDataView(self, None, ALL)
@property
@utils.cached_member('_cache', '_current_srctype_idx')
def _current_srctype_idx(self):
"""Checks the uniqueness of edge type in the view and get the index
of the source type.
def edges(self):
"""Return an edges view that can be used to set/get feature data of a
single edge type.
This allows reading/writing edge frame data and message passing routines.
Examples
--------
To set features of all "play" relationships:
>>> g.edges['plays'].data['h'] = torch.zeros(4, 4)
"""
srctype_idx, _ = self._endpoint_types(self._current_etype_idx)
return srctype_idx
return HeteroEdgeView(self)
@property
@utils.cached_member('_cache', '_current_dsttype_idx')
def _current_dsttype_idx(self):
"""Checks the uniqueness of edge type in the view and get the index
of the destination type.
def edata(self):
"""Return the data view of all the edges.
This allows reading/writing edge frame data and message passing routines.
Only works if the graph has only one edge type
Examples
--------
To set features of all edges in a heterogeneous graph with only one edge type:
>>> g.edata['h'] = torch.zeros(2, 5)
"""
_, dsttype_idx = self._endpoint_types(self._current_etype_idx)
return dsttype_idx
return HeteroEdgeDataView(self, None, ALL)
def _find_etypes(self, key):
    """Return the IDs of all canonical edge types matched by ``key``.

    Parameters
    ----------
    key : sequence
        ``(srctype, etype, dsttype)`` pattern. Each entry is either a
        value compared for equality with the corresponding entry of a
        canonical edge type, or ``SLICE_FULL``, which matches anything.

    Returns
    -------
    list of int
        Positions in ``self._canonical_etypes`` of the matching entries,
        in increasing order.
    """
    def _accepts(pattern, name):
        # A full slice is the wildcard; anything else must match exactly.
        return pattern == SLICE_FULL or pattern == name

    matched = []
    for etid, (srctype, etype, dsttype) in enumerate(self._canonical_etypes):
        if (_accepts(key[0], srctype) and
                _accepts(key[1], etype) and
                _accepts(key[2], dsttype)):
            matched.append(etid)
    return matched
def __getitem__(self, key):
    """Return the relation slice of this graph.

    A relation slice is accessed with ``self[srctype, etype, dsttype]``, where
    ``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
    slice (``:``) representing wildcard (i.e. any source/edge/destination type).
    A single non-tuple key ``self[etype]`` is treated as ``self[:, etype, :]``.

    A relation slice is a homogeneous (with one node type and one edge type) or
    bipartite (with two node types and one edge type) graph, transformed from
    the original heterogeneous graph.

    If there is only one canonical edge type found, then the returned relation
    slice would be a subgraph induced from the original graph. That is, it is
    equivalent to ``self.edge_type_subgraph(etype)``. The node and edge features
    of the returned graph would be shared with the original graph.

    If there are multiple canonical edge types found, then the source/edge/destination
    node types would be a *concatenation* of original node/edge types. The
    new source/destination node type would have the concatenation determined by
    :func:`dgl.combine_names() <dgl.combine_names>` called on original source/destination
    types as its name. The source/destination node would be formed by concatenating the
    common features of the original source/destination types, therefore they are not
    shared with the original graph. Edge type is similar.

    Parameters
    ----------
    key : str or tuple
        Either a single edge type, or a ``(srctype, etype, dsttype)`` triplet
        whose entries may be full slices.

    Returns
    -------
    DGLHeteroGraph
        The relation slice.

    Raises
    ------
    DGLError
        If ``key`` is a tuple whose length is not 3, or if no canonical
        edge type matches ``key``.
    """
    err_msg = ("Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
               "to get view of one relation type. Use : to slice multiple types (e.g. "
               "G['srctype', :, 'dsttype']).")

    orig_key = key
    if not isinstance(key, tuple):
        key = (SLICE_FULL, key, SLICE_FULL)

    if len(key) != 3:
        raise DGLError(err_msg)

    etypes = self._find_etypes(key)
    if len(etypes) == 0:
        # Without this check an empty list would reach flatten_relations()
        # below, which is meant for merging several relations.
        raise DGLError('Invalid key "{}". It does not match any edge type '
                       'of this graph.'.format(orig_key))

    if len(etypes) == 1:
        # no ambiguity: return the unitgraph itself
        srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
        stid = self.get_ntype_id(srctype)
        etid = self.get_etype_id((srctype, etype, dsttype))
        dtid = self.get_ntype_id(dsttype)
        new_g = self._graph.get_relation_graph(etid)

        if stid == dtid:
            new_ntypes = [srctype]
            new_nframes = [self._node_frames[stid]]
        else:
            new_ntypes = [srctype, dsttype]
            new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
        new_etypes = [etype]
        new_eframes = [self._edge_frames[etid]]

        return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
    else:
        flat = self._graph.flatten_relations(etypes)
        new_g = flat.graph

        # merge frames
        stids = flat.induced_srctype_set.asnumpy()
        dtids = flat.induced_dsttype_set.asnumpy()
        etids = flat.induced_etype_set.asnumpy()
        new_ntypes = [combine_names(self.ntypes, stids)]
        if new_g.number_of_ntypes() == 2:
            new_ntypes.append(combine_names(self.ntypes, dtids))
            new_nframes = [
                combine_frames(self._node_frames, stids),
                combine_frames(self._node_frames, dtids)]
        else:
            assert np.array_equal(stids, dtids)
            new_nframes = [combine_frames(self._node_frames, stids)]
        new_etypes = [combine_names(self.etypes, etids)]
        new_eframes = [combine_frames(self._edge_frames, etids)]

        # create new heterograph
        new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)

        src = new_ntypes[0]
        dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
        # put the parent node/edge type and IDs
        new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype)
        new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid)
        new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype)
        new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid)
        new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
        new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)

        return new_hg
#################################################################
# Graph query
#################################################################
def number_of_nodes(self, ntype=None):
"""Return the number of nodes of the given type in the heterograph.
Parameters
----------
ntype : str
The node type
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
Returns
-------
......@@ -250,40 +554,16 @@ class DGLBaseHeteroGraph(object):
>>> g['user'].number_of_nodes()
3
"""
return self._graph.number_of_nodes(self._ntypes_invmap[ntype])
return self._graph.number_of_nodes(self.get_ntype_id(ntype))
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_srctype_idx)
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self._graph.number_of_nodes(self._current_dsttype_idx)
@property
def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise.
"""
assert not self.is_view, 'not supported on views'
return self._graph.is_multigraph()
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self._graph.number_of_edges(self._current_etype_idx)
def number_of_edges(self, etype):
def number_of_edges(self, etype=None):
"""Return the number of edges of the given type in the heterograph.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -295,17 +575,28 @@ class DGLBaseHeteroGraph(object):
>>> g.number_of_edges(('user', 'plays', 'game'))
4
"""
return self._graph.number_of_edges(self._etypes_invmap[etype])
return self._graph.number_of_edges(self.get_etype_id(etype))
@property
def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise."""
return self._graph.is_multigraph()
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise."""
return self._graph.is_readonly()
def has_node(self, ntype, vid):
def has_node(self, vid, ntype=None):
"""Return True if the graph contains node `vid` of type `ntype`.
Parameters
----------
ntype : str
The node type.
vid : int
The node ID.
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
Returns
-------
......@@ -314,28 +605,29 @@ class DGLBaseHeteroGraph(object):
Examples
--------
>>> g.has_node('user', 0)
>>> g.has_node(0, 'user')
True
>>> g.has_node('user', 4)
>>> g.has_node(4, 'user')
False
See Also
--------
has_nodes
"""
return self._graph.has_node(self._ntypes_invmap[ntype], vid)
return self._graph.has_node(self.get_ntype_id(ntype), vid)
def has_nodes(self, ntype, vids):
def has_nodes(self, vids, ntype=None):
"""Return a 0-1 array ``a`` given the node ID array ``vids``.
``a[i]`` is 1 if the graph contains node ``vids[i]`` of type ``ntype``, 0 otherwise.
Parameters
----------
ntype : str
The node type.
vid : list or tensor
The array of node IDs.
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
Returns
-------
......@@ -346,7 +638,7 @@ class DGLBaseHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g.has_nodes('user', [0, 1, 2, 3, 4])
>>> g.has_nodes([0, 1, 2, 3, 4], 'user')
tensor([1, 1, 1, 0, 0])
See Also
......@@ -354,20 +646,21 @@ class DGLBaseHeteroGraph(object):
has_node
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(self._ntypes_invmap[ntype], vids)
rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids)
return rst.tousertensor()
def has_edge_between(self, etype, u, v):
def has_edge_between(self, u, v, etype=None):
"""Return True if the edge (u, v) of type ``etype`` is in the graph.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int
The node ID of source type.
v : int
The node ID of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -377,20 +670,20 @@ class DGLBaseHeteroGraph(object):
Examples
--------
Check whether Alice plays Tetris
>>> g.has_edge_between(('user', 'plays', 'game'), 0, 1)
>>> g.has_edge_between(0, 1, ('user', 'plays', 'game'))
True
And whether Alice plays Minecraft
>>> g.has_edge_between(('user', 'plays', 'game'), 0, 2)
>>> g.has_edge_between(0, 2, ('user', 'plays', 'game'))
False
See Also
--------
has_edges_between
"""
return self._graph.has_edge_between(self._etypes_invmap[etype], u, v)
return self._graph.has_edge_between(self.get_etype_id(etype), u, v)
def has_edges_between(self, etype, u, v):
def has_edges_between(self, u, v, etype=None):
"""Return a 0-1 array ``a`` given the source node ID array ``u`` and
destination node ID array ``v``.
......@@ -398,12 +691,13 @@ class DGLBaseHeteroGraph(object):
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor
The node ID array of source type.
v : list, tensor
The node ID array of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -414,7 +708,7 @@ class DGLBaseHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g.has_edges_between(('user', 'plays', 'game'), [0, 0], [1, 2])
>>> g.has_edges_between([0, 0], [1, 2], ('user', 'plays', 'game'))
tensor([1, 0])
See Also
......@@ -423,10 +717,10 @@ class DGLBaseHeteroGraph(object):
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges_between(self._etypes_invmap[etype], u, v)
rst = self._graph.has_edges_between(self.get_etype_id(etype), u, v)
return rst.tousertensor()
def predecessors(self, etype, v):
def predecessors(self, v, etype=None):
"""Return the predecessors of node `v` in the graph with the same
edge type.
......@@ -435,10 +729,11 @@ class DGLBaseHeteroGraph(object):
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -450,7 +745,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Query who plays Tetris:
>>> g.predecessors(('user', 'plays', 'game'), 0)
>>> g.predecessors(0, ('user', 'plays', 'game'))
tensor([0, 1])
This indicates User #0 (Alice) and User #1 (Bob).
......@@ -459,9 +754,9 @@ class DGLBaseHeteroGraph(object):
--------
successors
"""
return self._graph.predecessors(self._etypes_invmap[etype], v).tousertensor()
return self._graph.predecessors(self.get_etype_id(etype), v).tousertensor()
def successors(self, etype, v):
def successors(self, v, etype=None):
"""Return the successors of node `v` in the graph with the same edge
type.
......@@ -470,10 +765,11 @@ class DGLBaseHeteroGraph(object):
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node of source type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -485,7 +781,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Asks which game Alice plays:
>>> g.successors(('user', 'plays', 'game'), 0)
>>> g.successors(0, ('user', 'plays', 'game'))
tensor([0])
This indicates Game #0 (Tetris).
......@@ -494,26 +790,24 @@ class DGLBaseHeteroGraph(object):
--------
predecessors
"""
return self._graph.successors(self._etypes_invmap[etype], v).tousertensor()
return self._graph.successors(self.get_etype_id(etype), v).tousertensor()
def edge_id(self, etype, u, v, force_multi=False):
def edge_id(self, u, v, force_multi=False, etype=None):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`.
Only works if the graph has one edge type. For multiple types,
query with
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int
The node ID of source type.
v : int
The node ID of destination type.
force_multi : bool
force_multi : bool, optional
If False, will return a single edge ID if the graph is a simple graph.
If True, will always return an array.
If True, will always return an array. (Default: False)
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -526,33 +820,31 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the edge ID of "Bob plays Tetris"
>>> g.edge_id(('user', 'plays', 'game'), 1, 0)
>>> g.edge_id(1, 0, etype=('user', 'plays', 'game'))
1
See Also
--------
edge_ids
"""
idx = self._graph.edge_id(self._etypes_invmap[etype], u, v)
idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0]
def edge_ids(self, etype, u, v, force_multi=False):
def edge_ids(self, u, v, force_multi=False, etype=None):
"""Return all edge IDs between source node array `u` and destination
node array `v`.
Only works if the graph has one edge type. For multiple types,
query with
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : list, tensor
The node ID array of source type.
v : list, tensor
The node ID array of destination type.
force_multi : bool
Whether to always treat the graph as a multigraph.
force_multi : bool, optional
Whether to always treat the graph as a multigraph. (Default: False)
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -574,7 +866,7 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the edge IDs of "Alice plays Tetris" and "Bob plays Minecraft".
>>> g.edge_ids(('user', 'plays', 'game'), [0, 1], [0, 1])
>>> g.edge_ids([0, 1], [0, 1], etype=('user', 'plays', 'game'))
tensor([0, 2])
See Also
......@@ -583,23 +875,24 @@ class DGLBaseHeteroGraph(object):
"""
u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(self._etypes_invmap[etype], u, v)
src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
if force_multi or self._graph.is_multigraph():
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
return eid.tousertensor()
def find_edges(self, etype, eid):
def find_edges(self, eid, etype=None):
"""Given an edge ID array, return the source and destination node ID
array `s` and `d`. `s[i]` and `d[i]` are source and destination node
ID for edge `eid[i]`.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
eid : list, tensor
The edge ID array.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -613,20 +906,18 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the user and game of gameplay #0 and #2:
>>> g.find_edges(('user', 'plays', 'game'), [0, 2])
>>> g.find_edges([0, 2], ('user', 'plays', 'game'))
(tensor([0, 1]), tensor([0, 1]))
"""
eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(self._etypes_invmap[etype], eid)
src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
return src.tousertensor(), dst.tousertensor()
def in_edges(self, etype, v, form='uv'):
def in_edges(self, v, form='uv', etype=None):
"""Return the inbound edges of the node(s).
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor
The node(s) of destination type.
form : str, optional
......@@ -635,6 +926,9 @@ class DGLBaseHeteroGraph(object):
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -652,11 +946,11 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the gameplay IDs of game #0 (Tetris)
>>> g.in_edges(('user', 'plays', 'game'), 0, 'eid')
>>> g.in_edges(0, 'eid', ('user', 'plays', 'game'))
tensor([0, 1])
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(self._etypes_invmap[etype], v)
src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
......@@ -666,13 +960,11 @@ class DGLBaseHeteroGraph(object):
else:
raise DGLError('Invalid form:', form)
def out_edges(self, etype, v, form='uv'):
def out_edges(self, v, form='uv', etype=None):
"""Return the outbound edges of the node(s).
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int, list, tensor
The node(s) of source type.
form : str, optional
......@@ -681,6 +973,9 @@ class DGLBaseHeteroGraph(object):
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -698,11 +993,11 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the gameplay IDs of user #0 (Alice)
>>> g.out_edges(('user', 'plays', 'game'), 0, 'eid')
>>> g.out_edges(0, 'eid', ('user', 'plays', 'game'))
tensor([0])
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(self._etypes_invmap[etype], v)
src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
......@@ -712,13 +1007,11 @@ class DGLBaseHeteroGraph(object):
else:
raise DGLError('Invalid form:', form)
def all_edges(self, etype, form='uv', order=None):
def all_edges(self, form='uv', order=None, etype=None):
"""Return all the edges.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
form : str, optional
The return form. Currently support:
......@@ -731,6 +1024,9 @@ class DGLBaseHeteroGraph(object):
- 'srcdst' : sorted by their src and dst ids.
- 'eid' : sorted by edge Ids.
- None : the arbitrary order.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -749,10 +1045,10 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find the user-game pairs for all gameplays:
>>> g.all_edges(('user', 'plays', 'game'), 'uv')
>>> g.all_edges('uv', etype=('user', 'plays', 'game'))
(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1]))
"""
src, dst, eid = self._graph.edges(self._etypes_invmap[etype], order)
src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
......@@ -762,15 +1058,16 @@ class DGLBaseHeteroGraph(object):
else:
raise DGLError('Invalid form:', form)
def in_degree(self, etype, v):
def in_degree(self, v, etype=None):
"""Return the in-degree of node ``v``.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node ID of destination type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -782,27 +1079,28 @@ class DGLBaseHeteroGraph(object):
Examples
--------
Find how many users are playing Game #0 (Tetris):
>>> g.in_degree(('user', 'plays', 'game'), 0)
>>> g.in_degree(0, ('user', 'plays', 'game'))
2
See Also
--------
in_degrees
"""
return self._graph.in_degree(self._etypes_invmap[etype], v)
return self._graph.in_degree(self.get_etype_id(etype), v)
def in_degrees(self, etype, v=ALL):
def in_degrees(self, v=ALL, etype=None):
"""Return the array `d` of in-degrees of the node array `v`.
`d[i]` is the in-degree of node `v[i]`.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor, optional.
The node ID array of destination type. Default is to return the
degrees of all the nodes.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -814,30 +1112,31 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find how many users are playing Game #0 and #1 (Tetris and Minecraft):
>>> g.in_degrees(('user', 'plays', 'game'), [0, 1])
>>> g.in_degrees([0, 1], ('user', 'plays', 'game'))
tensor([2, 2])
See Also
--------
in_degree
"""
etype_idx = self._etypes_invmap[etype]
_, dsttype_idx = self._endpoint_types(etype_idx)
etid = self.get_etype_id(etype)
_, dtid = self._graph.metagraph.find_edge(etid)
if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(dsttype_idx)))
v = utils.toindex(slice(0, self._graph.number_of_nodes(dtid)))
else:
v = utils.toindex(v)
return self._graph.in_degrees(etype_idx, v).tousertensor()
return self._graph.in_degrees(etid, v).tousertensor()
def out_degree(self, etype, v):
def out_degree(self, v, etype=None):
"""Return the out-degree of node `v`.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : int
The node ID of source type.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -847,27 +1146,28 @@ class DGLBaseHeteroGraph(object):
Examples
--------
Find how many games User #0 Alice is playing
>>> g.out_degree(('user', 'plays', 'game'), 0)
>>> g.out_degree(0, ('user', 'plays', 'game'))
1
See Also
--------
out_degrees
"""
return self._graph.out_degree(self._etypes_invmap[etype], v)
return self._graph.out_degree(self.get_etype_id(etype), v)
def out_degrees(self, etype, v=ALL):
def out_degrees(self, v=ALL, etype=None):
"""Return the array `d` of out-degrees of the node array `v`.
`d[i]` is the out-degree of node `v[i]`.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
v : list, tensor, optional
The node ID array of source type. Default is to return the degrees
of all the nodes.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -879,462 +1179,294 @@ class DGLBaseHeteroGraph(object):
The following example uses PyTorch backend.
Find how many games User #0 and #1 (Alice and Bob) are playing
>>> g.out_degrees(('user', 'plays', 'game'), [0, 1])
>>> g.out_degrees([0, 1], ('user', 'plays', 'game'))
tensor([1, 2])
See Also
--------
out_degree
"""
etype_idx = self._etypes_invmap[etype]
srctype_idx, _ = self._endpoint_types(etype_idx)
etid = self.get_etype_id(etype)
stid, _ = self._graph.metagraph.find_edge(etid)
if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(srctype_idx)))
v = utils.toindex(slice(0, self._graph.number_of_nodes(stid)))
else:
v = utils.toindex(v)
return self._graph.out_degrees(etype_idx, v).tousertensor()
def bipartite_from_edge_list(u, v, num_src=None, num_dst=None):
    """Build a bipartite graph component of a heterogeneous graph from
    parallel lists of edge endpoints.

    Parameters
    ----------
    u, v : list[int]
        Source node IDs and destination node IDs of the edges.
    num_src : int, optional
        Number of nodes of the source type. When not given (or falsy), the
        largest source node ID in ``u`` plus 1 is used.
    num_dst : int, optional
        Number of nodes of the destination type. When not given (or falsy),
        the largest destination node ID in ``v`` plus 1 is used.

    Returns
    -------
    The result of ``heterograph_index.create_bipartite_from_coo``.
    """
    # Node counts are inferred from the raw ID lists before index conversion.
    if not num_src:
        num_src = max(u) + 1
    if not num_dst:
        num_dst = max(v) + 1
    src_index = utils.toindex(u)
    dst_index = utils.toindex(v)
    return heterograph_index.create_bipartite_from_coo(
        num_src, num_dst, src_index, dst_index)
return self._graph.out_degrees(etid, v).tousertensor()
def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
    """Internal helper that wraps a subgraph index into a DGLHeteroGraph.

    Parameters
    ----------
    sgi
        Subgraph index; its ``graph`` attribute becomes the structure of
        the returned graph.
    induced_nodes : list
        One index per node type, used to slice that type's node frame.
    induced_edges : list
        One index per edge type, used to slice that type's edge frame.

    Returns
    -------
    DGLHeteroGraph
        Graph flagged with ``is_subgraph = True`` whose node/edge data
        carry the induced IDs under the ``NID`` / ``EID`` keys.
    """
    # Slice every per-type feature frame down to the induced rows.
    node_frames = []
    for ntype_id, nids in enumerate(induced_nodes):
        node_rows = self._node_frames[ntype_id][nids]
        node_frames.append(FrameRef(Frame(node_rows, num_rows=len(nids))))

    edge_frames = []
    for etype_id, eids in enumerate(induced_edges):
        edge_rows = self._edge_frames[etype_id][eids]
        edge_frames.append(FrameRef(Frame(edge_rows, num_rows=len(eids))))

    subg = DGLHeteroGraph(
        sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
    subg.is_subgraph = True

    # Store the induced IDs as features so callers can map back.
    for ntype, nids in zip(self.ntypes, induced_nodes):
        subg.nodes[ntype].data[NID] = nids.tousertensor()
    for etype, eids in zip(self.canonical_etypes, induced_edges):
        subg.edges[etype].data[EID] = eids.tousertensor()

    return subg
def bipartite_from_scipy(spmat, with_edge_id=False):
"""Create a bipartite graph component of a heterogeneous graph with a
scipy sparse matrix.
def subgraph(self, nodes):
"""Return the subgraph induced on given nodes.
Parameters
----------
spmat : scipy sparse matrix
The bipartite graph matrix whose rows represent sources and columns
represent destinations.
with_edge_id : bool
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
"""
spmat = spmat.tocsr()
num_src, num_dst = spmat.shape
indptr = utils.toindex(spmat.indptr)
indices = utils.toindex(spmat.indices)
data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices))))
return heterograph_index.create_bipartite_from_csr(num_src, num_dst, indptr, indices, data)
The metagraph of the returned subgraph is the same as the parent graph.
Features are copied from the original graph.
class DGLHeteroGraph(DGLBaseHeteroGraph):
"""Base heterogeneous graph class.
Examples
--------
TBD
A Heterogeneous graph is defined as a graph with node types and edge
types.
Parameters
----------
nodes : dict[str, list or iterable]
A dictionary of node types to node ID array to construct
subgraph.
All nodes must exist in the graph.
If two edges share the same edge type, then their source nodes, as well
as their destination nodes, also have the same type (the source node
types don't have to be the same as the destination node types).
Parameters
----------
graph_data :
The graph data. It can be one of the following:
* (nx.MultiDiGraph, dict[str, list[tuple[int, int]]])
* (nx.MultiDiGraph, dict[str, scipy.sparse.matrix])
The first element is the metagraph of the heterogeneous graph, as a
networkx directed graph. Its nodes represent the node types, and
its edges represent the edge types. The edge type name should be
stored as edge keys.
The second element is a mapping from edge type to edge list. The
edge list can be either a list of (u, v) pairs, or a scipy sparse
matrix whose rows represent sources and columns represent
destinations. The edges will be added in the (source, destination)
order.
node_frames : dict[str, dict[str, Tensor]]
The node frames for each node type
edge_frames : dict[str, dict[str, Tensor]]
The edge frames for each edge type
multigraph : bool
Whether the heterogeneous graph is a multigraph.
readonly : bool
Whether the heterogeneous graph is readonly.
Examples
--------
Suppose that we want to construct the following heterogeneous graph:
.. graphviz::
digraph G {
Alice -> Bob [label=follows]
Bob -> Carol [label=follows]
Alice -> Tetris [label=plays]
Bob -> Tetris [label=plays]
Bob -> Minecraft [label=plays]
Carol -> Minecraft [label=plays]
Nintendo -> Tetris [label=develops]
Mojang -> Minecraft [label=develops]
{rank=source; Alice; Bob; Carol}
{rank=sink; Nintendo; Mojang}
}
One can analyze the graph and figure out the metagraph as follows:
.. graphviz::
Returns
-------
G : DGLHeteroGraph
The subgraph.
The nodes are relabeled so that node `i` of type `t` in the
subgraph is mapped to ``nodes[t][i]`` in the
original graph.
The edges are also relabeled.
One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via `dgl.NID` and `dgl.EID` node/edge features of the
subgraph.
"""
induced_nodes = [utils.toindex(nodes.get(ntype, [])) for ntype in self.ntypes]
sgi = self._graph.node_subgraph(induced_nodes)
induced_edges = sgi.induced_edges
digraph G {
User -> User [label=follows]
User -> Game [label=plays]
Developer -> Game [label=develops]
}
return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
Suppose that one maps the users, games and developers to the following
IDs:
def edge_subgraph(self, edges, preserve_nodes=False):
"""Return the subgraph induced on given edges.
User name Alice Bob Carol
User ID 0 1 2
The metagraph of the returned subgraph is the same as the parent graph.
Game name Tetris Minecraft
Game ID 0 1
Features are copied from the original graph.
Developer name Nintendo Mojang
Developer ID 0 1
Examples
--------
TBD
One can construct the graph as follows:
Parameters
----------
edges : dict[etype, list or iterable]
A dictionary of edge types to edge ID array to construct
subgraph.
All edges must exist in the graph.
The edge type is characterized by a triplet of source type name,
edge type name, and destination type name.
>>> mg = nx.MultiDiGraph([('user', 'user', 'follows'),
... ('user', 'game', 'plays'),
... ('developer', 'game', 'develops')])
>>> g = DGLHeteroGraph(
... mg, {
... 'follows': [(0, 1), (1, 2)],
... 'plays': [(0, 0), (1, 0), (1, 1), (2, 1)],
... 'develops': [(0, 0), (1, 1)]})
Returns
-------
G : DGLHeteroGraph
The subgraph.
The edges are relabeled so that edge `i` of type `t` in the
subgraph is mapped to ``edges[t][i]`` in the
original graph.
One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via `dgl.NID` and `dgl.EID` node/edge features of the
subgraph.
"""
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [
utils.toindex(edges.get(canonical_etype, []))
for canonical_etype in self.canonical_etypes]
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes
Then one can query the graph structure as follows:
return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
>>> g['user'].number_of_nodes()
3
>>> g['plays'].number_of_edges()
4
>>> g['develops'].out_degrees() # out-degrees of source nodes of 'develops' relation
tensor([1, 1])
>>> g['develops'].in_edges(0) # in-edges of destination node 0 of 'develops' relation
(tensor([0]), tensor([0]))
def node_type_subgraph(self, ntypes):
"""Return the subgraph induced on given node types.
Notes
-----
Currently, all heterogeneous graphs are readonly.
"""
# pylint: disable=unused-argument
def __init__(
self,
graph_data=None,
node_frames=None,
edge_frames=None,
multigraph=None,
readonly=True,
_view_ntype_idx=None,
_view_etype_idx=None):
assert readonly, "Only readonly heterogeneous graphs are supported"
The metagraph of the returned subgraph is the subgraph of the original metagraph
induced from the node types.
# Creating a view of another graph?
if isinstance(graph_data, DGLHeteroGraph):
super(DGLHeteroGraph, self).__init__(
graph_data._graph, graph_data._ntypes, graph_data._etypes,
graph_data._ntypes_invmap, graph_data._etypes_invmap,
graph_data._view_ntype_idx, graph_data._view_etype_idx)
self._node_frames = graph_data._node_frames
self._edge_frames = graph_data._edge_frames
self._msg_frames = graph_data._msg_frames
self._msg_indices = graph_data._msg_indices
self._view_ntype_idx = _view_ntype_idx
self._view_etype_idx = _view_etype_idx
return
Features are shared with the original graph.
if isinstance(graph_data, tuple):
metagraph, edges_by_type = graph_data
if not isinstance(metagraph, nx.MultiDiGraph):
raise TypeError('Metagraph should be networkx.MultiDiGraph')
# create metagraph graph index
srctypes, dsttypes, etypes = [], [], []
ntypes = []
ntypes_invmap = {}
etypes_invmap = {}
for srctype, dsttype, etype in metagraph.edges(keys=True):
srctypes.append(srctype)
dsttypes.append(dsttype)
etypes_invmap[(srctype, etype, dsttype)] = len(etypes_invmap)
etypes.append((srctype, etype, dsttype))
if srctype not in ntypes_invmap:
ntypes_invmap[srctype] = len(ntypes_invmap)
ntypes.append(srctype)
if dsttype not in ntypes_invmap:
ntypes_invmap[dsttype] = len(ntypes_invmap)
ntypes.append(dsttype)
srctypes = [ntypes_invmap[srctype] for srctype in srctypes]
dsttypes = [ntypes_invmap[dsttype] for dsttype in dsttypes]
metagraph_index = graph_index.create_graph_index(
list(zip(srctypes, dsttypes)), None, True) # metagraph is always immutable
# create base bipartites
bipartites = []
num_nodes = defaultdict(int)
# count the number of nodes for each type
for etype_triplet in etypes:
srctype, etype, dsttype = etype_triplet
edges = edges_by_type[etype_triplet]
if ssp.issparse(edges):
num_src, num_dst = edges.shape
elif isinstance(edges, list):
u, v = zip(*edges)
num_src = max(u) + 1
num_dst = max(v) + 1
else:
raise TypeError('unknown edge list type %s' % type(edges))
num_nodes[srctype] = max(num_nodes[srctype], num_src)
num_nodes[dsttype] = max(num_nodes[dsttype], num_dst)
# create actual objects
for etype_triplet in etypes:
srctype, etype, dsttype = etype_triplet
edges = edges_by_type[etype_triplet]
if ssp.issparse(edges):
bipartite = bipartite_from_scipy(edges)
elif isinstance(edges, list):
u, v = zip(*edges)
bipartite = bipartite_from_edge_list(
u, v, num_nodes[srctype], num_nodes[dsttype])
bipartites.append(bipartite)
hg_index = heterograph_index.create_heterograph(metagraph_index, bipartites)
super(DGLHeteroGraph, self).__init__(hg_index, ntypes, etypes)
else:
raise TypeError('Unrecognized graph data type %s' % type(graph_data))
Examples
--------
TBD
# node and edge frame
if node_frames is None:
self._node_frames = [
FrameRef(Frame(num_rows=self._graph.number_of_nodes(i)))
for i in range(len(self._ntypes))]
else:
self._node_frames = node_frames
Parameters
----------
ntypes : list[str]
The node types
if edge_frames is None:
self._edge_frames = [
FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
for i in range(len(self._etypes))]
else:
self._edge_frames = edge_frames
Returns
-------
G : DGLHeteroGraph
The subgraph.
"""
rel_graphs = []
meta_edges = []
induced_etypes = []
node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
edge_frames = []
# message indicators
self._msg_indices = [None] * len(self._etypes)
self._msg_frames = []
ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
for i in range(len(self._etypes)):
frame = FrameRef(Frame(num_rows=self._graph.number_of_edges(i)))
frame.set_initializer(init.zero_initializer)
self._msg_frames.append(frame)
srctype = self._ntypes[srctype_id[i]]
dsttype = self._ntypes[dsttype_id[i]]
def _create_view(self, ntype_idx, etype_idx):
return DGLHeteroGraph(
graph_data=self, _view_ntype_idx=ntype_idx, _view_etype_idx=etype_idx)
if srctype in ntypes and dsttype in ntypes:
meta_edges.append((ntypes_invmap[srctype], ntypes_invmap[dsttype]))
rel_graphs.append(self._graph.get_relation_graph(i))
induced_etypes.append(self.etypes[i])
edge_frames.append(self._edge_frames[i])
def _get_msg_index(self):
if self._msg_indices[self._current_etype_idx] is None:
self._msg_indices[self._current_etype_idx] = utils.zero_index(
size=self._graph.number_of_edges(self._current_etype_idx))
return self._msg_indices[self._current_etype_idx]
metagraph = graph_index.from_edge_list(meta_edges, True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
return hg
def _set_msg_index(self, index):
self._msg_indices[self._current_etype_idx] = index
def edge_type_subgraph(self, etypes):
"""Return the subgraph induced on given edge types.
def __getitem__(self, key):
if key in self._etypes_invmap:
return self._create_view(None, self._etypes_invmap[key])
else:
raise KeyError(key)
The metagraph of the returned subgraph is the subgraph of the original metagraph
induced from the edge types.
@property
def _node_frame(self):
# overrides DGLGraph._node_frame
return self._node_frames[self._current_ntype_idx]
Features are shared with the original graph.
@property
def _edge_frame(self):
# overrides DGLGraph._edge_frame
return self._edge_frames[self._current_etype_idx]
Examples
--------
TBD
@property
def _src_frame(self):
# overrides DGLGraph._src_frame
return self._node_frames[self._current_srctype_idx]
Parameters
----------
etypes : list[str or tuple]
The edge types
@property
def _dst_frame(self):
# overrides DGLGraph._dst_frame
return self._node_frames[self._current_dsttype_idx]
Returns
-------
G : DGLHeteroGraph
The subgraph.
"""
etype_ids = [self.get_etype_id(etype) for etype in etypes]
meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids))
rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
meta_src = meta_src.tonumpy()
meta_dst = meta_dst.tonumpy()
induced_ntype_ids = list(set(meta_src) | set(meta_dst))
mapped_meta_src = [induced_ntype_ids[v] for v in meta_src]
mapped_meta_dst = [induced_ntype_ids[v] for v in meta_dst]
node_frames = [self._node_frames[i] for i in induced_ntype_ids]
edge_frames = [self._edge_frames[i] for i in etype_ids]
induced_ntypes = [self._ntypes[i] for i in induced_ntype_ids]
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg
def adjacency_matrix(self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None):
"""Return the adjacency matrix of edges of the given edge type.
@property
def _msg_frame(self):
# overrides DGLGraph._msg_frame
return self._msg_frames[self._current_etype_idx]
By default, a row of returned adjacency matrix represents the
destination of an edge and the column represents the source.
def add_nodes(self, node_type, num, data=None):
"""Add multiple new nodes of the same node type
When transpose is True, a row represents the source and a column
represents a destination.
Parameters
----------
node_type : str
Type of the added nodes. Must appear in the metagraph.
num : int
Number of nodes to be added.
data : dict, optional
Feature data of the added nodes.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
transpose : bool, optional (default=False)
A flag to transpose the returned adjacency matrix.
ctx : context, optional (default=cpu)
The context of returned adjacency matrix.
scipy_fmt : str, optional (default=None)
If specified, return a scipy sparse matrix in the given format.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
>>> g['game'].number_of_nodes()
2
>>> g.add_nodes(3, 'game') # add 3 new games
>>> g['game'].number_of_nodes()
5
Returns
-------
SparseTensor or scipy.sparse.spmatrix
Adjacency matrix.
"""
pass
etid = self.get_etype_id(etype)
if scipy_fmt is None:
return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
else:
return self._graph.adjacency_matrix_scipy(etid, transpose, scipy_fmt, False)
def add_edge(self, etype, u, v, data=None):
"""Add an edge of ``etype`` between u of the source node type, and v
of the destination node type.
# Alias of ``adjacency_matrix``
adj = adjacency_matrix
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
u : int
The source node ID of type ``utype``. Must exist in the graph.
v : int
The destination node ID of type ``vtype``. Must exist in the
graph.
data : dict, optional
Feature data of the added edge.
def incidence_matrix(self, typestr, ctx=F.cpu(), etype=None):
"""Return the incidence matrix representation of edges with the given
edge type.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
An incidence matrix is an n x m sparse matrix, where n is
the number of nodes and m is the number of edges. Each nonzero
value indicates whether the edge is incident to the node
or not.
>>> g['plays'].number_of_edges()
4
>>> g.add_edge(2, 0, 'plays')
>>> g['plays'].number_of_edges()
5
"""
pass
There are three types of an incidence matrix :math:`I`:
def add_edges(self, u, v, etype, data=None):
"""Add multiple edges of ``etype`` between list of source nodes ``u``
and list of destination nodes ``v`` of type ``vtype``. A single edge
is added between every pair of ``u[i]`` and ``v[i]``.
* ``in``:
Parameters
----------
u : list, tensor
The source node IDs of type ``utype``. Must exist in the graph.
v : list, tensor
The destination node IDs of type ``vtype``. Must exist in the
graph.
etype : (str, str, str)
The source-edge-destination type triplet
data : dict, optional
Feature data of the added edge.
- :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`
(or :math:`v` is the dst node of :math:`e`);
- :math:`I[v, e] = 0` otherwise.
Examples
--------
The variable ``g`` is constructed from the example in
DGLBaseHeteroGraph.
* ``out``:
>>> g['plays'].number_of_edges()
4
>>> g.add_edges([0, 2], [1, 0], 'plays')
>>> g['plays'].number_of_edges()
6
"""
pass
def from_networkx(
self,
nx_graph,
node_type_attr_name='type',
edge_type_attr_name='type',
node_id_attr_name='id',
edge_id_attr_name='id',
node_attrs=None,
edge_attrs=None):
"""Convert from networkx graph.
The networkx graph must satisfy the metagraph. That is, for any
edge in the networkx graph, the source/destination node type must
be the same as the source/destination node of the edge type in
the metagraph. An error will be raised otherwise.
- :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`
(or :math:`v` is the src node of :math:`e`);
- :math:`I[v, e] = 0` otherwise.
* ``both`` (only if source and destination node type are the same):
- :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`;
- :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`;
- :math:`I[v, e] = 0` otherwise (including self-loop).
Parameters
----------
nx_graph : networkx.DiGraph
The networkx graph.
node_type_attr_name : str
The node attribute name for the node type.
The attribute contents must be strings.
edge_type_attr_name : str
The edge attribute name for the edge type.
The attribute contents must be strings.
node_id_attr_name : str
The node attribute name for node type-specific IDs.
The attribute contents must be integers.
If the IDs of the same type are not consecutive integers, its
nodes will be relabeled using consecutive integers. The new
node ordering will inherit that of the sorted IDs.
edge_id_attr_name : str or None
The edge attribute name for edge type-specific IDs.
The attribute contents must be integers.
If the IDs of the same type are not consecutive integers, its
edges will be relabeled using consecutive integers. The new
edge ordering will inherit that of the sorted IDs.
If None is provided, the edge order would be arbitrary.
node_attrs : iterable of str, optional
The node attributes whose data would be copied.
edge_attrs : iterable of str, optional
The edge attributes whose data would be copied.
typestr : str
Can be either ``in``, ``out`` or ``both``
ctx : context, optional (default=cpu)
The context of returned incidence matrix.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
SparseTensor
The incidence matrix.
"""
pass
etid = self.get_etype_id(etype)
return self._graph.incidence_matrix(etid, typestr, ctx)[0]
# Alias of ``incidence_matrix``
inc = incidence_matrix
def node_attr_schemes(self, ntype):
#################################################################
# Features
#################################################################
def node_attr_schemes(self, ntype=None):
"""Return the node feature schemes.
Each feature scheme is a named tuple that stores the shape and data type
......@@ -1342,8 +1474,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
ntype : str
The node type
ntype : str, optional
The node type. Could be omitted if there is only one node
type in the graph. Error will be raised otherwise.
(Default: None)
Returns
-------
......@@ -1354,13 +1488,13 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
--------
The following uses PyTorch backend.
>>> g.ndata['user']['h'] = torch.randn(3, 4)
>>> g.nodes['user'].data['h'] = torch.randn(3, 4)
>>> g.node_attr_schemes('user')
{'h': Scheme(shape=(4,), dtype=torch.float32)}
"""
return self._node_frames[self._ntypes_invmap[ntype]].schemes
return self._node_frames[self.get_ntype_id(ntype)].schemes
def edge_attr_schemes(self, etype):
def edge_attr_schemes(self, etype=None):
"""Return the edge feature schemes.
Each feature scheme is a named tuple that stores the shape and data type
......@@ -1368,8 +1502,9 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
......@@ -1380,60 +1515,75 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
--------
The following uses PyTorch backend.
>>> g.edata['user', 'plays', 'game']['h'] = torch.randn(4, 4)
>>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
>>> g.edge_attr_schemes(('user', 'plays', 'game'))
{'h': Scheme(shape=(4,), dtype=torch.float32)}
"""
return self._edge_frames[self._etypes_invmap[etype]].schemes
return self._edge_frames[self.get_etype_id(etype)].schemes
@property
def nodes(self):
"""Return a node view that can used to set/get feature data of a
single node type.
def set_n_initializer(self, initializer, field=None, ntype=None):
"""Set the initializer for empty node features.
Examples
--------
To set features of User #0 and #2 in a heterogeneous graph:
>>> g.nodes['user'][[0, 2]].data['h'] = torch.zeros(2, 5)
"""
return HeteroNodeView(self)
Initializer is a callable that returns a tensor given the shape, data type
and device context.
@property
def ndata(self):
"""Return the data view of all the nodes of a single node type.
When a subset of the nodes are assigned a new feature, initializer is
used to create feature for rest of the nodes.
Parameters
----------
initializer : callable
The initializer.
field : str, optional
The feature field name. By default, the initializer is set for all the
feature fields.
ntype : str, optional
The node type. Could be omitted if there is only one node
type in the graph. Error will be raised otherwise.
(Default: None)
Examples
--------
To set features of games in a heterogeneous graph:
>>> g.ndata['game']['h'] = torch.zeros(2, 5)
"""
return HeteroNodeDataView(self)
@property
def edges(self):
"""Return an edges view that can used to set/get feature data of a
single edge type.
Note
-----
User defined initializer must follow the signature of
:func:`dgl.init.base_initializer() <dgl.init.base_initializer>`
Examples
--------
To set features of gameplays #1 (Bob -> Tetris) and #3 (Carol ->
Minecraft) in a heterogeneous graph:
>>> g.edges['user', 'plays', 'game'][[1, 3]].data['h'] = torch.zeros(2, 5)
"""
return HeteroEdgeView(self)
ntid = self.get_ntype_id(ntype)
self._node_frames[ntid].set_initializer(initializer, field)
@property
def edata(self):
"""Return the data view of all the edges of a single edge type.
def set_e_initializer(self, initializer, field=None, etype=None):
"""Set the initializer for empty edge features.
Examples
--------
>>> g.edata['developer', 'develops', 'game']['h'] = torch.zeros(2, 5)
Initializer is a callable that returns a tensor given the shape, data
type and device context.
When a subset of the edges are assigned a new feature, initializer is
used to create feature for rest of the edges.
Parameters
----------
initializer : callable
The initializer.
field : str, optional
The feature field name. By default, the initializer is set for all the
feature fields.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Note
-----
User defined initializer must follow the signature of
:func:`dgl.init.base_initializer() <dgl.init.base_initializer>`
"""
return HeteroEdgeDataView(self)
etid = self.get_etype_id(etype)
self._edge_frames[etid].set_initializer(initializer, field)
def set_n_repr(self, ntype, data, u=ALL, inplace=False):
"""Set node(s) representation of a single node type.
def _set_n_repr(self, ntid, u, data, inplace=False):
"""Internal API to set node features.
`data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
......@@ -1445,18 +1595,18 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
ntype : str
The node type
data : dict of tensor
Node representation.
ntid : int
Node type id.
u : node, container or tensor
The node(s).
inplace : bool
data : dict of tensor
Node representation.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
"""
ntype = self._ntypes_invmap[ntype]
if is_all(u):
num_nodes = self._graph.number_of_nodes(ntype)
num_nodes = self._graph.number_of_nodes(ntid)
else:
u = utils.toindex(u)
num_nodes = len(u)
......@@ -1468,19 +1618,19 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
if is_all(u):
for key, val in data.items():
self._node_frames[ntype][key] = val
self._node_frames[ntid][key] = val
else:
self._node_frames[ntype].update_rows(u, data, inplace=inplace)
self._node_frames[ntid].update_rows(u, data, inplace=inplace)
def get_n_repr(self, ntype, u=ALL):
def _get_n_repr(self, ntid, u):
"""Get node(s) representation of a single node type.
The returned feature tensor batches multiple node features on the first dimension.
Parameters
----------
ntype : str
The node type
ntid : int
Node type id.
u : node, container or tensor
The node(s).
......@@ -1489,22 +1639,19 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict
Representation dict from feature name to feature tensor.
"""
if len(self.node_attr_schemes(ntype)) == 0:
return dict()
ntype_idx = self._ntypes_invmap[ntype]
if is_all(u):
return dict(self._node_frames[ntype_idx])
return dict(self._node_frames[ntid])
else:
u = utils.toindex(u)
return self._node_frames[ntype_idx].select_rows(u)
return self._node_frames[ntid].select_rows(u)
def pop_n_repr(self, ntype, key):
"""Get and remove the specified node repr of a given node type.
def _pop_n_repr(self, ntid, key):
"""Internal API to get and remove the specified node feature.
Parameters
----------
ntype : str
The node type
ntid : int
Node type id.
key : str
The attribute name.
......@@ -1513,11 +1660,10 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor
The popped representation
"""
ntype = self._ntypes_invmap[ntype]
return self._node_frames[ntype].pop(key)
return self._node_frames[ntid].pop(key)
def set_e_repr(self, etype, data, edges=ALL, inplace=False):
"""Set edge(s) representation of a single edge type.
def _set_e_repr(self, etid, edges, data, inplace=False):
"""Internal API to set edge(s) features.
`data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
......@@ -1528,10 +1674,8 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
data : tensor or dict of tensor
Edge representation.
etid : int
Edge type id.
edges : edges
Edges can be either
......@@ -1540,10 +1684,12 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
* A tensor of edge ids of the given type.
The default value is all the edges.
inplace : bool
data : tensor or dict of tensor
Edge representation.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
"""
etype_idx = self._etypes_invmap[etype]
# parse argument
if is_all(edges):
eid = ALL
......@@ -1552,7 +1698,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v)
_, _, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
......@@ -1562,7 +1708,7 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
' Got "%s" instead.' % type(data))
if is_all(eid):
num_edges = self._graph.number_of_edges(etype_idx)
num_edges = self._graph.number_of_edges(etid)
else:
eid = utils.toindex(eid)
num_edges = len(eid)
......@@ -1575,18 +1721,18 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
if is_all(eid):
# update column
for key, val in data.items():
self._edge_frames[etype_idx][key] = val
self._edge_frames[etid][key] = val
else:
# update row
self._edge_frames[etype_idx].update_rows(eid, data, inplace=inplace)
self._edge_frames[etid].update_rows(eid, data, inplace=inplace)
def get_e_repr(self, etype, edges=ALL):
"""Get edge(s) representation.
def _get_e_repr(self, etid, edges):
"""Internal API to get edge features.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
etid : int
Edge type id.
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
......@@ -1596,9 +1742,6 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
dict
Representation dict
"""
etype_idx = self._etypes_invmap[etype]
if len(self.edge_attr_schemes(etype)) == 0:
return dict()
# parse argument
if is_all(edges):
eid = ALL
......@@ -1607,23 +1750,23 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etype_idx, u, v)
_, _, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
if is_all(eid):
return dict(self._edge_frames[etype_idx])
return dict(self._edge_frames[etid])
else:
eid = utils.toindex(eid)
return self._edge_frames[etype_idx].select_rows(eid)
return self._edge_frames[etid].select_rows(eid)
def pop_e_repr(self, etype, key):
def _pop_e_repr(self, etid, key):
"""Get and remove the specified edge repr of a single edge type.
Parameters
----------
etype : (str, str, str)
The source-edge-destination type triplet
etid : int
Edge type id.
key : str
The attribute name.
......@@ -1632,136 +1775,51 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Tensor
The popped representation
"""
etype = self._etypes_invmap[etype]
self._edge_frames[etype].pop(key)
self._edge_frames[etid].pop(key)
def register_message_func(self, func):
"""Register global message function for each edge type provided.
#################################################################
# Message passing
#################################################################
def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
"""Apply the function on the nodes with the same type to update their
features.
Once registered, ``func`` will be used as the default
message function in message passing operations, including
:func:`send`, :func:`send_and_recv`, :func:`pull`,
:func:`push`, :func:`update_all`.
If None is provided for ``func``, nothing will happen.
Parameters
----------
func : callable
Message function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
See Also
--------
send
send_and_recv
pull
push
update_all
"""
raise NotImplementedError
def register_reduce_func(self, func):
"""Register global message reduce function for each edge type provided.
Once registered, ``func`` will be used as the default
message reduce function in message passing operations, including
:func:`recv`, :func:`send_and_recv`, :func:`push`, :func:`pull`,
:func:`update_all`.
Parameters
----------
func : callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
See Also
--------
recv
send_and_recv
push
pull
update_all
"""
raise NotImplementedError
def register_apply_node_func(self, func):
"""Register global node apply function for each node type provided.
Once registered, ``func`` will be used as the default apply
node function. Related operations include :func:`apply_nodes`,
:func:`recv`, :func:`send_and_recv`, :func:`push`, :func:`pull`,
:func:`update_all`.
Parameters
----------
func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
See Also
--------
apply_nodes
register_apply_edge_func
"""
raise NotImplementedError
def register_apply_edge_func(self, func):
"""Register global edge apply function for each edge type provided.
Once registered, ``func`` will be used as the default apply
edge function in :func:`apply_edges`.
Parameters
----------
func : callable
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
See Also
--------
apply_edges
register_apply_node_func
"""
raise NotImplementedError
def apply_nodes(self, func, v=ALL, inplace=False):
"""Apply the function on the nodes with the same type to update their
features.
If None is provided for ``func``, nothing will happen.
Parameters
----------
func : dict[str, callable] or None
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
v : dict[str, int or iterable of int or tensor], optional
The (type-specific) node (ids) on which to apply ``func``.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
v : int or iterable of int or tensor, optional
The (type-specific) node (ids) on which to apply ``func``.
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
Examples
--------
>>> g.ndata['user']['h'] = torch.ones(3, 5)
>>> g.apply_nodes({'user': lambda nodes: {'h': nodes.data['h'] * 2}})
>>> g.ndata['user']['h']
>>> g.nodes['user'].data['h'] = torch.ones(3, 5)
>>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
>>> g.nodes['user'].data['h']
tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]])
"""
for ntype, nfunc in func.items():
if is_all(v):
v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)))
else:
v_ntype = utils.toindex(v[ntype])
with ir.prog() as prog:
scheduler.schedule_apply_nodes(
graph=self._create_view(self._ntypes_invmap[ntype], None),
v=v_ntype,
apply_func=nfunc,
inplace=inplace)
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, inplace=False):
ntid = self.get_ntype_id(ntype)
if is_all(v):
v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)))
else:
v_ntype = utils.toindex(v)
with ir.prog() as prog:
scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
inplace=inplace)
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
"""Apply the function on the edges with the same type to update their
features.
......@@ -1769,52 +1827,50 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
Parameters
----------
func : dict[(str, str, str), callable] or None
func : callable or None
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
edges : dict[(str, str, str), any valid edge specification], optional
edges : edges data, optional
Edges on which to apply ``func``. See :func:`send` for valid
edge specification.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Examples
--------
>>> g.edata['user', 'plays', 'game']['h'] = torch.ones(4, 5)
>>> g.apply_edges(
... {('user', 'plays', 'game'): lambda edges: {'h': edges.data['h'] * 2}})
>>> g.edata['user', 'plays', 'game']['h']
>>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
>>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
>>> g.edges[('user', 'plays', 'game')].data['h']
tensor([[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]])
"""
for etype, efunc in func.items():
etype_idx = self._etypes_invmap[etype]
if is_all(edges):
u, v, _ = self._graph.edges(etype_idx, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(
graph=self._create_view(None, etype_idx),
u=u,
v=v,
eid=eid,
apply_func=efunc,
inplace=inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges):
u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etid, eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(
AdaptedHeteroGraph(self, stid, dtid, etid),
u, v, eid, func, inplace=inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
"""Group the edges by nodes and apply the function of the grouped
edges to update their features. The edges are of the same edge type
(hence having the same source and destination node type).
......@@ -1823,47 +1879,47 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
----------
group_by : str
Specify how to group edges. Expected to be either 'src' or 'dst'
func : dict[(str, str, str), callable]
func : callable
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. The input of `Edge UDF` should
be (bucket_size, degrees, *feature_shape), and
return the dict with values of the same shapes.
edges : dict[(str, str, str), valid edges type], optional
edges : edges data, optional
Edges on which to group and apply ``func``. See :func:`send` for valid
edges type. Default is all the edges.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if group_by not in ('src', 'dst'):
raise DGLError("Group_by should be either src or dst")
for etype, efunc in func.items():
etype_idx = self._etypes_invmap[etype]
if is_all(edges):
u, v, _ = self._graph.edges(etype_idx)
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etype_idx, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etype_idx, eid)
with ir.prog() as prog:
scheduler.schedule_group_apply_edge(
graph=self._create_view(None, etype_idx),
u=u,
v=v,
eid=eid,
apply_func=efunc,
group_by=group_by,
inplace=inplace)
Runtime.run(prog)
def send(self, edges=ALL, message_func=None):
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges):
u, v, _ = self._graph.edges(etid)
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(etid, eid)
with ir.prog() as prog:
scheduler.schedule_group_apply_edge(
AdaptedHeteroGraph(self, stid, dtid, etid),
u, v, eid,
func, group_by,
inplace=inplace)
Runtime.run(prog)
def send(self, edges, message_func, etype=None):
"""Send messages along the given edges with the same edge type.
``edges`` can be any of the following types:
......@@ -1903,101 +1959,195 @@ class DGLHeteroGraph(DGLBaseHeteroGraph):
On multigraphs, if :math:`u` and :math:`v` are specified, then the messages will be sent
along all edges between :math:`u` and :math:`v`.
"""
assert not utils.is_dict_like(message_func), \
"multiple-type message passing is not implemented"
assert message_func is not None
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges):
eid = utils.toindex(slice(0, self._graph.number_of_edges(self._current_etype_idx)))
u, v, _ = self._graph.edges(self._current_etype_idx)
eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
u, v, _ = self._graph.edges(etid)
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(self._current_etype_idx, u, v)
u, v, eid = self._graph.edge_ids(etid, u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(self._current_etype_idx, eid)
u, v, _ = self._graph.find_edges(etid, eid)
if len(eid) == 0:
# no edge to be triggered
return
with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
message_func=message_func)
scheduler.schedule_send(
AdaptedHeteroGraph(self, stid, dtid, etid),
u, v, eid,
message_func)
Runtime.run(prog)
def recv(self,
v=ALL,
reduce_func=None,
v,
reduce_func,
apply_node_func=None,
etype=None,
inplace=False):
"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
Optionally, apply a function to update the node features after receive.
It calculates:
.. math::
h_v^{new} = \sigma(\sum_{u\in\mathcal{N}_{t}(v)}m_{uv})
where :math:`\mathcal{N}_t(v)` defines the predecessors of node(s) ``v`` connected by
edge type :math:`t`, and :math:`m_{uv}` is the message on edge (u,v).
* ``reduce_func`` specifies :math:`\sum`.
* ``apply_node_func`` specifies :math:`\sigma`.
Other notes:
* `reduce_func` will be skipped for nodes with no incoming message.
* If all ``v`` have no incoming message, this will downgrade to an :func:`apply_nodes`.
* If some ``v`` have no incoming message, their new feature value will be calculated
by the column initializer (see :func:`set_n_initializer`). The feature shapes and
dtypes will be inferred.
* The node features will be updated by the result of the ``reduce_func``.
* Messages are consumed once received.
* The provided UDF may be called multiple times, so it is recommended to provide
a function with no side effects.
* The cross-type reducer will check the output field of each per-type reducer
and aggregate those who write to the **same** fields. If None is provided,
the default behavior is overwrite.
The node features will be updated by the result of the ``reduce_func``.
Examples
--------
Only one type of nodes in the graph:
Messages are consumed once received.
>>> import dgl.function as fn
>>> G.recv(v, fn.sum('m', 'h'))
The provided UDF may be called multiple times, so it is recommended to provide
a function with no side effects.
Specify reducer for each type and use cross-type reducer to accum results.
Only works if the graph has one edge type. For multiple types,
use
>>> import dgl.function as fn
>>> G.recv(v,
>>> ... {'plays' : fn.sum('m', 'h'), 'develops' : fn.max('m', 'h')},
>>> ... 'sum')
.. code::
Error will be thrown if per-type reducers cannot determine the node type of v.
g['edgetype'].recv(v, reduce_func, apply_node_func, inplace)
>>> import dgl.function as fn
>>> # ambiguous, v is of both 'user' and 'game' types
>>> G.recv(v,
>>> ... {('user', 'follows', 'user') : fn.sum('m', 'h'),
>>> ... ('user', 'plays', 'game') : fn.max('m', 'h')},
>>> ... 'sum')
Parameters
----------
v : int, container or tensor, optional
v : int, container or tensor
The node(s) to be updated. Default is receiving all the nodes.
reduce_func : callable, optional
reduce_func : callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
etype : str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
assert not utils.is_dict_like(reduce_func) and \
not utils.is_dict_like(apply_node_func), \
"multiple-type message passing is not implemented"
assert reduce_func is not None
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(v):
v = F.arange(0, self._graph.number_of_nodes(self._current_dsttype_idx))
v = F.arange(0, self.number_of_nodes(dtid))
elif isinstance(v, int):
v = [v]
v = utils.toindex(v)
if len(v) == 0:
# no vertex to be triggered.
return
with ir.prog() as prog:
scheduler.schedule_recv(graph=self,
recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
v, reduce_func, apply_node_func,
inplace=inplace)
Runtime.run(prog)
def multi_recv(self, v, reducer_dict, cross_reducer, apply_func=None, inplace=False):
r"""Receive messages from multiple edge types and perform aggregation.
It calculates:
.. math::
h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}m_{uv})
* ``reducer_dict`` is a dictionary from edge type to the reduce function
:math:`\sum_{u\in\mathcal{N}_t(v)}` of each type.
* ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`
* ``apply_func`` specifies :math:`\sigma`.
Examples
--------
TBD
Parameters
----------
v : int, container or tensor
The node(s) to be updated.
reducer_dict : dict
Per-edge-type arguments, keyed by edge type. Each value is either
(1) a reduce function or (2) a ``(reduce_func, apply_func)`` pair.
The functions should be :mod:`Node UDF <dgl.udf>`.
cross_reducer : str
Cross type reducer. One of "sum", "min", "max", "mean", "stack".
apply_func : callable, optional
Apply function on the nodes, invoked through :func:`apply_nodes`
after the per-type results have been merged. The function should be
a :mod:`Node UDF <dgl.udf>`. (Default: None)
inplace: bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
Raises
------
DGLError
If a value of ``reducer_dict`` is neither a reduce function nor a
``(reduce_func, apply_func)`` pair.
"""
# infer receive node type
ntype = infer_ntype_from_dict(self, reducer_dict)
ntid = self.get_ntype_id(ntype)
if is_all(v):
v = F.arange(0, self.number_of_nodes(ntid))
elif isinstance(v, int):
# wrap a single node id so that it is converted like a container
v = [v]
v = utils.toindex(v)
if len(v) == 0:
# no node to be updated
return
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
with ir.prog() as prog:
for ety, args in reducer_dict.items():
# One output frame per edge type, so that the per-type results can be
# merged afterwards.
# NOTE(review): presumably frame_like builds an empty frame with the
# same layout as the node frame -- confirm against frame_like.
outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
# Normalize the per-type value to a (reduce_func, apply_func) pair;
# a None result from pad_tuple is treated as invalid input.
args = pad_tuple(args, 2)
if args is None:
raise DGLError('Invalid per-type arguments. Should be either '
'(1) reduce_func or (2) (reduce_func, apply_func)')
rfunc, afunc = args
etid = self.get_etype_id(ety)
# stid/dtid: presumably the source/destination node type ids of this edge type
stid, dtid = self._graph.metagraph.find_edge(etid)
scheduler.schedule_recv(AdaptedHeteroGraph(self, stid, dtid, etid),
v, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
Runtime.run(prog)
# merge the per-type output frames with cross_reducer and write the result
# into the node frame of the receiving node type
self._node_frames[ntid].update(merge_frames(all_out, cross_reducer))
# apply the optional node function on the updated nodes
if apply_func is not None:
self.apply_nodes(apply_func, v, ntype, inplace)
def send_and_recv(self,
                  edges,
                  message_func,
                  reduce_func,
                  apply_node_func=None,
                  etype=None,
                  inplace=False):
    """Send messages along edges with the same edge type, and let destinations
    receive them.

    Parameters
    ----------
    edges : valid edges type
        Edges on which to apply ``func``. See :func:`send` for valid
        edges type.
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.
    """
    # per-type dictionaries are not accepted by this single-type entry point
    assert not utils.is_dict_like(message_func) and \
        not utils.is_dict_like(reduce_func) and \
        not utils.is_dict_like(apply_node_func), \
        "multiple-type message passing is not implemented"
    assert message_func is not None
    assert reduce_func is not None

    etid = self.get_etype_id(etype)
    stid, dtid = self._graph.metagraph.find_edge(etid)

    if isinstance(edges, tuple):
        u, v = edges
        u = utils.toindex(u)
        v = utils.toindex(v)
        # Rewrite u, v to handle edge broadcasting and multigraph.
        u, v, eid = self._graph.edge_ids(etid, u, v)
    else:
        eid = utils.toindex(edges)
        u, v, _ = self._graph.find_edges(etid, eid)

    if len(u) == 0:
        # no edges to be triggered
        return

    with ir.prog() as prog:
        scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                               (u, v, eid),
                               message_func, reduce_func, apply_node_func,
                               inplace=inplace)
        Runtime.run(prog)
def multi_send_and_recv(self, etype_dict, cross_reducer, apply_func=None, inplace=False):
    r"""Send and receive messages along multiple edge types and perform aggregation.

    It calculates:

    .. math::
        h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}\phi_t(
        h_u, h_v, h_{uv}))

    * ``etype_dict`` is a dictionary from edge type to a tuple of arguments for a
      normal ``send_and_recv``.
    * :math:`\mathcal{N}_t(v)` is defined by the edges given for type :math:`t`.
    * ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`
    * ``apply_func`` specifies :math:`\sigma`.

    Examples
    --------
    TBD

    Parameters
    ----------
    etype_dict : dict
        ``send_and_recv`` arguments per edge type. Each value is a tuple
        ``(edges, msg_func, reduce_func, [apply_func])``. All edge type
        keys must share the same destination node type.
    cross_reducer : str
        Cross type reducer. One of "sum", "min", "max", "mean", "stack".
    apply_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.
    """
    # infer receive node type
    ntype = infer_ntype_from_dict(self, etype_dict)
    dtid = self.get_ntype_id(ntype)

    # TODO(minjie): currently loop over each edge type and reuse the old schedule.
    #   Should replace it with fused kernel.
    all_out = []
    all_vs = []
    with ir.prog() as prog:
        for etype, args in etype_dict.items():
            etid = self.get_etype_id(etype)
            stid, _ = self._graph.metagraph.find_edge(etid)
            outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
            args = pad_tuple(args, 4)
            if args is None:
                raise DGLError('Invalid per-type arguments. Should be '
                               '(edges, msg_func, reduce_func, [apply_func])')
            edges, mfunc, rfunc, afunc = args
            if isinstance(edges, tuple):
                u, v = edges
                u = utils.toindex(u)
                v = utils.toindex(v)
                # Rewrite u, v to handle edge broadcasting and multigraph.
                u, v, eid = self._graph.edge_ids(etid, u, v)
            else:
                eid = utils.toindex(edges)
                u, v, _ = self._graph.find_edges(etid, eid)
            all_vs.append(v)
            if len(u) == 0:
                # no edges to be triggered
                continue
            scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
                                   (u, v, eid),
                                   mfunc, rfunc, afunc,
                                   inplace=inplace, outframe=outframe)
            all_out.append(outframe)
        Runtime.run(prog)
    if not all_out:
        # No edge type had any edge to trigger, so there is nothing to merge
        # (merge_frames indexes frames[0]) and no destination node to update.
        return
    # merge by cross_reducer
    self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
    # apply
    if apply_func is not None:
        dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
        self.apply_nodes(apply_func, dstnodes, ntype, inplace)
def pull(self,
         v,
         message_func,
         reduce_func,
         apply_node_func=None,
         etype=None,
         inplace=False):
    """Pull messages from the node(s)' predecessors and then update their features.

    Parameters
    ----------
    v : int, container or tensor
        The node(s) to be updated.
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.
    """
    # per-type dictionaries are not accepted by this single-type entry point
    assert not utils.is_dict_like(message_func) and \
        not utils.is_dict_like(reduce_func) and \
        not utils.is_dict_like(apply_node_func), \
        "multiple-type message passing is not implemented"
    assert message_func is not None
    assert reduce_func is not None

    # only one type of edges
    etid = self.get_etype_id(etype)
    stid, dtid = self._graph.metagraph.find_edge(etid)

    v = utils.toindex(v)
    if len(v) == 0:
        return
    with ir.prog() as prog:
        scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                v,
                                message_func, reduce_func, apply_node_func,
                                inplace=inplace)
        Runtime.run(prog)
def multi_pull(self, v, etype_dict, cross_reducer, apply_func=None, inplace=False):
    r"""Pull and receive messages of the given nodes along multiple edge types
    and perform aggregation.

    It calculates:

    .. math::
        h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}\phi_t(
        h_u, h_v, h_{uv}))

    * ``etype_dict`` is a dictionary from edge type to a tuple of arguments for a
      normal ``pull``.
    * :math:`\mathcal{N}_t(v)` is the set of predecessors of ``v`` connected by edge
      type :math:`t`.
    * ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`
    * ``apply_func`` specifies :math:`\sigma`.

    Examples
    --------
    TBD

    Parameters
    ----------
    v : int, container or tensor
        The node(s) to be updated.
    etype_dict : dict
        ``pull`` arguments per edge type. Each value is a tuple
        ``(msg_func, reduce_func, [apply_func])``. All edge type keys must
        share the same destination node type.
    cross_reducer : str
        Cross type reducer. One of "sum", "min", "max", "mean", "stack".
    apply_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.
    """
    v = utils.toindex(v)
    if len(v) == 0:
        return
    # infer receive node type
    ntype = infer_ntype_from_dict(self, etype_dict)
    dtid = self.get_ntype_id(ntype)
    # TODO(minjie): currently loop over each edge type and reuse the old schedule.
    #   Should replace it with fused kernel.
    all_out = []
    with ir.prog() as prog:
        for etype, args in etype_dict.items():
            etid = self.get_etype_id(etype)
            stid, _ = self._graph.metagraph.find_edge(etid)
            # per edge type scratch frame, merged by the cross reducer below
            outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
            args = pad_tuple(args, 3)
            if args is None:
                raise DGLError('Invalid per-type arguments. Should be '
                               '(msg_func, reduce_func, [apply_func])')
            mfunc, rfunc, afunc = args
            scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
                                    v,
                                    mfunc, rfunc, afunc,
                                    inplace=inplace, outframe=outframe)
            all_out.append(outframe)
        Runtime.run(prog)
    # merge by cross_reducer
    self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
    # apply
    if apply_func is not None:
        self.apply_nodes(apply_func, v, ntype, inplace)
def push(self,
         u,
         message_func,
         reduce_func,
         apply_node_func=None,
         etype=None,
         inplace=False):
    """Send message from the node(s) to their successors and update them.

    Parameters
    ----------
    u : int, container or tensor
        The node(s) to push messages out.
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.
    inplace: bool, optional
        If True, update will be done in place, but autograd will break.
    """
    # per-type dictionaries are not accepted by this single-type entry point
    assert not utils.is_dict_like(message_func) and \
        not utils.is_dict_like(reduce_func) and \
        not utils.is_dict_like(apply_node_func), \
        "multiple-type message passing is not implemented"
    assert message_func is not None
    assert reduce_func is not None

    # only one type of edges
    etid = self.get_etype_id(etype)
    stid, dtid = self._graph.metagraph.find_edge(etid)

    u = utils.toindex(u)
    if len(u) == 0:
        return
    with ir.prog() as prog:
        scheduler.schedule_push(AdaptedHeteroGraph(self, stid, dtid, etid),
                                u,
                                message_func, reduce_func, apply_node_func,
                                inplace=inplace)
        Runtime.run(prog)
def update_all(self,
               message_func,
               reduce_func,
               apply_node_func=None,
               etype=None):
    """Send messages through all edges and update all nodes.

    Optionally, apply a function to update the node features after receive.

    Parameters
    ----------
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.
    """
    # per-type dictionaries are not accepted by this single-type entry point
    assert not utils.is_dict_like(message_func) and \
        not utils.is_dict_like(reduce_func) and \
        not utils.is_dict_like(apply_node_func), \
        "multiple-type message passing is not implemented"
    assert message_func is not None
    assert reduce_func is not None

    # only one type of edges
    etid = self.get_etype_id(etype)
    stid, dtid = self._graph.metagraph.find_edge(etid)

    with ir.prog() as prog:
        scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                      message_func, reduce_func,
                                      apply_node_func)
        Runtime.run(prog)
def multi_update_all(self, etype_dict, cross_reducer, apply_func=None):
    r"""Send and receive messages along all edges.

    It calculates:

    .. math::
        h_v^{new} = \sigma(\prod_{t\in T_e}\sum_{u\in\mathcal{N}_t(v)}\phi_t(
        h_u, h_v, h_{uv}))

    * ``etype_dict`` is a dictionary from edge type to a tuple of arguments for a
      normal ``update_all``.
    * :math:`\mathcal{N}_t(v)` is the set of predecessors of ``v`` connected by edge
      type :math:`t`.
    * ``cross_reducer`` specifies :math:`\prod_{t\in T_e}`
    * ``apply_func`` specifies :math:`\sigma`.

    Examples
    --------
    TBD

    Parameters
    ----------
    etype_dict : dict
        ``update_all`` arguments per edge type. Each value is a tuple
        ``(msg_func, reduce_func, [apply_func])``. Unlike the other
        ``multi_*`` methods, the edge types may have different destination
        node types; results are merged per destination type.
    cross_reducer : str
        Cross type reducer. One of "sum", "min", "max", "mean", "stack".
    apply_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    """
    # TODO(minjie): currently loop over each edge type and reuse the old schedule.
    #   Should replace it with fused kernel.
    # destination node type id -> output frames of the edge types ending there
    all_out = defaultdict(list)
    with ir.prog() as prog:
        for etype, args in etype_dict.items():
            etid = self.get_etype_id(etype)
            stid, dtid = self._graph.metagraph.find_edge(etid)
            outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
            args = pad_tuple(args, 3)
            if args is None:
                raise DGLError('Invalid per-type arguments. Should be '
                               '(msg_func, reduce_func, [apply_func])')
            mfunc, rfunc, afunc = args
            scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
                                          mfunc, rfunc, afunc,
                                          outframe=outframe)
            all_out[dtid].append(outframe)
        Runtime.run(prog)
    for dtid, frames in all_out.items():
        # merge by cross_reducer
        self._node_frames[dtid].update(merge_frames(frames, cross_reducer))
        # apply
        if apply_func is not None:
            self.apply_nodes(apply_func, ALL, self.ntypes[dtid], inplace=False)
def prop_nodes(self,
               nodes_generator,
               message_func,
               reduce_func,
               apply_node_func=None,
               etype=None):
    """Propagate messages using graph traversal by triggering
    :func:`pull()` on nodes.

    The traversal order is specified by the ``nodes_generator``. It generates
    node frontiers, which is a list or a tensor of nodes. The nodes in the
    same frontier will be triggered together, while nodes in different frontiers
    will be triggered according to the generating order.

    Parameters
    ----------
    nodes_generator : iterable, each element is a list or a tensor of node ids
        The generator of node frontiers. It specifies which nodes perform
        :func:`pull` at each timestep.
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.

    See Also
    --------
    prop_edges
    """
    # one pull per frontier, in the order the generator yields them
    for node_frontier in nodes_generator:
        self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype)
def prop_edges(self,
               edges_generator,
               message_func,
               reduce_func,
               apply_node_func=None,
               etype=None):
    """Propagate messages using graph traversal by triggering
    :func:`send_and_recv()` on edges.

    The traversal order is specified by the ``edges_generator``. It generates
    edge frontiers. The edge frontiers should be of *valid edges type*.
    See :func:`send` for more details.

    Edges in the same frontier will be triggered together, while edges in
    different frontiers will be triggered according to the generating order.

    Parameters
    ----------
    edges_generator : generator
        The generator of edge frontiers.
    message_func : callable
        Message function on the edges. The function should be
        an :mod:`Edge UDF <dgl.udf>`.
    reduce_func : callable
        Reduce function on the node. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    apply_node_func : callable, optional
        Apply function on the nodes. The function should be
        a :mod:`Node UDF <dgl.udf>`.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.

    See Also
    --------
    prop_nodes
    """
    # one send_and_recv per frontier, in the order the generator yields them
    for edge_frontier in edges_generator:
        self.send_and_recv(edge_frontier, message_func, reduce_func,
                           apply_node_func, etype=etype)
#################################################################
# Misc
#################################################################

def to_networkx(self, node_attrs=None, edge_attrs=None):
    """Convert this graph to networkx graph.

    The edge id will be saved as the 'id' edge attribute.

    Only graphs with a single node type and a single edge type are
    supported for now.

    Parameters
    ----------
    node_attrs : iterable of str, optional
        The node attributes to be copied.
    edge_attrs : iterable of str, optional
        The edge attributes to be copied.

    Returns
    -------
    networkx.DiGraph
        The nx graph. A ``networkx.MultiDiGraph`` is returned instead when
        this graph is a multigraph.

    Examples
    --------

    .. note:: Here we use pytorch syntax for demo. The general idea applies
        to other frameworks with minor syntax change (e.g. replace
        ``torch.tensor`` with ``mxnet.ndarray``).

    >>> import torch as th
    >>> g = DGLGraph()
    >>> g.add_nodes(5, {'n1': th.randn(5, 10)})
    >>> g.add_edges([0,1,3,4], [2,4,0,3], {'e1': th.randn(4, 6)})
    >>> nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])

    See Also
    --------
    dgl.to_networkx
    """
    # TODO(minjie): multi-type support
    assert len(self.ntypes) == 1
    assert len(self.etypes) == 1
    src, dst = self.edges()
    src = F.asnumpy(src)
    dst = F.asnumpy(dst)
    # xiangsx: Always treat graph as multigraph
    nx_graph = nx.MultiDiGraph() if self.is_multigraph else nx.DiGraph()
    nx_graph.add_nodes_from(range(self.number_of_nodes()))
    # edges are enumerated in the order returned by self.edges(); that
    # position is what gets stored as the 'id' attribute
    for eid, (u, v) in enumerate(zip(src, dst)):
        nx_graph.add_edge(u, v, id=eid)

    if node_attrs is not None:
        for nid, attr in nx_graph.nodes(data=True):
            feat_dict = self._get_n_repr(0, nid)
            attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
    if edge_attrs is not None:
        for _, _, attr in nx_graph.edges(data=True):
            eid = attr['id']
            feat_dict = self._get_e_repr(0, eid)
            attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
    return nx_graph
def filter_nodes(self, predicate, nodes=ALL, ntype=None):
    """Return a tensor of node IDs with the given node type that satisfy
    the given predicate.

    Parameters
    ----------
    predicate : callable
        A function of signature ``func(nodes) -> tensor``.
        ``nodes`` are :class:`NodeBatch` objects as in :mod:`~dgl.udf`.
        The returned tensor is used as a mask: each element indicates
        whether the corresponding node in
        the batch satisfies the predicate.
    nodes : int, iterable or tensor of ints
        The nodes to filter on. Default value is all the nodes.
    ntype : str, optional
        The node type. Can be omitted if there is only one node type
        in the graph.

    Returns
    -------
    tensor
        The nodes that satisfy the predicate.
    """
    ntid = self.get_ntype_id(ntype)
    if is_all(nodes):
        v = utils.toindex(slice(0, self._graph.number_of_nodes(ntid)))
    else:
        v = utils.toindex(nodes)

    n_repr = self._get_n_repr(ntid, v)
    nbatch = NodeBatch(v, n_repr)
    # the mask is moved to CPU before it is used for selection
    n_mask = F.copy_to(predicate(nbatch), F.cpu())

    if is_all(nodes):
        # positions in the mask are the node IDs themselves
        return F.nonzero_1d(n_mask)
    else:
        nodes = F.tensor(nodes)
        return F.boolean_mask(nodes, n_mask)
def filter_edges(self, predicate, edges=ALL, etype=None):
    """Return a tensor of edge IDs with the given edge type that satisfy
    the given predicate.

    Parameters
    ----------
    predicate : callable
        A function of signature ``func(edges) -> tensor``.
        ``edges`` are :class:`EdgeBatch` objects as in :mod:`~dgl.udf`.
        The returned tensor is used as a mask: each element indicates
        whether the corresponding edge in the batch satisfies the predicate.
    edges : valid edges type
        Edges on which to apply ``func``. See :func:`send` for valid
        edges type. Default value is all the edges.
    etype : str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.

    Returns
    -------
    tensor
        The edges that satisfy the predicate.
    """
    etid = self.get_etype_id(etype)
    stid, dtid = self._graph.metagraph.find_edge(etid)
    if is_all(edges):
        u, v, _ = self._graph.edges(etid, 'eid')
        eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
    elif isinstance(edges, tuple):
        u, v = edges
        u = utils.toindex(u)
        v = utils.toindex(v)
        # Rewrite u, v to handle edge broadcasting and multigraph.
        u, v, eid = self._graph.edge_ids(etid, u, v)
    else:
        eid = utils.toindex(edges)
        u, v, _ = self._graph.find_edges(etid, eid)

    src_data = self._get_n_repr(stid, u)
    edge_data = self._get_e_repr(etid, eid)
    dst_data = self._get_n_repr(dtid, v)
    ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
    # the mask is moved to CPU before it is used for selection
    e_mask = F.copy_to(predicate(ebatch), F.cpu())

    if is_all(edges):
        # positions in the mask are the edge IDs themselves
        return F.nonzero_1d(e_mask)
    # Mask the resolved edge IDs rather than the raw ``edges`` argument:
    # when ``edges`` is a (u, v) pair the raw argument is not a 1-D list of
    # edge IDs, so it cannot be selected with a per-edge mask.
    return F.boolean_mask(eid.tousertensor(), e_mask)
def to(self, ctx):  # pylint: disable=invalid-name
    """Move both ndata and edata to the targeted mode (cpu/gpu)
    Framework agnostic

    Every feature column of every node type and edge type is copied to
    ``ctx`` and stored back under the same key.

    Parameters
    ----------
    ctx : framework-specific context object
        The context to move data to.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import torch
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(5, {'h': torch.ones((5, 2))})
    >>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
    >>> G.to(torch.device('cuda:0'))
    """
    for i in range(len(self._node_frames)):
        for k in self._node_frames[i].keys():
            self._node_frames[i][k] = F.copy_to(self._node_frames[i][k], ctx)
    for i in range(len(self._edge_frames)):
        for k in self._edge_frames[i].keys():
            self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx)
def local_var(self):
    """Return a graph object that can be used in a local function scope.

    The returned graph object shares the feature data and graph structure of this graph.
    However, any out-place mutation to the feature data will not reflect to this graph,
    thus making it easier to use in a function scope.

    If set, the local graph object will use same initializers for node features and
    edge features.

    Examples
    --------
    The following example uses PyTorch backend.

    Avoid accidentally overriding existing feature data. This is quite common when
    implementing a NN module:

    >>> def foo(g):
    >>>     g = g.local_var()
    >>>     g.ndata['h'] = torch.ones((g.number_of_nodes(), 3))
    >>>     return g.ndata['h']
    >>>
    >>> g = ... # some graph
    >>> g.ndata['h'] = torch.zeros((g.number_of_nodes(), 3))
    >>> newh = foo(g)  # get tensor of all ones
    >>> print(g.ndata['h'])  # still get tensor of all zeros

    Automatically garbage collect locally-defined tensors without the need to manually
    ``pop`` the tensors.

    >>> def foo(g):
    >>>     g = g.local_var()
    >>>     # This 'xxx' feature will stay local and be GCed when the function exits
    >>>     g.ndata['xxx'] = torch.ones((g.number_of_nodes(), 3))
    >>>     return g.ndata['xxx']
    >>>
    >>> g = ... # some graph
    >>> xxx = foo(g)
    >>> print('xxx' in g.ndata)
    False

    Notes
    -----
    Internally, the returned graph shares the same feature tensors, but construct a new
    dictionary structure (aka. Frame) so adding/removing feature tensors from the returned
    graph will not reflect to the original graph. However, inplace operations do change
    the shared tensor values, so will be reflected to the original graph. This function
    also has little overhead when the number of feature tensors in this graph is small.

    See Also
    --------
    local_scope

    Returns
    -------
    DGLHeteroGraph
        The graph object that can be used as a local variable.
    """
    local_node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
    local_edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
    # Use same per-column initializers and default initializer.
    # If registered, a column (based on key) initializer will be used first,
    # otherwise the default initializer will be used.
    for fr1, fr2 in zip(local_node_frames, self._node_frames):
        sync_frame_initializer(fr1._frame, fr2._frame)
    for fr1, fr2 in zip(local_edge_frames, self._edge_frames):
        sync_frame_initializer(fr1._frame, fr2._frame)
    return DGLHeteroGraph(self._graph, self.ntypes, self.etypes,
                          local_node_frames,
                          local_edge_frames)
@contextmanager
def local_scope(self):
    """Enter a local scope context for this graph.

    By entering a local scope, any out-place mutation to the feature data will
    not reflect to the original graph, thus making it easier to use in a function scope.

    If set, the local scope will use same initializers for node features and
    edge features.

    The original frames are restored when the scope exits, including when
    the body of the ``with`` statement raises.

    Examples
    --------
    The following example uses PyTorch backend.

    Avoid accidentally overriding existing feature data. This is quite common when
    implementing a NN module:

    >>> def foo(g):
    >>>     with g.local_scope():
    >>>         g.ndata['h'] = torch.ones((g.number_of_nodes(), 3))
    >>>         return g.ndata['h']
    >>>
    >>> g = ... # some graph
    >>> g.ndata['h'] = torch.zeros((g.number_of_nodes(), 3))
    >>> newh = foo(g)  # get tensor of all ones
    >>> print(g.ndata['h'])  # still get tensor of all zeros

    Automatically garbage collect locally-defined tensors without the need to manually
    ``pop`` the tensors.

    >>> def foo(g):
    >>>     with g.local_scope():
    >>>         # This 'xxx' feature will stay local and be GCed when the function exits
    >>>         g.ndata['xxx'] = torch.ones((g.number_of_nodes(), 3))
    >>>         return g.ndata['xxx']
    >>>
    >>> g = ... # some graph
    >>> xxx = foo(g)
    >>> print('xxx' in g.ndata)
    False

    See Also
    --------
    local_var
    """
    old_nframes = self._node_frames
    old_eframes = self._edge_frames
    self._node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
    self._edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
    try:
        # Use same per-column initializers and default initializer.
        # If registered, a column (based on key) initializer will be used first,
        # otherwise the default initializer will be used.
        for fr1, fr2 in zip(self._node_frames, old_nframes):
            sync_frame_initializer(fr1._frame, fr2._frame)
        for fr1, fr2 in zip(self._edge_frames, old_eframes):
            sync_frame_initializer(fr1._frame, fr2._frame)
        yield
    finally:
        # Without the finally clause an exception raised inside the scope
        # would leave the graph pointing at the temporary frames.
        self._node_frames = old_nframes
        self._edge_frames = old_eframes
############################################################
# Internal APIs
############################################################
def make_canonical_etypes(etypes, ntypes, metagraph):
    """Internal function to convert etype name to (srctype, etype, dsttype)

    Each metagraph edge ``(src, dst, eid)`` is turned into the triplet
    ``(ntypes[src], etypes[eid], ntypes[dst])``, in the order the metagraph
    reports its edges.

    Parameters
    ----------
    etypes : list of str
        Edge type list
    ntypes : list of str
        Node type list
    metagraph : GraphIndex
        Meta graph.

    Returns
    -------
    list of tuples (srctype, etype, dsttype)

    Raises
    ------
    DGLError
        If the length of ``etypes`` or ``ntypes`` does not match the number
        of edges or nodes of the metagraph.
    """
    # sanity check
    num_meta_edges = metagraph.number_of_edges()
    if len(etypes) != num_meta_edges:
        raise DGLError('Length of edge type list must match the number of '
                       'edges in the metagraph. {} vs {}'.format(
                           len(etypes), metagraph.number_of_edges()))
    num_meta_nodes = metagraph.number_of_nodes()
    if len(ntypes) != num_meta_nodes:
        raise DGLError('Length of nodes type list must match the number of '
                       'nodes in the metagraph. {} vs {}'.format(
                           len(ntypes), metagraph.number_of_nodes()))
    src_ids, dst_ids, edge_ids = metagraph.edges()
    canonical = []
    for sid, did, eid in zip(src_ids, dst_ids, edge_ids):
        canonical.append((ntypes[sid], etypes[eid], ntypes[did]))
    return canonical
def infer_ntype_from_dict(graph, etype_dict):
    """Work out the single destination node type shared by a set of edge types.

    Every key of ``etype_dict`` is converted with ``graph.to_canonical_etype``
    and the destination types of all keys must agree.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    etype_dict : dict
        Dictionary whose key is edge type

    Returns
    -------
    str
        The common destination node type. None if the dictionary is empty.

    Raises
    ------
    DGLError
        If two keys resolve to different destination node types.
    """
    inferred = None
    for etype in etype_dict:
        _srctype, _relation, dsttype = graph.to_canonical_etype(etype)
        if inferred is None:
            # First key seen: it fixes the expected destination type.
            inferred = dsttype
        elif inferred != dsttype:
            raise DGLError("Cannot infer destination node type from the dictionary. "
                           "A valid specification must make sure that all the edge "
                           "type keys share the same destination node type.")
    return inferred
def pad_tuple(tup, length, pad_val=None):
    """Right-pad a tuple with ``pad_val`` until it has ``length`` elements.

    A non-tuple input is first wrapped into a one-element tuple.
    Returns None when the input is already longer than ``length``.
    """
    items = tup if isinstance(tup, tuple) else (tup, )
    missing = length - len(items)
    if missing < 0:
        # Too long to fit: signal failure to the caller.
        return None
    if missing == 0:
        return items
    return items + (pad_val,) * missing
def merge_frames(frames, reducer):
    """Combine several frames into one, reducing columns that appear more than once.

    A single input frame is returned as is. Otherwise, every column key found
    in any input frame appears in the result; a key present in only one frame
    is copied over, a key present in several frames is reduced.

    Parameters
    ----------
    frames : list of FrameRef
        Input frames
    reducer : str
        One of "sum", "max", "min", "mean", "stack"

    Returns
    -------
    FrameRef
        Merged frame

    Raises
    ------
    DGLError
        If ``reducer`` is not "stack" and the backend has no attribute of that name.
    """
    if len(frames) == 1:
        return frames[0]
    if reducer == 'stack':
        # TODO(minjie): Stack order does not matter. However, it must
        #   be consistent! Need to enforce one type of order.
        # NOTE(review): every tensor is unsqueezed on axis 1 and then stacked
        #   on axis 1 again -- confirm the extra axis is intended.
        def merger(flist):
            expanded = [F.unsqueeze(feat, 1) for feat in flist]
            return F.stack(expanded, 1)
    else:
        redfn = getattr(F, reducer, None)
        if redfn is None:
            raise DGLError('Invalid cross type reducer. Must be one of '
                           '"sum", "max", "min", "mean" or "stack".')
        def merger(flist):
            stacked = F.stack(flist, 0)
            return redfn(stacked, 0)
    merged = FrameRef(frame_like(frames[0]._frame))
    # Union of the column names over all input frames.
    all_keys = set()
    for frame in frames:
        all_keys.update(frame.keys())
    for key in all_keys:
        candidates = [frame[key] for frame in frames if key in frame]
        merged[key] = merger(candidates) if len(candidates) > 1 else candidates[0]
    return merged
def combine_frames(frames, ids):
    """Merge the frames into one frame, taking the common columns.

    Return None if there is no common columns.

    Parameters
    ----------
    frames : List[FrameRef]
        List of frames
    ids : List[int]
        List of frame IDs, i.e. positions in ``frames`` of the frames to merge.

    Returns
    -------
    FrameRef
        The resulting frame, or None if no frame is selected or the selected
        frames share no column.

    Raises
    ------
    DGLError
        If a column shared by the selected frames has mismatching schemes.
    """
    # Nothing selected means nothing in common.
    if len(ids) == 0:
        return None

    # find common columns and check if their schemes match
    schemes = dict(frames[ids[0]].schemes.items())
    for frame_id in ids:
        frame = frames[frame_id]
        # Iterate over a snapshot because keys are deleted while scanning.
        for key, scheme in list(schemes.items()):
            if key in frame.schemes:
                if frame.schemes[key] != scheme:
                    raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
                                   (key, frame.schemes[key], scheme))
            else:
                del schemes[key]

    if len(schemes) == 0:
        return None

    # concatenate the columns; frames without rows are left out of the concatenation
    def to_cat(key):
        return [frames[i][key] for i in ids if frames[i].num_rows > 0]
    cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
    return FrameRef(Frame(cols))
def combine_names(names, ids=None):
    """Build one name by joining the selected names, sorted, with '+'.

    Parameters
    ----------
    names : list of str
        String names
    ids : numpy.ndarray, optional
        Selected index. If omitted, all names are used.

    Returns
    -------
    str
        The sorted names joined by '+'.
    """
    chosen = names if ids is None else [names[i] for i in ids]
    return '+'.join(sorted(chosen))
class AdaptedHeteroGraph(GraphAdapter):
    """Adapt DGLHeteroGraph to interface required by scheduler.

    The adapter exposes one relation of the heterograph, identified by its
    source node type, destination node type and edge type IDs.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    stid : int
        Source node type id
    dtid : int
        Destination node type id
    etid : int
        Edge type id
    """
    def __init__(self, graph, stid, dtid, etid):
        self.graph = graph
        self.stid = stid
        self.dtid = dtid
        self.etid = etid

    @property
    def gidx(self):
        """Graph index object of the adapted graph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph._graph.number_of_nodes(self.stid)

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph._graph.number_of_nodes(self.dtid)

    def num_edges(self):
        """Number of edges."""
        return self.graph._graph.number_of_edges(self.etid)

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frames[self.stid]

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        return self.graph._node_frames[self.dtid]

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frames[self.etid]

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frames[self.etid]

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index(self.etid)

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(self.etid, val)

    def in_edges(self, nodes):
        """Get the in edges of the given nodes for the adapted edge type."""
        return self.graph._graph.in_edges(self.etid, nodes)

    def out_edges(self, nodes):
        """Get the out edges of the given nodes for the adapted edge type."""
        return self.graph._graph.out_edges(self.etid, nodes)

    def edges(self, form):
        """Get all edges of the adapted edge type in the requested form."""
        return self.graph._graph.edges(self.etid, form)

    def get_immutable_gidx(self, ctx):
        """Get the unitgraph of the adapted edge type on the given context."""
        return self.graph._graph.get_unitgraph(self.etid, ctx)

    def bits_needed(self):
        """Number of integer bits needed to represent the adapted edge type's graph."""
        return self.graph._graph.bits_needed(self.etid)
"""Module for heterogeneous graph index class definition."""
from __future__ import absolute_import
import numpy as np
import scipy
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError
......@@ -48,7 +51,7 @@ class HeteroGraphIndex(ObjectBase):
return self.metagraph.number_of_edges()
def get_relation_graph(self, etype):
"""Get the bipartite graph of the given edge/relation type.
"""Get the unitgraph graph of the given edge/relation type.
Parameters
----------
......@@ -58,10 +61,26 @@ class HeteroGraphIndex(ObjectBase):
Returns
-------
HeteroGraphIndex
The bipartite graph.
The unitgraph graph.
"""
return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))
def flatten_relations(self, etypes):
    """Flatten the unitgraphs of the requested relations into a single unitgraph.

    Parameters
    ----------
    etypes : list[int]
        The edge/relation types.

    Returns
    -------
    HeteroGraphIndex
        The flattened unitgraph.
    """
    flattened = _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)
    return flattened
def add_nodes(self, ntype, num):
"""Add nodes.
......@@ -131,7 +150,7 @@ class HeteroGraphIndex(ObjectBase):
return _CAPI_DGLHeteroNumBits(self)
def bits_needed(self, etype):
"""Return the number of integer bits needed to represent the bipartite graph.
"""Return the number of integer bits needed to represent the unitgraph graph.
Parameters
----------
......@@ -658,6 +677,146 @@ class HeteroGraphIndex(ObjectBase):
else:
raise Exception("unknown format")
def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None):
    """Build the scipy sparse adjacency matrix of one edge type.

    Without transposing, rows index destination nodes and columns index
    source nodes. With ``transpose`` set, rows index source nodes and columns
    index destination nodes.

    Parameters
    ----------
    etype : int
        Edge type
    transpose : bool
        A flag to transpose the returned adjacency matrix.
    fmt : str
        Format of the returned matrix: "csr" or "coo".
    return_edge_ids : bool
        If True, the matrix elements are edge IDs; if False, they are 1.
        If None, a FutureWarning is issued and edge IDs are returned.

    Returns
    -------
    scipy.sparse.spmatrix
        The scipy representation of adjacency matrix.

    Raises
    ------
    DGLError
        If ``transpose`` is not a bool.
    Exception
        If ``fmt`` is neither "csr" nor "coo".
    """
    if not isinstance(transpose, bool):
        raise DGLError('Expect bool value for "transpose" arg,'
                       ' but got %s.' % (type(transpose)))

    if return_edge_ids is None:
        dgl_warning(
            "Adjacency matrix by default currently returns edge IDs."
            " As a result there is one 0 entry which is not eliminated."
            " In the next release it will return 1s by default,"
            " and 0 will be eliminated otherwise.",
            FutureWarning)
        return_edge_ids = True

    rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
    srctype, dsttype = self.metagraph.find_edge(etype)
    # Rows follow the source type only when the matrix is transposed.
    if transpose:
        row_type, col_type = srctype, dsttype
    else:
        row_type, col_type = dsttype, srctype
    shape = (self.number_of_nodes(row_type), self.number_of_nodes(col_type))
    nnz = self.number_of_edges(etype)

    if fmt == "csr":
        indptr = utils.toindex(rst(0)).tonumpy()
        indices = utils.toindex(rst(1)).tonumpy()
        if return_edge_ids:
            data = utils.toindex(rst(2)).tonumpy()
        else:
            data = np.ones_like(indices)
        return scipy.sparse.csr_matrix((data, indices, indptr), shape=shape)
    if fmt == 'coo':
        idx = utils.toindex(rst(0)).tonumpy()
        # The flat index holds the row IDs followed by the column IDs.
        row, col = np.reshape(idx, (2, nnz))
        if return_edge_ids:
            data = np.arange(0, nnz)
        else:
            data = np.ones_like(row)
        return scipy.sparse.coo_matrix((data, (row, col)), shape=shape)
    raise Exception("unknown format")
def incidence_matrix(self, etype, typestr, ctx):
    """Return the incidence matrix representation of one edge type of this graph.

    An incidence matrix is an n x m sparse matrix, where n is
    the number of nodes and m is the number of edges. Each nonzero
    value indicates whether the edge is incident to the node
    or not.

    There are three types of an incidence matrix `I`:
    * "in":
      - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);
      - I[v, e] = 0 otherwise.
    * "out":
      - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);
      - I[v, e] = 0 otherwise.
    * "both":
      - I[v, e] = 1 if e is the in-edge of v;
      - I[v, e] = -1 if e is the out-edge of v;
      - I[v, e] = 0 otherwise (including self-loop).

    Parameters
    ----------
    etype : int
        Edge type
    typestr : str
        Can be either "in", "out" or "both"
    ctx : context
        The context of returned incidence matrix.

    Returns
    -------
    SparseTensor
        The incidence matrix.
    utils.Index
        A index for data shuffling due to sparse format change. Return None
        if shuffle is not required.

    Raises
    ------
    DGLError
        If ``typestr`` is not one of "in", "out" or "both".
    """
    src, dst, eid = self.edges(etype)
    src = src.tousertensor(ctx)  # the index of the ctx will be cached
    dst = dst.tousertensor(ctx)  # the index of the ctx will be cached
    eid = eid.tousertensor(ctx)  # the index of the ctx will be cached
    srctype, dsttype = self.metagraph.find_edge(etype)
    # Number of columns: all edges of this type (computed before any masking below).
    m = self.number_of_edges(etype)
    if typestr == 'in':
        # Rows are destination nodes, columns are edge IDs.
        n = self.number_of_nodes(dsttype)
        row = F.unsqueeze(dst, 0)
        col = F.unsqueeze(eid, 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    elif typestr == 'out':
        # Rows are source nodes, columns are edge IDs.
        n = self.number_of_nodes(srctype)
        row = F.unsqueeze(src, 0)
        col = F.unsqueeze(eid, 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    elif typestr == 'both':
        assert srctype == dsttype, \
            "'both' is supported only if source and destination type are the same"
        n = self.number_of_nodes(srctype)
        # first remove entries for self loops
        mask = F.logical_not(F.equal(src, dst))
        src = F.boolean_mask(src, mask)
        dst = F.boolean_mask(dst, mask)
        eid = F.boolean_mask(eid, mask)
        n_entries = F.shape(src)[0]
        # create index: each remaining edge contributes two entries,
        # one at (src, eid) and one at (dst, eid)
        row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
        col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        # -1 for the (src, eid) entries, +1 for the (dst, eid) entries
        x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        dat = F.cat([x, y], dim=0)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    else:
        raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
    # Wrap the backend shuffle index (if any) into a utils.Index.
    shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
    return inc, shuffle_idx
def node_subgraph(self, induced_nodes):
"""Return the induced node subgraph.
......@@ -696,16 +855,16 @@ class HeteroGraphIndex(ObjectBase):
eids = [edges.todgltensor() for edges in induced_edges]
return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='bipartite')
def get_bipartite(self, etype, ctx):
"""Create a bipartite graph from given edge type and copy to the given device
@utils.cached_member(cache='_cache', prefix='unitgraph')
def get_unitgraph(self, etype, ctx):
"""Create a unitgraph graph from given edge type and copy to the given device
context.
Note: this internal function is for DGL scheduler use only
Parameters
----------
etype : int, or None
etype : int
If the graph index is a Bipartite graph index, this argument must be None.
Otherwise, it represents the edge type.
ctx : DGLContext
......@@ -715,7 +874,7 @@ class HeteroGraphIndex(ObjectBase):
-------
HeteroGraphIndex
"""
g = self.get_relation_graph(etype) if etype is not None else self
g = self.get_relation_graph(etype)
return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx)
def get_csr_shuffle_order(self, etype):
......@@ -778,11 +937,17 @@ class HeteroSubgraphIndex(ObjectBase):
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret]
def create_bipartite_from_coo(num_src, num_dst, row, col):
"""Create a bipartite graph index from COO format
#################################################################
# Creators
#################################################################
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col):
"""Create a unitgraph graph index from COO format
Parameters
----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int
Number of nodes in the src type.
num_dst : int
......@@ -796,14 +961,16 @@ def create_bipartite_from_coo(num_src, num_dst, row, col):
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCOO(
int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_ntypes), int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
"""Create a bipartite graph index from CSR format
def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edge_ids):
"""Create a unitgraph graph index from CSR format
Parameters
----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int
Number of nodes in the src type.
num_dst : int
......@@ -819,11 +986,11 @@ def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCSR(
int(num_src), int(num_dst),
return _CAPI_DGLHeteroCreateUnitGraphFromCSR(
int(num_ntypes), int(num_src), int(num_dst),
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
def create_heterograph(metagraph, rel_graphs):
def create_heterograph_from_relations(metagraph, rel_graphs):
"""Create a heterograph from metagraph and graphs of every relation.
Parameters
......
......@@ -3,3 +3,4 @@ from __future__ import absolute_import
from . import scheduler
from .runtime import Runtime
from .adapter import GraphAdapter
"""Temporary adapter to unify DGLGraph and HeteroGraph for scheduler.
NOTE(minjie): remove once all scheduler codes are migrated to heterograph
"""
from __future__ import absolute_import
from abc import ABC, abstractmethod
class GraphAdapter(ABC):
    """Temporary adapter class to unify DGLGraph and DGLHeteroGraph for schedulers.

    Every member is abstract; concrete adapters must implement all of them.
    """
    @property
    @abstractmethod
    def gidx(self):
        """Get graph index object."""

    @abstractmethod
    def num_src(self):
        """Number of source nodes."""

    @abstractmethod
    def num_dst(self):
        """Number of destination nodes."""

    @abstractmethod
    def num_edges(self):
        """Number of edges."""

    @property
    @abstractmethod
    def srcframe(self):
        """Frame to store source node features."""

    @property
    @abstractmethod
    def dstframe(self):
        """Frame to store destination node features."""

    @property
    @abstractmethod
    def edgeframe(self):
        """Frame to store edge features."""

    @property
    @abstractmethod
    def msgframe(self):
        """Frame to store messages."""

    @property
    @abstractmethod
    def msgindicator(self):
        """Message indicator tensor."""

    @msgindicator.setter
    @abstractmethod
    def msgindicator(self, val):
        """Set new message indicator tensor."""

    @abstractmethod
    def in_edges(self, nodes):
        """Get in edges

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def out_edges(self, nodes):
        """Get out edges

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def edges(self, form):
        """Get all edges

        Parameters
        ----------
        form : str
            "eid", "uv", etc.

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def get_immutable_gidx(self, ctx):
        """Get immutable graph index for kernel computation.

        Parameters
        ----------
        ctx : DGLContext
            The context of the returned graph.

        Returns
        -------
        GraphIndex
        """

    @abstractmethod
    def bits_needed(self):
        """Return the number of integer bits needed to represent the graph

        Returns
        -------
        int
            The number of bits needed
        """
......@@ -10,7 +10,6 @@ from . import ir
from .ir import var
def gen_degree_bucketing_schedule(
graph,
reduce_udf,
message_ids,
dst_nodes,
......@@ -28,8 +27,6 @@ def gen_degree_bucketing_schedule(
Parameters
----------
graph : DGLGraph
DGLGraph to use
reduce_udf : callable
The UDF to reduce messages.
message_ids : utils.Index
......@@ -56,7 +53,7 @@ def gen_degree_bucketing_schedule(
fd_list = []
for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt)
rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt)
# vars
vbkt = var.IDX(vbkt)
mid = var.IDX(mid)
......@@ -144,7 +141,7 @@ def _process_node_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
"""Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key):
......@@ -152,12 +149,11 @@ def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nbatch = NodeBatch(graph, vbkt, node_data, reshaped_mail_data)
nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data)
return reduce_udf(nbatch)
return _rfunc_wrapper
def gen_group_apply_edge_schedule(
graph,
apply_func,
u, v, eid,
group_by,
......@@ -175,8 +171,6 @@ def gen_group_apply_edge_schedule(
Parameters
----------
graph : DGLGraph
DGLGraph to use
apply_func: callable
The edge_apply_func UDF
u: utils.Index
......@@ -209,7 +203,7 @@ def gen_group_apply_edge_schedule(
fd_list = []
for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids):
# create per-bkt efunc
_efunc = var.FUNC(_create_per_bkt_efunc(graph, apply_func, deg,
_efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg,
u_bkt, v_bkt, eid_bkt))
# vars
var_u = var.IDX(u_bkt)
......@@ -280,7 +274,7 @@ def _process_edge_buckets(buckets):
eids = split(eids)
return degs, uids, vids, eids
def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid):
def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
"""Internal function to generate the per degree bucket edge UDF."""
batch_size = len(u) // deg
def _efunc_wrapper(src_data, edge_data, dst_data):
......@@ -302,7 +296,7 @@ def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid):
edge_data.keys())
reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data),
dst_data.keys())
ebatch = EdgeBatch(graph, (u, v, eid), reshaped_src_data,
ebatch = EdgeBatch((u, v, eid), reshaped_src_data,
reshaped_edge_data, reshaped_dst_data)
return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()}
return _efunc_wrapper
......
......@@ -8,8 +8,6 @@ from .. import backend as F
from ..frame import frame_like, FrameRef
from ..function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex
from . import ir
from .ir import var
......@@ -30,22 +28,16 @@ __all__ = [
"schedule_pull"
]
def _dispatch(graph, method, *args, **kwargs):
    """Call ``method`` on the graph index of ``graph``.

    For a HeteroGraphIndex, ``graph._current_etype_idx`` is passed as the
    first positional argument, ahead of ``args``. For a GraphIndex the
    arguments are forwarded unchanged. Any other index type raises TypeError.
    """
    index = graph._graph
    is_homogeneous = isinstance(index, GraphIndex)
    if not is_homogeneous and not isinstance(index, HeteroGraphIndex):
        raise TypeError('unknown type %s' % type(index))
    bound_method = getattr(graph._graph, method)
    if is_homogeneous:
        return bound_method(*args, **kwargs)
    return bound_method(graph._current_etype_idx, *args, **kwargs)
def schedule_send(graph, u, v, eid, message_func):
"""get send schedule
def schedule_send(graph,
u, v, eid,
message_func,
msgframe=None):
"""Schedule send
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdaptor
Graph
u : utils.Index
Source nodes
v : utils.Index
......@@ -54,11 +46,13 @@ def schedule_send(graph, u, v, eid, message_func):
Ids of sending edges
message_func: callable or list of callable
The message function
msgframe : FrameRef, optional
The storage to write messages to. If None, use graph.msgframe.
"""
var_mf = var.FEAT_DICT(graph._msg_frame)
var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_mf = var.FEAT_DICT(msgframe if msgframe is not None else graph.msgframe)
var_src_nf = var.FEAT_DICT(graph.srcframe)
var_dst_nf = var.FEAT_DICT(graph.dstframe)
var_ef = var.FEAT_DICT(graph.edgeframe)
var_eid = var.IDX(eid)
var_msg = _gen_send(graph=graph,
......@@ -73,19 +67,20 @@ def schedule_send(graph, u, v, eid, message_func):
# write tmp msg back
ir.WRITE_ROW_(var_mf, var_eid, var_msg)
# set message indicator to 1
graph._set_msg_index(graph._get_msg_index().set_items(eid, 1))
graph.msgindicator = graph.msgindicator.set_items(eid, 1)
def schedule_recv(graph,
recv_nodes,
reduce_func,
apply_func,
inplace):
inplace,
outframe=None):
"""Schedule recv.
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdaptor
Graph
recv_nodes: utils.Index
Nodes to recv.
reduce_func: callable or list of callable
......@@ -94,10 +89,12 @@ def schedule_recv(graph,
The apply node function
inplace: bool
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
"""
src, dst, eid = _dispatch(graph, 'in_edges', recv_nodes)
src, dst, eid = graph.in_edges(recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
nonzero_idx = graph.msgindicator.get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx)
......@@ -106,9 +103,11 @@ def schedule_recv(graph,
# 1) all recv nodes are 0-degree nodes
# 2) no send has been called
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
schedule_apply_nodes(recv_nodes, apply_func, graph.dstframe,
inplace, outframe)
else:
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
# sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
......@@ -117,23 +116,24 @@ def schedule_recv(graph,
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_(var_out_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._get_msg_index().has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf'))
graph.msgindicator = graph.msgindicator.set_items(eid, 0)
if not graph.msgindicator.has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph.msgframe, name='mf'))
def schedule_snr(graph,
edge_tuples,
message_func,
reduce_func,
apply_func,
inplace):
inplace,
outframe=None):
"""Schedule send_and_recv.
Currently it builds a subgraph from edge_tuples with the same number of
......@@ -142,8 +142,8 @@ def schedule_snr(graph,
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdaptor
Graph
edge_tuples: tuple
A tuple of (src ids, dst ids, edge ids) representing edges to perform
send_and_recv
......@@ -155,12 +155,15 @@ def schedule_snr(graph,
The apply node function
inplace: bool
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
"""
u, v, eid = edge_tuples
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
# create vars
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
......@@ -168,12 +171,11 @@ def schedule_snr(graph,
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
edge_tuples, graph._number_of_src_nodes(), graph._number_of_dst_nodes())
edge_tuples, graph.num_src(), graph.num_dst())
out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits)
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._src_frame,
dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame,
reduced_feat = _gen_send_reduce(src_node_frame=graph.srcframe,
dst_node_frame=graph.dstframe,
edge_frame=graph.edgeframe,
message_func=message_func,
reduce_func=reduce_func,
var_send_edges=var_eid,
......@@ -182,52 +184,56 @@ def schedule_snr(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, reduced_feat,
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, reduced_feat,
apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_(var_out_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph,
message_func,
reduce_func,
apply_func):
"""get send and recv schedule
apply_func,
outframe=None):
"""Get send and recv schedule
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdaptor
Graph
message_func: callable or list of callable
The message function
reduce_func: callable or list of callable
The reduce function
apply_func: callable
The apply node function
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
"""
if graph._number_of_edges() == 0:
if graph.num_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph._number_of_dst_nodes()))
schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
nodes = utils.toindex(slice(0, graph.num_dst()))
schedule_apply_nodes(nodes, apply_func, graph.dstframe,
inplace=False, outframe=outframe)
else:
eid = utils.toindex(slice(0, graph._number_of_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph._number_of_dst_nodes())) # ALL
eid = utils.toindex(slice(0, graph.num_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL
# create vars
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
var_eid = var.IDX(eid)
# generate send + reduce
def uv_getter():
src, dst, _ = _dispatch(graph, 'edges', 'eid')
src, dst, _ = graph.edges('eid')
return var.IDX(src), var.IDX(dst)
adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph)
out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._src_frame,
dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame,
reduced_feat = _gen_send_reduce(src_node_frame=graph.srcframe,
dst_node_frame=graph.dstframe,
edge_frame=graph.edgeframe,
message_func=message_func,
reduce_func=reduce_func,
var_send_edges=var_eid,
......@@ -236,50 +242,54 @@ def schedule_update_all(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
ir.WRITE_DICT_(var_dst_nf, final_feat)
ir.WRITE_DICT_(var_out_nf, final_feat)
def schedule_apply_nodes(v,
                         apply_func,
                         node_frame,
                         inplace,
                         outframe=None):
    """Get apply nodes schedule

    Reads the rows of ``v`` from ``node_frame``, runs ``apply_func`` on them
    wrapped in a ``NodeBatch``, and writes the result rows to the output frame.

    Parameters
    ----------
    v : utils.Index
        Nodes to apply
    apply_func : callable
        The apply node function
    node_frame : FrameRef
        Node feature frame.
    inplace: bool
        If True, the update will be done in place
    outframe : FrameRef, optional
        The storage to write output data. If None, use the given node_frame.

    Notes
    -----
    Nothing is returned explicitly; the schedule is expressed through the
    ``ir`` calls made here.
    """
    var_v = var.IDX(v)
    var_nf = var.FEAT_DICT(node_frame, name='nf')
    # Write back to the input frame unless a separate output frame is given.
    var_out_nf = var_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
    v_nf = ir.READ_ROW(var_nf, var_v)
    def _afunc_wrapper(node_data):
        nbatch = NodeBatch(v, node_data)
        return apply_func(nbatch)
    afunc = var.FUNC(_afunc_wrapper)
    applied_feat = ir.NODE_UDF(afunc, v_nf)
    if inplace:
        ir.WRITE_ROW_INPLACE_(var_out_nf, var_v, applied_feat)
    else:
        ir.WRITE_ROW_(var_out_nf, var_v, applied_feat)
def schedule_nodeflow_apply_nodes(graph,
layer_id,
v,
apply_func,
inplace):
"""get apply nodes schedule in NodeFlow.
"""Get apply nodes schedule in NodeFlow.
Parameters
----------
......@@ -302,7 +312,7 @@ def schedule_nodeflow_apply_nodes(graph,
var_v = var.IDX(v)
v_nf = ir.READ_ROW(var_nf, var_v)
def _afunc_wrapper(node_data):
nbatch = NodeBatch(graph, v, node_data)
nbatch = NodeBatch(v, node_data)
return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
......@@ -315,13 +325,14 @@ def schedule_nodeflow_apply_nodes(graph,
def schedule_apply_edges(graph,
u, v, eid,
apply_func,
inplace):
"""get apply edges schedule
inplace,
outframe=None):
"""Get apply edges schedule
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdapter
Graph
u : utils.Index
Source nodes of edges to apply
v : utils.Index
......@@ -332,23 +343,25 @@ def schedule_apply_edges(graph,
The apply edge function
inplace: bool
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.edge_frame.
Returns
-------
A list of executors for DGL Runtime
"""
# vars
var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_src_nf = var.FEAT_DICT(graph.srcframe, 'uframe')
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'vframe')
var_ef = var.FEAT_DICT(graph.edgeframe, 'eframe')
var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, 'out_ef')
var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func,
var_src_nf=var_src_nf, var_dst_nf=var_dst_nf,
var_ef=var_ef)
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_eid = var.IDX(eid)
# schedule apply edges
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out)
else:
ir.WRITE_ROW_(var_ef, var_eid, var_out)
......@@ -356,7 +369,7 @@ def schedule_nodeflow_apply_edges(graph, block_id,
u, v, eid,
apply_func,
inplace):
"""get apply edges schedule in NodeFlow.
"""Get apply edges schedule in NodeFlow.
Parameters
----------
......@@ -397,13 +410,14 @@ def schedule_push(graph,
message_func,
reduce_func,
apply_func,
inplace):
"""get push schedule
inplace,
outframe=None):
"""Get push schedule
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdapter
Graph
u : utils.Index
Source nodes for push
message_func: callable or list of callable
......@@ -414,26 +428,30 @@ def schedule_push(graph,
The apply node function
inplace: bool
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
"""
u, v, eid = _dispatch(graph, 'out_edges', u)
u, v, eid = graph.out_edges(u)
if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled.
return
schedule_snr(graph, (u, v, eid),
message_func, reduce_func, apply_func, inplace)
message_func, reduce_func, apply_func,
inplace, outframe)
def schedule_pull(graph,
pull_nodes,
message_func,
reduce_func,
apply_func,
inplace):
"""get pull schedule
inplace,
outframe=None):
"""Get pull schedule
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdapter
Graph
pull_nodes : utils.Index
Destination nodes for pull
message_func: callable or list of callable
......@@ -444,20 +462,23 @@ def schedule_pull(graph,
The apply node function
inplace: bool
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use graph.dstframe.
"""
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat
# directly from pull node frontier.
u, v, eid = _dispatch(graph, 'in_edges', pull_nodes)
u, v, eid = graph.in_edges(pull_nodes)
if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
schedule_apply_nodes(graph, pull_nodes, apply_func, inplace)
schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace, outframe)
else:
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
# create vars
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u)
var_v = var.IDX(v)
......@@ -465,31 +486,33 @@ def schedule_pull(graph,
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
(u, v, eid), graph.num_src(), graph.num_dst())
out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits)
reduced_feat = _gen_send_reduce(graph, graph._src_frame,
graph._dst_frame, graph._edge_frame,
reduced_feat = _gen_send_reduce(graph.srcframe,
graph.dstframe, graph.edgeframe,
message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator,
out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_dst_nf, reduced_feat, apply_func)
final_feat = _apply_with_accum(var_pull_nodes, var_dst_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_pull_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_out_nf, var_pull_nodes, final_feat)
else:
ir.WRITE_ROW_(var_dst_nf, var_pull_nodes, final_feat)
ir.WRITE_ROW_(var_out_nf, var_pull_nodes, final_feat)
def schedule_group_apply_edge(graph,
u, v, eid,
apply_func,
group_by,
inplace):
"""group apply edges schedule
inplace,
outframe=None):
"""Group apply edges schedule
Parameters
----------
graph: DGLGraph
The DGLGraph to use
graph: GraphAdapter
Graph
u : utils.Index
Source nodes of edges to apply
v : utils.Index
......@@ -502,23 +525,22 @@ def schedule_group_apply_edge(graph,
Specify how to group edges. Expected to be either 'src' or 'dst'
inplace: bool
If True, the update will be done in place
Returns
-------
A list of executors for DGL Runtime
outframe : FrameRef, optional
The storage to write output data. If None, use graph.edgeframe.
"""
# vars
var_src_nf = var.FEAT_DICT(graph._src_frame, name='src_nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_src_nf = var.FEAT_DICT(graph.srcframe, name='src_nf')
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_ef = var.FEAT_DICT(graph.edgeframe, name='ef')
var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, name='out_ef')
var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by,
db.gen_group_apply_edge_schedule(apply_func, u, v, eid, group_by,
var_src_nf, var_dst_nf, var_ef, var_out)
var_eid = var.IDX(eid)
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out)
else:
ir.WRITE_ROW_(var_ef, var_eid, var_out)
ir.WRITE_ROW_(var_out_ef, var_eid, var_out)
def schedule_nodeflow_update_all(graph,
......@@ -526,7 +548,7 @@ def schedule_nodeflow_update_all(graph,
message_func,
reduce_func,
apply_func):
"""get update_all schedule in a block.
"""Get update_all schedule in a block.
Parameters
----------
......@@ -555,8 +577,7 @@ def schedule_nodeflow_update_all(graph,
return var.IDX(utils.toindex(src)), var.IDX(utils.toindex(dst))
adj_creator = lambda: spmv.build_gidx_and_mapping_block(graph, block_id)
out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._get_node_frame(block_id),
reduced_feat = _gen_send_reduce(src_node_frame=graph._get_node_frame(block_id),
dst_node_frame=graph._get_node_frame(block_id + 1),
edge_frame=graph._get_edge_frame(block_id),
message_func=message_func,
......@@ -567,7 +588,7 @@ def schedule_nodeflow_update_all(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_dest_nodes, var_nf, reduced_feat, apply_func)
final_feat = _apply_with_accum(var_dest_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat)
......@@ -626,8 +647,7 @@ def schedule_nodeflow_compute(graph,
graph, block_id, (u, v, eid))
out_map_creator = lambda nbits: _build_idx_map(utils.toindex(dest_nodes), nbits)
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._get_node_frame(block_id),
reduced_feat = _gen_send_reduce(src_node_frame=graph._get_node_frame(block_id),
dst_node_frame=graph._get_node_frame(block_id + 1),
edge_frame=graph._get_edge_frame(block_id),
message_func=message_func,
......@@ -638,7 +658,7 @@ def schedule_nodeflow_compute(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_dest_nodes, var_nf,
final_feat = _apply_with_accum(var_dest_nodes, var_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_dest_nodes, final_feat)
......@@ -680,7 +700,7 @@ def _standardize_func_usage(func, func_name):
' Got: %s' % (func_name, str(func)))
return func
def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func):
def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
"""Apply with accumulated features.
Parameters
......@@ -702,7 +722,7 @@ def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func):
v_nf = ir.UPDATE_DICT(v_nf, var_accum)
def _afunc_wrapper(node_data):
nbatch = NodeBatch(graph, var_nodes.data, node_data)
nbatch = NodeBatch(var_nodes.data, node_data)
return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
......@@ -716,7 +736,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
Parameters
----------
graph : DGLGraph
graph : GraphAdapter
reduce_func : callable
edge_tuples : tuple of utils.Index
recv_nodes : utils.Index
......@@ -734,16 +754,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# node frame.
# TODO(minjie): should replace this with an IR call to make the program
# stateless.
tmpframe = FrameRef(frame_like(graph._dst_frame._frame, len(recv_nodes)))
tmpframe = FrameRef(frame_like(graph.dstframe._frame, len(recv_nodes)))
# vars
var_msg = var.FEAT_DICT(graph._msg_frame, 'msg')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, 'nf')
var_msg = var.FEAT_DICT(graph.msgframe, 'msg')
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'nf')
var_out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv(
(src, dst, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
(src, dst, eid), graph.num_src(), graph.num_dst())
# using edge map instead of message map because messages are in global
# message frame
var_out_map = _build_idx_map(recv_nodes, nbits)
......@@ -757,12 +777,11 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
return var_out
else:
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes,
db.gen_degree_bucketing_schedule(rfunc, eid, dst, recv_nodes,
var_dst_nf, var_msg, var_out)
return var_out
def _gen_send_reduce(
graph,
src_node_frame,
dst_node_frame,
edge_frame,
......@@ -793,8 +812,6 @@ def _gen_send_reduce(
Parameters
----------
graph : DGLGraph
The graph
src_node_frame : NodeFrame
The node frame of the source nodes.
dst_node_frame : NodeFrame
......@@ -899,7 +916,7 @@ def _gen_send_reduce(
edge_map=edge_map)
else:
# generate UDF send schedule
var_mf = _gen_udf_send(graph, var_src_nf, var_dst_nf, var_ef, var_u,
var_mf = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc)
# 6. Generate reduce
......@@ -916,18 +933,18 @@ def _gen_send_reduce(
else:
# gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data)))
db.gen_degree_bucketing_schedule(graph, rfunc, mid, var_v.data,
db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data,
reduce_nodes, var_dst_nf, var_mf,
var_out)
return var_out
def _gen_udf_send(graph, var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc):
def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc):
"""Internal function to generate send schedule for UDF message function."""
fdsrc = ir.READ_ROW(var_src_nf, u)
fddst = ir.READ_ROW(var_dst_nf, v)
fdedge = ir.READ_ROW(var_ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch(graph, (u.data, v.data, eid.data),
ebatch = EdgeBatch((u.data, v.data, eid.data),
src_data, edge_data, dst_data)
return mfunc(ebatch)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper)
......@@ -944,15 +961,15 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid = var.IDX(eid)
if mfunc_is_list:
if eid.is_slice(0, graph._number_of_edges()):
if eid.is_slice(0, graph.num_edges()):
# full graph case
res = spmv.build_gidx_and_mapping_graph(graph)
else:
res = spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
(u, v, eid), graph.num_src(), graph.num_dst())
adj, edge_map, _ = res
# create a tmp message frame
tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid)))
tmp_mfr = FrameRef(frame_like(graph.edgeframe._frame, len(eid)))
var_out = var.FEAT_DICT(data=tmp_mfr)
spmv.gen_v2e_spmv_schedule(graph=adj,
mfunc=mfunc,
......@@ -964,7 +981,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
edge_map=edge_map)
else:
# UDF send
var_out = _gen_udf_send(graph, var_src_nf, var_dst_nf, var_ef, var_u,
var_out = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc)
return var_out
......@@ -1004,5 +1021,4 @@ def _build_idx_map(idx, nbits):
old_to_new = F.zerocopy_to_dgl_ndarray(old_to_new)
return utils.CtxCachedObject(lambda ctx: nd.array(old_to_new, ctx=ctx))
_init_api("dgl.runtime.scheduler")
......@@ -6,8 +6,7 @@ from ..base import DGLError
from .. import backend as F
from .. import utils
from .. import ndarray as nd
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from ..heterograph_index import create_unitgraph_from_coo
from . import ir
from .ir import var
......@@ -129,8 +128,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters
----------
graph : DGLGraph or DGLHeteroGraph
The homogeneous graph, or a bipartite view of the heterogeneous graph.
graph : GraphAdapter
Graph
Returns
-------
......@@ -142,30 +141,21 @@ def build_gidx_and_mapping_graph(graph):
nbits : int
Number of ints needed to represent the graph
"""
gidx = graph._graph
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
return graph.get_immutable_gidx, None, graph.bits_needed()
def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of
nodes in the graph. Therefore, when doing SPMV, the src node data should be
all the node features.
The matrix is of shape (num_src, num_dst).
Parameters
---------
edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid)
num_src, num_dst : int
The number of source and destination nodes.
num_src : int
Number of source nodes.
num_dst : int
Number of destination nodes.
Returns
-------
......@@ -178,7 +168,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
Number of ints needed to represent the graph
"""
u, v, eid = edge_tuples
gidx = create_bipartite_from_coo(num_src, num_dst, u, v)
gidx = create_unitgraph_from_coo(2, num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor()
nbits = gidx.bits_needed(0)
......@@ -189,8 +179,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx)))
return partial(gidx.get_bipartite, None), edge_map, nbits
return partial(gidx.get_unitgraph, 0), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
"""Build immutable graph index and mapping for node flow
......
"""User-defined function related data structures."""
from __future__ import absolute_import
from .base import is_all
from . import backend as F
from . import utils
class EdgeBatch(object):
"""The class that can represent a batch of edges.
Parameters
----------
g : DGLGraph
The graph object.
edges : tuple of utils.Index
The edge tuple (u, v, eid).
src_data : dict
......@@ -24,8 +18,7 @@ class EdgeBatch(object):
The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values
"""
def __init__(self, g, edges, src_data, edge_data, dst_data):
self._g = g
def __init__(self, edges, src_data, edge_data, dst_data):
self._edges = edges
self._src_data = src_data
self._edge_data = edge_data
......@@ -75,9 +68,6 @@ class EdgeBatch(object):
destination node and the edge id for the ith edge
in the batch.
"""
if is_all(self._edges[2]):
self._edges = self._edges[:2] + (utils.toindex(F.arange(
0, self._g.number_of_edges())),)
u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
......@@ -104,9 +94,7 @@ class NodeBatch(object):
Parameters
----------
g : DGLGraph
The graph object.
nodes : utils.Index or ALL
nodes : utils.Index
The node ids.
data : dict
The node features, in the form of ``dict``
......@@ -115,8 +103,7 @@ class NodeBatch(object):
The messages, in the form of ``dict``
with ``str`` keys and ``tensor`` values
"""
def __init__(self, g, nodes, data, msgs=None):
self._g = g
def __init__(self, nodes, data, msgs=None):
self._nodes = nodes
self._data = data
self._msgs = msgs
......@@ -154,9 +141,6 @@ class NodeBatch(object):
tensor
The nodes.
"""
if is_all(self._nodes):
self._nodes = utils.toindex(F.arange(
0, self._g.number_of_nodes()))
return self._nodes.tousertensor()
def batch_size(self):
......@@ -166,10 +150,7 @@ class NodeBatch(object):
-------
int
"""
if is_all(self._nodes):
return self._g.number_of_nodes()
else:
return len(self._nodes)
return len(self._nodes)
def __len__(self):
"""Return the number of nodes in this node batch.
......
......@@ -505,3 +505,14 @@ def to_nbits_int(tensor, nbits):
return F.astype(tensor, F.int32)
else:
return F.astype(tensor, F.int64)
def make_invmap(array, use_numpy=True):
    """Deduplicate ``array`` and re-express it as indices into its unique values.

    Parameters
    ----------
    array : iterable of hashable values
        The elements to deduplicate.
    use_numpy : bool, optional
        If True, the unique values are computed with ``np.unique``.
        Otherwise they are computed with ``list(set(array))``, whose
        ordering is whatever the set iteration yields.

    Returns
    -------
    uniques : numpy.ndarray or list
        The unique values of ``array``.
    invmap : dict
        Maps each unique value to its position in ``uniques``.
    remapped : numpy.ndarray
        For every element of ``array``, its position in ``uniques``.
    """
    uniques = np.unique(array) if use_numpy else list(set(array))
    # Position of each unique value inside ``uniques``.
    invmap = dict(zip(uniques, range(len(uniques))))
    remapped = np.array([invmap[item] for item in array])
    return uniques, invmap, remapped
......@@ -10,6 +10,7 @@ from .base import ALL, is_all, DGLError
from . import backend as F
NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class NodeView(object):
"""A NodeView class to act as G.nodes for a DGLGraph.
......@@ -79,8 +80,6 @@ class NodeDataView(MutableMapping):
data = self._graph.get_n_repr(self._nodes)
return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class EdgeView(object):
"""A EdgeView class to act as G.edges for a DGLGraph.
......@@ -256,111 +255,57 @@ class HeteroNodeView(object):
def __init__(self, graph):
self._graph = graph
def __getitem__(self, ntype):
return HeteroNodeTypeView(self._graph, ntype)
class HeteroNodeTypeView(object):
"""A NodeView class to act as G.nodes[ntype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __len__(self):
return self._graph.number_of_nodes(self._graph._ntypes_invmap[self._ntype])
def __getitem__(self, nodes):
if isinstance(nodes, slice):
def __getitem__(self, key):
if isinstance(key, slice):
# slice
if not (nodes.start is None and nodes.stop is None
and nodes.step is None):
if not (key.start is None and key.stop is None
and key.step is None):
raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, ALL))
nodes = ALL
ntype = None
elif isinstance(key, tuple):
nodes, ntype = key
elif isinstance(key, str):
nodes = ALL
ntype = key
else:
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, nodes))
nodes = key
ntype = None
return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, nodes))
def __call__(self):
def __call__(self, ntype=None):
"""Return the nodes."""
return F.arange(0, len(self))
class HeteroNodeTypeDataView(MutableMapping):
"""The data view class when G.nodes[ntype][...].data is called.
return F.arange(0, self._graph.number_of_nodes(ntype))
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype', '_nodes']
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
def __init__(self, graph, ntype, nodes):
self._graph = graph
self._ntype = ntype
self._ntid = self._graph.get_ntype_id(ntype)
self._nodes = nodes
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}, self._nodes)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __iter__(self):
return iter(self._graph.get_n_repr(self._ntype, self._nodes))
def __repr__(self):
data = self._graph.get_n_repr(self._ntype, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroNodeDataView(object):
"""The data view class when G.ndata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroNodeDataTypeView(self._graph, key)
class HeteroNodeDataTypeView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype)[key]
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val})
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key):
self._graph.pop_n_repr(self._ntype, key)
self._graph._pop_n_repr(self._ntid, key)
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
return len(self._graph._node_frames[self._ntid])
def __iter__(self):
return iter(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
return iter(self._graph._node_frames[self._ntid])
def __repr__(self):
data = self._graph.get_n_repr(self._ntype)
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
......@@ -369,108 +314,59 @@ class HeteroEdgeView(object):
def __init__(self, graph):
self._graph = graph
def __getitem__(self, etype):
return HeteroEdgeTypeView(self._graph, etype)
class HeteroEdgeTypeView(object):
"""A EdgeView class to act as G.edges[etype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __len__(self):
return self._graph.number_of_edges(self._graph._etypes_invmap[self._etype])
def __getitem__(self, edges):
if isinstance(edges, slice):
def __getitem__(self, key):
if isinstance(key, slice):
# slice
if not (edges.start is None and edges.stop is None
and edges.step is None):
if not (key.start is None and key.stop is None
and key.step is None):
raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, ALL))
edges = ALL
etype = None
elif isinstance(key, tuple):
if len(key) == 3:
edges = ALL
etype = key
else:
edges = key
etype = None
elif isinstance(key, (str, tuple)):
edges = ALL
etype = key
else:
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, edges))
edges = key
etype = None
return EdgeSpace(data=HeteroEdgeDataView(self._graph, etype, edges))
def __call__(self):
"""Return the edges."""
return F.arange(0, len(self))
class HeteroEdgeTypeDataView(MutableMapping):
"""The data view class when G.edges[etype][...].data is called.
def __call__(self, *args, **kwargs):
"""Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype', '_edges']
class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.ndata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges']
def __init__(self, graph, etype, edges):
self._graph = graph
self._etype = etype
self._etid = self._graph.get_etype_id(etype)
self._edges = edges
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}, self._edges)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of edges. Please use `del G.edata[key]` instead.')
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __iter__(self):
return iter(self._graph.get_e_repr(self._etype, self._edges))
def __repr__(self):
data = self._graph.get_e_repr(self._etype, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
class HeteroEdgeDataView(object):
"""The data view class when G.edata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroEdgeDataTypeView(self._graph, key)
class HeteroEdgeDataTypeView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype)[key]
return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val})
self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key):
self._graph.pop_e_repr(self._etype, key)
self._graph._pop_e_repr(self._etid, key)
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
return len(self._graph._edge_frames[self._etid])
def __iter__(self):
return iter(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
return iter(self._graph._edge_frames[self._etid])
def __repr__(self):
data = self._graph.get_e_repr(self._etype)
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
for key in self._graph._edge_frames[self._etid]})
......@@ -4,10 +4,11 @@
* \brief Heterograph implementation
*/
#include "./heterograph.h"
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "../c_api_common.h"
#include "./bipartite.h"
#include "./unit_graph.h"
using namespace dgl::runtime;
......@@ -50,7 +51,7 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
// following heterograph:
//
// Meta graph: A -> B -> C
// Bipartite graphs:
// Unit graphs:
// * A -> B: (0, 0), (0, 1)
// * B -> C: (1, 0), (1, 1)
//
......@@ -91,7 +92,8 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
subrels[etype] = Bipartite::CreateFromCOO(
subrels[etype] = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
ret.induced_vertices[src_vtype]->shape[0],
ret.induced_vertices[dst_vtype]->shape[0],
subedges[etype].src,
......@@ -108,10 +110,9 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
// Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graph must be bipartite graphs
// all relation graphs must have only one edge type
for (const auto rg : rel_graphs) {
CHECK_EQ(rg->NumVertexTypes(), 2) << "Each relation graph must be a bipartite graph.";
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph.";
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
}
// create num verts per type
num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
......@@ -125,17 +126,20 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i];
const auto& rg = rel_graphs[etype];
const auto sty = 0;
const auto dty = rg->NumVertexTypes() == 1? 0 : 1;
size_t nv;
// # nodes of source type
nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType);
nv = rg->NumVertices(sty);
if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv;
else
CHECK_EQ(num_verts_per_type_[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type
nv = rel_graphs[etype]->NumVertices(Bipartite::kDstVType);
nv = rg->NumVertices(dty);
if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv;
else
......@@ -171,8 +175,10 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
auto pair = meta_graph_->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(
{vids[src_vtype], vids[dst_vtype]});
const std::vector<IdArray> rel_vids = (src_vtype == dst_vtype) ?
std::vector<IdArray>({vids[src_vtype]}) :
std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0];
}
......@@ -189,18 +195,106 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph(
}
}
// creator implementation
HeteroGraphPtr CreateBipartiteFromCOO(
int64_t num_src, int64_t num_dst, IdArray row, IdArray col) {
return Bipartite::CreateFromCOO(num_src, num_dst, row, col);
}
FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0;
std::vector<dgl_id_t> result_src, result_dst;
std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<dgl_id_t> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type to number of nodes after this
// loop.
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t num_srctype_nodes = NumVertices(srctype);
size_t num_dsttype_nodes = NumVertices(dsttype);
if (srctype_offsets.count(srctype) == 0) {
srctype_offsets[srctype] = num_srctype_nodes;
srctype_set.push_back(srctype);
}
if (dsttype_offsets.count(dsttype) == 0) {
dsttype_offsets[dsttype] = num_dsttype_nodes;
dsttype_set.push_back(dsttype);
}
}
// Sort the node types so that we can compare the sets and decide whether a homograph
// should be returned.
std::sort(srctype_set.begin(), srctype_set.end());
std::sort(dsttype_set.begin(), dsttype_set.end());
bool homograph = (srctype_set.size() == dsttype_set.size()) &&
std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
// XXXtype_offsets contain the mapping from node type to node ID offsets after these
// two loops.
for (size_t i = 0; i < srctype_set.size(); ++i) {
dgl_type_t ntype = srctype_set[i];
size_t num_nodes = srctype_offsets[ntype];
srctype_offsets[ntype] = src_nodes;
src_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_srctype.push_back(ntype);
induced_srcid.push_back(j);
}
}
for (size_t i = 0; i < dsttype_set.size(); ++i) {
dgl_type_t ntype = dsttype_set[i];
size_t num_nodes = dsttype_offsets[ntype];
dsttype_offsets[ntype] = dst_nodes;
dst_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_dsttype.push_back(ntype);
induced_dstid.push_back(j);
}
}
HeteroGraphPtr CreateBipartiteFromCSR(
int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids) {
return Bipartite::CreateFromCSR(num_src, num_dst, indptr, indices, edge_ids);
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t srctype_offset = srctype_offsets[srctype];
size_t dsttype_offset = dsttype_offsets[dsttype];
EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const dgl_id_t* edges_eid_data = static_cast<const dgl_id_t*>(edges.id->data);
// TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset);
result_dst.push_back(edges_dst_data[i] + dsttype_offset);
induced_etype.push_back(etype);
induced_eid.push_back(edges_eid_data[i]);
}
}
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2,
src_nodes,
dst_nodes,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(gptr);
result->induced_srctype = aten::VecToIdArray(induced_srctype);
result->induced_srctype_set = aten::VecToIdArray(srctype_set);
result->induced_srcid = aten::VecToIdArray(induced_srcid);
result->induced_etype = aten::VecToIdArray(induced_etype);
result->induced_etype_set = aten::VecToIdArray(etypes);
result->induced_eid = aten::VecToIdArray(induced_eid);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype);
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set);
result->induced_dstid = aten::VecToIdArray(induced_dstid);
return FlattenedHeteroGraphPtr(result);
}
// creator implementation
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
......@@ -208,24 +302,27 @@ HeteroGraphPtr CreateHeteroGraph(
///////////////////////// C APIs /////////////////////////
// C API entry point that builds a graph from COO arrays and returns it to the
// caller as a HeteroGraphRef through `rv`.
//
// Two variants are visible in this span:
//   * args = (num_src, num_dst, row, col), forwarded to CreateBipartiteFromCOO;
//   * args = (nvtypes, num_src, num_dst, row, col), forwarded to
//     UnitGraph::CreateFromCOO.
// NOTE(review): both registration heads and both argument-decoding sequences
// appear back to back (the same local names are declared twice), which looks
// like the before/after lines of a diff rendered together -- confirm which
// variant is current before relying on either argument layout.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0];
int64_t num_dst = args[1];
IdArray row = args[2];
IdArray col = args[3];
auto hgptr = CreateBipartiteFromCOO(num_src, num_dst, row, col);
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray row = args[3];
IdArray col = args[4];
auto hgptr = UnitGraph::CreateFromCOO(nvtypes, num_src, num_dst, row, col);
*rv = HeteroGraphRef(hgptr);
});
// C API entry point that builds a graph from CSR arrays and returns it to the
// caller as a HeteroGraphRef through `rv`.
//
// Two variants are visible in this span:
//   * args = (num_src, num_dst, indptr, indices, edge_ids), forwarded to
//     CreateBipartiteFromCSR;
//   * args = (nvtypes, num_src, num_dst, indptr, indices, edge_ids), forwarded
//     to UnitGraph::CreateFromCSR.
// NOTE(review): both registration heads and both argument-decoding sequences
// appear back to back (the same local names are declared twice), which looks
// like the before/after lines of a diff rendered together -- confirm which
// variant is current before relying on either argument layout.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0];
int64_t num_dst = args[1];
IdArray indptr = args[2];
IdArray indices = args[3];
IdArray edge_ids = args[4];
auto hgptr = CreateBipartiteFromCSR(num_src, num_dst, indptr, indices, edge_ids);
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray indptr = args[3];
IdArray indices = args[4];
IdArray edge_ids = args[5];
auto hgptr = UnitGraph::CreateFromCSR(
nvtypes, num_src, num_dst, indptr, indices, edge_ids);
*rv = HeteroGraphRef(hgptr);
});
......@@ -252,7 +349,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
if (hg->NumEdgeTypes() == 1) {
CHECK_EQ(etype, 0);
*rv = hg;
} else {
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}
});
// C API entry point: flattens a set of edge types of a heterograph.
//   args[0]: the heterograph handle.
//   args[1]: list of values, each holding one edge type ID.
// The result of hg->Flatten(...) is returned through `rv`, wrapped in a
// FlattenedHeteroGraphRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    List<Value> etypes = args[1];
    // Flatten() takes a std::vector<dgl_type_t>; build the vector with that
    // element type instead of relying on dgl_id_t being the same alias.
    std::vector<dgl_type_t> etypes_vec;
    for (const Value& val : etypes) {
      // Convert in a separate statement so the conversion target is explicit.
      const dgl_type_t etype = val->data;
      etypes_vec.push_back(etype);
    }
    *rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
......@@ -551,7 +664,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int bits = args[1];
HeteroGraphPtr hg_new = Bipartite::AsNumBits(hg.sptr(), bits);
HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits);
*rv = HeteroGraphRef(hg_new);
});
......@@ -563,7 +676,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
HeteroGraphPtr hg_new = Bipartite::CopyTo(hg.sptr(), ctx);
HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new);
});
......
......@@ -20,14 +20,6 @@ class HeteroGraph : public BaseHeteroGraph {
public:
HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
/*! \brief Number of vertex types, i.e. the number of vertices of the metagraph */
uint64_t NumVertexTypes() const override {
  const uint64_t num_vtypes = meta_graph_->NumVertices();
  return num_vtypes;
}
/*! \brief Number of edge types, i.e. the number of edges of the metagraph */
uint64_t NumEdgeTypes() const override {
  const uint64_t num_etypes = meta_graph_->NumEdges();
  return num_etypes;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
return relation_graphs_[etype];
......@@ -172,8 +164,10 @@ class HeteroGraph : public BaseHeteroGraph {
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
private:
/*! \brief A map from edge type to bipartite graph */
/*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_;
/*! \brief A map from vert type to the number of verts in the type */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment