Unverified Commit 9b4d6079 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Hetero] New syntax (#824)

* WIP. remove graph arg in NodeBatch and EdgeBatch

* refactor: use graph adapter for scheduler

* WIP: recv

* draft impl

* stuck at bipartite

* bipartite->unitgraph; support dsttype == srctype

* pass test_query

* pass test_query

* pass test_view

* test apply

* pass udf message passing tests

* pass quan's test using builtins

* WIP: wildcard slicing

* new construct methods

* broken

* good

* add stack cross reducer

* fix bug; fix mx

* fix bug in csrmm2 when the CSR is not square

* lint

* removed FlattenedHeteroGraph class

* WIP

* prop nodes, prop edges, filter nodes/edges

* add DGLGraph tests to heterograph. Fix several bugs

* finish nx<->hetero graph conversion

* create bipartite from nx

* more spec on hetero/homo conversion

* silly fixes

* check node and edge types

* repr

* to api

* adj APIs

* inc

* fix some lints and bugs

* fix some lints

* hetero/homo conversion

* fix flatten test

* more spec in hetero_from_homo and test

* flatten using concat names

* WIP: creators

* rewrite hetero_from_homo in a more efficient way

* remove useless variables

* fix lint

* subgraphs and typed subgraphs

* lint & removed heterosubgraph class

* lint x2

* disable heterograph mutation test

* docstring update

* add edge id for nx graph test

* fix mx unittests

* fix bug

* try fix

* fix unittest when cross_reducer is stack

* fix ci

* fix nx bipartite bug; docstring

* fix scipy creation bug

* lint

* fix bug when converting heterograph from homograph

* fix bug in hetero_from_homo about ntype order

* trailing white

* docstring fixes for add_foo and data views

* docstring for relation slice

* to_hetero and to_homo with feature support

* lint

* lint

* DGLGraph compatibility

* incidence matrix & docstring fixes

* example string fixes

* feature in hetero_from_relations

* deduplication of edge types in to_hetero

* fix lint

* fix
parent ddb5d804
...@@ -21,7 +21,9 @@ namespace dgl { ...@@ -21,7 +21,9 @@ namespace dgl {
// Forward declaration // Forward declaration
class BaseHeteroGraph; class BaseHeteroGraph;
// Forward-declared with the same class-key (struct) as its definition, so that
// compilers which check or mangle the class-key see a consistent declaration.
struct FlattenedHeteroGraph;
typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr; typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;
typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;
struct HeteroSubgraph; struct HeteroSubgraph;
/*! /*!
...@@ -46,10 +48,14 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -46,10 +48,14 @@ class BaseHeteroGraph : public runtime::Object {
////////////////////////// query/operations on meta graph //////////////////////// ////////////////////////// query/operations on meta graph ////////////////////////
/*! \return the number of vertex types */ /*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const = 0; virtual uint64_t NumVertexTypes() const {
return meta_graph_->NumVertices();
}
/*! \return the number of edge types */ /*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const = 0; virtual uint64_t NumEdgeTypes() const {
return meta_graph_->NumEdges();
}
/*! \return the meta graph */ /*! \return the meta graph */
virtual GraphPtr meta_graph() const { virtual GraphPtr meta_graph() const {
...@@ -351,6 +357,17 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -351,6 +357,17 @@ class BaseHeteroGraph : public runtime::Object {
virtual HeteroSubgraph EdgeSubgraph( virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0; const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*!
 * \brief Merge the relation graphs of the requested edge types into a single graph.
 *
 * This default implementation does not perform the operation; it only reports
 * a fatal error. The method is virtual, so a derived class may provide a real one.
 *
 * \param etypes The list of edge type IDs.
 * \return The flattened graph, with induced source/edge/destination types/IDs.
 */
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const {
  LOG(FATAL) << "Flatten operation unsupported";
  // An empty pointer is returned so that the function has a value on every path.
  return FlattenedHeteroGraphPtr();
}
static constexpr const char* _type_key = "graph.HeteroGraph"; static constexpr const char* _type_key = "graph.HeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);
...@@ -381,6 +398,62 @@ struct HeteroSubgraph : public runtime::Object { ...@@ -381,6 +398,62 @@ struct HeteroSubgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object);
}; };
/*!
 * \brief The flattened heterograph.
 *
 * Holds a graph together with nine induced arrays: for each of the source nodes,
 * the edges and the destination nodes there is a type array, a type-set array and
 * a local-ID array, all expressed in terms of the parent graph. Every member is
 * exposed to attribute visitors by VisitAttrs().
 */
struct FlattenedHeteroGraph : public runtime::Object {
/*! \brief The graph */
HeteroGraphRef graph;
/*!
 * \brief Mapping from source node ID to node type in parent graph
 * \note The induced type array guarantees that the same type always appears contiguously.
 */
IdArray induced_srctype;
/*!
 * \brief The set of node types in parent graph appearing in source nodes.
 */
IdArray induced_srctype_set;
/*! \brief Mapping from source node ID to local node ID in parent graph */
IdArray induced_srcid;
/*!
 * \brief Mapping from edge ID to edge type in parent graph
 * \note The induced type array guarantees that the same type always appears contiguously.
 */
IdArray induced_etype;
/*!
 * \brief The set of edge types in parent graph appearing in edges.
 */
IdArray induced_etype_set;
/*! \brief Mapping from edge ID to local edge ID in parent graph */
IdArray induced_eid;
/*!
 * \brief Mapping from destination node ID to node type in parent graph
 * \note The induced type array guarantees that the same type always appears contiguously.
 */
IdArray induced_dsttype;
/*!
 * \brief The set of node types in parent graph appearing in destination nodes.
 */
IdArray induced_dsttype_set;
/*! \brief Mapping from destination node ID to local node ID in parent graph */
IdArray induced_dstid;
/*!
 * \brief Visit every member; each one is registered under a key equal to its member name.
 * \param v The attribute visitor that receives the (key, member pointer) pairs.
 */
void VisitAttrs(runtime::AttrVisitor *v) final {
v->Visit("graph", &graph);
v->Visit("induced_srctype", &induced_srctype);
v->Visit("induced_srctype_set", &induced_srctype_set);
v->Visit("induced_srcid", &induced_srcid);
v->Visit("induced_etype", &induced_etype);
v->Visit("induced_etype_set", &induced_etype_set);
v->Visit("induced_eid", &induced_eid);
v->Visit("induced_dsttype", &induced_dsttype);
v->Visit("induced_dsttype_set", &induced_dsttype_set);
v->Visit("induced_dstid", &induced_dstid);
}
/*! \brief Type key string of this object type. */
static constexpr const char* _type_key = "graph.FlattenedHeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(FlattenedHeteroGraph, runtime::Object);
};
DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Define HeteroSubgraphRef // Define HeteroSubgraphRef
DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph); DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);
......
...@@ -18,6 +18,7 @@ namespace runtime { ...@@ -18,6 +18,7 @@ namespace runtime {
// forward declaration // forward declaration
class Object; class Object;
class ObjectRef; class ObjectRef;
class NDArray;
/*! /*!
* \brief Visitor class to each object attribute. * \brief Visitor class to each object attribute.
...@@ -33,6 +34,7 @@ class AttrVisitor { ...@@ -33,6 +34,7 @@ class AttrVisitor {
virtual void Visit(const char* key, bool* value) = 0; virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0; virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0; virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum, template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type> typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) { void Visit(const char* key, ENum* ptr) {
......
...@@ -13,9 +13,10 @@ from ._ffi.runtime_ctypes import TypeCode ...@@ -13,9 +13,10 @@ from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
from ._ffi.base import DGLError, __version__ from ._ffi.base import DGLError, __version__
from .base import ALL from .base import ALL, NTYPE, NID, ETYPE, EID
from .backend import load_backend from .backend import load_backend
from .batched_graph import * from .batched_graph import *
from .convert import *
from .graph import DGLGraph from .graph import DGLGraph
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from .nodeflow import * from .nodeflow import *
......
...@@ -8,6 +8,13 @@ from ._ffi.function import _init_internal_api ...@@ -8,6 +8,13 @@ from ._ffi.function import _init_internal_api
# A special symbol for selecting all nodes or edges. # A special symbol for selecting all nodes or edges.
ALL = "__ALL__" ALL = "__ALL__"
# An alias for [:], i.e. a slice with no start, stop or step.
SLICE_FULL = slice(None, None, None)
# Reserved column names for storing parent node/edge types and IDs in flattened heterographs
# NOTE(review): the node and edge names below are the same strings ('_TYPE' / '_ID');
# presumably this is fine because node and edge features are kept in separate frames -- confirm.
NTYPE = '_TYPE'
NID = '_ID'
ETYPE = '_TYPE'
EID = '_ID'
def is_all(arg): def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges.""" """Return true if the argument is a special symbol for all nodes or edges."""
......
This diff is collapsed.
...@@ -186,8 +186,8 @@ class Frame(MutableMapping): ...@@ -186,8 +186,8 @@ class Frame(MutableMapping):
update on one will not reflect to the other. The inplace update will update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container. be seen by both. This follows the semantic of python's container.
num_rows : int, optional [default=0] num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided, ``num_rows`` The number of rows in this frame. If ``data`` is provided and is not empty,
will be ignored and inferred from the given data. ``num_rows`` will be ignored and inferred from the given data.
""" """
def __init__(self, data=None, num_rows=0): def __init__(self, data=None, num_rows=0):
if data is None: if data is None:
...@@ -202,7 +202,7 @@ class Frame(MutableMapping): ...@@ -202,7 +202,7 @@ class Frame(MutableMapping):
elif len(self._columns) != 0: elif len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values()))) self._num_rows = len(next(iter(self._columns.values())))
else: else:
self._num_rows = 0 self._num_rows = num_rows
# sanity check # sanity check
for name, col in self._columns.items(): for name, col in self._columns.items():
if len(col) != self._num_rows: if len(col) != self._num_rows:
...@@ -880,23 +880,23 @@ class FrameRef(MutableMapping): ...@@ -880,23 +880,23 @@ class FrameRef(MutableMapping):
""" """
return self._index.get_items(query) return self._index.get_items(query)
def frame_like(other, num_rows): def frame_like(other, num_rows=None):
"""Create a new frame that has the same scheme as the given one. """Create an empty frame that has the same initializer as the given one.
Parameters Parameters
---------- ----------
other : Frame other : Frame
The given frame. The given frame.
num_rows : int num_rows : int
The number of rows of the new one. The number of rows of the new one. If None, use other.num_rows
(Default: None)
Returns Returns
------- -------
Frame Frame
The new frame. The new frame.
""" """
# TODO(minjie): scheme is not inherited at the moment. Fix this num_rows = other.num_rows if num_rows is None else num_rows
# when moving per-col initializer to column scheme.
newf = Frame(num_rows=num_rows) newf = Frame(num_rows=num_rows)
# set global initializr # set global initializr
if other.get_initializer() is None: if other.get_initializer() is None:
......
...@@ -11,7 +11,7 @@ from . import backend as F ...@@ -11,7 +11,7 @@ from . import backend as F
from . import init from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from . import graph_index from . import graph_index
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime, GraphAdapter
from . import utils from . import utils
from .view import NodeView, EdgeView from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
...@@ -49,14 +49,6 @@ class DGLBaseGraph(object): ...@@ -49,14 +49,6 @@ class DGLBaseGraph(object):
""" """
return self._graph.number_of_nodes() return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self): def __len__(self):
"""Return the number of nodes in the graph.""" """Return the number of nodes in the graph."""
return self.number_of_nodes() return self.number_of_nodes()
...@@ -73,10 +65,6 @@ class DGLBaseGraph(object): ...@@ -73,10 +65,6 @@ class DGLBaseGraph(object):
""" """
return self._graph.is_readonly() return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges in the graph. """Return the number of edges in the graph.
...@@ -951,14 +939,6 @@ class DGLGraph(DGLBaseGraph): ...@@ -951,14 +939,6 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index): def _set_msg_index(self, index):
self._msg_index = index self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
"""Add multiple new nodes. """Add multiple new nodes.
...@@ -2089,9 +2069,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -2089,9 +2069,9 @@ class DGLGraph(DGLBaseGraph):
else: else:
v = utils.toindex(v) v = utils.toindex(v)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_nodes(graph=self, scheduler.schedule_apply_nodes(v=v,
v=v,
apply_func=func, apply_func=func,
node_frame=self._node_frame,
inplace=inplace) inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
...@@ -2159,12 +2139,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2159,12 +2139,7 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid) u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_edges(graph=self, scheduler.schedule_apply_edges(AdaptedDGLGraph(self), u, v, eid, func, inplace)
u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, inplace=False): def group_apply_edges(self, group_by, func, edges=ALL, inplace=False):
...@@ -2241,10 +2216,8 @@ class DGLGraph(DGLBaseGraph): ...@@ -2241,10 +2216,8 @@ class DGLGraph(DGLBaseGraph):
u, v, _ = self._graph.find_edges(eid) u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_group_apply_edge(graph=self, scheduler.schedule_group_apply_edge(graph=AdaptedDGLGraph(self),
u=u, u=u, v=v, eid=eid,
v=v,
eid=eid,
apply_func=func, apply_func=func,
group_by=group_by, group_by=group_by,
inplace=inplace) inplace=inplace)
...@@ -2308,7 +2281,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2308,7 +2281,7 @@ class DGLGraph(DGLBaseGraph):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid, scheduler.schedule_send(graph=AdaptedDGLGraph(self), u=u, v=v, eid=eid,
message_func=message_func) message_func=message_func)
Runtime.run(prog) Runtime.run(prog)
...@@ -2407,7 +2380,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2407,7 +2380,7 @@ class DGLGraph(DGLBaseGraph):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_recv(graph=self, scheduler.schedule_recv(graph=AdaptedDGLGraph(self),
recv_nodes=v, recv_nodes=v,
reduce_func=reduce_func, reduce_func=reduce_func,
apply_func=apply_node_func, apply_func=apply_node_func,
...@@ -2515,7 +2488,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2515,7 +2488,7 @@ class DGLGraph(DGLBaseGraph):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_snr(graph=self, scheduler.schedule_snr(graph=AdaptedDGLGraph(self),
edge_tuples=(u, v, eid), edge_tuples=(u, v, eid),
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -2618,7 +2591,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2618,7 +2591,7 @@ class DGLGraph(DGLBaseGraph):
if len(v) == 0: if len(v) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_pull(graph=self, scheduler.schedule_pull(graph=AdaptedDGLGraph(self),
pull_nodes=v, pull_nodes=v,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -2715,7 +2688,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2715,7 +2688,7 @@ class DGLGraph(DGLBaseGraph):
if len(u) == 0: if len(u) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_push(graph=self, scheduler.schedule_push(graph=AdaptedDGLGraph(self),
u=u, u=u,
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
...@@ -2762,7 +2735,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2762,7 +2735,7 @@ class DGLGraph(DGLBaseGraph):
assert reduce_func is not None assert reduce_func is not None
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_update_all(graph=self, scheduler.schedule_update_all(graph=AdaptedDGLGraph(self),
message_func=message_func, message_func=message_func,
reduce_func=reduce_func, reduce_func=reduce_func,
apply_func=apply_node_func) apply_func=apply_node_func)
...@@ -3219,7 +3192,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3219,7 +3192,7 @@ class DGLGraph(DGLBaseGraph):
v = utils.toindex(nodes) v = utils.toindex(nodes)
n_repr = self.get_n_repr(v) n_repr = self.get_n_repr(v)
nbatch = NodeBatch(self, v, n_repr) nbatch = NodeBatch(v, n_repr)
n_mask = F.copy_to(predicate(nbatch), F.cpu()) n_mask = F.copy_to(predicate(nbatch), F.cpu())
if is_all(nodes): if is_all(nodes):
...@@ -3277,8 +3250,8 @@ class DGLGraph(DGLBaseGraph): ...@@ -3277,8 +3250,8 @@ class DGLGraph(DGLBaseGraph):
filter_nodes filter_nodes
""" """
if is_all(edges): if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges('eid') u, v, _ = self._graph.edges('eid')
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
...@@ -3292,7 +3265,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3292,7 +3265,7 @@ class DGLGraph(DGLBaseGraph):
src_data = self.get_n_repr(u) src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid) edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v) dst_data = self.get_n_repr(v)
ebatch = EdgeBatch(self, (u, v, eid), src_data, edge_data, dst_data) ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
e_mask = F.copy_to(predicate(ebatch), F.cpu()) e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges): if is_all(edges):
...@@ -3492,3 +3465,79 @@ class DGLGraph(DGLBaseGraph): ...@@ -3492,3 +3465,79 @@ class DGLGraph(DGLBaseGraph):
yield yield
self._node_frame = old_nframe self._node_frame = old_nframe
self._edge_frame = old_eframe self._edge_frame = old_eframe
############################################################
# Internal APIs
############################################################
class AdaptedDGLGraph(GraphAdapter):
    """Adapt DGLGraph to interface required by scheduler.

    The wrapped graph has a single node frame, so the source and destination
    views exposed here (node counts and frames) both refer to that same frame.

    Parameters
    ----------
    graph : DGLGraph
        Graph
    """
    def __init__(self, graph):
        self.graph = graph

    @property
    def gidx(self):
        """Graph index object of the wrapped graph."""
        return self.graph._graph

    def num_src(self):
        """Number of source nodes."""
        return self.graph.number_of_nodes()

    def num_dst(self):
        """Number of destination nodes."""
        return self.graph.number_of_nodes()

    def num_edges(self):
        """Number of edges."""
        return self.graph.number_of_edges()

    @property
    def srcframe(self):
        """Frame to store source node features."""
        return self.graph._node_frame

    @property
    def dstframe(self):
        """Frame to store destination node features."""
        return self.graph._node_frame

    @property
    def edgeframe(self):
        """Frame to store edge features."""
        return self.graph._edge_frame

    @property
    def msgframe(self):
        """Frame to store messages."""
        return self.graph._msg_frame

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index()

    @msgindicator.setter
    def msgindicator(self, val):
        """Set new message indicator tensor."""
        self.graph._set_msg_index(val)

    # The methods below forward to the graph index through the ``gidx``
    # accessor so that there is a single place that reaches into the graph.

    def in_edges(self, nodes):
        """Return the graph index's in-edges query result for the given nodes."""
        return self.gidx.in_edges(nodes)

    def out_edges(self, nodes):
        """Return the graph index's out-edges query result for the given nodes."""
        return self.gidx.out_edges(nodes)

    def edges(self, form):
        """Return the graph index's edge listing for the given ``form``."""
        return self.gidx.edges(form)

    def get_immutable_gidx(self, ctx):
        """Return the graph index's immutable graph index for context ``ctx``."""
        return self.gidx.get_immutable_gidx(ctx)

    def bits_needed(self):
        """Return the value of ``bits_needed()`` reported by the graph index."""
        return self.gidx.bits_needed()
...@@ -1129,9 +1129,12 @@ def from_edge_list(elist, is_multigraph, readonly): ...@@ -1129,9 +1129,12 @@ def from_edge_list(elist, is_multigraph, readonly):
Parameters Parameters
--------- ---------
elist : list elist : list, tuple
List of (u, v) edge tuple. List of (u, v) edge tuple, or a tuple of src/dst lists
""" """
if isinstance(elist, tuple):
src, dst = elist
else:
src, dst = zip(*elist) src, dst = zip(*elist)
src = np.array(src) src = np.array(src)
dst = np.array(dst) dst = np.array(dst)
......
This diff is collapsed.
"""Module for heterogeneous graph index class definition.""" """Module for heterogeneous graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np
import scipy
from ._ffi.object import register_object, ObjectBase from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
...@@ -48,7 +51,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -48,7 +51,7 @@ class HeteroGraphIndex(ObjectBase):
return self.metagraph.number_of_edges() return self.metagraph.number_of_edges()
def get_relation_graph(self, etype): def get_relation_graph(self, etype):
"""Get the bipartite graph of the given edge/relation type. """Get the unitgraph graph of the given edge/relation type.
Parameters Parameters
---------- ----------
...@@ -58,10 +61,26 @@ class HeteroGraphIndex(ObjectBase): ...@@ -58,10 +61,26 @@ class HeteroGraphIndex(ObjectBase):
Returns Returns
------- -------
HeteroGraphIndex HeteroGraphIndex
The bipartite graph. The unitgraph graph.
""" """
return _CAPI_DGLHeteroGetRelationGraph(self, int(etype)) return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))
def flatten_relations(self, etypes):
    """Flatten the relation graphs of the requested edge types into a single graph.

    Parameters
    ----------
    etypes : list[int]
        The edge/relation types.

    Returns
    -------
    HeteroGraphIndex
        The unitgraph.
        NOTE(review): the value is whatever ``_CAPI_DGLHeteroGetFlattenedGraph``
        returns; confirm the exact Python type against the C API.
    """
    flattened = _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)
    return flattened
def add_nodes(self, ntype, num): def add_nodes(self, ntype, num):
"""Add nodes. """Add nodes.
...@@ -131,7 +150,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -131,7 +150,7 @@ class HeteroGraphIndex(ObjectBase):
return _CAPI_DGLHeteroNumBits(self) return _CAPI_DGLHeteroNumBits(self)
def bits_needed(self, etype): def bits_needed(self, etype):
"""Return the number of integer bits needed to represent the bipartite graph. """Return the number of integer bits needed to represent the unitgraph graph.
Parameters Parameters
---------- ----------
...@@ -658,6 +677,146 @@ class HeteroGraphIndex(ObjectBase): ...@@ -658,6 +677,146 @@ class HeteroGraphIndex(ObjectBase):
else: else:
raise Exception("unknown format") raise Exception("unknown format")
def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None):
    """Return the scipy adjacency matrix representation of this graph.

    By default, a row of returned adjacency matrix represents the destination
    of an edge and the column represents the source.

    When transpose is True, a row represents the source and a column represents
    a destination.

    Parameters
    ----------
    etype : int
        Edge type
    transpose : bool
        A flag to transpose the returned adjacency matrix.
    fmt : str
        Indicates the format of returned adjacency matrix: "csr" or "coo".
    return_edge_ids : bool
        Indicates whether to return edge IDs or 1 as elements. If None, a
        FutureWarning is issued and edge IDs are returned.

    Returns
    -------
    scipy.sparse.spmatrix
        The scipy representation of adjacency matrix.

    Raises
    ------
    DGLError
        If ``transpose`` is not a bool.
    Exception
        If ``fmt`` is neither "csr" nor "coo".
    """
    # A bare ``import scipy`` does not guarantee that the ``scipy.sparse``
    # submodule has been loaded, so bind it explicitly before using it.
    from scipy import sparse as sp_sparse

    if not isinstance(transpose, bool):
        raise DGLError('Expect bool value for "transpose" arg,'
                       ' but got %s.' % (type(transpose)))

    if return_edge_ids is None:
        dgl_warning(
            "Adjacency matrix by default currently returns edge IDs."
            " As a result there is one 0 entry which is not eliminated."
            " In the next release it will return 1s by default,"
            " and 0 will be eliminated otherwise.",
            FutureWarning)
        return_edge_ids = True

    rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
    srctype, dsttype = self.metagraph.find_edge(etype)
    # Rows index destination nodes unless the matrix is transposed.
    nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype)
    ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
    nnz = self.number_of_edges(etype)
    if fmt == "csr":
        indptr = utils.toindex(rst(0)).tonumpy()
        indices = utils.toindex(rst(1)).tonumpy()
        data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
        return sp_sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols))
    elif fmt == 'coo':
        # The single returned array is split into its row half and column half.
        idx = utils.toindex(rst(0)).tonumpy()
        row, col = np.reshape(idx, (2, nnz))
        data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row)
        return sp_sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols))
    else:
        raise Exception("unknown format")
def incidence_matrix(self, etype, typestr, ctx):
    """Build the incidence matrix of one edge type.

    The result is an n x m sparse matrix with n the number of nodes and m
    the number of edges; a non-zero entry marks that an edge is incident to
    a node.

    Three flavours of incidence matrix `I` are available:

    * "in":
      - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);
      - I[v, e] = 0 otherwise.
    * "out":
      - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);
      - I[v, e] = 0 otherwise.
    * "both":
      - I[v, e] = 1 if e is the in-edge of v;
      - I[v, e] = -1 if e is the out-edge of v;
      - I[v, e] = 0 otherwise (including self-loop).

    Parameters
    ----------
    etype : int
        Edge type
    typestr : str
        Can be either "in", "out" or "both"
    ctx : context
        The context of returned incidence matrix.

    Returns
    -------
    SparseTensor
        The incidence matrix.
    utils.Index
        A index for data shuffling due to sparse format change. Return None
        if shuffle is not required.
    """
    src, dst, eid = self.edges(etype)
    # the per-context tensors are cached by the index objects
    src = src.tousertensor(ctx)
    dst = dst.tousertensor(ctx)
    eid = eid.tousertensor(ctx)
    srctype, dsttype = self.metagraph.find_edge(etype)
    m = self.number_of_edges(etype)
    if typestr in ('in', 'out'):
        # "in" marks destination endpoints, "out" marks source endpoints;
        # everything else is shared between the two cases.
        if typestr == 'in':
            n = self.number_of_nodes(dsttype)
            endpoints = dst
        else:
            n = self.number_of_nodes(srctype)
            endpoints = src
        row = F.unsqueeze(endpoints, 0)
        col = F.unsqueeze(eid, 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        dat = F.ones((m,), dtype=F.float32, ctx=ctx)
    elif typestr == 'both':
        assert srctype == dsttype, \
            "'both' is supported only if source and destination type are the same"
        n = self.number_of_nodes(srctype)
        # self loops contribute nothing, so drop them first
        not_loop = F.logical_not(F.equal(src, dst))
        src = F.boolean_mask(src, not_loop)
        dst = F.boolean_mask(dst, not_loop)
        eid = F.boolean_mask(eid, not_loop)
        n_entries = F.shape(src)[0]
        # source endpoints come first, destination endpoints second
        row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
        col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
        idx = F.cat([row, col], dim=0)
        # FIXME(minjie): data type
        neg_ones = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        pos_ones = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
        dat = F.cat([neg_ones, pos_ones], dim=0)
    else:
        raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
    inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
    if shuffle_idx is not None:
        shuffle_idx = utils.toindex(shuffle_idx)
    else:
        shuffle_idx = None
    return inc, shuffle_idx
def node_subgraph(self, induced_nodes): def node_subgraph(self, induced_nodes):
"""Return the induced node subgraph. """Return the induced node subgraph.
...@@ -696,16 +855,16 @@ class HeteroGraphIndex(ObjectBase): ...@@ -696,16 +855,16 @@ class HeteroGraphIndex(ObjectBase):
eids = [edges.todgltensor() for edges in induced_edges] eids = [edges.todgltensor() for edges in induced_edges]
return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes) return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='bipartite') @utils.cached_member(cache='_cache', prefix='unitgraph')
def get_bipartite(self, etype, ctx): def get_unitgraph(self, etype, ctx):
"""Create a bipartite graph from given edge type and copy to the given device """Create a unitgraph graph from given edge type and copy to the given device
context. context.
Note: this internal function is for DGL scheduler use only Note: this internal function is for DGL scheduler use only
Parameters Parameters
---------- ----------
etype : int, or None etype : int
If the graph index is a Bipartite graph index, this argument must be None. If the graph index is a Bipartite graph index, this argument must be None.
Otherwise, it represents the edge type. Otherwise, it represents the edge type.
ctx : DGLContext ctx : DGLContext
...@@ -715,7 +874,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -715,7 +874,7 @@ class HeteroGraphIndex(ObjectBase):
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
g = self.get_relation_graph(etype) if etype is not None else self g = self.get_relation_graph(etype)
return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx) return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx)
def get_csr_shuffle_order(self, etype): def get_csr_shuffle_order(self, etype):
...@@ -778,11 +937,17 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -778,11 +937,17 @@ class HeteroSubgraphIndex(ObjectBase):
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self) ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret] return [utils.toindex(v.data) for v in ret]
def create_bipartite_from_coo(num_src, num_dst, row, col): #################################################################
"""Create a bipartite graph index from COO format # Creators
#################################################################
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col):
"""Create a unitgraph graph index from COO format
Parameters Parameters
---------- ----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int num_src : int
Number of nodes in the src type. Number of nodes in the src type.
num_dst : int num_dst : int
...@@ -796,14 +961,16 @@ def create_bipartite_from_coo(num_src, num_dst, row, col): ...@@ -796,14 +961,16 @@ def create_bipartite_from_coo(num_src, num_dst, row, col):
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
return _CAPI_DGLHeteroCreateBipartiteFromCOO( return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_src), int(num_dst), row.todgltensor(), col.todgltensor()) int(num_ntypes), int(num_src), int(num_dst), row.todgltensor(), col.todgltensor())
def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids): def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edge_ids):
"""Create a bipartite graph index from CSR format """Create a unitgraph graph index from CSR format
Parameters Parameters
---------- ----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int num_src : int
Number of nodes in the src type. Number of nodes in the src type.
num_dst : int num_dst : int
...@@ -819,11 +986,11 @@ def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids): ...@@ -819,11 +986,11 @@ def create_bipartite_from_csr(num_src, num_dst, indptr, indices, edge_ids):
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
return _CAPI_DGLHeteroCreateBipartiteFromCSR( return _CAPI_DGLHeteroCreateUnitGraphFromCSR(
int(num_src), int(num_dst), int(num_ntypes), int(num_src), int(num_dst),
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor()) indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor())
def create_heterograph(metagraph, rel_graphs): def create_heterograph_from_relations(metagraph, rel_graphs):
"""Create a heterograph from metagraph and graphs of every relation. """Create a heterograph from metagraph and graphs of every relation.
Parameters Parameters
......
...@@ -3,3 +3,4 @@ from __future__ import absolute_import ...@@ -3,3 +3,4 @@ from __future__ import absolute_import
from . import scheduler from . import scheduler
from .runtime import Runtime from .runtime import Runtime
from .adapter import GraphAdapter
"""Temporary adapter to unify DGLGraph and HeteroGraph for scheduler.
NOTE(minjie): remove once all scheduler codes are migrated to heterograph
"""
from __future__ import absolute_import
from abc import ABC, abstractmethod
class GraphAdapter(ABC):
    """Temporary adapter class to unify DGLGraph and DGLHeteroGraph for schedulers.

    Every member below is abstract, so this class cannot be instantiated
    directly; a concrete adapter must implement all of them.
    """

    @property
    @abstractmethod
    def gidx(self):
        """Get graph index object."""

    @abstractmethod
    def num_src(self):
        """Number of source nodes."""

    @abstractmethod
    def num_dst(self):
        """Number of destination nodes."""

    @abstractmethod
    def num_edges(self):
        """Number of edges."""

    @property
    @abstractmethod
    def srcframe(self):
        """Frame to store source node features."""

    @property
    @abstractmethod
    def dstframe(self):
        """Frame to store destination node features."""

    @property
    @abstractmethod
    def edgeframe(self):
        """Frame to store edge features."""

    @property
    @abstractmethod
    def msgframe(self):
        """Frame to store messages."""

    @property
    @abstractmethod
    def msgindicator(self):
        """Message indicator tensor."""

    @msgindicator.setter
    @abstractmethod
    def msgindicator(self, val):
        """Set new message indicator tensor."""

    @abstractmethod
    def in_edges(self, nodes):
        """Get in edges

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def out_edges(self, nodes):
        """Get out edges

        Parameters
        ----------
        nodes : utils.Index
            Nodes

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def edges(self, form):
        """Get all edges

        Parameters
        ----------
        form : str
            "eid", "uv", etc.

        Returns
        -------
        tuple of utils.Index
            (src, dst, eid)
        """

    @abstractmethod
    def get_immutable_gidx(self, ctx):
        """Get immutable graph index for kernel computation.

        Parameters
        ----------
        ctx : DGLContext
            The context of the returned graph.

        Returns
        -------
        GraphIndex
        """

    @abstractmethod
    def bits_needed(self):
        """Return the number of integer bits needed to represent the graph

        Returns
        -------
        int
            The number of bits needed
        """
...@@ -10,7 +10,6 @@ from . import ir ...@@ -10,7 +10,6 @@ from . import ir
from .ir import var from .ir import var
def gen_degree_bucketing_schedule( def gen_degree_bucketing_schedule(
graph,
reduce_udf, reduce_udf,
message_ids, message_ids,
dst_nodes, dst_nodes,
...@@ -28,8 +27,6 @@ def gen_degree_bucketing_schedule( ...@@ -28,8 +27,6 @@ def gen_degree_bucketing_schedule(
Parameters Parameters
---------- ----------
graph : DGLGraph
DGLGraph to use
reduce_udf : callable reduce_udf : callable
The UDF to reduce messages. The UDF to reduce messages.
message_ids : utils.Index message_ids : utils.Index
...@@ -56,7 +53,7 @@ def gen_degree_bucketing_schedule( ...@@ -56,7 +53,7 @@ def gen_degree_bucketing_schedule(
fd_list = [] fd_list = []
for deg, vbkt, mid in zip(degs, buckets, msg_ids): for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc # create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt) rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt)
# vars # vars
vbkt = var.IDX(vbkt) vbkt = var.IDX(vbkt)
mid = var.IDX(mid) mid = var.IDX(mid)
...@@ -144,7 +141,7 @@ def _process_node_buckets(buckets): ...@@ -144,7 +141,7 @@ def _process_node_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt): def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
"""Internal function to generate the per degree bucket node UDF.""" """Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data): def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key): def _reshaped_getter(key):
...@@ -152,12 +149,11 @@ def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt): ...@@ -152,12 +149,11 @@ def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
new_shape = (len(vbkt), deg) + F.shape(msg)[1:] new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys()) reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nbatch = NodeBatch(graph, vbkt, node_data, reshaped_mail_data) nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data)
return reduce_udf(nbatch) return reduce_udf(nbatch)
return _rfunc_wrapper return _rfunc_wrapper
def gen_group_apply_edge_schedule( def gen_group_apply_edge_schedule(
graph,
apply_func, apply_func,
u, v, eid, u, v, eid,
group_by, group_by,
...@@ -175,8 +171,6 @@ def gen_group_apply_edge_schedule( ...@@ -175,8 +171,6 @@ def gen_group_apply_edge_schedule(
Parameters Parameters
---------- ----------
graph : DGLGraph
DGLGraph to use
apply_func: callable apply_func: callable
The edge_apply_func UDF The edge_apply_func UDF
u: utils.Index u: utils.Index
...@@ -209,7 +203,7 @@ def gen_group_apply_edge_schedule( ...@@ -209,7 +203,7 @@ def gen_group_apply_edge_schedule(
fd_list = [] fd_list = []
for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids): for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids):
# create per-bkt efunc # create per-bkt efunc
_efunc = var.FUNC(_create_per_bkt_efunc(graph, apply_func, deg, _efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg,
u_bkt, v_bkt, eid_bkt)) u_bkt, v_bkt, eid_bkt))
# vars # vars
var_u = var.IDX(u_bkt) var_u = var.IDX(u_bkt)
...@@ -280,7 +274,7 @@ def _process_edge_buckets(buckets): ...@@ -280,7 +274,7 @@ def _process_edge_buckets(buckets):
eids = split(eids) eids = split(eids)
return degs, uids, vids, eids return degs, uids, vids, eids
def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid): def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
"""Internal function to generate the per degree bucket edge UDF.""" """Internal function to generate the per degree bucket edge UDF."""
batch_size = len(u) // deg batch_size = len(u) // deg
def _efunc_wrapper(src_data, edge_data, dst_data): def _efunc_wrapper(src_data, edge_data, dst_data):
...@@ -302,7 +296,7 @@ def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid): ...@@ -302,7 +296,7 @@ def _create_per_bkt_efunc(graph, apply_func, deg, u, v, eid):
edge_data.keys()) edge_data.keys())
reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data), reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data),
dst_data.keys()) dst_data.keys())
ebatch = EdgeBatch(graph, (u, v, eid), reshaped_src_data, ebatch = EdgeBatch((u, v, eid), reshaped_src_data,
reshaped_edge_data, reshaped_dst_data) reshaped_edge_data, reshaped_dst_data)
return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()} return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()}
return _efunc_wrapper return _efunc_wrapper
......
This diff is collapsed.
...@@ -6,8 +6,7 @@ from ..base import DGLError ...@@ -6,8 +6,7 @@ from ..base import DGLError
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from .. import ndarray as nd from .. import ndarray as nd
from ..graph_index import GraphIndex from ..heterograph_index import create_unitgraph_from_coo
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from . import ir from . import ir
from .ir import var from .ir import var
...@@ -129,8 +128,8 @@ def build_gidx_and_mapping_graph(graph): ...@@ -129,8 +128,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters Parameters
---------- ----------
graph : DGLGraph or DGLHeteroGraph graph : GraphAdapter
The homogeneous graph, or a bipartite view of the heterogeneous graph. Graph
Returns Returns
------- -------
...@@ -142,30 +141,21 @@ def build_gidx_and_mapping_graph(graph): ...@@ -142,30 +141,21 @@ def build_gidx_and_mapping_graph(graph):
nbits : int nbits : int
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
gidx = graph._graph return graph.get_immutable_gidx, None, graph.bits_needed()
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst): def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges """Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of The matrix is of shape (num_src, num_dst).
nodes in the graph. Therefore, when doing SPMV, the src node data should be
all the node features.
Parameters Parameters
--------- ---------
edge_tuples : tuple of three utils.Index edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid) A tuple of (u, v, eid)
num_src, num_dst : int num_src : int
The number of source and destination nodes. Number of source nodes.
num_dst : int
Number of destination nodes.
Returns Returns
------- -------
...@@ -178,7 +168,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst): ...@@ -178,7 +168,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
u, v, eid = edge_tuples u, v, eid = edge_tuples
gidx = create_bipartite_from_coo(num_src, num_dst, u, v) gidx = create_unitgraph_from_coo(2, num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order(0) forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor() eid = eid.tousertensor()
nbits = gidx.bits_needed(0) nbits = gidx.bits_needed(0)
...@@ -189,8 +179,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst): ...@@ -189,8 +179,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
edge_map = utils.CtxCachedObject( edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx), lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx))) nd.array(backward_map, ctx=ctx)))
return partial(gidx.get_bipartite, None), edge_map, nbits return partial(gidx.get_unitgraph, 0), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None): def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
"""Build immutable graph index and mapping for node flow """Build immutable graph index and mapping for node flow
......
"""User-defined function related data structures.""" """User-defined function related data structures."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import is_all
from . import backend as F
from . import utils
class EdgeBatch(object): class EdgeBatch(object):
"""The class that can represent a batch of edges. """The class that can represent a batch of edges.
Parameters Parameters
---------- ----------
g : DGLGraph
The graph object.
edges : tuple of utils.Index edges : tuple of utils.Index
The edge tuple (u, v, eid). eid can be ALL The edge tuple (u, v, eid). eid can be ALL
src_data : dict src_data : dict
...@@ -24,8 +18,7 @@ class EdgeBatch(object): ...@@ -24,8 +18,7 @@ class EdgeBatch(object):
The dst node features, in the form of ``dict`` The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values with ``str`` keys and ``tensor`` values
""" """
def __init__(self, g, edges, src_data, edge_data, dst_data): def __init__(self, edges, src_data, edge_data, dst_data):
self._g = g
self._edges = edges self._edges = edges
self._src_data = src_data self._src_data = src_data
self._edge_data = edge_data self._edge_data = edge_data
...@@ -75,9 +68,6 @@ class EdgeBatch(object): ...@@ -75,9 +68,6 @@ class EdgeBatch(object):
destination node and the edge id for the ith edge destination node and the edge id for the ith edge
in the batch. in the batch.
""" """
if is_all(self._edges[2]):
self._edges = self._edges[:2] + (utils.toindex(F.arange(
0, self._g.number_of_edges())),)
u, v, eid = self._edges u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor()) return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
...@@ -104,9 +94,7 @@ class NodeBatch(object): ...@@ -104,9 +94,7 @@ class NodeBatch(object):
Parameters Parameters
---------- ----------
g : DGLGraph nodes : utils.Index
The graph object.
nodes : utils.Index or ALL
The node ids. The node ids.
data : dict data : dict
The node features, in the form of ``dict`` The node features, in the form of ``dict``
...@@ -115,8 +103,7 @@ class NodeBatch(object): ...@@ -115,8 +103,7 @@ class NodeBatch(object):
The messages, , in the form of ``dict`` The messages, , in the form of ``dict``
with ``str`` keys and ``tensor`` values with ``str`` keys and ``tensor`` values
""" """
def __init__(self, g, nodes, data, msgs=None): def __init__(self, nodes, data, msgs=None):
self._g = g
self._nodes = nodes self._nodes = nodes
self._data = data self._data = data
self._msgs = msgs self._msgs = msgs
...@@ -154,9 +141,6 @@ class NodeBatch(object): ...@@ -154,9 +141,6 @@ class NodeBatch(object):
tensor tensor
The nodes. The nodes.
""" """
if is_all(self._nodes):
self._nodes = utils.toindex(F.arange(
0, self._g.number_of_nodes()))
return self._nodes.tousertensor() return self._nodes.tousertensor()
def batch_size(self): def batch_size(self):
...@@ -166,9 +150,6 @@ class NodeBatch(object): ...@@ -166,9 +150,6 @@ class NodeBatch(object):
------- -------
int int
""" """
if is_all(self._nodes):
return self._g.number_of_nodes()
else:
return len(self._nodes) return len(self._nodes)
def __len__(self): def __len__(self):
......
...@@ -505,3 +505,14 @@ def to_nbits_int(tensor, nbits): ...@@ -505,3 +505,14 @@ def to_nbits_int(tensor, nbits):
return F.astype(tensor, F.int32) return F.astype(tensor, F.int32)
else: else:
return F.astype(tensor, F.int64) return F.astype(tensor, F.int64)
def make_invmap(array, use_numpy=True):
    """Find the unique elements of the array and return another array with indices
    to the array of unique elements.

    Parameters
    ----------
    array : iterable
        The elements to deduplicate. Each element is used as a dictionary key,
        so it must be hashable.
    use_numpy : bool, optional
        If True (default), the unique elements are computed with ``np.unique``.
        Otherwise they are computed with ``list(set(array))``, so their order
        is whatever set iteration yields.

    Returns
    -------
    uniques : numpy.ndarray or list
        The unique elements.
    invmap : dict
        Mapping from each unique element to its position in ``uniques``.
    remapped : numpy.ndarray
        int64 array holding, for every element of ``array``, its position
        in ``uniques``.
    """
    if use_numpy:
        uniques = np.unique(array)
    else:
        uniques = list(set(array))
    invmap = {x: i for i, x in enumerate(uniques)}
    # Pin the dtype: without it an empty input produces a float64 array,
    # which is not a valid index array.
    remapped = np.array([invmap[x] for x in array], dtype=np.int64)
    return uniques, invmap, remapped
...@@ -10,6 +10,7 @@ from .base import ALL, is_all, DGLError ...@@ -10,6 +10,7 @@ from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
NodeSpace = namedtuple('NodeSpace', ['data']) NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class NodeView(object): class NodeView(object):
"""A NodeView class to act as G.nodes for a DGLGraph. """A NodeView class to act as G.nodes for a DGLGraph.
...@@ -79,8 +80,6 @@ class NodeDataView(MutableMapping): ...@@ -79,8 +80,6 @@ class NodeDataView(MutableMapping):
data = self._graph.get_n_repr(self._nodes) data = self._graph.get_n_repr(self._nodes)
return repr({key : data[key] for key in self._graph._node_frame}) return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class EdgeView(object): class EdgeView(object):
"""A EdgeView class to act as G.edges for a DGLGraph. """A EdgeView class to act as G.edges for a DGLGraph.
...@@ -256,111 +255,57 @@ class HeteroNodeView(object): ...@@ -256,111 +255,57 @@ class HeteroNodeView(object):
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
def __getitem__(self, ntype): def __getitem__(self, key):
return HeteroNodeTypeView(self._graph, ntype) if isinstance(key, slice):
class HeteroNodeTypeView(object):
"""A NodeView class to act as G.nodes[ntype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.nodes
"""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __len__(self):
return self._graph.number_of_nodes(self._graph._ntypes_invmap[self._ntype])
def __getitem__(self, nodes):
if isinstance(nodes, slice):
# slice # slice
if not (nodes.start is None and nodes.stop is None if not (key.start is None and key.stop is None
and nodes.step is None): and key.step is None):
raise DGLError('Currently only full slice ":" is supported') raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, ALL)) nodes = ALL
ntype = None
elif isinstance(key, tuple):
nodes, ntype = key
elif isinstance(key, str):
nodes = ALL
ntype = key
else: else:
return NodeSpace(data=HeteroNodeTypeDataView(self._graph, self._ntype, nodes)) nodes = key
ntype = None
return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, nodes))
def __call__(self): def __call__(self, ntype=None):
"""Return the nodes.""" """Return the nodes."""
return F.arange(0, len(self)) return F.arange(0, self._graph.number_of_nodes(ntype))
class HeteroNodeTypeDataView(MutableMapping):
"""The data view class when G.nodes[ntype][...].data is called.
See Also class HeteroNodeDataView(MutableMapping):
-------- """The data view class when G.ndata[ntype] is called."""
dgl.DGLGraph.nodes __slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
"""
__slots__ = ['_graph', '_ntype', '_nodes']
def __init__(self, graph, ntype, nodes): def __init__(self, graph, ntype, nodes):
self._graph = graph self._graph = graph
self._ntype = ntype self._ntype = ntype
self._ntid = self._graph.get_ntype_id(ntype)
self._nodes = nodes self._nodes = nodes
def __getitem__(self, key): def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype, self._nodes)[key] return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}, self._nodes)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]])
def __iter__(self):
return iter(self._graph.get_n_repr(self._ntype, self._nodes))
def __repr__(self):
data = self._graph.get_n_repr(self._ntype, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]})
class HeteroNodeDataView(object):
"""The data view class when G.ndata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroNodeDataTypeView(self._graph, key)
class HeteroNodeDataTypeView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype']
def __init__(self, graph, ntype):
self._graph = graph
self._ntype = ntype
def __getitem__(self, key):
return self._graph.get_n_repr(self._ntype)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph.set_n_repr(self._ntype, {key : val}) self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph.pop_n_repr(self._ntype, key) self._graph._pop_n_repr(self._ntid, key)
def __len__(self): def __len__(self):
return len(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]) return len(self._graph._node_frames[self._ntid])
def __iter__(self): def __iter__(self):
return iter(self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]) return iter(self._graph._node_frames[self._ntid])
def __repr__(self): def __repr__(self):
data = self._graph.get_n_repr(self._ntype) data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key] return repr({key : data[key]
for key in self._graph._node_frames[self._graph._ntypes_invmap[self._ntype]]}) for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph.""" """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
...@@ -369,108 +314,59 @@ class HeteroEdgeView(object): ...@@ -369,108 +314,59 @@ class HeteroEdgeView(object):
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
def __getitem__(self, etype): def __getitem__(self, key):
return HeteroEdgeTypeView(self._graph, etype) if isinstance(key, slice):
class HeteroEdgeTypeView(object):
"""A EdgeView class to act as G.edges[etype] for a DGLHeteroGraph.
See Also
--------
dgl.DGLGraph.edges
"""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __len__(self):
return self._graph.number_of_edges(self._graph._etypes_invmap[self._etype])
def __getitem__(self, edges):
if isinstance(edges, slice):
# slice # slice
if not (edges.start is None and edges.stop is None if not (key.start is None and key.stop is None
and edges.step is None): and key.step is None):
raise DGLError('Currently only full slice ":" is supported') raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, ALL)) edges = ALL
etype = None
elif isinstance(key, tuple):
if len(key) == 3:
edges = ALL
etype = key
else: else:
return EdgeSpace(data=HeteroEdgeTypeDataView(self._graph, self._etype, edges)) edges = key
etype = None
def __call__(self): elif isinstance(key, (str, tuple)):
"""Return the edges.""" edges = ALL
return F.arange(0, len(self)) etype = key
else:
edges = key
etype = None
return EdgeSpace(data=HeteroEdgeDataView(self._graph, etype, edges))
class HeteroEdgeTypeDataView(MutableMapping): def __call__(self, *args, **kwargs):
"""The data view class when G.edges[etype][...].data is called. """Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
See Also class HeteroEdgeDataView(MutableMapping):
-------- """The data view class when G.edata[etype] is called."""
dgl.DGLGraph.edges __slots__ = ['_graph', '_etype', '_etid', '_edges']
"""
__slots__ = ['_graph', '_etype', '_edges']
def __init__(self, graph, etype, edges): def __init__(self, graph, etype, edges):
self._graph = graph self._graph = graph
self._etype = etype self._etype = etype
self._etid = self._graph.get_etype_id(etype)
self._edges = edges self._edges = edges
def __getitem__(self, key): def __getitem__(self, key):
return self._graph.get_e_repr(self._etype, self._edges)[key] return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}, self._edges)
def __delitem__(self, key):
raise DGLError('Delete feature data is not supported on only a subset'
' of edges. Please use `del G.edata[key]` instead.')
def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]])
def __iter__(self):
return iter(self._graph.get_e_repr(self._etype, self._edges))
def __repr__(self):
data = self._graph.get_e_repr(self._etype, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]})
class HeteroEdgeDataView(object):
"""The data view class when G.edata is called."""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __getitem__(self, key):
return HeteroEdgeDataTypeView(self._graph, key)
class HeteroEdgeDataTypeView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype']
def __init__(self, graph, etype):
self._graph = graph
self._etype = etype
def __getitem__(self, key):
return self._graph.get_e_repr(self._etype)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph.set_e_repr(self._etype, {key : val}) self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph.pop_e_repr(self._etype, key) self._graph._pop_e_repr(self._etid, key)
def __len__(self): def __len__(self):
return len(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]) return len(self._graph._edge_frames[self._etid])
def __iter__(self): def __iter__(self):
return iter(self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]) return iter(self._graph._edge_frames[self._etid])
def __repr__(self): def __repr__(self):
data = self._graph.get_e_repr(self._etype) data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key] return repr({key : data[key]
for key in self._graph._edge_frames[self._graph._etypes_invmap[self._etype]]}) for key in self._graph._edge_frames[self._etid]})
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
* \brief Heterograph implementation * \brief Heterograph implementation
*/ */
#include "./heterograph.h" #include "./heterograph.h"
#include <dgl/array.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./bipartite.h" #include "./unit_graph.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -50,7 +51,7 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -50,7 +51,7 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
// following heterograph: // following heterograph:
// //
// Meta graph: A -> B -> C // Meta graph: A -> B -> C
// Bipartite graphs: // UnitGraph graphs:
// * A -> B: (0, 0), (0, 1) // * A -> B: (0, 0), (0, 1)
// * B -> C: (1, 0), (1, 1) // * B -> C: (1, 0), (1, 1)
// //
...@@ -91,7 +92,8 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -91,7 +92,8 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
auto pair = hg->meta_graph()->FindEdge(etype); auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
subrels[etype] = Bipartite::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
ret.induced_vertices[src_vtype]->shape[0], ret.induced_vertices[src_vtype]->shape[0],
ret.induced_vertices[dst_vtype]->shape[0], ret.induced_vertices[dst_vtype]->shape[0],
subedges[etype].src, subedges[etype].src,
...@@ -108,10 +110,9 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -108,10 +110,9 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
// Sanity check // Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size()); CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed."; CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graph must be bipartite graphs // all relation graphs must have only one edge type
for (const auto rg : rel_graphs) { for (const auto rg : rel_graphs) {
CHECK_EQ(rg->NumVertexTypes(), 2) << "Each relation graph must be a bipartite graph."; CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must be a bipartite graph.";
} }
// create num verts per type // create num verts per type
num_verts_per_type_.resize(meta_graph->NumVertices(), -1); num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
...@@ -125,17 +126,20 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -125,17 +126,20 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
dgl_type_t srctype = srctypes[i]; dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i]; dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i]; dgl_type_t etype = etypes[i];
const auto& rg = rel_graphs[etype];
const auto sty = 0;
const auto dty = rg->NumVertexTypes() == 1? 0 : 1;
size_t nv; size_t nv;
// # nodes of source type // # nodes of source type
nv = rel_graphs[etype]->NumVertices(Bipartite::kSrcVType); nv = rg->NumVertices(sty);
if (num_verts_per_type_[srctype] < 0) if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv; num_verts_per_type_[srctype] = nv;
else else
CHECK_EQ(num_verts_per_type_[srctype], nv) CHECK_EQ(num_verts_per_type_[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype; << "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type // # nodes of destination type
nv = rel_graphs[etype]->NumVertices(Bipartite::kDstVType); nv = rg->NumVertices(dty);
if (num_verts_per_type_[dsttype] < 0) if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv; num_verts_per_type_[dsttype] = nv;
else else
...@@ -171,8 +175,10 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con ...@@ -171,8 +175,10 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
auto pair = meta_graph_->FindEdge(etype); auto pair = meta_graph_->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph( const std::vector<IdArray> rel_vids = (src_vtype == dst_vtype) ?
{vids[src_vtype], vids[dst_vtype]}); std::vector<IdArray>({vids[src_vtype]}) :
std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
subrels[etype] = rel_vsg.graph; subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0]; ret.induced_edges[etype] = rel_vsg.induced_edges[0];
} }
...@@ -189,18 +195,106 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph( ...@@ -189,18 +195,106 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph(
} }
} }
// creator implementation FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etypes) const {
HeteroGraphPtr CreateBipartiteFromCOO( std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
int64_t num_src, int64_t num_dst, IdArray row, IdArray col) { size_t src_nodes = 0, dst_nodes = 0;
return Bipartite::CreateFromCOO(num_src, num_dst, row, col); std::vector<dgl_id_t> result_src, result_dst;
} std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<dgl_id_t> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type and number of nodes after this
// loop.
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t num_srctype_nodes = NumVertices(srctype);
size_t num_dsttype_nodes = NumVertices(dsttype);
if (srctype_offsets.count(srctype) == 0) {
srctype_offsets[srctype] = num_srctype_nodes;
srctype_set.push_back(srctype);
}
if (dsttype_offsets.count(dsttype) == 0) {
dsttype_offsets[dsttype] = num_dsttype_nodes;
dsttype_set.push_back(dsttype);
}
}
HeteroGraphPtr CreateBipartiteFromCSR( // Sort the node types so that we can compare the sets and decide whether a homograph
int64_t num_src, int64_t num_dst, // should be returned.
IdArray indptr, IdArray indices, IdArray edge_ids) { std::sort(srctype_set.begin(), srctype_set.end());
return Bipartite::CreateFromCSR(num_src, num_dst, indptr, indices, edge_ids); std::sort(dsttype_set.begin(), dsttype_set.end());
bool homograph = (srctype_set.size() == dsttype_set.size()) &&
std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
// XXXtype_offsets contain the mapping from node type to node ID offsets after these
// two loops.
for (size_t i = 0; i < srctype_set.size(); ++i) {
dgl_type_t ntype = srctype_set[i];
size_t num_nodes = srctype_offsets[ntype];
srctype_offsets[ntype] = src_nodes;
src_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_srctype.push_back(ntype);
induced_srcid.push_back(j);
}
}
for (size_t i = 0; i < dsttype_set.size(); ++i) {
dgl_type_t ntype = dsttype_set[i];
size_t num_nodes = dsttype_offsets[ntype];
dsttype_offsets[ntype] = dst_nodes;
dst_nodes += num_nodes;
for (size_t j = 0; j < num_nodes; ++j) {
induced_dsttype.push_back(ntype);
induced_dstid.push_back(j);
}
}
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
dgl_type_t dsttype = src_dsttype.second;
size_t srctype_offset = srctype_offsets[srctype];
size_t dsttype_offset = dsttype_offsets[dsttype];
EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const dgl_id_t* edges_eid_data = static_cast<const dgl_id_t*>(edges.id->data);
// TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset);
result_dst.push_back(edges_dst_data[i] + dsttype_offset);
induced_etype.push_back(etype);
induced_eid.push_back(edges_eid_data[i]);
}
}
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2,
src_nodes,
dst_nodes,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(gptr);
result->induced_srctype = aten::VecToIdArray(induced_srctype);
result->induced_srctype_set = aten::VecToIdArray(srctype_set);
result->induced_srcid = aten::VecToIdArray(induced_srcid);
result->induced_etype = aten::VecToIdArray(induced_etype);
result->induced_etype_set = aten::VecToIdArray(etypes);
result->induced_eid = aten::VecToIdArray(induced_eid);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype);
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set);
result->induced_dstid = aten::VecToIdArray(induced_dstid);
return FlattenedHeteroGraphPtr(result);
} }
// creator implementation
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs)); return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
...@@ -208,24 +302,27 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -208,24 +302,27 @@ HeteroGraphPtr CreateHeteroGraph(
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCOO") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0]; int64_t nvtypes = args[0];
int64_t num_dst = args[1]; int64_t num_src = args[1];
IdArray row = args[2]; int64_t num_dst = args[2];
IdArray col = args[3]; IdArray row = args[3];
auto hgptr = CreateBipartiteFromCOO(num_src, num_dst, row, col); IdArray col = args[4];
auto hgptr = UnitGraph::CreateFromCOO(nvtypes, num_src, num_dst, row, col);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateBipartiteFromCSR") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t num_src = args[0]; int64_t nvtypes = args[0];
int64_t num_dst = args[1]; int64_t num_src = args[1];
IdArray indptr = args[2]; int64_t num_dst = args[2];
IdArray indices = args[3]; IdArray indptr = args[3];
IdArray edge_ids = args[4]; IdArray indices = args[4];
auto hgptr = CreateBipartiteFromCSR(num_src, num_dst, indptr, indices, edge_ids); IdArray edge_ids = args[5];
auto hgptr = UnitGraph::CreateFromCSR(
nvtypes, num_src, num_dst, indptr, indices, edge_ids);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
...@@ -252,7 +349,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") ...@@ -252,7 +349,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
if (hg->NumEdgeTypes() == 1) {
CHECK_EQ(etype, 0);
*rv = hg;
} else {
*rv = HeteroGraphRef(hg->GetRelationGraph(etype)); *rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> etypes = args[1];
std::vector<dgl_id_t> etypes_vec;
for (Value val : etypes)
etypes_vec.push_back(val->data);
*rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
...@@ -551,7 +664,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits") ...@@ -551,7 +664,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
int bits = args[1]; int bits = args[1];
HeteroGraphPtr hg_new = Bipartite::AsNumBits(hg.sptr(), bits); HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits);
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
...@@ -563,7 +676,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") ...@@ -563,7 +676,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
DLContext ctx; DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
HeteroGraphPtr hg_new = Bipartite::CopyTo(hg.sptr(), ctx); HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
......
...@@ -20,14 +20,6 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -20,14 +20,6 @@ class HeteroGraph : public BaseHeteroGraph {
public: public:
HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs); HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
uint64_t NumVertexTypes() const override {
return meta_graph_->NumVertices();
}
uint64_t NumEdgeTypes() const override {
return meta_graph_->NumEdges();
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype; CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
return relation_graphs_[etype]; return relation_graphs_[etype];
...@@ -172,8 +164,10 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -172,8 +164,10 @@ class HeteroGraph : public BaseHeteroGraph {
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
private: private:
/*! \brief A map from edge type to bipartite graph */ /*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_; std::vector<HeteroGraphPtr> relation_graphs_;
/*! \brief A map from vert type to the number of verts in the type */ /*! \brief A map from vert type to the number of verts in the type */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment