Unverified Commit bbfff8ce authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Casting between DGLGraph and DGLHeteroGraph (#1391)



* [Feature] Casting between DGLGraph and DGLHeteroGraph

* lint

* address comments
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent ebda932d
...@@ -438,6 +438,12 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -438,6 +438,12 @@ class BaseHeteroGraph : public runtime::Object {
return nullptr; return nullptr;
} }
/*! \brief Cast this graph to immutable graph */
virtual GraphPtr AsImmutableGraph() const {
LOG(FATAL) << "AsImmutableGraph not supported.";
return nullptr;
}
static constexpr const char* _type_key = "graph.HeteroGraph"; static constexpr const char* _type_key = "graph.HeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "runtime/ndarray.h" #include "runtime/ndarray.h"
#include "graph_interface.h" #include "graph_interface.h"
#include "lazy.h" #include "lazy.h"
#include "base_heterograph.h"
namespace dgl { namespace dgl {
...@@ -975,8 +976,12 @@ class ImmutableGraph: public GraphInterface { ...@@ -975,8 +976,12 @@ class ImmutableGraph: public GraphInterface {
GetOutCSR()->SortCSR(); GetOutCSR()->SortCSR();
} }
/*! \brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const;
protected: protected:
friend class Serializer; friend class Serializer;
friend class UnitGraph;
/* !\brief internal default constructor */ /* !\brief internal default constructor */
ImmutableGraph() {} ImmutableGraph() {}
......
...@@ -37,7 +37,9 @@ __all__ = [ ...@@ -37,7 +37,9 @@ __all__ = [
'to_simple', 'to_simple',
'in_subgraph', 'in_subgraph',
'out_subgraph', 'out_subgraph',
'remove_edges'] 'remove_edges',
'as_immutable_graph',
'as_heterograph']
def pairwise_squared_distance(x): def pairwise_squared_distance(x):
...@@ -1084,4 +1086,50 @@ def to_simple(g, return_counts='count', writeback_mapping=None): ...@@ -1084,4 +1086,50 @@ def to_simple(g, return_counts='count', writeback_mapping=None):
return simple_graph return simple_graph
def as_heterograph(g, ntype='_U', etype='_E'):
"""Convert a DGLGraph to a DGLHeteroGraph with one node and edge type.
Node and edge features are preserved.
Parameters
----------
g : DGLGraph
The graph
ntype : str, optional
The node type name
etype : str, optional
The edge type name
Returns
-------
DGLHeteroGraph
The heterograph.
"""
hgi = _CAPI_DGLAsHeteroGraph(g._graph)
hg = DGLHeteroGraph(hgi, [ntype], [etype])
hg.ndata.update(g.ndata)
hg.edata.update(g.edata)
return hg
def as_immutable_graph(hg):
"""Convert a DGLHeteroGraph with one node and edge type into a DGLGraph.
Node and edge features are preserved.
Parameters
----------
g : DGLHeteroGraph
The heterograph
Returns
-------
DGLGraph
The graph.
"""
gidx = _CAPI_DGLAsImmutableGraph(hg._graph)
g = DGLGraph(gidx)
g.ndata.update(hg.ndata)
g.edata.update(hg.edata)
return g
_init_api("dgl.transform") _init_api("dgl.transform")
...@@ -358,4 +358,12 @@ void HeteroGraph::Save(dmlc::Stream* fs) const { ...@@ -358,4 +358,12 @@ void HeteroGraph::Save(dmlc::Stream* fs) const {
fs->Write(num_verts_per_type_); fs->Write(num_verts_per_type_);
} }
GraphPtr HeteroGraph::AsImmutableGraph() const {
CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
auto unit_graph = CHECK_NOTNULL(
std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
return unit_graph->AsImmutableGraph();
}
} // namespace dgl } // namespace dgl
...@@ -194,6 +194,8 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -194,6 +194,8 @@ class HeteroGraph : public BaseHeteroGraph {
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override; FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
GraphPtr AsImmutableGraph() const override;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/ /*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs); bool Load(dmlc::Stream* fs);
......
...@@ -470,4 +470,10 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLOutSubgraph") ...@@ -470,4 +470,10 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLOutSubgraph")
*rv = HeteroGraphRef(ret); *rv = HeteroGraphRef(ret);
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = GraphRef(hg->AsImmutableGraph());
});
} // namespace dgl } // namespace dgl
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h> #include <dgl/runtime/smart_ptr_serializer.h>
#include <dgl/base_heterograph.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h> #include <dmlc/type_traits.h>
#include <string.h> #include <string.h>
...@@ -15,6 +16,8 @@ ...@@ -15,6 +16,8 @@
#include <tuple> #include <tuple>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "heterograph.h"
#include "unit_graph.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -665,6 +668,33 @@ void ImmutableGraph::Save(dmlc::Stream *fs) const { ...@@ -665,6 +668,33 @@ void ImmutableGraph::Save(dmlc::Stream *fs) const {
fs->Write(GetOutCSR()); fs->Write(GetOutCSR());
} }
HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
aten::CSRMatrix in_csr, out_csr;
aten::COOMatrix coo;
if (in_csr_)
in_csr = GetInCSR()->ToCSRMatrix();
if (out_csr_)
out_csr = GetOutCSR()->ToCSRMatrix();
if (coo_)
coo = GetCOO()->ToCOOMatrix();
auto g = UnitGraph::CreateHomographFrom(
in_csr, out_csr, coo,
in_csr_ != nullptr,
out_csr_ != nullptr,
coo_ != nullptr);
return HeteroGraphPtr(new HeteroGraph(g->meta_graph(), {g}));
}
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ig) << "graph is not readonly";
*rv = HeteroGraphRef(ig->AsHeteroGraph());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
......
...@@ -1175,6 +1175,30 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c ...@@ -1175,6 +1175,30 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
CHECK(GetAny()) << "At least one graph structure should exist."; CHECK(GetAny()) << "At least one graph structure should exist.";
} }
HeteroGraphPtr UnitGraph::CreateHomographFrom(
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
SparseFormat restrict_format) {
auto mg = CreateUnitGraphMetaGraph1();
CSRPtr in_csr_ptr = nullptr;
CSRPtr out_csr_ptr = nullptr;
COOPtr coo_ptr = nullptr;
if (has_in_csr)
in_csr_ptr = CSRPtr(new CSR(mg, in_csr));
if (has_out_csr)
out_csr_ptr = CSRPtr(new CSR(mg, out_csr));
if (has_coo)
coo_ptr = COOPtr(new COO(mg, coo));
return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format));
}
UnitGraph::CSRPtr UnitGraph::GetInCSR() const { UnitGraph::CSRPtr UnitGraph::GetInCSR() const {
if (!in_csr_) { if (!in_csr_) {
if (out_csr_) { if (out_csr_) {
...@@ -1272,6 +1296,31 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { ...@@ -1272,6 +1296,31 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::kCOO; return SparseFormat::kCOO;
} }
GraphPtr UnitGraph::AsImmutableGraph() const {
CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
dgl::CSRPtr in_csr_ptr = nullptr, out_csr_ptr = nullptr;
dgl::COOPtr coo_ptr = nullptr;
if (in_csr_) {
aten::CSRMatrix csc = GetCSCMatrix(0);
in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data, true));
}
if (out_csr_) {
aten::CSRMatrix csr = GetCSRMatrix(0);
out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data, true));
}
if (coo_) {
aten::COOMatrix coo = GetCOOMatrix(0);
if (!COOHasData(coo)) {
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col, true));
} else {
IdArray new_src = Scatter(coo.row, coo.data);
IdArray new_dst = Scatter(coo.col, coo.data);
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst, true));
}
}
return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
}
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127; constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
bool UnitGraph::Load(dmlc::Stream* fs) { bool UnitGraph::Load(dmlc::Stream* fs) {
......
...@@ -229,6 +229,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -229,6 +229,7 @@ class UnitGraph : public BaseHeteroGraph {
private: private:
friend class Serializer; friend class Serializer;
friend class HeteroGraph; friend class HeteroGraph;
friend class ImmutableGraph;
// private empty constructor // private empty constructor
UnitGraph() {} UnitGraph() {}
...@@ -243,6 +244,25 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -243,6 +244,25 @@ class UnitGraph : public BaseHeteroGraph {
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief constructor
* \param metagraph metagraph
* \param in_csr in edge csr
* \param out_csr out edge csr
* \param coo coo
* \param has_in_csr whether in_csr is valid
* \param has_out_csr whether out_csr is valid
* \param has_coo whether coo is valid
*/
static HeteroGraphPtr CreateHomographFrom(
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
SparseFormat restrict_format = SparseFormat::kAny);
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
...@@ -268,6 +288,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -268,6 +288,8 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Whether the graph is hypersparse */ /*! \return Whether the graph is hypersparse */
bool IsHypersparse() const; bool IsHypersparse() const;
GraphPtr AsImmutableGraph() const override;
// Graph stored in different format. We use an on-demand strategy: the format is // Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked. // only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */ /*! \brief CSR graph that stores reverse edges */
......
...@@ -552,6 +552,32 @@ def test_remove_edges(): ...@@ -552,6 +552,32 @@ def test_remove_edges():
check(g4, 'AB', g, [3, 1, 2, 0]) check(g4, 'AB', g, [3, 1, 2, 0])
check(g4, 'BA', g, []) check(g4, 'BA', g, [])
def test_cast():
m = spsp.coo_matrix(([1, 1], ([0, 1], [1, 2])), (4, 4))
g = dgl.DGLGraph(m, readonly=True)
gsrc, gdst = g.edges(order='eid')
ndata = F.randn((4, 5))
edata = F.randn((2, 4))
g.ndata['x'] = ndata
g.edata['y'] = edata
hg = dgl.as_heterograph(g, 'A', 'AA')
assert hg.ntypes == ['A']
assert hg.etypes == ['AA']
assert hg.canonical_etypes == [('A', 'AA', 'A')]
assert hg.number_of_nodes() == 4
assert hg.number_of_edges() == 2
hgsrc, hgdst = hg.edges(order='eid')
assert F.array_equal(gsrc, hgsrc)
assert F.array_equal(gdst, hgdst)
g2 = dgl.as_immutable_graph(hg)
assert g2.number_of_nodes() == 4
assert g2.number_of_edges() == 2
g2src, g2dst = hg.edges(order='eid')
assert F.array_equal(g2src, gsrc)
assert F.array_equal(g2dst, gdst)
if __name__ == '__main__': if __name__ == '__main__':
test_line_graph() test_line_graph()
test_no_backtracking() test_no_backtracking()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment