Unverified Commit ce6e19f2 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Allows node types with no inbound/outbound edges (#1323)

* add num nodes in ctors

* fix

* lint

* addresses comments

* replace with constexpr

* remove function with rvalue reference

* address comments
parent ac74233c
...@@ -125,6 +125,12 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -125,6 +125,12 @@ class BaseHeteroGraph : public runtime::Object {
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
virtual uint64_t NumVertices(dgl_type_t vtype) const = 0; virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;
/*!
 * \brief Number of vertices of every vertex type, returned as a vector.
 * \note This default implementation only reports a fatal error; the return
 *       statement merely satisfies the signature.
 * \return a vector of per-type vertex counts (empty in this default implementation).
 */
virtual std::vector<int64_t> NumVerticesPerType() const {
  LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object.";
  return std::vector<int64_t>();
}
/*! \return the number of edges in the graph.*/ /*! \return the number of edges in the graph.*/
virtual uint64_t NumEdges(dgl_type_t etype) const = 0; virtual uint64_t NumEdges(dgl_type_t etype) const = 0;
...@@ -543,9 +549,14 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph); ...@@ -543,9 +549,14 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Declarations of functions and algorithms // Declarations of functions and algorithms
/*! \brief Create a heterograph from meta graph and a list of bipartite graph */ /*!
* \brief Create a heterograph from meta graph and a list of bipartite graph,
* additionally specifying number of nodes per type.
*/
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs); GraphPtr meta_graph,
const std::vector<HeteroGraphPtr> &rel_graphs,
const std::vector<int64_t> &num_nodes_per_type = {});
/*! /*!
* \brief Create a heterograph from COO input. * \brief Create a heterograph from COO input.
...@@ -651,6 +662,9 @@ struct HeteroPickleStates : public runtime::Object { ...@@ -651,6 +662,9 @@ struct HeteroPickleStates : public runtime::Object {
/*! \brief Metagraph. */ /*! \brief Metagraph. */
GraphPtr metagraph; GraphPtr metagraph;
/*! \brief Number of nodes per type */
std::vector<int64_t> num_nodes_per_type;
/*! \brief adjacency matrices of each relation graph */ /*! \brief adjacency matrices of each relation graph */
std::vector<std::shared_ptr<SparseMatrix> > adjs; std::vector<std::shared_ptr<SparseMatrix> > adjs;
......
...@@ -14,8 +14,45 @@ ...@@ -14,8 +14,45 @@
#include "serializer.h" #include "serializer.h"
#include "shared_mem.h" #include "shared_mem.h"
/*!
 * \brief Equality of two DLDataType values.
 * \return true only when the type code, the bit width and the lane count all match.
 */
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
  if (ty1.code != ty2.code)
    return false;
  if (ty1.bits != ty2.bits)
    return false;
  return ty1.lanes == ty2.lanes;
}
/*!
 * \brief Equality of two DLContext values.
 * \return true only when both the device type and the device id match.
 */
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
  if (ctx1.device_type != ctx2.device_type)
    return false;
  return ctx1.device_id == ctx2.device_id;
}
namespace dgl { namespace dgl {
/*!
 * \brief Compile-time mapping from a C type to the matching DLDataType.
 *
 * The primary template only carries an all-zero placeholder dtype.
 *
 * Usage:
 *   DLDataTypeTraits<int>::dtype == dtype
 */
template <class T>
struct DLDataTypeTraits {
  static constexpr DLDataType dtype = {0, 0, 0};  // dummy
};
/*!
 * \brief Helper macro that emits a full specialization DLDataTypeTraits<T> whose
 *        dtype is {code, bits, 1}, i.e. the given type code and bit width with a
 *        single lane.
 *
 * The macro is only used for the specializations listed below and is undefined
 * again right after them.
 */
#define GEN_DLDATATYPETRAITS_FOR(T, code, bits) \
template<> \
struct DLDataTypeTraits<T> { \
static constexpr DLDataType dtype{code, bits, 1}; \
}
// Signed integers map to kDLInt with their own bit width.
GEN_DLDATATYPETRAITS_FOR(int32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed DTypes.
// Consequently uint32_t/uint64_t report the same dtype as int32_t/int64_t.
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64);
// Floating point types map to kDLFloat with their own bit width.
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32);
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64);
#undef GEN_DLDATATYPETRAITS_FOR
namespace runtime { namespace runtime {
/*! /*!
* \brief Managed NDArray. * \brief Managed NDArray.
* The array is backed by reference counted blocks. * The array is backed by reference counted blocks.
...@@ -191,8 +228,14 @@ class NDArray { ...@@ -191,8 +228,14 @@ class NDArray {
DGL_DLL static NDArray FromVector( DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0}); const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0});
/*!
* \brief Create a std::vector from a 1D NDArray.
* \tparam T Type of vector data.
* \note Type casting is NOT performed. The caller has to make sure that the vector
* type matches the dtype of NDArray.
*/
template<typename T> template<typename T>
static NDArray FromVector(const std::vector<T>& vec, DLDataType dtype, DLContext ctx); std::vector<T> ToVector() const;
/*! /*!
* \brief Function to copy data from one array to another. * \brief Function to copy data from one array to another.
......
...@@ -281,7 +281,7 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True ...@@ -281,7 +281,7 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def hetero_from_relations(rel_graphs): def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
"""Create a heterograph from graphs representing connections of each relation. """Create a heterograph from graphs representing connections of each relation.
The input is a list of heterographs where the ``i``th graph contains edges of type The input is a list of heterographs where the ``i``th graph contains edges of type
...@@ -294,6 +294,9 @@ def hetero_from_relations(rel_graphs): ...@@ -294,6 +294,9 @@ def hetero_from_relations(rel_graphs):
---------- ----------
rel_graphs : list of DGLHeteroGraph rel_graphs : list of DGLHeteroGraph
Each element corresponds to a heterograph for one (src, edge, dst) relation. Each element corresponds to a heterograph for one (src, edge, dst) relation.
num_nodes_per_type : dict[str, int], optional
Number of nodes for each node type, keyed by node type name. If not given, DGL will
infer the node types and the number of nodes from the given relation graphs.
Returns Returns
------- -------
...@@ -349,32 +352,35 @@ def hetero_from_relations(rel_graphs): ...@@ -349,32 +352,35 @@ def hetero_from_relations(rel_graphs):
# TODO(minjie): this API can be generalized as a union operation of the input graphs # TODO(minjie): this API can be generalized as a union operation of the input graphs
# TODO(minjie): handle node/edge data # TODO(minjie): handle node/edge data
# infer meta graph # infer meta graph
ntype_set = set() meta_edges_src, meta_edges_dst = [], []
meta_edges = []
ntypes = [] ntypes = []
etypes = [] etypes = []
# TODO(BarclayII): I'm keeping the node type names sorted because even if # TODO(BarclayII): I'm keeping the node type names sorted because even if
# the metagraph is the same, the same node type name in different graphs may # the metagraph is the same, the same node type name in different graphs may
# map to different node type IDs. # map to different node type IDs.
# In the future, we need to lower the type names into C++. # In the future, we need to lower the type names into C++.
if num_nodes_per_type is None:
ntype_set = set()
for rgrh in rel_graphs: for rgrh in rel_graphs:
assert len(rgrh.etypes) == 1 assert len(rgrh.etypes) == 1
stype, etype, dtype = rgrh.canonical_etypes[0] stype, etype, dtype = rgrh.canonical_etypes[0]
ntype_set.add(stype) ntype_set.add(stype)
ntype_set.add(dtype) ntype_set.add(dtype)
ntypes = list(sorted(ntype_set)) ntypes = list(sorted(ntype_set))
else:
ntypes = list(sorted(num_nodes_per_type.keys()))
num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes])
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)} ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
for rgrh in rel_graphs: for rgrh in rel_graphs:
stype, etype, dtype = rgrh.canonical_etypes[0] stype, etype, dtype = rgrh.canonical_etypes[0]
stid = ntype_dict[stype] meta_edges_src.append(ntype_dict[stype])
dtid = ntype_dict[dtype] meta_edges_dst.append(ntype_dict[dtype])
meta_edges.append((stid, dtid))
etypes.append(etype) etypes.append(etype)
metagraph = graph_index.from_edge_list(meta_edges, True, True) metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True, True)
# create graph index # create graph index
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs]) metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
retg = DGLHeteroGraph(hgidx, ntypes, etypes) retg = DGLHeteroGraph(hgidx, ntypes, etypes)
for i, rgrh in enumerate(rel_graphs): for i, rgrh in enumerate(rel_graphs):
for ntype in rgrh.ntypes: for ntype in rgrh.ntypes:
...@@ -441,8 +447,12 @@ def heterograph(data_dict, num_nodes_dict=None): ...@@ -441,8 +447,12 @@ def heterograph(data_dict, num_nodes_dict=None):
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0}) nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0})
ndst = data.number_of_nodes() - nsrc ndst = data.number_of_nodes() - nsrc
elif isinstance(data, DGLHeteroGraph): elif isinstance(data, DGLHeteroGraph):
# Do nothing; handled in the next loop # original node type and edge type of ``data`` is ignored.
continue assert len(data.canonical_etypes) == 1, \
"Relational graphs must have only one edge type."
srctype, _, dsttype = data.canonical_etypes[0]
nsrc = data.number_of_nodes(srctype)
ndst = data.number_of_nodes(dsttype)
else: else:
raise DGLError('Unsupported graph data type %s for %s' % ( raise DGLError('Unsupported graph data type %s for %s' % (
type(data), (srctype, etype, dsttype))) type(data), (srctype, etype, dsttype)))
...@@ -464,7 +474,7 @@ def heterograph(data_dict, num_nodes_dict=None): ...@@ -464,7 +474,7 @@ def heterograph(data_dict, num_nodes_dict=None):
data, srctype, etype, dsttype, data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False)) card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))
return hetero_from_relations(rel_graphs) return hetero_from_relations(rel_graphs, num_nodes_dict)
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None): def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
"""Convert the given homogeneous graph to a heterogeneous graph. """Convert the given homogeneous graph to a heterogeneous graph.
...@@ -622,7 +632,8 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -622,7 +632,8 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
card=(ntype_count[stid], ntype_count[dtid]), validate=False) card=(ntype_count[stid], ntype_count[dtid]), validate=False)
rel_graphs.append(rel_graph) rel_graphs.append(rel_graph)
hg = hetero_from_relations(rel_graphs) hg = hetero_from_relations(
rel_graphs, {ntype: count for ntype, count in zip(ntypes, ntype_count)})
ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)} ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
for ntid, ntype in enumerate(hg.ntypes): for ntid, ntype in enumerate(hg.ntypes):
......
...@@ -1684,6 +1684,7 @@ class DGLHeteroGraph(object): ...@@ -1684,6 +1684,7 @@ class DGLHeteroGraph(object):
node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes] node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
edge_frames = [] edge_frames = []
num_nodes_per_type = [self.number_of_nodes(ntype) for ntype in ntypes]
ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)} ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid') srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
for i in range(len(self._etypes)): for i in range(len(self._etypes)):
...@@ -1697,7 +1698,8 @@ class DGLHeteroGraph(object): ...@@ -1697,7 +1698,8 @@ class DGLHeteroGraph(object):
edge_frames.append(self._edge_frames[i]) edge_frames.append(self._edge_frames[i])
metagraph = graph_index.from_edge_list(meta_edges, True, True) metagraph = graph_index.from_edge_list(meta_edges, True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs) hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames) hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
return hg return hg
...@@ -1767,9 +1769,11 @@ class DGLHeteroGraph(object): ...@@ -1767,9 +1769,11 @@ class DGLHeteroGraph(object):
edge_frames = [self._edge_frames[i] for i in etype_ids] edge_frames = [self._edge_frames[i] for i in etype_ids]
induced_ntypes = [self._ntypes[i] for i in ntypes_invmap] induced_ntypes = [self._ntypes[i] for i in ntypes_invmap]
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True) metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs) hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames) hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg return hg
......
...@@ -1008,7 +1008,7 @@ def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edg ...@@ -1008,7 +1008,7 @@ def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edg
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor(), indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor(),
restrict_format) restrict_format)
def create_heterograph_from_relations(metagraph, rel_graphs): def create_heterograph_from_relations(metagraph, rel_graphs, num_nodes_per_type):
"""Create a heterograph from metagraph and graphs of every relation. """Create a heterograph from metagraph and graphs of every relation.
Parameters Parameters
...@@ -1017,12 +1017,18 @@ def create_heterograph_from_relations(metagraph, rel_graphs): ...@@ -1017,12 +1017,18 @@ def create_heterograph_from_relations(metagraph, rel_graphs):
Meta-graph. Meta-graph.
rel_graphs : list of HeteroGraphIndex rel_graphs : list of HeteroGraphIndex
Bipartite graph of each relation. Bipartite graph of each relation.
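num_nodes_per_type : utils.Index or None
Number of nodes per node type. If None, no explicit counts are passed and the
heterograph is created from the metagraph and the relation graphs alone.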
Returns Returns
------- -------
HeteroGraphIndex HeteroGraphIndex
""" """
if num_nodes_per_type is None:
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs) return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
else:
return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(
metagraph, rel_graphs, num_nodes_per_type.todgltensor())
def disjoint_union(metagraph, graphs): def disjoint_union(metagraph, graphs):
"""Return a disjoint union of the input heterographs. """Return a disjoint union of the input heterographs.
...@@ -1085,6 +1091,17 @@ class HeteroPickleStates(ObjectBase): ...@@ -1085,6 +1091,17 @@ class HeteroPickleStates(ObjectBase):
""" """
return _CAPI_DGLHeteroPickleStatesGetMetagraph(self) return _CAPI_DGLHeteroPickleStatesGetMetagraph(self)
@property
def num_nodes_per_type(self):
"""Number of nodes per node type
Returns
-------
Tensor
Array of number of nodes for each type
"""
return F.zerocopy_from_dgl_ndarray(_CAPI_DGLHeteroPickleStatesGetNumVertices(self))
@property @property
def adjs(self): def adjs(self):
"""Adjacency matrices of all the relation graphs """Adjacency matrices of all the relation graphs
...@@ -1097,11 +1114,12 @@ class HeteroPickleStates(ObjectBase): ...@@ -1097,11 +1114,12 @@ class HeteroPickleStates(ObjectBase):
return list(_CAPI_DGLHeteroPickleStatesGetAdjs(self)) return list(_CAPI_DGLHeteroPickleStatesGetAdjs(self))
def __getstate__(self): def __getstate__(self):
return self.metagraph, self.adjs return self.metagraph, self.num_nodes_per_type, self.adjs
def __setstate__(self, state): def __setstate__(self, state):
metagraph, adjs = state metagraph, num_nodes_per_type, adjs = state
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStates, metagraph, adjs) _CAPI_DGLCreateHeteroPickleStates, metagraph, num_nodes_per_type, adjs)
_init_api("dgl.heterograph_index") _init_api("dgl.heterograph_index")
...@@ -84,6 +84,10 @@ class IdHashMap { ...@@ -84,6 +84,10 @@ class IdHashMap {
return values; return values;
} }
/*! \brief Return the number of entries currently stored in oldv2newv_. */
size_t Size() const {
  const size_t num_entries = oldv2newv_.size();
  return num_entries;
}
private: private:
static constexpr int32_t kFilterMask = 0xFFFFFF; static constexpr int32_t kFilterMask = 0xFFFFFF;
static constexpr int32_t kFilterSize = kFilterMask + 1; static constexpr int32_t kFilterSize = kFilterMask + 1;
......
...@@ -186,7 +186,7 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) { ...@@ -186,7 +186,7 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
} }
} }
} }
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx); return NDArray::FromVector(ret_vec, csr.data->ctx);
} }
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t); template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
...@@ -228,7 +228,7 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { ...@@ -228,7 +228,7 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
} }
} }
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx); return NDArray::FromVector(ret_vec, csr.data->ctx);
} }
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols); template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
...@@ -306,9 +306,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c ...@@ -306,9 +306,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
} }
} }
return {NDArray::FromVector(ret_rows, csr.indptr->dtype, csr.indptr->ctx), return {NDArray::FromVector(ret_rows, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->dtype, csr.indptr->ctx), NDArray::FromVector(ret_cols, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->dtype, csr.data->ctx)}; NDArray::FromVector(ret_data, csr.data->ctx)};
} }
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>( template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
...@@ -536,8 +536,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -536,8 +536,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
IdType* ptr = static_cast<IdType*>(sub_data_arr->data); IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
std::copy(sub_data.begin(), sub_data.end(), ptr); std::copy(sub_data.begin(), sub_data.end(), ptr);
return CSRMatrix{new_nrows, new_ncols, return CSRMatrix{new_nrows, new_ncols,
NDArray::FromVector(sub_indptr, csr.indptr->dtype, csr.indptr->ctx), NDArray::FromVector(sub_indptr, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->dtype, csr.indptr->ctx), NDArray::FromVector(sub_indices, csr.indptr->ctx),
sub_data_arr}; sub_data_arr};
} }
......
...@@ -16,16 +16,6 @@ ...@@ -16,16 +16,6 @@
using dgl::runtime::operator<<; using dgl::runtime::operator<<;
/*!
 * \brief Equality of two DLDataType values.
 * \return true only when the type code, the bit width and the lane count all match.
 */
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
  if (ty1.code != ty2.code)
    return false;
  if (ty1.bits != ty2.bits)
    return false;
  return ty1.lanes == ty2.lanes;
}
/*!
 * \brief Equality of two DLContext values.
 * \return true only when both the device type and the device id match.
 */
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
  if (ctx1.device_type != ctx2.device_type)
    return false;
  return ctx1.device_id == ctx2.device_id;
}
/*! \brief Output the string representation of device context.*/ /*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) { inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
return os << ctx.device_type << ":" << ctx.device_id; return os << ctx.device_type << ":" << ctx.device_id;
......
...@@ -10,8 +10,10 @@ namespace dgl { ...@@ -10,8 +10,10 @@ namespace dgl {
// creator implementation // creator implementation
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) { GraphPtr meta_graph,
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs)); const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
} }
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
......
...@@ -37,7 +37,8 @@ HeteroSubgraph EdgeSubgraphPreserveNodes( ...@@ -37,7 +37,8 @@ HeteroSubgraph EdgeSubgraphPreserveNodes(
ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0]; ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0];
ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1]; ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1];
} }
ret.graph = HeteroGraphPtr(new HeteroGraph(hg->meta_graph(), subrels)); ret.graph = HeteroGraphPtr(new HeteroGraph(
hg->meta_graph(), subrels, hg->NumVerticesPerType()));
return ret; return ret;
} }
...@@ -86,8 +87,10 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -86,8 +87,10 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
subedges[etype] = earray; subedges[etype] = earray;
} }
// step (3) // step (3)
std::vector<int64_t> num_vertices_per_type(hg->NumVertexTypes());
for (dgl_type_t vtype = 0; vtype < hg->NumVertexTypes(); ++vtype) { for (dgl_type_t vtype = 0; vtype < hg->NumVertexTypes(); ++vtype) {
ret.induced_vertices[vtype] = aten::Relabel_(vtype2incnodes[vtype]); ret.induced_vertices[vtype] = aten::Relabel_(vtype2incnodes[vtype]);
num_vertices_per_type[vtype] = ret.induced_vertices[vtype]->shape[0];
} }
// step (4) // step (4)
std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes()); std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
...@@ -102,14 +105,12 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -102,14 +105,12 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
subedges[etype].src, subedges[etype].src,
subedges[etype].dst); subedges[etype].dst);
} }
ret.graph = HeteroGraphPtr(new HeteroGraph(hg->meta_graph(), subrels)); ret.graph = HeteroGraphPtr(new HeteroGraph(
hg->meta_graph(), subrels, std::move(num_vertices_per_type)));
return ret; return ret;
} }
} // namespace void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs)
: BaseHeteroGraph(meta_graph) {
// Sanity check // Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size()); CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed."; CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
...@@ -117,8 +118,12 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -117,8 +118,12 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
for (const auto &rg : rel_graphs) { for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type."; CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
} }
}
std::vector<int64_t>
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// create num verts per type // create num verts per type
num_verts_per_type_.resize(meta_graph->NumVertices(), -1); std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);
EdgeArray etype_array = meta_graph->Edges(); EdgeArray etype_array = meta_graph->Edges();
dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data); dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data);
...@@ -136,30 +141,49 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -136,30 +141,49 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
// # nodes of source type // # nodes of source type
nv = rg->NumVertices(sty); nv = rg->NumVertices(sty);
if (num_verts_per_type_[srctype] < 0) if (num_verts_per_type[srctype] < 0)
num_verts_per_type_[srctype] = nv; num_verts_per_type[srctype] = nv;
else else
CHECK_EQ(num_verts_per_type_[srctype], nv) CHECK_EQ(num_verts_per_type[srctype], nv)
<< "Mismatch number of vertices for vertex type " << srctype; << "Mismatch number of vertices for vertex type " << srctype;
// # nodes of destination type // # nodes of destination type
nv = rg->NumVertices(dty); nv = rg->NumVertices(dty);
if (num_verts_per_type_[dsttype] < 0) if (num_verts_per_type[dsttype] < 0)
num_verts_per_type_[dsttype] = nv; num_verts_per_type[dsttype] = nv;
else else
CHECK_EQ(num_verts_per_type_[dsttype], nv) CHECK_EQ(num_verts_per_type[dsttype], nv)
<< "Mismatch number of vertices for vertex type " << dsttype; << "Mismatch number of vertices for vertex type " << dsttype;
} }
return num_verts_per_type;
}
relation_graphs_.resize(rel_graphs.size()); std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) {
std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
for (size_t i = 0; i < rel_graphs.size(); ++i) { for (size_t i = 0; i < rel_graphs.size(); ++i) {
HeteroGraphPtr relg = rel_graphs[i]; HeteroGraphPtr relg = rel_graphs[i];
if (std::dynamic_pointer_cast<UnitGraph>(relg)) { if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
relation_graphs_[i] = std::dynamic_pointer_cast<UnitGraph>(relg); relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
} else { } else {
relation_graphs_[i] = CHECK_NOTNULL( relation_graphs[i] = CHECK_NOTNULL(
std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0))); std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
} }
} }
return relation_graphs;
}
} // namespace
/*!
 * \brief Construct a heterograph from a metagraph and its relation graphs.
 *
 * \param meta_graph The metagraph; it is forwarded to the BaseHeteroGraph constructor.
 * \param rel_graphs The relation graphs. They are validated by HeteroGraphSanityCheck
 *        (same count as the metagraph's edges, non-empty, one edge type each) and
 *        then converted with CastToUnitGraphs.
 * \param num_nodes_per_type Number of nodes of each node type. When empty, the
 *        counts are inferred from the relation graphs by InferNumVerticesPerType.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph,
    const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
  // Validate first: the inference and the cast below both read rel_graphs, so
  // the size/emptiness checks must have passed before they run.
  HeteroGraphSanityCheck(meta_graph, rel_graphs);

  if (num_nodes_per_type.empty()) {
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  } else {
    // NOTE(review): explicitly given counts are stored as-is; they are not
    // cross-checked against the relation graphs or the metagraph here.
    num_verts_per_type_ = num_nodes_per_type;
  }

  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}
bool HeteroGraph::IsMultigraph() const { bool HeteroGraph::IsMultigraph() const {
...@@ -183,6 +207,9 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con ...@@ -183,6 +207,9 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
<< "Invalid input: the input list size must be the same as the number of vertex types."; << "Invalid input: the input list size must be the same as the number of vertex types.";
HeteroSubgraph ret; HeteroSubgraph ret;
ret.induced_vertices = vids; ret.induced_vertices = vids;
std::vector<int64_t> num_vertices_per_type(NumVertexTypes());
for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)
num_vertices_per_type[vtype] = vids[vtype]->shape[0];
ret.induced_edges.resize(NumEdgeTypes()); ret.induced_edges.resize(NumEdgeTypes());
std::vector<HeteroGraphPtr> subrels(NumEdgeTypes()); std::vector<HeteroGraphPtr> subrels(NumEdgeTypes());
for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
...@@ -196,7 +223,8 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con ...@@ -196,7 +223,8 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
subrels[etype] = rel_vsg.graph; subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0]; ret.induced_edges[etype] = rel_vsg.induced_edges[0];
} }
ret.graph = HeteroGraphPtr(new HeteroGraph(meta_graph_, subrels)); ret.graph = HeteroGraphPtr(new HeteroGraph(
meta_graph_, subrels, std::move(num_vertices_per_type)));
return ret; return ret;
} }
......
...@@ -19,7 +19,10 @@ namespace dgl { ...@@ -19,7 +19,10 @@ namespace dgl {
/*! \brief Heterograph */ /*! \brief Heterograph */
class HeteroGraph : public BaseHeteroGraph { class HeteroGraph : public BaseHeteroGraph {
public: public:
HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs); HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {});
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype; CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
...@@ -65,6 +68,10 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -65,6 +68,10 @@ class HeteroGraph : public BaseHeteroGraph {
return num_verts_per_type_[vtype]; return num_verts_per_type_[vtype];
} }
/*! \return a copy of the stored per-type vertex counts (num_verts_per_type_). */
std::vector<int64_t> NumVerticesPerType() const override {
  return num_verts_per_type_;
}
uint64_t NumEdges(dgl_type_t etype) const override { uint64_t NumEdges(dgl_type_t etype) const override {
return GetRelationGraph(etype)->NumEdges(0); return GetRelationGraph(etype)->NumEdges(0);
} }
......
...@@ -55,6 +55,21 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph") ...@@ -55,6 +55,21 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
// Packed-function entry point: creates a heterograph from (metagraph, list of
// relation graphs, array of node counts per type) and returns it as a HeteroGraphRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Unpack the three positional arguments.
    GraphRef meta_graph = args[0];
    List<HeteroGraphRef> rel_graphs = args[1];
    IdArray num_nodes_per_type = args[2];

    // Collect the underlying shared pointers of all relation graphs.
    std::vector<HeteroGraphPtr> relation_ptrs;
    relation_ptrs.reserve(rel_graphs.size());
    for (const auto& rel_ref : rel_graphs)
      relation_ptrs.push_back(rel_ref.sptr());

    // Read the node counts as int64_t values.
    const std::vector<int64_t> node_counts = num_nodes_per_type.ToVector<int64_t>();

    auto created = CreateHeteroGraph(meta_graph.sptr(), relation_ptrs, node_counts);
    *rv = HeteroGraphRef(created);
  });
///////////////////////// HeteroGraph member functions ///////////////////////// ///////////////////////// HeteroGraph member functions /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
......
...@@ -15,6 +15,7 @@ namespace dgl { ...@@ -15,6 +15,7 @@ namespace dgl {
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroPickleStates states; HeteroPickleStates states;
states.metagraph = graph->meta_graph(); states.metagraph = graph->meta_graph();
states.num_nodes_per_type = graph->NumVerticesPerType();
states.adjs.resize(graph->NumEdgeTypes()); states.adjs.resize(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::ANY); SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::ANY);
...@@ -36,6 +37,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { ...@@ -36,6 +37,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
const auto metagraph = states.metagraph; const auto metagraph = states.metagraph;
const auto &num_nodes_per_type = states.num_nodes_per_type;
CHECK_EQ(states.adjs.size(), metagraph->NumEdges()); CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges()); std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
...@@ -58,7 +60,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -58,7 +60,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
LOG(FATAL) << "Unsupported sparse format."; LOG(FATAL) << "Unsupported sparse format.";
} }
} }
return CreateHeteroGraph(metagraph, relgraphs); return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
} }
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph")
...@@ -67,6 +69,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph") ...@@ -67,6 +69,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph")
*rv = GraphRef(st->metagraph); *rv = GraphRef(st->metagraph);
}); });
/*!
 * \brief C API entry: expose the per-node-type vertex counts stored in a
 *        HeteroPickleStates object.
 *
 * Arguments:
 *   args[0] - HeteroPickleStatesRef, the pickle states to read from.
 * Returns an NDArray built from \c num_nodes_per_type through \c rv.
 */
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef states = args[0];
    // Convert the stored std::vector<int64_t> into an array for the caller.
    *rv = NDArray::FromVector(states->num_nodes_per_type);
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
...@@ -77,9 +85,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs") ...@@ -77,9 +85,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0]; GraphRef metagraph = args[0];
List<SparseMatrixRef> adjs = args[1]; IdArray num_nodes_per_type = args[1];
List<SparseMatrixRef> adjs = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
st->metagraph = metagraph.sptr(); st->metagraph = metagraph.sptr();
st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
st->adjs.reserve(adjs.size()); st->adjs.reserve(adjs.size());
for (const auto& ref : adjs) for (const auto& ref : adjs)
st->adjs.push_back(ref.sptr()); st->adjs.push_back(ref.sptr());
......
...@@ -86,7 +86,7 @@ HeteroSubgraph SampleNeighbors( ...@@ -86,7 +86,7 @@ HeteroSubgraph SampleNeighbors(
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels); ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
return ret; return ret;
...@@ -160,7 +160,7 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -160,7 +160,7 @@ HeteroSubgraph SampleNeighborsTopk(
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels); ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
return ret; return ret;
......
...@@ -38,7 +38,7 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray ...@@ -38,7 +38,7 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray
} }
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels); ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
return ret; return ret;
} }
...@@ -73,7 +73,7 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra ...@@ -73,7 +73,7 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra
} }
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels); ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
return ret; return ret;
} }
......
...@@ -32,7 +32,8 @@ CompactGraphs( ...@@ -32,7 +32,8 @@ CompactGraphs(
const std::vector<IdArray> &always_preserve) { const std::vector<IdArray> &always_preserve) {
// TODO(BarclayII): check whether the node space and metagraph of each graph is the same. // TODO(BarclayII): check whether the node space and metagraph of each graph is the same.
// Step 1: Collect the nodes that has connections for each type. // Step 1: Collect the nodes that has connections for each type.
std::vector<aten::IdHashMap<IdType>> hashmaps(graphs[0]->NumVertexTypes()); const int64_t num_ntypes = graphs[0]->NumVertexTypes();
std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
std::vector<std::vector<EdgeArray>> all_edges(graphs.size()); // all_edges[i][etype] std::vector<std::vector<EdgeArray>> all_edges(graphs.size()); // all_edges[i][etype]
for (size_t i = 0; i < always_preserve.size(); ++i) for (size_t i = 0; i < always_preserve.size(); ++i)
...@@ -56,9 +57,12 @@ CompactGraphs( ...@@ -56,9 +57,12 @@ CompactGraphs(
} }
// Step 2: Relabel the nodes for each type to a smaller ID space and save the mapping. // Step 2: Relabel the nodes for each type to a smaller ID space and save the mapping.
std::vector<IdArray> induced_nodes; std::vector<IdArray> induced_nodes(num_ntypes);
for (auto &hashmap : hashmaps) std::vector<int64_t> num_induced_nodes(num_ntypes);
induced_nodes.push_back(hashmap.Values()); for (int64_t i = 0; i < num_ntypes; ++i) {
induced_nodes[i] = hashmaps[i].Values();
num_induced_nodes[i] = hashmaps[i].Size();
}
// Step 3: Remap the edges of each graph. // Step 3: Remap the edges of each graph.
std::vector<HeteroGraphPtr> new_graphs; std::vector<HeteroGraphPtr> new_graphs;
...@@ -84,7 +88,7 @@ CompactGraphs( ...@@ -84,7 +88,7 @@ CompactGraphs(
mapped_cols)); mapped_cols));
} }
new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs)); new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
} }
return std::make_pair(new_graphs, induced_nodes); return std::make_pair(new_graphs, induced_nodes);
......
...@@ -71,7 +71,8 @@ ToSimpleGraph(const HeteroGraphPtr graph) { ...@@ -71,7 +71,8 @@ ToSimpleGraph(const HeteroGraphPtr graph) {
coalesced_adj.col); coalesced_adj.col);
} }
const HeteroGraphPtr result = CreateHeteroGraph(metagraph, rel_graphs); const HeteroGraphPtr result = CreateHeteroGraph(
metagraph, rel_graphs, graph->NumVerticesPerType());
return std::make_tuple(result, counts, edge_maps); return std::make_tuple(result, counts, edge_maps);
} }
......
...@@ -12,6 +12,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -12,6 +12,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty"; CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges()); std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices());
// Loop over all canonical etypes // Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
...@@ -46,8 +47,10 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -46,8 +47,10 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
aten::VecToIdArray(result_src), aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst)); aten::VecToIdArray(result_dst));
rel_graphs[etype] = rgptr; rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset;
} }
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs)); return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
} }
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
...@@ -121,8 +124,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -121,8 +124,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
} }
std::vector<HeteroGraphPtr> rst; std::vector<HeteroGraphPtr> rst;
std::vector<int64_t> num_nodes_per_type(num_vertex_types);
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
rst.push_back(HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs[g]))); for (uint64_t i = 0; i < num_vertex_types; ++i)
num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
} }
return rst; return rst;
} }
......
...@@ -92,6 +92,11 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -92,6 +92,11 @@ class UnitGraph : public BaseHeteroGraph {
uint64_t NumVertices(dgl_type_t vtype) const override; uint64_t NumVertices(dgl_type_t vtype) const override;
/*!
 * \brief Per-type vertex counts are not provided by unit graphs.
 *
 * Calling this reports a fatal error through LOG(FATAL); the empty vector
 * below only satisfies the return type.
 */
inline std::vector<int64_t> NumVerticesPerType() const override {
  LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on unit graphs.";
  return std::vector<int64_t>();
}
uint64_t NumEdges(dgl_type_t etype) const override; uint64_t NumEdges(dgl_type_t etype) const override;
bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override; bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;
......
...@@ -224,7 +224,8 @@ void NDArray::CopyFromTo(DLTensor* from, ...@@ -224,7 +224,8 @@ void NDArray::CopyFromTo(DLTensor* from,
} }
template<typename T> template<typename T>
NDArray NDArray::FromVector(const std::vector<T>& vec, DLDataType dtype, DLContext ctx) { NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) {
const DLDataType dtype = DLDataTypeTraits<T>::dtype;
int64_t size = static_cast<int64_t>(vec.size()); int64_t size = static_cast<int64_t>(vec.size());
NDArray ret = NDArray::Empty({size}, dtype, DLContext{kDLCPU, 0}); NDArray ret = NDArray::Empty({size}, dtype, DLContext{kDLCPU, 0});
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
...@@ -241,27 +242,41 @@ NDArray NDArray::FromVector(const std::vector<T>& vec, DLDataType dtype, DLConte ...@@ -241,27 +242,41 @@ NDArray NDArray::FromVector(const std::vector<T>& vec, DLDataType dtype, DLConte
} }
// export specializations // export specializations
template NDArray NDArray::FromVector(const std::vector<int32_t>&, DLDataType, DLContext); template NDArray NDArray::FromVector<int32_t>(const std::vector<int32_t>&, DLContext);
template NDArray NDArray::FromVector(const std::vector<int64_t>&, DLDataType, DLContext); template NDArray NDArray::FromVector<int64_t>(const std::vector<int64_t>&, DLContext);
template NDArray NDArray::FromVector(const std::vector<uint32_t>&, DLDataType, DLContext); template NDArray NDArray::FromVector<uint32_t>(const std::vector<uint32_t>&, DLContext);
template NDArray NDArray::FromVector(const std::vector<uint64_t>&, DLDataType, DLContext); template NDArray NDArray::FromVector<uint64_t>(const std::vector<uint64_t>&, DLContext);
template NDArray NDArray::FromVector(const std::vector<float>&, DLDataType, DLContext); template NDArray NDArray::FromVector<float>(const std::vector<float>&, DLContext);
template NDArray NDArray::FromVector(const std::vector<double>&, DLDataType, DLContext); template NDArray NDArray::FromVector<double>(const std::vector<double>&, DLContext);
// specializations of FromVector template<typename T>
#define GEN_FROMVECTOR_FOR(T, DTypeCode, DTypeBits) \ std::vector<T> NDArray::ToVector() const {
template<> \ const DLDataType dtype = DLDataTypeTraits<T>::dtype;
NDArray NDArray::FromVector<T>(const std::vector<T> &vec, DLContext ctx) { \ CHECK(data_->dl_tensor.ndim == 1) << "ToVector() only supported for 1D arrays";
return FromVector<T>(vec, DLDataType{DTypeCode, DTypeBits, 1}, ctx); \ CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch";
}
GEN_FROMVECTOR_FOR(int32_t, kDLInt, 32); int64_t size = data_->dl_tensor.shape[0];
GEN_FROMVECTOR_FOR(int64_t, kDLInt, 64); std::vector<T> vec(size);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just const DLContext &ctx = data_->dl_tensor.ctx;
// converting uints to signed NDArrays. DeviceAPI::Get(ctx)->CopyDataFromTo(
GEN_FROMVECTOR_FOR(uint32_t, kDLInt, 32); static_cast<T*>(data_->dl_tensor.data),
GEN_FROMVECTOR_FOR(uint64_t, kDLInt, 64); 0,
GEN_FROMVECTOR_FOR(float, kDLFloat, 32); vec.data(),
GEN_FROMVECTOR_FOR(double, kDLFloat, 64); 0,
size * sizeof(T),
ctx,
DLContext{kDLCPU, 0},
dtype,
nullptr);
return vec;
}
template std::vector<int32_t> NDArray::ToVector<int32_t>() const;
template std::vector<int64_t> NDArray::ToVector<int64_t>() const;
template std::vector<uint32_t> NDArray::ToVector<uint32_t>() const;
template std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;
template std::vector<float> NDArray::ToVector<float>() const;
template std::vector<double> NDArray::ToVector<double>() const;
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment