Unverified Commit ce6e19f2 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Allows node types with no inbound/outbound edges (#1323)

* add num nodes in ctors

* fix

* lint

* addresses comments

* replace with constexpr

* remove function with rvalue reference

* address comments
parent ac74233c
......@@ -125,6 +125,12 @@ class BaseHeteroGraph : public runtime::Object {
/*! \return the number of vertices in the graph.*/
virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;
/*! \return the number of vertices for each type in the graph as a vector */
inline virtual std::vector<int64_t> NumVerticesPerType() const {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object.";
return {};
}
/*! \return the number of edges in the graph.*/
virtual uint64_t NumEdges(dgl_type_t etype) const = 0;
......@@ -543,9 +549,14 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Declarations of functions and algorithms
/*! \brief Create a heterograph from meta graph and a list of bipartite graph */
/*!
* \brief Create a heterograph from meta graph and a list of bipartite graph,
* additionally specifying number of nodes per type.
*/
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr> &rel_graphs,
const std::vector<int64_t> &num_nodes_per_type = {});
/*!
* \brief Create a heterograph from COO input.
......@@ -651,6 +662,9 @@ struct HeteroPickleStates : public runtime::Object {
/*! \brief Metagraph. */
GraphPtr metagraph;
/*! \brief Number of nodes per type */
std::vector<int64_t> num_nodes_per_type;
/*! \brief adjacency matrices of each relation graph */
std::vector<std::shared_ptr<SparseMatrix> > adjs;
......
......@@ -14,8 +14,45 @@
#include "serializer.h"
#include "shared_mem.h"
/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}
/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
}
namespace dgl {
/*!
* \brief Type traits that converts a C type to a DLDataType.
*
* Usage:
* DLDataTypeTraits<int>::dtype == dtype
*/
template<typename T>
struct DLDataTypeTraits {
static constexpr DLDataType dtype{0, 0, 0}; // dummy
};
#define GEN_DLDATATYPETRAITS_FOR(T, code, bits) \
template<> \
struct DLDataTypeTraits<T> { \
static constexpr DLDataType dtype{code, bits, 1}; \
}
GEN_DLDATATYPETRAITS_FOR(int32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed DTypes.
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64);
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32);
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64);
#undef GEN_DLDATATYPETRAITS_FOR
namespace runtime {
/*!
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
......@@ -191,8 +228,14 @@ class NDArray {
DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0});
/*!
* \brief Create a std::vector from a 1D NDArray.
* \tparam T Type of vector data.
* \note Type casting is NOT performed. The caller has to make sure that the vector
* type matches the dtype of NDArray.
*/
template<typename T>
static NDArray FromVector(const std::vector<T>& vec, DLDataType dtype, DLContext ctx);
std::vector<T> ToVector() const;
/*!
* \brief Function to copy data from one array to another.
......
......@@ -281,7 +281,7 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
else:
raise DGLError('Unsupported graph data type:', type(data))
def hetero_from_relations(rel_graphs):
def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
"""Create a heterograph from graphs representing connections of each relation.
The input is a list of heterographs where the ``i``th graph contains edges of type
......@@ -294,6 +294,9 @@ def hetero_from_relations(rel_graphs):
----------
rel_graphs : list of DGLHeteroGraph
Each element corresponds to a heterograph for one (src, edge, dst) relation.
num_nodes_per_type : dict[str, Tensor], optional
Number of nodes per node type. If not given, DGL will infer the number of nodes
from the given relation graphs.
Returns
-------
......@@ -349,32 +352,35 @@ def hetero_from_relations(rel_graphs):
# TODO(minjie): this API can be generalized as a union operation of the input graphs
# TODO(minjie): handle node/edge data
# infer meta graph
ntype_set = set()
meta_edges = []
meta_edges_src, meta_edges_dst = [], []
ntypes = []
etypes = []
# TODO(BarclayII): I'm keeping the node type names sorted because even if
# the metagraph is the same, the same node type name in different graphs may
# map to different node type IDs.
# In the future, we need to lower the type names into C++.
for rgrh in rel_graphs:
assert len(rgrh.etypes) == 1
stype, etype, dtype = rgrh.canonical_etypes[0]
ntype_set.add(stype)
ntype_set.add(dtype)
ntypes = list(sorted(ntype_set))
if num_nodes_per_type is None:
ntype_set = set()
for rgrh in rel_graphs:
assert len(rgrh.etypes) == 1
stype, etype, dtype = rgrh.canonical_etypes[0]
ntype_set.add(stype)
ntype_set.add(dtype)
ntypes = list(sorted(ntype_set))
else:
ntypes = list(sorted(num_nodes_per_type.keys()))
num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes])
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
for rgrh in rel_graphs:
stype, etype, dtype = rgrh.canonical_etypes[0]
stid = ntype_dict[stype]
dtid = ntype_dict[dtype]
meta_edges.append((stid, dtid))
meta_edges_src.append(ntype_dict[stype])
meta_edges_dst.append(ntype_dict[dtype])
etypes.append(etype)
metagraph = graph_index.from_edge_list(meta_edges, True, True)
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True, True)
# create graph index
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs])
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
retg = DGLHeteroGraph(hgidx, ntypes, etypes)
for i, rgrh in enumerate(rel_graphs):
for ntype in rgrh.ntypes:
......@@ -441,8 +447,12 @@ def heterograph(data_dict, num_nodes_dict=None):
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0})
ndst = data.number_of_nodes() - nsrc
elif isinstance(data, DGLHeteroGraph):
# Do nothing; handled in the next loop
continue
# original node type and edge type of ``data`` is ignored.
assert len(data.canonical_etypes) == 1, \
"Relational graphs must have only one edge type."
srctype, _, dsttype = data.canonical_etypes[0]
nsrc = data.number_of_nodes(srctype)
ndst = data.number_of_nodes(dsttype)
else:
raise DGLError('Unsupported graph data type %s for %s' % (
type(data), (srctype, etype, dsttype)))
......@@ -464,7 +474,7 @@ def heterograph(data_dict, num_nodes_dict=None):
data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))
return hetero_from_relations(rel_graphs)
return hetero_from_relations(rel_graphs, num_nodes_dict)
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
"""Convert the given homogeneous graph to a heterogeneous graph.
......@@ -622,7 +632,8 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
card=(ntype_count[stid], ntype_count[dtid]), validate=False)
rel_graphs.append(rel_graph)
hg = hetero_from_relations(rel_graphs)
hg = hetero_from_relations(
rel_graphs, {ntype: count for ntype, count in zip(ntypes, ntype_count)})
ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
for ntid, ntype in enumerate(hg.ntypes):
......
......@@ -1684,6 +1684,7 @@ class DGLHeteroGraph(object):
node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
edge_frames = []
num_nodes_per_type = [self.number_of_nodes(ntype) for ntype in ntypes]
ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
for i in range(len(self._etypes)):
......@@ -1697,7 +1698,8 @@ class DGLHeteroGraph(object):
edge_frames.append(self._edge_frames[i])
metagraph = graph_index.from_edge_list(meta_edges, True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
return hg
......@@ -1767,9 +1769,11 @@ class DGLHeteroGraph(object):
edge_frames = [self._edge_frames[i] for i in etype_ids]
induced_ntypes = [self._ntypes[i] for i in ntypes_invmap]
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg
......
......@@ -1008,7 +1008,7 @@ def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edg
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor(),
restrict_format)
def create_heterograph_from_relations(metagraph, rel_graphs):
def create_heterograph_from_relations(metagraph, rel_graphs, num_nodes_per_type):
"""Create a heterograph from metagraph and graphs of every relation.
Parameters
......@@ -1017,12 +1017,18 @@ def create_heterograph_from_relations(metagraph, rel_graphs):
Meta-graph.
rel_graphs : list of HeteroGraphIndex
Bipartite graph of each relation.
num_nodes_per_type : utils.Index, optional
Number of nodes per node type
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
if num_nodes_per_type is None:
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
else:
return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(
metagraph, rel_graphs, num_nodes_per_type.todgltensor())
def disjoint_union(metagraph, graphs):
"""Return a disjoint union of the input heterographs.
......@@ -1085,6 +1091,17 @@ class HeteroPickleStates(ObjectBase):
"""
return _CAPI_DGLHeteroPickleStatesGetMetagraph(self)
@property
def num_nodes_per_type(self):
"""Number of nodes per edge type
Returns
-------
Tensor
Array of number of nodes for each type
"""
return F.zerocopy_from_dgl_ndarray(_CAPI_DGLHeteroPickleStatesGetNumVertices(self))
@property
def adjs(self):
"""Adjacency matrices of all the relation graphs
......@@ -1097,11 +1114,12 @@ class HeteroPickleStates(ObjectBase):
return list(_CAPI_DGLHeteroPickleStatesGetAdjs(self))
def __getstate__(self):
return self.metagraph, self.adjs
return self.metagraph, self.num_nodes_per_type, self.adjs
def __setstate__(self, state):
metagraph, adjs = state
metagraph, num_nodes_per_type, adjs = state
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStates, metagraph, adjs)
_CAPI_DGLCreateHeteroPickleStates, metagraph, num_nodes_per_type, adjs)
_init_api("dgl.heterograph_index")
......@@ -84,6 +84,10 @@ class IdHashMap {
return values;
}
inline size_t Size() const {
return oldv2newv_.size();
}
private:
static constexpr int32_t kFilterMask = 0xFFFFFF;
static constexpr int32_t kFilterSize = kFilterMask + 1;
......
......@@ -186,7 +186,7 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
}
}
}
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
return NDArray::FromVector(ret_vec, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
......@@ -228,7 +228,7 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
}
}
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
return NDArray::FromVector(ret_vec, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
......@@ -306,9 +306,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
}
}
return {NDArray::FromVector(ret_rows, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->dtype, csr.data->ctx)};
return {NDArray::FromVector(ret_rows, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->ctx)};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
......@@ -536,8 +536,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
std::copy(sub_data.begin(), sub_data.end(), ptr);
return CSRMatrix{new_nrows, new_ncols,
NDArray::FromVector(sub_indptr, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(sub_indptr, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->ctx),
sub_data_arr};
}
......
......@@ -16,16 +16,6 @@
using dgl::runtime::operator<<;
/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}
/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
}
/*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
return os << ctx.device_type << ":" << ctx.device_id;
......
......@@ -10,8 +10,10 @@ namespace dgl {
// creator implementation
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
}
HeteroGraphPtr CreateFromCOO(
......
......@@ -37,7 +37,8 @@ HeteroSubgraph EdgeSubgraphPreserveNodes(
ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0];
ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1];
}
ret.graph = HeteroGraphPtr(new HeteroGraph(hg->meta_graph(), subrels));
ret.graph = HeteroGraphPtr(new HeteroGraph(
hg->meta_graph(), subrels, hg->NumVerticesPerType()));
return ret;
}
......@@ -86,8 +87,10 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
subedges[etype] = earray;
}
// step (3)
std::vector<int64_t> num_vertices_per_type(hg->NumVertexTypes());
for (dgl_type_t vtype = 0; vtype < hg->NumVertexTypes(); ++vtype) {
ret.induced_vertices[vtype] = aten::Relabel_(vtype2incnodes[vtype]);
num_vertices_per_type[vtype] = ret.induced_vertices[vtype]->shape[0];
}
// step (4)
std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
......@@ -102,14 +105,12 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
subedges[etype].src,
subedges[etype].dst);
}
ret.graph = HeteroGraphPtr(new HeteroGraph(hg->meta_graph(), subrels));
ret.graph = HeteroGraphPtr(new HeteroGraph(
hg->meta_graph(), subrels, std::move(num_vertices_per_type)));
return ret;
}
} // namespace
HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs)
: BaseHeteroGraph(meta_graph) {
void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
......@@ -117,8 +118,12 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
}
}
std::vector<int64_t>
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// create num verts per type
num_verts_per_type_.resize(meta_graph->NumVertices(), -1);
std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);
EdgeArray etype_array = meta_graph->Edges();
dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data);
......@@ -136,30 +141,49 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
// # nodes of source type
nv = rg->NumVertices(sty);
if (num_verts_per_type_[srctype] < 0)
num_verts_per_type_[srctype] = nv;
if (num_verts_per_type[srctype] < 0)
num_verts_per_type[srctype] = nv;
else
CHECK_EQ(num_verts_per_type_[srctype], nv)
CHECK_EQ(num_verts_per_type[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type
nv = rg->NumVertices(dty);
if (num_verts_per_type_[dsttype] < 0)
num_verts_per_type_[dsttype] = nv;
if (num_verts_per_type[dsttype] < 0)
num_verts_per_type[dsttype] = nv;
else
CHECK_EQ(num_verts_per_type_[dsttype], nv)
CHECK_EQ(num_verts_per_type[dsttype], nv)
<< "Mismatch number of vertices for vertex type " << dsttype;
}
return num_verts_per_type;
}
relation_graphs_.resize(rel_graphs.size());
std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) {
std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
for (size_t i = 0; i < rel_graphs.size(); ++i) {
HeteroGraphPtr relg = rel_graphs[i];
if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
relation_graphs_[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
} else {
relation_graphs_[i] = CHECK_NOTNULL(
relation_graphs[i] = CHECK_NOTNULL(
std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
}
}
return relation_graphs;
}
} // namespace
HeteroGraph::HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
if (num_nodes_per_type.size() == 0)
num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
else
num_verts_per_type_ = num_nodes_per_type;
HeteroGraphSanityCheck(meta_graph, rel_graphs);
relation_graphs_ = CastToUnitGraphs(rel_graphs);
}
bool HeteroGraph::IsMultigraph() const {
......@@ -183,6 +207,9 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
<< "Invalid input: the input list size must be the same as the number of vertex types.";
HeteroSubgraph ret;
ret.induced_vertices = vids;
std::vector<int64_t> num_vertices_per_type(NumVertexTypes());
for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)
num_vertices_per_type[vtype] = vids[vtype]->shape[0];
ret.induced_edges.resize(NumEdgeTypes());
std::vector<HeteroGraphPtr> subrels(NumEdgeTypes());
for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
......@@ -196,7 +223,8 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0];
}
ret.graph = HeteroGraphPtr(new HeteroGraph(meta_graph_, subrels));
ret.graph = HeteroGraphPtr(new HeteroGraph(
meta_graph_, subrels, std::move(num_vertices_per_type)));
return ret;
}
......
......@@ -19,7 +19,10 @@ namespace dgl {
/*! \brief Heterograph */
class HeteroGraph : public BaseHeteroGraph {
public:
HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {});
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
......@@ -65,6 +68,10 @@ class HeteroGraph : public BaseHeteroGraph {
return num_verts_per_type_[vtype];
}
inline std::vector<int64_t> NumVerticesPerType() const override {
return num_verts_per_type_;
}
uint64_t NumEdges(dgl_type_t etype) const override {
return GetRelationGraph(etype)->NumEdges(0);
}
......
......@@ -55,6 +55,21 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1];
IdArray num_nodes_per_type = args[2];
std::vector<HeteroGraphPtr> rel_ptrs;
rel_ptrs.reserve(rel_graphs.size());
for (const auto& ref : rel_graphs) {
rel_ptrs.push_back(ref.sptr());
}
auto hgptr = CreateHeteroGraph(
meta_graph.sptr(), rel_ptrs, num_nodes_per_type.ToVector<int64_t>());
*rv = HeteroGraphRef(hgptr);
});
///////////////////////// HeteroGraph member functions /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
......
......@@ -15,6 +15,7 @@ namespace dgl {
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroPickleStates states;
states.metagraph = graph->meta_graph();
states.num_nodes_per_type = graph->NumVerticesPerType();
states.adjs.resize(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::ANY);
......@@ -36,6 +37,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
const auto metagraph = states.metagraph;
const auto &num_nodes_per_type = states.num_nodes_per_type;
CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
......@@ -58,7 +60,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
LOG(FATAL) << "Unsupported sparse format.";
}
}
return CreateHeteroGraph(metagraph, relgraphs);
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph")
......@@ -67,6 +69,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph")
*rv = GraphRef(st->metagraph);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0];
*rv = NDArray::FromVector(st->num_nodes_per_type);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0];
......@@ -77,9 +85,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0];
List<SparseMatrixRef> adjs = args[1];
IdArray num_nodes_per_type = args[1];
List<SparseMatrixRef> adjs = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
st->metagraph = metagraph.sptr();
st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
st->adjs.reserve(adjs.size());
for (const auto& ref : adjs)
st->adjs.push_back(ref.sptr());
......
......@@ -86,7 +86,7 @@ HeteroSubgraph SampleNeighbors(
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels);
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges);
return ret;
......@@ -160,7 +160,7 @@ HeteroSubgraph SampleNeighborsTopk(
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels);
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges);
return ret;
......
......@@ -38,7 +38,7 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray
}
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels);
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
ret.induced_edges = std::move(induced_edges);
return ret;
}
......@@ -73,7 +73,7 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra
}
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels);
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
ret.induced_edges = std::move(induced_edges);
return ret;
}
......
......@@ -32,7 +32,8 @@ CompactGraphs(
const std::vector<IdArray> &always_preserve) {
// TODO(BarclayII): check whether the node space and metagraph of each graph is the same.
// Step 1: Collect the nodes that has connections for each type.
std::vector<aten::IdHashMap<IdType>> hashmaps(graphs[0]->NumVertexTypes());
const int64_t num_ntypes = graphs[0]->NumVertexTypes();
std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
std::vector<std::vector<EdgeArray>> all_edges(graphs.size()); // all_edges[i][etype]
for (size_t i = 0; i < always_preserve.size(); ++i)
......@@ -56,9 +57,12 @@ CompactGraphs(
}
// Step 2: Relabel the nodes for each type to a smaller ID space and save the mapping.
std::vector<IdArray> induced_nodes;
for (auto &hashmap : hashmaps)
induced_nodes.push_back(hashmap.Values());
std::vector<IdArray> induced_nodes(num_ntypes);
std::vector<int64_t> num_induced_nodes(num_ntypes);
for (int64_t i = 0; i < num_ntypes; ++i) {
induced_nodes[i] = hashmaps[i].Values();
num_induced_nodes[i] = hashmaps[i].Size();
}
// Step 3: Remap the edges of each graph.
std::vector<HeteroGraphPtr> new_graphs;
......@@ -84,7 +88,7 @@ CompactGraphs(
mapped_cols));
}
new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs));
new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
}
return std::make_pair(new_graphs, induced_nodes);
......
......@@ -71,7 +71,8 @@ ToSimpleGraph(const HeteroGraphPtr graph) {
coalesced_adj.col);
}
const HeteroGraphPtr result = CreateHeteroGraph(metagraph, rel_graphs);
const HeteroGraphPtr result = CreateHeteroGraph(
metagraph, rel_graphs, graph->NumVerticesPerType());
return std::make_tuple(result, counts, edge_maps);
}
......
......@@ -12,6 +12,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices());
// Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
......@@ -46,8 +47,10 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset;
}
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
}
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
......@@ -121,8 +124,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
}
std::vector<HeteroGraphPtr> rst;
std::vector<int64_t> num_nodes_per_type(num_vertex_types);
for (uint64_t g = 0; g < batch_size; ++g) {
rst.push_back(HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs[g])));
for (uint64_t i = 0; i < num_vertex_types; ++i)
num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
}
return rst;
}
......
......@@ -92,6 +92,11 @@ class UnitGraph : public BaseHeteroGraph {
uint64_t NumVertices(dgl_type_t vtype) const override;
inline std::vector<int64_t> NumVerticesPerType() const override {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on unit graphs.";
return {};
}
uint64_t NumEdges(dgl_type_t etype) const override;
bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;
......
......@@ -224,7 +224,8 @@ void NDArray::CopyFromTo(DLTensor* from,
}
template<typename T>
NDArray NDArray::FromVector(const std::vector<T>& vec, DLDataType dtype, DLContext ctx) {
NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) {
const DLDataType dtype = DLDataTypeTraits<T>::dtype;
int64_t size = static_cast<int64_t>(vec.size());
NDArray ret = NDArray::Empty({size}, dtype, DLContext{kDLCPU, 0});
DeviceAPI::Get(ctx)->CopyDataFromTo(
......@@ -241,27 +242,41 @@ NDArray NDArray::FromVector(const std::vector<T>& vec, DLDataType dtype, DLConte
}
// export specializations
template NDArray NDArray::FromVector(const std::vector<int32_t>&, DLDataType, DLContext);
template NDArray NDArray::FromVector(const std::vector<int64_t>&, DLDataType, DLContext);
template NDArray NDArray::FromVector(const std::vector<uint32_t>&, DLDataType, DLContext);
template NDArray NDArray::FromVector(const std::vector<uint64_t>&, DLDataType, DLContext);
template NDArray NDArray::FromVector(const std::vector<float>&, DLDataType, DLContext);
template NDArray NDArray::FromVector(const std::vector<double>&, DLDataType, DLContext);
// specializations of FromVector
#define GEN_FROMVECTOR_FOR(T, DTypeCode, DTypeBits) \
template<> \
NDArray NDArray::FromVector<T>(const std::vector<T> &vec, DLContext ctx) { \
return FromVector<T>(vec, DLDataType{DTypeCode, DTypeBits, 1}, ctx); \
}
GEN_FROMVECTOR_FOR(int32_t, kDLInt, 32);
GEN_FROMVECTOR_FOR(int64_t, kDLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed NDArrays.
GEN_FROMVECTOR_FOR(uint32_t, kDLInt, 32);
GEN_FROMVECTOR_FOR(uint64_t, kDLInt, 64);
GEN_FROMVECTOR_FOR(float, kDLFloat, 32);
GEN_FROMVECTOR_FOR(double, kDLFloat, 64);
template NDArray NDArray::FromVector<int32_t>(const std::vector<int32_t>&, DLContext);
template NDArray NDArray::FromVector<int64_t>(const std::vector<int64_t>&, DLContext);
template NDArray NDArray::FromVector<uint32_t>(const std::vector<uint32_t>&, DLContext);
template NDArray NDArray::FromVector<uint64_t>(const std::vector<uint64_t>&, DLContext);
template NDArray NDArray::FromVector<float>(const std::vector<float>&, DLContext);
template NDArray NDArray::FromVector<double>(const std::vector<double>&, DLContext);
template<typename T>
std::vector<T> NDArray::ToVector() const {
const DLDataType dtype = DLDataTypeTraits<T>::dtype;
CHECK(data_->dl_tensor.ndim == 1) << "ToVector() only supported for 1D arrays";
CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch";
int64_t size = data_->dl_tensor.shape[0];
std::vector<T> vec(size);
const DLContext &ctx = data_->dl_tensor.ctx;
DeviceAPI::Get(ctx)->CopyDataFromTo(
static_cast<T*>(data_->dl_tensor.data),
0,
vec.data(),
0,
size * sizeof(T),
ctx,
DLContext{kDLCPU, 0},
dtype,
nullptr);
return vec;
}
template std::vector<int32_t> NDArray::ToVector<int32_t>() const;
template std::vector<int64_t> NDArray::ToVector<int64_t>() const;
template std::vector<uint32_t> NDArray::ToVector<uint32_t>() const;
template std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;
template std::vector<float> NDArray::ToVector<float>() const;
template std::vector<double> NDArray::ToVector<double>() const;
} // namespace runtime
} // namespace dgl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment