Unverified Commit c2e61ce1 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Refactor] Restrict sparse format for DGLHeteroGraph (#1474)



* upd

* simplify

* further simplify

* lint

* doc

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* lint

* rename format

* upd

* lint

* upd

* upd

* upd

* udp

* debug

* upd

* upd

* upd

* upd

* upd

* 无可厚非吧

* 中三边肥

* 你一定要喊吗

* lint

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* Update unit_graph.h
Co-authored-by: default avatarZihao Ye <yzh119@192.168.0.110>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 1e84168e
......@@ -66,6 +66,26 @@ Querying graph structure
DGLHeteroGraph.out_degree
DGLHeteroGraph.out_degrees
Querying and manipulating sparse format
---------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLHeteroGraph.format_in_use
DGLHeteroGraph.restrict_format
DGLHeteroGraph.to_format
Querying and manipulating index data type
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLHeteroGraph.idtype
DGLHeteroGraph.long
DGLHeteroGraph.int
Using Node/edge features
------------------------
......
......@@ -24,6 +24,15 @@ namespace dgl {
typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t;
/*! \brief Type for dgl format code, whose binary representation indicates
* which sparse format is in use and which is not.
*
* Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use.
* - z indicates whether coo is in use.
*/
typedef uint8_t dgl_format_code_t;
using dgl::runtime::NDArray;
......@@ -41,7 +50,8 @@ enum class SparseFormat {
kAny = 0,
kCOO = 1,
kCSR = 2,
kCSC = 3
kCSC = 3,
kAuto = 4 // kAuto is a placeholder that indicates it would be materialized later.
};
// Parse sparse format from string.
......@@ -52,10 +62,29 @@ inline SparseFormat ParseSparseFormat(const std::string& name) {
return SparseFormat::kCSR;
else if (name == "csc")
return SparseFormat::kCSC;
else if (name == "any")
return SparseFormat::kAny;
else if (name == "auto")
return SparseFormat::kAuto;
else
LOG(FATAL) << "Sparse format not recognized";
return SparseFormat::kAny;
}
// Convert a sparse format enum value to its lowercase string name.
// Any value other than kCOO/kCSR/kCSC/kAny is reported as "auto".
inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
  switch (sparse_format) {
    case SparseFormat::kCOO:
      return std::string("coo");
    case SparseFormat::kCSR:
      return std::string("csr");
    case SparseFormat::kCSC:
      return std::string("csc");
    case SparseFormat::kAny:
      return std::string("any");
    default:
      return std::string("auto");
  }
}
// Sparse matrix object that is exposed to python API.
struct SparseMatrix : public runtime::Object {
// Sparse format.
......@@ -660,7 +689,7 @@ bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
* \brief Batched implementation of COOIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray COOIsNonZero(COOMatrix, runtime::NDArray row, runtime::NDArray col);
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row);
......
......@@ -368,6 +368,27 @@ class BaseHeteroGraph : public runtime::Object {
*/
virtual SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const = 0;
/*!
* \brief Get restrict sparse format of the graph.
*
* \return a string representing the sparse format: 'coo'/'csr'/'csc'/'any'
*/
virtual std::string GetRestrictFormat() const = 0;
/*!
* \brief Return the sparse format in use for the graph.
*
* The code is a bitmask; see dgl_format_code_t for which bit stands for
* which sparse format.
*
* \return a number of type dgl_format_code_t.
*/
virtual dgl_format_code_t GetFormatInUse() const = 0;
/*!
* \brief Return the graph in specified restrict format.
*
* \param restrict_format The restrict sparse format requested for the returned graph.
* \return The new graph.
*/
virtual HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const = 0;
/*!
* \brief Get adjacency matrix in COO format.
* \param etype Edge type.
......
......@@ -939,7 +939,7 @@ class ImmutableGraph: public GraphInterface {
/*! \brief Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const;
void SortCSR() {
void SortCSR() override {
GetInCSR()->SortCSR();
GetOutCSR()->SortCSR();
}
......
......@@ -122,7 +122,6 @@ class ObjectRef {
* Check whether the two refer to the same object (compared by address).
*
* \param other Another object ref.
* \param other Another object ref.
* \return the compare result.
*/
inline bool same_as(const ObjectRef& other) const;
......
......@@ -22,7 +22,7 @@ __all__ = [
]
def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True,
restrict_format='any', index_dtype="int64", **kwargs):
restrict_format='auto', index_dtype="int64", **kwargs):
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
......@@ -53,8 +53,10 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
If True, check if node ids are within cardinality, the check process may take
some time. (Default: True)
If False and card is not None, user would receive a warning.
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
......@@ -150,7 +152,7 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None,
validate=True, restrict_format='any', index_dtype='int64', **kwargs):
validate=True, restrict_format='auto', index_dtype='int64', **kwargs):
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
......@@ -187,8 +189,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
If True, check if node ids are within cardinality, the check process may take
some time. (Default: True)
If False and card is not None, user would receive a warning.
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
......@@ -292,8 +296,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
return create_from_scipy(
data, utype, etype, vtype, restrict_format=restrict_format, index_dtype=index_dtype)
elif isinstance(data, nx.Graph):
return create_from_networkx_bipartite(data, utype, etype,
vtype, restrict_format=restrict_format,
return create_from_networkx_bipartite(data, utype, etype, vtype,
restrict_format=restrict_format,
index_dtype=index_dtype, **kwargs)
else:
raise DGLError('Unsupported graph data type:', type(data))
......@@ -410,7 +414,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
retg._edge_frames[i].update(rgrh._edge_frames[0])
return retg
def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
def heterograph(data_dict, num_nodes_dict=None, restrict_format='auto', index_dtype='int64'):
"""Create a heterogeneous graph from a dictionary between edge types and edge lists.
Parameters
......@@ -428,6 +432,11 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
By default DGL infers the number of nodes for each node type from ``data_dict``
by taking the maximum node ID plus one for each node type.
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns
-------
......@@ -494,12 +503,17 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
elif srctype == dsttype:
rel_graphs.append(graph(
data, srctype, etype,
num_nodes=num_nodes_dict[srctype], validate=False, index_dtype=index_dtype))
num_nodes=num_nodes_dict[srctype],
validate=False,
restrict_format=restrict_format,
index_dtype=index_dtype))
else:
rel_graphs.append(bipartite(
data, srctype, etype, dsttype,
num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]),
validate=False, index_dtype=index_dtype))
validate=False,
restrict_format=restrict_format,
index_dtype=index_dtype))
return hetero_from_relations(rel_graphs, num_nodes_dict)
......@@ -772,7 +786,7 @@ def to_homo(G):
############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True,
restrict_format="any", index_dtype='int64'):
restrict_format="auto", index_dtype='int64'):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
......@@ -797,8 +811,10 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns
-------
......@@ -835,7 +851,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
validate=True, restrict_format='any', index_dtype='int64'):
validate=True, restrict_format='auto', index_dtype='int64'):
"""Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype
......@@ -858,8 +874,10 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns
-------
......@@ -875,7 +893,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
validate, restrict_format, index_dtype=index_dtype)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
restrict_format='any', index_dtype='int64'):
restrict_format='auto', index_dtype='int64'):
"""Internal function to create a heterograph from a scipy sparse matrix with types.
Parameters
......@@ -897,8 +915,10 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
are always assumed to be ordered by edge ID already.
validate : bool, optional
If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns
-------
......@@ -929,7 +949,8 @@ def create_from_networkx(nx_graph,
edge_id_attr_name='id',
node_attrs=None,
edge_attrs=None,
restrict_format='any', index_dtype='int64'):
restrict_format='auto',
index_dtype='int64'):
"""Create a heterograph that has only one set of nodes and edges.
Parameters
......@@ -946,8 +967,10 @@ def create_from_networkx(nx_graph,
Names for node features to retrieve from the NetworkX graph (Default: None)
edge_attrs : list of str
Names for edge features to retrieve from the NetworkX graph (Default: None)
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns
-------
......@@ -1035,7 +1058,8 @@ def create_from_networkx_bipartite(nx_graph,
edge_id_attr_name='id',
node_attrs=None,
edge_attrs=None,
restrict_format='any', index_dtype='int64'):
restrict_format='auto',
index_dtype='int64'):
"""Create a heterograph that has one set of source nodes, one set of
destination nodes and one set of edges.
......@@ -1059,8 +1083,10 @@ def create_from_networkx_bipartite(nx_graph,
Names for node features to retrieve from the NetworkX graph (Default: None)
edge_attrs : list of str
Names for edge features to retrieve from the NetworkX graph (Default: None)
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns
-------
......
......@@ -1099,6 +1099,11 @@ class DGLHeteroGraph(object):
-------
backend dtype object
th.int32/th.int64 or tf.int32/tf.int64 etc.
See Also
--------
long
int
"""
return getattr(F, self._graph.dtype)
......@@ -4150,6 +4155,60 @@ class DGLHeteroGraph(object):
"""Return if the graph is homogeneous."""
return len(self.ntypes) == 1 and len(self.etypes) == 1
def format_in_use(self, etype=None, return_all=False):
    """List the sparse formats currently in use for an edge/relation type.

    Parameters
    ----------
    etype : optional
        The edge/relation type; it is resolved to a type ID through
        ``get_etype_id``. (Default: None)
    return_all : bool, optional
        Accepted but not used by this method. (Default: False)

    Returns
    -------
    list of string
        All the formats currently in use (there may be more than one).

    See Also
    --------
    restrict_format
    to_format
    """
    etype_id = self.get_etype_id(etype)
    return self._graph.format_in_use(etype_id)
def restrict_format(self, etype=None):
    """Return the allowed sparse format of the given edge/relation type.

    Parameters
    ----------
    etype : optional
        The edge/relation type; it is resolved to a type ID through
        ``get_etype_id``. (Default: None)

    Returns
    -------
    string : 'any', 'coo', 'csr', or 'csc'
        'any' indicates all sparse formats are allowed.

    See Also
    --------
    format_in_use
    to_format
    """
    etype_id = self.get_etype_id(etype)
    return self._graph.restrict_format(etype_id)
def to_format(self, restrict_format):
    """Clone this graph into the given restrict format.

    Passing 'any' relaxes the restrict format of the returned graph.
    The node/edge frames of this graph are handed to the new graph as-is,
    so the returned graph shares the node/edge data of the original.

    Parameters
    ----------
    restrict_format : string
        Desired restrict format ('any', 'coo', 'csr', 'csc').

    Returns
    -------
    A new graph.

    See Also
    --------
    format_in_use
    restrict_format
    """
    converted_index = self._graph.to_format(restrict_format)
    return DGLHeteroGraph(converted_index, self.ntypes, self.etypes,
                          self._node_frames,
                          self._edge_frames)
def long(self):
"""Return a heterograph object use int64 as index dtype,
with the ndata and edata as the original object
......@@ -4169,6 +4228,7 @@ class DGLHeteroGraph(object):
See Also
--------
int
idtype
"""
return DGLHeteroGraph(self._graph.asbits(64), self.ntypes, self.etypes,
self._node_frames,
......@@ -4193,6 +4253,7 @@ class DGLHeteroGraph(object):
See Also
--------
long
idtype
"""
return DGLHeteroGraph(self._graph.asbits(32), self.ntypes, self.etypes,
self._node_frames,
......
......@@ -901,6 +901,62 @@ class HeteroGraphIndex(ObjectBase):
rev_order = rev_csr(2)
return utils.toindex(order, self.dtype), utils.toindex(rev_order, self.dtype)
def format_in_use(self, etype):
    """Return the sparse formats in use of the given edge/relation type.

    Parameters
    ----------
    etype : int
        The edge/relation type.

    Returns
    -------
    list of string : all the formats currently in use (could be multiple),
        listed in the order 'coo', 'csr', 'csc'.
    """
    format_code = _CAPI_DGLHeteroGetFormatInUse(self, etype)
    # Bit 0 flags coo, bit 1 flags csr and bit 2 flags csc.
    format_names = ('coo', 'csr', 'csc')
    return [name for bit, name in enumerate(format_names)
            if (format_code >> bit) & 1]
def restrict_format(self, etype):
    """Return the restrict sparse format of the given edge/relation type.

    Parameters
    ----------
    etype : int
        The edge/relation type.

    Returns
    -------
    string : 'any', 'coo', 'csr', or 'csc'
    """
    return _CAPI_DGLHeteroGetRestrictFormat(self, etype)
def to_format(self, restrict_format):
    """Clone this graph index into the given sparse format.

    Passing 'any' relaxes the restrict format of the returned graph index.

    Parameters
    ----------
    restrict_format : string
        Desired restrict format ('any', 'coo', 'csr', 'csc').

    Returns
    -------
    A new graph index.
    """
    converted = _CAPI_DGLHeteroGetFormatGraph(self, restrict_format)
    return converted
@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
......@@ -942,6 +998,7 @@ class HeteroSubgraphIndex(ObjectBase):
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data, self.graph.dtype) for v in ret]
#################################################################
# Creators
#################################################################
......
......@@ -113,7 +113,7 @@ class SAGEConv(nn.Block):
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # saame as above if homogeneous
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# divide in degrees
degs = graph.in_degrees().astype(feat_dst.dtype)
......
......@@ -246,6 +246,30 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
hgindex->num_verts_per_type_));
}
/*!
 * \brief Copy every relation graph of a heterograph to the given context.
 * \param g The graph to copy; it must point to a HeteroGraph.
 * \param ctx The target context.
 * \return \a g itself when its context already equals \a ctx; otherwise a new
 *         HeteroGraph built from the same metagraph, the copied relation
 *         graphs and the same per-type vertex counts.
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);
  std::vector<HeteroGraphPtr> rel_graphs;
  // The loop variable is named `rel` so it does not shadow the parameter `g`,
  // and is bound by const reference so no extra shared_ptr copy is made
  // for each relation.
  for (const auto& rel : hgindex->relation_graphs_) {
    rel_graphs.push_back(UnitGraph::CopyTo(rel, ctx));
  }
  return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
                                        hgindex->num_verts_per_type_));
}
/*!
 * \brief Build a new heterograph whose relation graphs are converted to the
 *        given restrict format.
 * \param restrict_format The restrict sparse format passed to every relation graph.
 * \return A new HeteroGraph with the same metagraph and per-type vertex counts.
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(SparseFormat restrict_format) const {
  const auto num_etypes = NumEdgeTypes();
  std::vector<HeteroGraphPtr> format_rels(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    auto relgraph = std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
    // The cast yields null if the relation graph is not a UnitGraph; fail with
    // a clear check instead of dereferencing a null pointer below.
    CHECK_NOTNULL(relgraph);
    format_rels[etype] = relgraph->GetGraphInFormat(restrict_format);
  }
  return HeteroGraphPtr(new HeteroGraph(
      meta_graph_, format_rels, NumVerticesPerType()));
}
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
const std::vector<dgl_type_t>& etypes) const {
const int64_t bits = NumBits();
......
......@@ -187,11 +187,23 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->SelectFormat(0, preferred_format);
}
// Not supported on HeteroGraph: a fatal error is logged, and the empty string
// returned afterwards is only a placeholder value.
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
return std::string("");
}
// Not supported on HeteroGraph: a fatal error is logged, and the 0 returned
// afterwards is only a placeholder value.
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
GraphPtr AsImmutableGraph() const override;
......@@ -205,6 +217,9 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
private:
// To create empty class
friend class Serializer;
......
......@@ -85,12 +85,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
// Test if the heterograph is a unit graph. If so, return itself.
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg.sptr());
if (bg != nullptr)
*rv = bg;
else
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
auto unit_graph = hg->GetRelationGraph(etype);
auto meta_graph = unit_graph->meta_graph();
auto hgptr = CreateHeteroGraph(
meta_graph, {unit_graph}, unit_graph->NumVerticesPerType());
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
......@@ -430,7 +429,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx);
HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new);
});
......@@ -474,6 +473,30 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
*rv = ret_list;
});
// Return the restrict sparse format (as a string) of the relation graph of
// the given edge type. args[0] is the graph, args[1] the edge type ID.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRestrictFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    // Edge type IDs run from 0 to NumEdgeTypes() - 1, so the bound must be
    // strict; CHECK_LE would accept etype == NumEdgeTypes().
    CHECK_LT(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
    *rv = hg->GetRelationGraph(etype)->GetRestrictFormat();
  });
// Return the format code (dgl_format_code_t) of the relation graph of the
// given edge type. args[0] is the graph, args[1] the edge type ID.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    // Edge type IDs run from 0 to NumEdgeTypes() - 1, so the bound must be
    // strict; CHECK_LE would accept etype == NumEdgeTypes().
    CHECK_LT(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
    *rv = hg->GetRelationGraph(etype)->GetFormatInUse();
  });
// Return a graph converted to the restrict format named by args[1]
// (a string parsed with ParseSparseFormat). args[0] is the source graph.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph_ref = args[0];
    const std::string format_name = args[1];
    const SparseFormat target_format = ParseSparseFormat(format_name);
    *rv = HeteroGraphRef(graph_ref->GetGraphInFormat(target_format));
  });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
......
......@@ -334,6 +334,16 @@ class UnitGraph::COO : public BaseHeteroGraph {
return SparseFormat::kAny;
}
// Not supported on UnitGraph::COO: a fatal error is logged, and the empty
// string returned afterwards is only a placeholder value.
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for COO graph";
return std::string("");
}
// Not supported on UnitGraph::COO: a fatal error is logged, and the 0
// returned afterwards is only a placeholder value.
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for COO graph";
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
......@@ -375,6 +385,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return subg;
}
// Not supported on UnitGraph::COO: a fatal error is logged, and the nullptr
// returned afterwards is only a placeholder value. restrict_format is unused.
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override {
LOG(FATAL) << "Not enabled for COO graph.";
return nullptr;
}
aten::COOMatrix adj() const {
return adj_;
}
......@@ -709,6 +724,16 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return SparseFormat::kAny;
}
// Not supported on UnitGraph::CSR: a fatal error is logged, and the empty
// string returned afterwards is only a placeholder value.
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for CSR graph";
return std::string("");
}
// Not supported on UnitGraph::CSR: a fatal error is logged, and the 0
// returned afterwards is only a placeholder value.
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for CSR graph";
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
......@@ -730,6 +755,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return {};
}
// Not supported on UnitGraph::CSR: a fatal error is logged, and the nullptr
// returned afterwards is only a placeholder value. restrict_format is unused.
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return nullptr;
}
aten::CSRMatrix adj() const {
return adj_;
}
......@@ -1096,6 +1126,7 @@ HeteroGraphPtr UnitGraph::CreateFromCOO(
CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, mat));
return HeteroGraphPtr(
new UnitGraph(mg, nullptr, nullptr, coo, restrict_format));
}
......@@ -1182,14 +1213,25 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format)
: BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
restrict_format_ = restrict_format;
// If the graph is hypersparse and in COO format, switch the restricted format to COO.
// If the graph is given as CSR, the indptr array is already materialized so we don't
// care about restricting conversion anyway (even if it is hypersparse).
if (restrict_format == SparseFormat::kAny) {
if (coo && coo->IsHypersparse())
restrict_format_ = SparseFormat::kCOO;
restrict_format_ = AutoDetectFormat(in_csr, out_csr, coo, restrict_format);
switch (restrict_format) {
case SparseFormat::kCSC:
in_csr_ = GetInCSR();
coo_ = nullptr;
out_csr_ = nullptr;
break;
case SparseFormat::kCSR:
out_csr_ = GetOutCSR();
coo_ = nullptr;
in_csr_ = nullptr;
break;
case SparseFormat::kCOO:
coo_ = GetCOO();
in_csr_ = nullptr;
out_csr_ = nullptr;
break;
default:
break;
}
CHECK(GetAny()) << "At least one graph structure should exist.";
......@@ -1219,50 +1261,80 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom(
return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format));
}
UnitGraph::CSRPtr UnitGraph::GetInCSR() const {
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
if (inplace)
if (restrict_format_ != SparseFormat::kAny &&
restrict_format_ != SparseFormat::kCSC)
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create CSC matrix.";
CSRPtr ret = in_csr_;
if (!in_csr_) {
if (out_csr_) {
const auto& newadj = aten::CSRTranspose(out_csr_->adj());
const_cast<UnitGraph*>(this)->in_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret;
} else {
CHECK(coo_) << "None of CSR, COO exist";
const auto& adj = coo_->adj();
const auto& newadj = aten::COOToCSR(
aten::COOMatrix{adj.num_cols, adj.num_rows, adj.col, adj.row});
const_cast<UnitGraph*>(this)->in_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret;
}
}
return in_csr_;
return ret;
}
/*! \brief Return out csr. If not exist, transpose the other one.*/
UnitGraph::CSRPtr UnitGraph::GetOutCSR() const {
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace)
if (restrict_format_ != SparseFormat::kAny &&
restrict_format_ != SparseFormat::kCSR)
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create CSR matrix.";
CSRPtr ret = out_csr_;
if (!out_csr_) {
if (in_csr_) {
const auto& newadj = aten::CSRTranspose(in_csr_->adj());
const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret;
} else {
CHECK(coo_) << "None of CSR, COO exist";
const auto& newadj = aten::COOToCSR(coo_->adj());
const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret;
}
}
return out_csr_;
return ret;
}
/*! \brief Return coo. If not exist, create from csr.*/
UnitGraph::COOPtr UnitGraph::GetCOO() const {
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace)
if (restrict_format_ != SparseFormat::kAny &&
restrict_format_ != SparseFormat::kCOO)
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create COO matrix.";
COOPtr ret = coo_;
if (!coo_) {
if (in_csr_) {
const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
ret = std::make_shared<COO>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->coo_ = ret;
} else {
CHECK(out_csr_) << "Both CSR are missing.";
const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
ret = std::make_shared<COO>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->coo_ = ret;
}
}
return coo_;
return ret;
}
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
......@@ -1287,6 +1359,16 @@ HeteroGraphPtr UnitGraph::GetAny() const {
}
}
/*!
 * \brief Encode which sparse structures are currently materialized.
 * \return A 3-bit code: bit 2 is set when in_csr_ exists, bit 1 when out_csr_
 *         exists and bit 0 when coo_ exists.
 */
dgl_format_code_t UnitGraph::GetFormatInUse() const {
  const dgl_format_code_t in_csr_bit = in_csr_ ? 4 : 0;
  const dgl_format_code_t out_csr_bit = out_csr_ ? 2 : 0;
  const dgl_format_code_t coo_bit = coo_ ? 1 : 0;
  return static_cast<dgl_format_code_t>(in_csr_bit | out_csr_bit | coo_bit);
}
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) {
case SparseFormat::kCSR:
......@@ -1297,15 +1379,45 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
return GetCOO();
case SparseFormat::kAny:
return GetAny();
default:
LOG(FATAL) << "unsupported format code";
default: // SparseFormat::kAuto
LOG(FATAL) << "Must specify a restrict format.";
return nullptr;
}
}
/*!
 * \brief Create a new unit graph stored under the given restrict format.
 *
 * For kCOO/kCSC/kCSR the matching structure is obtained with the non-caching
 * getters (inplace = false), so this object is left unmodified. For kAny the
 * new graph reuses the structures this graph already holds.
 *
 * \param restrict_format The requested restrict format; kAuto is rejected.
 * \return The new graph.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(SparseFormat restrict_format) const {
  const int64_t vtype_count = NumVertexTypes();
  if (restrict_format == SparseFormat::kCOO) {
    return CreateFromCOO(vtype_count, GetCOO(false)->adj(), restrict_format);
  }
  if (restrict_format == SparseFormat::kCSC) {
    return CreateFromCSC(vtype_count, GetInCSR(false)->adj(), restrict_format);
  }
  if (restrict_format == SparseFormat::kCSR) {
    return CreateFromCSR(vtype_count, GetOutCSR(false)->adj(), restrict_format);
  }
  if (restrict_format == SparseFormat::kAny) {
    return HeteroGraphPtr(
        new UnitGraph(meta_graph_, in_csr_, out_csr_, coo_, restrict_format));
  }
  // SparseFormat::kAuto
  LOG(FATAL) << "Must specify a restrict format.";
  return nullptr;
}
/*!
 * \brief Resolve the kAuto placeholder into a concrete restrict format.
 *
 * Any format other than kAuto is returned unchanged. For kAuto the result is
 * kCOO when a COO structure is given and reports itself as hypersparse, and
 * kAny otherwise. The in_csr and out_csr arguments are not consulted.
 */
SparseFormat UnitGraph::AutoDetectFormat(
    CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, SparseFormat restrict_format) const {
  if (restrict_format == SparseFormat::kAuto) {
    const bool hypersparse_coo = coo && coo->IsHypersparse();
    return hypersparse_coo ? SparseFormat::kCOO : SparseFormat::kAny;
  }
  return restrict_format;
}
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
if (restrict_format_ != SparseFormat::kAny)
return restrict_format_;
return restrict_format_; // force to select the restricted format
else if (preferred_format != SparseFormat::kAny)
return preferred_format;
else if (in_csr_)
......
......@@ -93,8 +93,10 @@ class UnitGraph : public BaseHeteroGraph {
uint64_t NumVertices(dgl_type_t vtype) const override;
inline std::vector<int64_t> NumVerticesPerType() const override {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on unit graphs.";
return {};
std::vector<int64_t> num_nodes_per_type;
for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)
num_nodes_per_type.push_back(NumVertices(vtype));
return num_nodes_per_type;
}
uint64_t NumEdges(dgl_type_t etype) const override;
......@@ -169,31 +171,31 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAny);
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAuto);
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
/*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
/*! \brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
......@@ -201,14 +203,29 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
/*! \return Return the in-edge CSR format. Create from other format if not exist. */
CSRPtr GetInCSR() const;
/*!
* \brief Create in-edge CSR format of the unit graph.
* \param inplace if true and the in-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the in-edge CSR format. Create from other format if not exist.
*/
CSRPtr GetInCSR(bool inplace = true) const;
/*! \return Return the out-edge CSR format. Create from other format if not exist. */
CSRPtr GetOutCSR() const;
/*!
* \brief Create out-edge CSR format of the unit graph.
* \param inplace if true and the out-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the out-edge CSR format. Create from other format if not exist.
*/
CSRPtr GetOutCSR(bool inplace = true) const;
/*! \return Return the COO format. Create from other format if not exist. */
COOPtr GetCOO() const;
/*!
* \brief Create COO format of the unit graph.
* \param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the COO format. Create from other format if not exist.
*/
COOPtr GetCOO(bool inplace = true) const;
/*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;
......@@ -219,10 +236,22 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
/*! \brief Heuristic rules that resolve a kAuto restrict format into a concrete one. */
SparseFormat AutoDetectFormat(
CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, SparseFormat restrict_format) const;
/*!
 * \brief Per-edge-type overload of SelectFormat.
 * \param etype edge type; not consulted, the call forwards to the
 *        single-argument overload.
 * \param preferred_format the format forwarded to SelectFormat(SparseFormat).
 * \return whatever SelectFormat(preferred_format) returns.
 */
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
  const SparseFormat chosen = this->SelectFormat(preferred_format);
  return chosen;
}
std::string GetRestrictFormat() const override {
return ToStringSparseFormat(this->restrict_format_);
}
dgl_format_code_t GetFormatInUse() const override;
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;
/*! \brief Load UnitGraph from stream, using CSRMatrix */
bool Load(dmlc::Stream* fs);
......@@ -245,7 +274,7 @@ class UnitGraph : public BaseHeteroGraph {
* \param coo coo
*/
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
/*!
* \brief constructor
......@@ -264,7 +293,7 @@ class UnitGraph : public BaseHeteroGraph {
bool has_in_csr,
bool has_out_csr,
bool has_coo,
SparseFormat restrict_format = SparseFormat::kAny);
SparseFormat restrict_format = SparseFormat::kAuto);
/*! \return Return any existing format. */
HeteroGraphPtr GetAny() const;
......
......@@ -347,7 +347,7 @@ void csrwrapper_switch(DGLArgValue argval,
fn(wrapper);
} else if (argval.IsObjectType<HeteroGraphRef>()) {
HeteroGraphRef g = argval;
auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr());
auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0));
CHECK_NOTNULL(bgptr);
UnitGraphCSRWrapper wrapper(bgptr.get());
fn(wrapper);
......
import dgl
import backend as F
import unittest
def tree1():
"""Generate a tree
......@@ -102,6 +103,7 @@ def test_batch_unbatch1():
assert F.allclose(t2.ndata['h'], rs2.ndata['h'])
assert F.allclose(t2.edata['h'], rs2.edata['h'])
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_batch_unbatch_frame():
"""Test module of node/edge frames of batched/unbatched DGLGraphs.
Also address the bug mentioned in https://github.com/dmlc/dgl/issues/1475.
......@@ -118,7 +120,6 @@ def test_batch_unbatch_frame():
t2.ndata['h'] = F.randn((N2, D))
t2.edata['h'] = F.randn((E2, D))
if F.backend_name != 'tensorflow': # tf's tensor is immutable
b1 = dgl.batch([t1, t2])
b2 = dgl.batch([t2])
b1.ndata['h'][:N1] = F.zeros((N1, D))
......
......@@ -1632,6 +1632,58 @@ def test_dtype_cast(index_dtype):
assert F.array_equal(g.ndata["feat"], g_cast.ndata["feat"])
assert F.array_equal(g.edata["h"], g_cast.edata["h"])
def test_format():
    """Exercise restrict_format / format_in_use / to_format.

    Covers a single-relation graph restricted to COO and a multi-relation
    heterograph restricted to CSR, and checks that to_format returns a new
    graph while leaving the source graph's formats untouched.
    """
    # single relation
    g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], restrict_format='coo')
    assert g.restrict_format() == 'coo'
    assert g.format_in_use() == ['coo']
    try:
        g.adjacency_matrix(scipy_fmt="csr")
    except Exception:
        print('test passed, graph with restrict_format coo should not create csr matrix.')
    else:
        assert False, 'cannot create csr when restrict_format is coo'
    g1 = g.to_format('any')
    assert g1.restrict_format() == 'any'
    # Touch every sparse format so that each one gets materialized on g1.
    g1.adjacency_matrix(scipy_fmt='coo')
    g1.adjacency_matrix(scipy_fmt='csr')
    g1.adjacency_matrix(transpose=True, scipy_fmt='csr')
    # Check the formats in use: len(restrict_format()) would only measure
    # the string 'any' and pass regardless of what was materialized.
    assert len(g1.format_in_use()) == 3
    # The original graph must be unaffected by the conversions on g1.
    assert g.restrict_format() == 'coo'
    assert g.format_in_use() == ['coo']

    # multiple relation
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
        ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
    }, restrict_format='csr')
    user_feat = F.randn((g['follows'].number_of_src_nodes(), 5))
    g['follows'].srcdata['h'] = user_feat
    for rel_type in ['follows', 'plays', 'develops']:
        assert g.restrict_format(rel_type) == 'csr'
        print(g.format_in_use(rel_type), g.restrict_format(rel_type))
        assert g.format_in_use(rel_type) == ['csr']
        try:
            g[rel_type].adjacency_matrix(scipy_fmt='coo')
        except Exception:
            print('test passed, graph with restrict_format csr should not create coo matrix')
        else:
            assert False, 'cannot create coo when restrict_format is csr'

    g1 = g.to_format('csc')
    # test frame
    assert F.array_equal(g1['follows'].srcdata['h'], user_feat)
    # test each relation graph
    for rel_type in ['follows', 'plays', 'develops']:
        assert g1.restrict_format(rel_type) == 'csc'
        assert g1.format_in_use(rel_type) == ['csc']
        # The source graph keeps its own restriction and formats.
        assert g.restrict_format(rel_type) == 'csr'
        assert g.format_in_use(rel_type) == ['csr']
if __name__ == '__main__':
# test_create()
# test_query()
......@@ -1656,4 +1708,5 @@ if __name__ == '__main__':
# test_stack_reduce()
# test_isolated_ntype()
# test_bipartite()
test_dtype_cast()
# test_dtype_cast()
test_format()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment