Unverified Commit 9b4d6079 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Hetero] New syntax (#824)

* WIP. remove graph arg in NodeBatch and EdgeBatch

* refactor: use graph adapter for scheduler

* WIP: recv

* draft impl

* stuck at bipartite

* bipartite->unitgraph; support dsttype == srctype

* pass test_query

* pass test_query

* pass test_view

* test apply

* pass udf message passing tests

* pass quan's test using builtins

* WIP: wildcard slicing

* new construct methods

* broken

* good

* add stack cross reducer

* fix bug; fix mx

* fix bug in csrmm2 when the CSR is not square

* lint

* removed FlattenedHeteroGraph class

* WIP

* prop nodes, prop edges, filter nodes/edges

* add DGLGraph tests to heterograph. Fix several bugs

* finish nx<->hetero graph conversion

* create bipartite from nx

* more spec on hetero/homo conversion

* silly fixes

* check node and edge types

* repr

* to api

* adj APIs

* inc

* fix some lints and bugs

* fix some lints

* hetero/homo conversion

* fix flatten test

* more spec in hetero_from_homo and test

* flatten using concat names

* WIP: creators

* rewrite hetero_from_homo in a more efficient way

* remove useless variables

* fix lint

* subgraphs and typed subgraphs

* lint & removed heterosubgraph class

* lint x2

* disable heterograph mutation test

* docstring update

* add edge id for nx graph test

* fix mx unittests

* fix bug

* try fix

* fix unittest when cross_reducer is stack

* fix ci

* fix nx bipartite bug; docstring

* fix scipy creation bug

* lint

* fix bug when converting heterograph from homograph

* fix bug in hetero_from_homo about ntype order

* trailing white

* docstring fixes for add_foo and data views

* docstring for relation slice

* to_hetero and to_homo with feature support

* lint

* lint

* DGLGraph compatibility

* incidence matrix & docstring fixes

* example string fixes

* feature in hetero_from_relations

* deduplication of edge types in to_hetero

* fix lint

* fix
parent ddb5d804
......@@ -21,7 +21,9 @@ namespace dgl {
// Forward declaration
class BaseHeteroGraph;
class FlattenedHeteroGraph;
typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;
typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;
struct HeteroSubgraph;
/*!
......@@ -46,10 +48,14 @@ class BaseHeteroGraph : public runtime::Object {
////////////////////////// query/operations on meta graph ////////////////////////
/*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const = 0;
virtual uint64_t NumVertexTypes() const {
return meta_graph_->NumVertices();
}
/*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const = 0;
virtual uint64_t NumEdgeTypes() const {
return meta_graph_->NumEdges();
}
/*! \return the meta graph */
virtual GraphPtr meta_graph() const {
......@@ -351,6 +357,17 @@ class BaseHeteroGraph : public runtime::Object {
virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*!
 * \brief Convert the unit graphs of the requested edge types into a single unit graph.
 *
 * This default implementation does not perform any flattening: it only reports
 * a fatal error through LOG(FATAL).
 *
 * \param etypes The list of edge type IDs.
 * \return The flattened graph, with induced source/edge/destination types/IDs.
 */
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const {
  LOG(FATAL) << "Flatten operation unsupported";
  // NOTE(review): presumably never reached once LOG(FATAL) fires; it gives the
  // function a return value on every path.
  return nullptr;
}
static constexpr const char* _type_key = "graph.HeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);
......@@ -381,6 +398,62 @@ struct HeteroSubgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object);
};
/*!
 * \brief The flattened heterograph.
 *
 * Holds a graph together with the arrays that map its source nodes, edges and
 * destination nodes back to their types and local IDs in the parent graph.
 */
struct FlattenedHeteroGraph : public runtime::Object {
  /*! \brief The graph */
  HeteroGraphRef graph;
  /*!
   * \brief Mapping from source node ID to node type in parent graph
   * \note The induced type array guarantees that the same type always appears contiguously.
   */
  IdArray induced_srctype;
  /*!
   * \brief The set of node types in parent graph appearing in source nodes.
   */
  IdArray induced_srctype_set;
  /*! \brief Mapping from source node ID to local node ID in parent graph */
  IdArray induced_srcid;
  /*!
   * \brief Mapping from edge ID to edge type in parent graph
   * \note The induced type array guarantees that the same type always appears contiguously.
   */
  IdArray induced_etype;
  /*!
   * \brief The set of edge types in parent graph appearing in edges.
   */
  IdArray induced_etype_set;
  /*! \brief Mapping from edge ID to local edge ID in parent graph */
  IdArray induced_eid;
  /*!
   * \brief Mapping from destination node ID to node type in parent graph
   * \note The induced type array guarantees that the same type always appears contiguously.
   */
  IdArray induced_dsttype;
  /*!
   * \brief The set of node types in parent graph appearing in destination nodes.
   */
  IdArray induced_dsttype_set;
  /*! \brief Mapping from destination node ID to local node ID in parent graph */
  IdArray induced_dstid;

  /*!
   * \brief Visit every member, each under a key equal to its field name.
   * \param v The attribute visitor.
   */
  void VisitAttrs(runtime::AttrVisitor *v) final {
    v->Visit("graph", &graph);
    v->Visit("induced_srctype", &induced_srctype);
    v->Visit("induced_srctype_set", &induced_srctype_set);
    v->Visit("induced_srcid", &induced_srcid);
    v->Visit("induced_etype", &induced_etype);
    v->Visit("induced_etype_set", &induced_etype_set);
    v->Visit("induced_eid", &induced_eid);
    v->Visit("induced_dsttype", &induced_dsttype);
    v->Visit("induced_dsttype_set", &induced_dsttype_set);
    v->Visit("induced_dstid", &induced_dstid);
  }

  static constexpr const char* _type_key = "graph.FlattenedHeteroGraph";
  DGL_DECLARE_OBJECT_TYPE_INFO(FlattenedHeteroGraph, runtime::Object);
};
DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Define HeteroSubgraphRef
DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);
......
......@@ -18,6 +18,7 @@ namespace runtime {
// forward declaration
class Object;
class ObjectRef;
class NDArray;
/*!
* \brief Visitor class to each object attribute.
......@@ -33,6 +34,7 @@ class AttrVisitor {
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
......
......@@ -13,9 +13,10 @@ from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
from ._ffi.base import DGLError, __version__
from .base import ALL
from .base import ALL, NTYPE, NID, ETYPE, EID
from .backend import load_backend
from .batched_graph import *
from .convert import *
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from .nodeflow import *
......
......@@ -8,6 +8,13 @@ from ._ffi.function import _init_internal_api
# A special symbol for selecting all nodes or edges.
ALL = "__ALL__"
# An alias for [:]
SLICE_FULL = slice(None, None, None)
# Reserved column names for storing parent node/edge types and IDs in flattened heterographs
NTYPE = '_TYPE'
NID = '_ID'
ETYPE = '_TYPE'
EID = '_ID'
def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges."""
......
This diff is collapsed.
......@@ -186,8 +186,8 @@ class Frame(MutableMapping):
update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container.
num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided, ``num_rows``
will be ignored and inferred from the given data.
The number of rows in this frame. If ``data`` is provided and is not empty,
``num_rows`` will be ignored and inferred from the given data.
"""
def __init__(self, data=None, num_rows=0):
if data is None:
......@@ -202,7 +202,7 @@ class Frame(MutableMapping):
elif len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values())))
else:
self._num_rows = 0
self._num_rows = num_rows
# sanity check
for name, col in self._columns.items():
if len(col) != self._num_rows:
......@@ -880,23 +880,23 @@ class FrameRef(MutableMapping):
"""
return self._index.get_items(query)
def frame_like(other, num_rows):
"""Create a new frame that has the same scheme as the given one.
def frame_like(other, num_rows=None):
"""Create an empty frame that has the same initializer as the given one.
Parameters
----------
other : Frame
The given frame.
num_rows : int
The number of rows of the new one.
The number of rows of the new one. If None, use other.num_rows
(Default: None)
Returns
-------
Frame
The new frame.
"""
# TODO(minjie): scheme is not inherited at the moment. Fix this
# when moving per-col initializer to column scheme.
num_rows = other.num_rows if num_rows is None else num_rows
newf = Frame(num_rows=num_rows)
# set global initializer
if other.get_initializer() is None:
......
......@@ -11,7 +11,7 @@ from . import backend as F
from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from . import graph_index
from .runtime import ir, scheduler, Runtime
from .runtime import ir, scheduler, Runtime, GraphAdapter
from . import utils
from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
......@@ -49,14 +49,6 @@ class DGLBaseGraph(object):
"""
return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
......@@ -73,10 +65,6 @@ class DGLBaseGraph(object):
"""
return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self):
"""Return the number of edges in the graph.
......@@ -951,14 +939,6 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index):
self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
......@@ -2089,9 +2069,9 @@ class DGLGraph(DGLBaseGraph):
else:
v = utils.toindex(v)
with ir.prog() as prog:
scheduler.schedule_apply_nodes(graph=self,
v=v,
scheduler.schedule_apply_nodes(v=v,
apply_func=func,
node_frame=self._node_frame,
inplace=inplace)
Runtime.run(prog)
......@@ -2159,12 +2139,7 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(graph=self,
u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
scheduler.schedule_apply_edges(AdaptedDGLGraph(self), u, v, eid, func, inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
......@@ -2241,10 +2216,8 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog:
scheduler.schedule_group_apply_edge(graph=self,
u=u,
v=v,
eid=eid,
scheduler.schedule_group_apply_edge(graph=AdaptedDGLGraph(self),
u=u, v=v, eid=eid,
apply_func=func,
group_by=group_by,
inplace=inplace)
......@@ -2308,7 +2281,7 @@ class DGLGraph(DGLBaseGraph):
return
with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
scheduler.schedule_send(graph=AdaptedDGLGraph(self), u=u, v=v, eid=eid,
message_func=message_func)
Runtime.run(prog)
......@@ -2407,7 +2380,7 @@ class DGLGraph(DGLBaseGraph):
return
with ir.prog() as prog:
scheduler.schedule_recv(graph=self,
scheduler.schedule_recv(graph=AdaptedDGLGraph(self),
recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
......@@ -2515,7 +2488,7 @@ class DGLGraph(DGLBaseGraph):
return
with ir.prog() as prog:
scheduler.schedule_snr(graph=self,
scheduler.schedule_snr(graph=AdaptedDGLGraph(self),
edge_tuples=(u, v, eid),
message_func=message_func,
reduce_func=reduce_func,
......@@ -2618,7 +2591,7 @@ class DGLGraph(DGLBaseGraph):
if len(v) == 0:
return
with ir.prog() as prog:
scheduler.schedule_pull(graph=self,
scheduler.schedule_pull(graph=AdaptedDGLGraph(self),
pull_nodes=v,
message_func=message_func,
reduce_func=reduce_func,
......@@ -2715,7 +2688,7 @@ class DGLGraph(DGLBaseGraph):
if len(u) == 0:
return
with ir.prog() as prog:
scheduler.schedule_push(graph=self,
scheduler.schedule_push(graph=AdaptedDGLGraph(self),
u=u,
message_func=message_func,
reduce_func=reduce_func,
......@@ -2762,7 +2735,7 @@ class DGLGraph(DGLBaseGraph):
assert reduce_func is not None
with ir.prog() as prog:
scheduler.schedule_update_all(graph=self,
scheduler.schedule_update_all(graph=AdaptedDGLGraph(self),
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func)
......@@ -3219,7 +3192,7 @@ class DGLGraph(DGLBaseGraph):
v = utils.toindex(nodes)
n_repr = self.get_n_repr(v)
nbatch = NodeBatch(self, v, n_repr)
nbatch = NodeBatch(v, n_repr)
n_mask = F.copy_to(predicate(nbatch), F.cpu())
if is_all(nodes):
......@@ -3277,8 +3250,8 @@ class DGLGraph(DGLBaseGraph):
filter_nodes
"""
if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges('eid')
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
......@@ -3292,7 +3265,7 @@ class DGLGraph(DGLBaseGraph):
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
ebatch = EdgeBatch(self, (u, v, eid), src_data, edge_data, dst_data)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges):
......@@ -3492,3 +3465,79 @@ class DGLGraph(DGLBaseGraph):
yield
self._node_frame = old_nframe
self._edge_frame = old_eframe
############################################################
# Internal APIs
############################################################
class AdaptedDGLGraph(GraphAdapter):
    """Adapt DGLGraph to interface required by scheduler.

    Every member forwards to the wrapped graph or to its graph index.
    ``srcframe`` and ``dstframe`` both return the wrapped graph's node frame,
    and ``num_src`` and ``num_dst`` both return its number of nodes.

    Parameters
    ----------
    graph : DGLGraph
        Graph
    """
    def __init__(self, graph):
        self.graph = graph

    @property
    def gidx(self):
        """Graph index object of the wrapped graph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph.number_of_nodes()

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph.number_of_nodes()

    def num_edges(self):
        """Number of edges."""
        return self.graph.number_of_edges()

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frame

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        return self.graph._node_frame

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frame

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frame

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index()

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(val)

    def in_edges(self, nodes):
        """Get the in edges of the given nodes from the graph index."""
        return self.graph._graph.in_edges(nodes)

    def out_edges(self, nodes):
        """Get the out edges of the given nodes from the graph index."""
        return self.graph._graph.out_edges(nodes)

    def edges(self, form):
        """Get all edges from the graph index, in the requested form."""
        return self.graph._graph.edges(form)

    def get_immutable_gidx(self, ctx):
        """Get the immutable graph index for the given context."""
        return self.graph._graph.get_immutable_gidx(ctx)

    def bits_needed(self):
        """Return the number of integer bits needed to represent the graph."""
        return self.graph._graph.bits_needed()
......@@ -1129,10 +1129,13 @@ def from_edge_list(elist, is_multigraph, readonly):
Parameters
---------
elist : list
List of (u, v) edge tuple.
elist : list, tuple
List of (u, v) edge tuple, or a tuple of src/dst lists
"""
src, dst = zip(*elist)
if isinstance(elist, tuple):
src, dst = elist
else:
src, dst = zip(*elist)
src = np.array(src)
dst = np.array(dst)
src_ids = utils.toindex(src)
......
This diff is collapsed.
"""Module for heterogeneous graph index class definition."""
from __future__ import absolute_import
import numpy as np
import scipy
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError
......@@ -48,7 +51,7 @@ class HeteroGraphIndex(ObjectBase):
return self.metagraph.number_of_edges()
def get_relation_graph(self, etype):
"""Get the bipartite graph of the given edge/relation type.
"""Get the unitgraph graph of the given edge/relation type.
Parameters
----------
......@@ -58,10 +61,26 @@ class HeteroGraphIndex(ObjectBase):
Returns
-------
HeteroGraphIndex
The bipartite graph.
The unitgraph graph.
"""
return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))
def flatten_relations(self, etypes):
    """Convert the unit graphs of the requested edge/relation types into a
    single unit graph.

    The work is delegated entirely to ``_CAPI_DGLHeteroGetFlattenedGraph``.

    Parameters
    ----------
    etypes : list[int]
        The edge/relation types.

    Returns
    -------
    HeteroGraphIndex
        The unit graph.
        NOTE(review): the value is whatever the C API returns; it may be a
        flattened-graph object that also carries the induced type/ID arrays
        rather than a bare HeteroGraphIndex -- confirm against the C API.
    """
    return _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)
def add_nodes(self, ntype, num):
"""Add nodes.
......@@ -131,7 +150,7 @@ class HeteroGraphIndex(ObjectBase):
return _CAPI_DGLHeteroNumBits(self)
def bits_needed(self, etype):
"""Return the number of integer bits needed to represent the bipartite graph.
"""Return the number of integer bits needed to represent the unitgraph graph.
Parameters
----------
......@@ -658,6 +677,146 @@ class HeteroGraphIndex(ObjectBase):
else:
raise Exception("unknown format")
def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None):
    """Build the scipy sparse adjacency matrix of one edge type.

    Without transposition, rows index destination nodes and columns index
    source nodes. With ``transpose=True`` the roles are swapped, so rows
    index source nodes and columns index destination nodes.

    Parameters
    ----------
    etype : int
        Edge type
    transpose : bool
        Whether to transpose the returned adjacency matrix.
    fmt : str
        Format of the returned adjacency matrix ("csr" or "coo").
    return_edge_ids : bool
        Whether the matrix elements are edge IDs (True) or ones (False).
        When left as None, a FutureWarning is issued and edge IDs are used.

    Returns
    -------
    scipy.sparse.spmatrix
        The scipy representation of adjacency matrix.
    """
    if not isinstance(transpose, bool):
        raise DGLError('Expect bool value for "transpose" arg,'
                       ' but got %s.' % (type(transpose)))

    if return_edge_ids is None:
        dgl_warning(
            "Adjacency matrix by default currently returns edge IDs."
            " As a result there is one 0 entry which is not eliminated."
            " In the next release it will return 1s by default,"
            " and 0 will be eliminated otherwise.",
            FutureWarning)
        return_edge_ids = True

    adj_parts = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
    srctype, dsttype = self.metagraph.find_edge(etype)
    # Rows follow the source type only when transposed; columns take the other type.
    row_ntype, col_ntype = (srctype, dsttype) if transpose else (dsttype, srctype)
    num_rows = self.number_of_nodes(row_ntype)
    num_cols = self.number_of_nodes(col_ntype)
    num_edges = self.number_of_edges(etype)

    if fmt == "csr":
        indptr = utils.toindex(adj_parts(0)).tonumpy()
        indices = utils.toindex(adj_parts(1)).tonumpy()
        if return_edge_ids:
            values = utils.toindex(adj_parts(2)).tonumpy()
        else:
            values = np.ones_like(indices)
        return scipy.sparse.csr_matrix((values, indices, indptr),
                                       shape=(num_rows, num_cols))

    if fmt == 'coo':
        flat_idx = utils.toindex(adj_parts(0)).tonumpy()
        # The flat index is reshaped into one row array and one column array.
        row, col = np.reshape(flat_idx, (2, num_edges))
        if return_edge_ids:
            values = np.arange(0, num_edges)
        else:
            values = np.ones_like(row)
        return scipy.sparse.coo_matrix((values, (row, col)),
                                       shape=(num_rows, num_cols))

    raise Exception("unknown format")
def incidence_matrix(self, etype, typestr, ctx):
    """Return the incidence matrix representation of this graph.

    An incidence matrix is an n x m sparse matrix, where n is
    the number of nodes and m is the number of edges. Each nonzero
    value indicates that the edge is incident to the node.

    n is the number of nodes of the destination type for "in", and of the
    source type for "out" and "both"; m is the number of edges of ``etype``.

    There are three types of an incidence matrix `I`:
    * "in":
      - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);
      - I[v, e] = 0 otherwise.
    * "out":
      - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);
      - I[v, e] = 0 otherwise.
    * "both":
      - I[v, e] = 1 if e is the in-edge of v;
      - I[v, e] = -1 if e is the out-edge of v;
      - I[v, e] = 0 otherwise (including self-loop).

    Parameters
    ----------
    etype : int
        Edge type
    typestr : str
        Can be either "in", "out" or "both"
    ctx : context
        The context of returned incidence matrix.

    Returns
    -------
    SparseTensor
        The incidence matrix.
    utils.Index
        An index for data shuffling due to sparse format change. Return None
        if shuffle is not required.

    Raises
    ------
    DGLError
        If ``typestr`` is not one of "in", "out" or "both".
    """
    src, dst, eid = self.edges(etype)
    src = src.tousertensor(ctx)  # the index of the ctx will be cached
    dst = dst.tousertensor(ctx)  # the index of the ctx will be cached
    eid = eid.tousertensor(ctx)  # the index of the ctx will be cached
    srctype, dsttype = self.metagraph.find_edge(etype)

    m = self.number_of_edges(etype)
    if typestr == 'in':
        # one entry per edge, placed at (destination node, edge id)
        n = self.number_of_nodes(dsttype)
        row = F.unsqueeze(dst, 0)
        col = F.unsqueeze(eid, 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    elif typestr == 'out':
        # one entry per edge, placed at (source node, edge id)
        n = self.number_of_nodes(srctype)
        row = F.unsqueeze(src, 0)
        col = F.unsqueeze(eid, 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    elif typestr == 'both':
        assert srctype == dsttype, \
            "'both' is supported only if source and destination type are the same"
        n = self.number_of_nodes(srctype)
        # first remove entries for self loops
        mask = F.logical_not(F.equal(src, dst))
        src = F.boolean_mask(src, mask)
        dst = F.boolean_mask(dst, mask)
        eid = F.boolean_mask(eid, mask)
        n_entries = F.shape(src)[0]
        # create index: each remaining edge contributes two entries, one on
        # its source row and one on its destination row
        row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
        col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        # x (-1) lines up with the source rows, y (+1) with the destination rows
        x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        dat = F.cat([x, y], dim=0)
        inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    else:
        raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
    # wrap the shuffle index only when the backend returned one
    shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
    return inc, shuffle_idx
def node_subgraph(self, induced_nodes):
"""Return the induced node subgraph.
......@@ -696,16 +855,16 @@ class HeteroGraphIndex(ObjectBase):
eids = [edges.todgltensor() for edges in induced_edges]
return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='bipartite')
def get_bipartite(self, etype, ctx):
"""Create a bipartite graph from given edge type and copy to the given device
@utils.cached_member(cache='_cache', prefix='unitgraph')
def get_unitgraph(self, etype, ctx):
"""Create a unitgraph graph from given edge type and copy to the given device
context.
Note: this internal function is for DGL scheduler use only
Parameters
----------
etype : int, or None
etype : int
The edge type whose relation graph is extracted.
ctx : DGLContext
......@@ -715,7 +874,7 @@ class HeteroGraphIndex(ObjectBase):
-------
HeteroGraphIndex
"""
g = self.get_relation_graph(etype) if etype is not None else self
g = self.get_relation_graph(etype)
return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx)
def get_csr_shuffle_order(self, etype):
......@@ -778,11 +937,17 @@ class HeteroSubgraphIndex(ObjectBase):
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret]
def create_bipartite_from_coo(num_src, num_dst, row, col):
"""Create a bipartite graph index from COO format
#################################################################
# Creators
#################################################################
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col):
"""Create a unitgraph graph index from COO format
Parameters
----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int
Number of nodes in the src type.
num_dst : int
......@@ -796,14 +961,16 @@ def create_bipartite_from_coo(num_src, num_dst, row, col):
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCOO(
int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_ntypes), int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
"""Create a bipartite graph index from CSR format
def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edge_ids):
"""Create a unitgraph graph index from CSR format
Parameters
----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int
Number of nodes in the src type.
num_dst : int
......@@ -819,11 +986,11 @@ def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateBipartiteFromCSR(
int(num_src), int(num_dst),
return _CAPI_DGLHeteroCreateUnitGraphFromCSR(
int(num_ntypes), int(num_src), int(num_dst),
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
def create_heterograph(metagraph, rel_graphs):
def create_heterograph_from_relations(metagraph, rel_graphs):
"""Create a heterograph from metagraph and graphs of every relation.
Parameters
......
......@@ -3,3 +3,4 @@ from __future__ import absolute_import
from . import scheduler
from .runtime import Runtime
from .adapter import GraphAdapter
"""Temporary adapter to unify DGLGraph and HeteroGraph for scheduler.
NOTE(minjie): remove once all scheduler codes are migrated to heterograph
"""
from __future__ import absolute_import
from abc import ABC, abstractmethod
class GraphAdapter(ABC):
    """Temporary adapter class to unify DGLGraph and DGLHeteroGraph for schedulers.

    Every member is abstract; concrete adapters supply the implementations.
    """
    @property
    @abstractmethod
    def gidx(self):
        """Get graph index object."""

    @abstractmethod
    def num_src(self):
        """Number of source nodes."""

    @abstractmethod
    def num_dst(self):
        """Number of destination nodes."""

    @abstractmethod
    def num_edges(self):
        """Number of edges."""

    @property
    @abstractmethod
    def srcframe(self):
        """Frame to store source node features."""

    @property
    @abstractmethod
    def dstframe(self):
        """Frame to store destination node features."""

    @property
    @abstractmethod
    def edgeframe(self):
        """Frame to store edge features."""

    @property
    @abstractmethod
    def msgframe(self):
        """Frame to store messages."""

    @property
    @abstractmethod
    def msgindicator(self):
        """Message indicator tensor."""

    @msgindicator.setter
    @abstractmethod
    def msgindicator(self, val):
        """Set new message indicator tensor."""

    @abstractmethod
    def in_edges(self, nodes):
        """Get in edges

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def out_edges(self, nodes):
        """Get out edges

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def edges(self, form):
        """Get all edges

        Parameters
        ----------
        form : str
            "eid", "uv", etc.

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def get_immutable_gidx(self, ctx):
        """Get immutable graph index for kernel computation.

        Parameters
        ----------
        ctx : DGLContext
            The context of the returned graph.

        Returns
        -------
        GraphIndex
        """

    @abstractmethod
    def bits_needed(self):
        """Return the number of integer bits needed to represent the graph

        Returns
        -------
        int
            The number of bits needed
        """
......@@ -10,7 +10,6 @@ from . import ir
from .ir import var
def gen_degree_bucketing_schedule(
graph,
reduce_udf,
message_ids,
dst_nodes,
......@@ -28,8 +27,6 @@ def gen_degree_bucketing_schedule(
Parameters
----------
graph : DGLGraph
DGLGraph to use
reduce_udf : callable
The UDF to reduce messages.
message_ids : utils.Index
......@@ -56,7 +53,7 @@ def gen_degree_bucketing_schedule(
fd_list = []
for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt)
rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt)
# vars
vbkt = var.IDX(vbkt)
mid = var.IDX(mid)
......@@ -144,7 +141,7 @@ def _process_node_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
"""Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key):
......@@ -152,12 +149,11 @@ def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nbatch = NodeBatch(graph, vbkt, node_data, reshaped_mail_data)
nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data)
return reduce_udf(nbatch)
return _rfunc_wrapper
def gen_group_apply_edge_schedule(
graph,
apply_func,
u, v, eid,
group_by,
......@@ -175,8 +171,6 @@ def gen_group_apply_edge_schedule(
Parameters
----------
graph : DGLGraph
DGLGraph to use
apply_func: callable
The edge_apply_func UDF
u: utils.Index
......@@ -209,7 +203,7 @@ def gen_group_apply_edge_schedule(
fd_list = []
for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids):
# create per-bkt efunc
_efunc = var.FUNC(_create_per_bkt_efunc(graph, apply_func, deg,
_efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg,
u_bkt, v_bkt, eid_bkt))
# vars
var_u = var.IDX(u_bkt)
......@@ -280,7 +274,7 @@ def _process_edge_buckets(buckets):
eids = split(eids)
return degs, uids, vids, eids
def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid):
def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
"""Internal function to generate the per degree bucket edge UDF."""
batch_size = len(u) // deg
def _efunc_wrapper(src_data, edge_data, dst_data):
......@@ -302,7 +296,7 @@ def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid):
edge_data.keys())
reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data),
dst_data.keys())
ebatch = EdgeBatch(graph, (u, v, eid), reshaped_src_data,
ebatch = EdgeBatch((u, v, eid), reshaped_src_data,
reshaped_edge_data, reshaped_dst_data)
return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()}
return _efunc_wrapper
......
This diff is collapsed.
......@@ -6,8 +6,7 @@ from ..base import DGLError
from .. import backend as F
from .. import utils
from .. import ndarray as nd
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from ..heterograph_index import create_unitgraph_from_coo
from . import ir
from .ir import var
......@@ -129,8 +128,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters
----------
graph : DGLGraph or DGLHeteroGraph
The homogeneous graph, or a bipartite view of the heterogeneous graph.
graph : GraphAdapter
Graph
Returns
-------
......@@ -142,30 +141,21 @@ def build_gidx_and_mapping_graph(graph):
nbits : int
Number of ints needed to represent the graph
"""
gidx = graph._graph
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
return graph.get_immutable_gidx, None, graph.bits_needed()
def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of
nodes in the graph. Therefore, when doing SPMV, the src node data should be
all the node features.
The matrix is of shape (num_src, num_dst).
Parameters
---------
edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid)
num_src, num_dst : int
The number of source and destination nodes.
num_src : int
Number of source nodes.
num_dst : int
Number of destination nodes.
Returns
-------
......@@ -178,7 +168,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
Number of ints needed to represent the graph
"""
u, v, eid = edge_tuples
gidx = create_bipartite_from_coo(num_src, num_dst, u, v)
gidx = create_unitgraph_from_coo(2, num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor()
nbits = gidx.bits_needed(0)
......@@ -189,8 +179,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx)))
return partial(gidx.get_bipartite, None), edge_map, nbits
return partial(gidx.get_unitgraph, 0), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
"""Build immutable graph index and mapping for node flow
......
"""User-defined function related data structures."""
from __future__ import absolute_import
from .base import is_all
from . import backend as F
from . import utils
class EdgeBatch(object):
"""The class that can represent a batch of edges.
Parameters
----------
g : DGLGraph
The graph object.
edges : tuple of utils.Index
The edge tuple (u, v, eid). eid can be ALL
src_data : dict
......@@ -24,8 +18,7 @@ class EdgeBatch(object):
The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values
"""
def __init__(self, g, edges, src_data, edge_data, dst_data):
self._g = g
def __init__(self, edges, src_data, edge_data, dst_data):
self._edges = edges
self._src_data = src_data
self._edge_data = edge_data
......@@ -75,9 +68,6 @@ class EdgeBatch(object):
destination node and the edge id for the ith edge
in the batch.
"""
if is_all(self._edges[2]):
self._edges = self._edges[:2] + (utils.toindex(F.arange(
0, self._g.number_of_edges())),)
u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
......@@ -104,9 +94,7 @@ class NodeBatch(object):
Parameters
----------
g : DGLGraph
The graph object.
nodes : utils.Index or ALL
nodes : utils.Index
The node ids.
data : dict
The node features, in the form of ``dict``
......@@ -115,8 +103,7 @@ class NodeBatch(object):
The messages, in the form of ``dict``
with ``str`` keys and ``tensor`` values
"""
def __init__(self, g, nodes, data, msgs=None):
self._g = g
def __init__(self, nodes, data, msgs=None):
self._nodes = nodes
self._data = data
self._msgs = msgs
......@@ -154,9 +141,6 @@ class NodeBatch(object):
tensor
The nodes.
"""
if is_all(self._nodes):
self._nodes = utils.toindex(F.arange(
0, self._g.number_of_nodes()))
return self._nodes.tousertensor()
def batch_size(self):
......@@ -166,10 +150,7 @@ class NodeBatch(object):
-------
int
"""
if is_all(self._nodes):
return self._g.number_of_nodes()
else:
return len(self._nodes)
return len(self._nodes)
def __len__(self):
"""Return the number of nodes in this node batch.
......
......@@ -505,3 +505,14 @@ def to_nbits_int(tensor, nbits):
return F.astype(tensor, F.int32)
else:
return F.astype(tensor, F.int64)
def make_invmap(array, use_numpy=True):
    """Find the unique elements of the array and return another array with indices
    to the array of unique elements.

    Parameters
    ----------
    array : iterable
        One-dimensional sequence of hashable elements.
    use_numpy : bool, optional
        If True (default), the unique elements are computed with ``np.unique``
        and are therefore returned in sorted order. If False, they are computed
        with ``set`` and their order is unspecified.

    Returns
    -------
    uniques : numpy.ndarray or list
        The unique elements of ``array``.
    invmap : dict
        Mapping from each unique element to its position in ``uniques``.
    remapped : numpy.ndarray
        Integer array of the same length as ``array`` such that
        ``uniques[remapped[i]] == array[i]``.
    """
    if use_numpy:
        uniques = np.unique(array)
    else:
        uniques = list(set(array))
    invmap = {x: i for i, x in enumerate(uniques)}
    # Pin an integer dtype: without it an empty input would produce a float64
    # array, which cannot be used as an index array downstream.
    remapped = np.array([invmap[x] for x in array], dtype=np.int64)
    return uniques, invmap, remapped
......@@ -10,6 +10,7 @@ from .base import ALL, is_all, DGLError
from . import backend as F
NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class NodeView(object):
"""A NodeView class to act as G.nodes for a DGLGraph.
......@@ -79,8 +80,6 @@ class NodeDataView(MutableMapping):
data = self._graph.get_n_repr(self._nodes)
return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class EdgeView(object):
"""A EdgeView class to act as G.edges for a DGLGraph.
......@@ -256,111 +255,57 @@ class HeteroNodeView(object):
def __init__(self, graph):
self._graph = graph
def __getitem__(self, ntype):
return HeteroNodeTypeView(self._graph, ntype)
class HeteroNodeTypeView(object):
"""A NodeView class to act as G.nodes[ntype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __len__(self):
return self._graph.number_of_nodes(self._graph._ntypes_invmap[self._ntype])
def __getitem__(self, nodes):
if isinstance(nodes, slice):
def __getitem__(self, key):
if isinstance(key, slice):
# slice
if not (nodes.start is None and nodes.stop is None
and nodes.step is None):
if not (key.start is None and key.stop is None
and key.step is None):
raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, ALL))
nodes = ALL
ntype = None
elif isinstance(key, tuple):
nodes, ntype = key
elif isinstance(key, str):
nodes = ALL
ntype = key
else:
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, nodes))
nodes = key
ntype = None
return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, nodes))
def __call__(self):
def __call__(self, ntype=None):
"""Return the nodes."""
return F.arange(0, len(self))
class HeteroNodeTypeDataView(MutableMapping):
"""The data view class when G.nodes[ntype][...].data is called.
return F.arange(0, self._graph.number_of_nodes(ntype))
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype', '_nodes']
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
def __init__(self, graph, ntype, nodes):
self._graph = graph
self._ntype = ntype
self._ntid = self._graph.get_ntype_id(ntype)
self._nodes = nodes
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}, self._nodes)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __iter__(self):
return iter(self._graph.get_n_repr(self._ntype, self._nodes))
def __repr__(self):
data = self._graph.get_n_repr(self._ntype, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroNodeDataView(object):
"""The data view class when G.ndata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroNodeDataTypeView(self._graph, key)
class HeteroNodeDataTypeView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype)[key]
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val})
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key):
self._graph.pop_n_repr(self._ntype, key)
self._graph._pop_n_repr(self._ntid, key)
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
return len(self._graph._node_frames[self._ntid])
def __iter__(self):
return iter(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
return iter(self._graph._node_frames[self._ntid])
def __repr__(self):
data = self._graph.get_n_repr(self._ntype)
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
......@@ -369,108 +314,59 @@ class HeteroEdgeView(object):
def __init__(self, graph):
self._graph = graph
def __getitem__(self, etype):
return HeteroEdgeTypeView(self._graph, etype)
class HeteroEdgeTypeView(object):
"""A EdgeView class to act as G.edges[etype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __len__(self):
return self._graph.number_of_edges(self._graph._etypes_invmap[self._etype])
def __getitem__(self, edges):
if isinstance(edges, slice):
def __getitem__(self, key):
if isinstance(key, slice):
# slice
if not (edges.start is None and edges.stop is None
and edges.step is None):
if not (key.start is None and key.stop is None
and key.step is None):
raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, ALL))
edges = ALL
etype = None
elif isinstance(key, tuple):
if len(key) == 3:
edges = ALL
etype = key
else:
edges = key
etype = None
elif isinstance(key, (str, tuple)):
edges = ALL
etype = key
else:
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, edges))
edges = key
etype = None
return EdgeSpace(data=HeteroEdgeDataView(self._graph, etype, edges))
def __call__(self):
"""Return the edges."""
return F.arange(0, len(self))
class HeteroEdgeTypeDataView(MutableMapping):
"""The data view class when G.edges[etype][...].data is called.
def __call__(self, *args, **kwargs):
"""Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype', '_edges']
class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.ndata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges']
def __init__(self, graph, etype, edges):
self._graph = graph
self._etype = etype
self._etid = self._graph.get_etype_id(etype)
self._edges = edges
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}, self._edges)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of edges. Please use `del G.edata[key]` instead.')
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __iter__(self):
return iter(self._graph.get_e_repr(self._etype, self._edges))
def __repr__(self):
data = self._graph.get_e_repr(self._etype, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
class HeteroEdgeDataView(object):
"""The data view class when G.edata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroEdgeDataTypeView(self._graph, key)
class HeteroEdgeDataTypeView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype)[key]
return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val})
self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key):
self._graph.pop_e_repr(self._etype, key)
self._graph._pop_e_repr(self._etid, key)
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
return len(self._graph._edge_frames[self._etid])
def __iter__(self):
return iter(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
return iter(self._graph._edge_frames[self._etid])
def __repr__(self):
data = self._graph.get_e_repr(self._etype)
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
for key in self._graph._edge_frames[self._etid]})
......@@ -4,10 +4,11 @@
* \brief Heterograph implementation
*/
#include "./heterograph.h"
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "../c_api_common.h"
#include "./bipartite.h"
#include "./unit_graph.h"
using namespace dgl::runtime;
......@@ -50,7 +51,7 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
// following heterograph:
//
// Meta graph: A -> B -> C
// Bipartite graphs:
// UnitGraph graphs:
// * A -> B: (0, 0), (0, 1)
// * B -> C: (1, 0), (1, 1)
//
......@@ -91,7 +92,8 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
subrels[etype] = Bipartite::CreateFromCOO(
subrels[etype] = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
ret.induced_vertices[src_vtype]->shape[0],
ret.induced_vertices[dst_vtype]->shape[0],
subedges[etype].src,
......@@ -108,10 +110,9 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
// Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graph must be bipartite graphs
// all relation graphs must have only one edge type
for (const auto rg : rel_graphs) {
CHECK_EQ(rg->NumVertexTypes(), 2) << "Each relation graph must be a bipartite graph.";
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph.";
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
}
// create num verts per type
num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
......@@ -125,17 +126,20 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i];
const auto& rg = rel_graphs[etype];
const auto sty = 0;
const auto dty = rg->NumVertexTypes() == 1? 0 : 1;
size_t nv;
// # nodes of source type
nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType);
nv = rg->NumVertices(sty);
if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv;
else
CHECK_EQ(num_verts_per_type_[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type
nv = rel_graphs[etype]->NumVertices(Bipartite::kDstVType);
nv = rg->NumVertices(dty);
if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv;
else
......@@ -171,8 +175,10 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
auto pair = meta_graph_->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(
{vids[src_vtype], vids[dst_vtype]});
const std::vector<IdArray> rel_vids = (src_vtype == dst_vtype) ?
std::vector<IdArray>({vids[src_vtype]}) :
std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0];
}
......@@ -189,18 +195,106 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph(
}
}
// creator implementation
HeteroGraphPtr CreateBipartiteFromCOO(
int64_t num_src, int64_t num_dst, IdArray row, IdArray col) {
return Bipartite::CreateFromCOO(num_src, num_dst, row, col);
}
// Flatten the relations listed in `etypes` into a single graph.
//
// The source node types and the destination node types touched by `etypes` are
// collected separately, sorted by type ID, and laid out back to back on their
// side, so that node j of type t becomes node (offset[t] + j). Every edge of every
// listed edge type is then copied with its endpoints shifted by those offsets,
// following the order of `etypes`.
//
// The resulting graph is created with one vertex type when the sorted source and
// destination type sets are identical, and with two vertex types otherwise.
//
// Besides the graph, the returned FlattenedHeteroGraph records the original type
// and original ID of every flattened source node, destination node and edge
// (induced_*), together with the type sets that were used.
//
// \param etypes Edge type IDs to include. No deduplication is performed here.
// \return Pointer to a newly allocated FlattenedHeteroGraph.
FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0;
std::vector<dgl_id_t> result_src, result_dst;
std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<dgl_id_t> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type to number of nodes after this
// loop. XXXtype_set collect each distinct node type once, in first-seen order.
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t num_srctype_nodes = NumVertices(srctype);
size_t num_dsttype_nodes = NumVertices(dsttype);
if (srctype_offsets.count(srctype) == 0) {
srctype_offsets[srctype] = num_srctype_nodes;
srctype_set.push_back(srctype);
}
if (dsttype_offsets.count(dsttype) == 0) {
dsttype_offsets[dsttype] = num_dsttype_nodes;
dsttype_set.push_back(dsttype);
}
}
// Sort the node types so that we can compare the sets and decide whether a homograph
// should be returned.
std::sort(srctype_set.begin(), srctype_set.end());
std::sort(dsttype_set.begin(), dsttype_set.end());
bool homograph = (srctype_set.size() == dsttype_set.size()) &&
std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
// XXXtype_offsets contain the mapping from node type to node ID offsets after these
// two loops. Each loop also emits one (type, original ID) record per flattened node.
for (size_t i = 0; i < srctype_set.size(); ++i) {
dgl_type_t ntype = srctype_set[i];
size_t num_nodes = srctype_offsets[ntype];
srctype_offsets[ntype] = src_nodes;
src_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_srctype.push_back(ntype);
induced_srcid.push_back(j);
}
}
for (size_t i = 0; i < dsttype_set.size(); ++i) {
dgl_type_t ntype = dsttype_set[i];
size_t num_nodes = dsttype_offsets[ntype];
dsttype_offsets[ntype] = dst_nodes;
dst_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_dsttype.push_back(ntype);
induced_dstid.push_back(j);
}
}
// NOTE(review): the next four lines define CreateBipartiteFromCSR in the middle of
// this function body; they look like leftovers of a removed helper — confirm and
// drop them if so.
HeteroGraphPtr CreateBipartiteFromCSR(
int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids) {
return Bipartite::CreateFromCSR(num_src, num_dst, indptr, indices, edge_ids);
// Copy the edges of each requested type, shifting both endpoints into the
// flattened ID space and remembering the original edge type and edge ID.
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t srctype_offset = srctype_offsets[srctype];
size_t dsttype_offset = dsttype_offsets[dsttype];
EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype);
// NOTE(review): the casts below assume the edge arrays store dgl_id_t-sized
// elements — confirm against the IdArray dtype.
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const dgl_id_t* edges_eid_data = static_cast<const dgl_id_t*>(edges.id->data);
// TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset);
result_dst.push_back(edges_dst_data[i] + dsttype_offset);
induced_etype.push_back(etype);
induced_eid.push_back(edges_eid_data[i]);
}
}
// One vertex type if source and destination type sets coincide, two otherwise.
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2,
src_nodes,
dst_nodes,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
// Package the graph together with the bookkeeping arrays built above.
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(gptr);
result->induced_srctype = aten::VecToIdArray(induced_srctype);
result->induced_srctype_set = aten::VecToIdArray(srctype_set);
result->induced_srcid = aten::VecToIdArray(induced_srcid);
result->induced_etype = aten::VecToIdArray(induced_etype);
result->induced_etype_set = aten::VecToIdArray(etypes);
result->induced_eid = aten::VecToIdArray(induced_eid);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype);
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set);
result->induced_dstid = aten::VecToIdArray(induced_dstid);
return FlattenedHeteroGraphPtr(result);
}
// creator implementation
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
......@@ -208,24 +302,27 @@ HeteroGraphPtr CreateHeteroGraph(
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCOO")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0];
int64_t num_dst = args[1];
IdArray row = args[2];
IdArray col = args[3];
auto hgptr = CreateBipartiteFromCOO(num_src, num_dst, row, col);
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray row = args[3];
IdArray col = args[4];
auto hgptr = UnitGraph::CreateFromCOO(nvtypes, num_src, num_dst, row, col);
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCSR")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0];
int64_t num_dst = args[1];
IdArray indptr = args[2];
IdArray indices = args[3];
IdArray edge_ids = args[4];
auto hgptr = CreateBipartiteFromCSR(num_src, num_dst, indptr, indices, edge_ids);
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray indptr = args[3];
IdArray indices = args[4];
IdArray edge_ids = args[5];
auto hgptr = UnitGraph::CreateFromCSR(
nvtypes, num_src, num_dst, indptr, indices, edge_ids);
*rv = HeteroGraphRef(hgptr);
});
......@@ -252,7 +349,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
if (hg->NumEdgeTypes() == 1) {
CHECK_EQ(etype, 0);
*rv = hg;
} else {
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}
});
// C API: flatten a set of edge types of a heterograph into a single graph.
//   args[0] : HeteroGraphRef - the heterograph to flatten
//   args[1] : List<Value>    - the edge type IDs to include
// Sets *rv to a FlattenedHeteroGraphRef wrapping the result of the graph's Flatten().
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    List<Value> etypes = args[1];
    // Flatten() takes edge *type* IDs, so collect them as dgl_type_t to match
    // its signature instead of the node/edge ID type dgl_id_t.
    std::vector<dgl_type_t> etypes_vec;
    for (Value val : etypes)
      etypes_vec.push_back(val->data);
    *rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
......@@ -551,7 +664,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int bits = args[1];
HeteroGraphPtr hg_new = Bipartite::AsNumBits(hg.sptr(), bits);
HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits);
*rv = HeteroGraphRef(hg_new);
});
......@@ -563,7 +676,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
HeteroGraphPtr hg_new = Bipartite::CopyTo(hg.sptr(), ctx);
HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new);
});
......
......@@ -20,14 +20,6 @@ class HeteroGraph : public BaseHeteroGraph {
public:
HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
uint64_t NumVertexTypes() const override {
return meta_graph_->NumVertices();
}
uint64_t NumEdgeTypes() const override {
return meta_graph_->NumEdges();
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
return relation_graphs_[etype];
......@@ -172,8 +164,10 @@ class HeteroGraph : public BaseHeteroGraph {
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
private:
/*! \brief A map from edge type to bipartite graph */
/*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_;
/*! \brief A map from vert type to the number of verts in the type */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment