"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "500b9cf184d9a6aa950ccbb9a4d967b95877196d"
Unverified Commit c2e61ce1 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Refactor] Restrict sparse format for DGLHeteroGraph (#1474)



* upd

* simplify

* further simplify

* lint

* doc

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* lint

* rename format

* upd

* lint

* upd

* upd

* upd

* udp

* debug

* upd

* upd

* upd

* upd

* upd

* 无可厚非吧

* 中三边肥

* 你一定要喊吗

* lint

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* Update unit_graph.h
Co-authored-by: default avatarZihao Ye <yzh119@192.168.0.110>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 1e84168e
...@@ -66,6 +66,26 @@ Querying graph structure ...@@ -66,6 +66,26 @@ Querying graph structure
DGLHeteroGraph.out_degree DGLHeteroGraph.out_degree
DGLHeteroGraph.out_degrees DGLHeteroGraph.out_degrees
Querying and manipulating sparse format
---------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLHeteroGraph.format_in_use
DGLHeteroGraph.restrict_format
DGLHeteroGraph.to_format
Querying and manipulating index data type
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLHeteroGraph.idtype
DGLHeteroGraph.long
DGLHeteroGraph.int
Using Node/edge features Using Node/edge features
------------------------ ------------------------
......
...@@ -24,6 +24,15 @@ namespace dgl { ...@@ -24,6 +24,15 @@ namespace dgl {
typedef uint64_t dgl_id_t; typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t; typedef uint64_t dgl_type_t;
/*! \brief Type for dgl fomrat code, whose binary representation indices
* which sparse format is in use and which is not.
*
* Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use.
* - z indicates whether coo is in use.
*/
typedef uint8_t dgl_format_code_t;
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
...@@ -41,7 +50,8 @@ enum class SparseFormat { ...@@ -41,7 +50,8 @@ enum class SparseFormat {
kAny = 0, kAny = 0,
kCOO = 1, kCOO = 1,
kCSR = 2, kCSR = 2,
kCSC = 3 kCSC = 3,
kAuto = 4 // kAuto is a placeholder that indicates it would be materialized later.
}; };
// Parse sparse format from string. // Parse sparse format from string.
...@@ -52,8 +62,27 @@ inline SparseFormat ParseSparseFormat(const std::string& name) { ...@@ -52,8 +62,27 @@ inline SparseFormat ParseSparseFormat(const std::string& name) {
return SparseFormat::kCSR; return SparseFormat::kCSR;
else if (name == "csc") else if (name == "csc")
return SparseFormat::kCSC; return SparseFormat::kCSC;
else else if (name == "any")
return SparseFormat::kAny; return SparseFormat::kAny;
else if (name == "auto")
return SparseFormat::kAuto;
else
LOG(FATAL) << "Sparse format not recognized";
return SparseFormat::kAny;
}
// Create string from sparse format.
inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
if (sparse_format == SparseFormat::kCOO)
return std::string("coo");
else if (sparse_format == SparseFormat::kCSR)
return std::string("csr");
else if (sparse_format == SparseFormat::kCSC)
return std::string("csc");
else if (sparse_format == SparseFormat::kAny)
return std::string("any");
else
return std::string("auto");
} }
// Sparse matrix object that is exposed to python API. // Sparse matrix object that is exposed to python API.
...@@ -660,7 +689,7 @@ bool COOIsNonZero(COOMatrix , int64_t row, int64_t col); ...@@ -660,7 +689,7 @@ bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
* \brief Batched implementation of COOIsNonZero. * \brief Batched implementation of COOIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/ */
runtime::NDArray COOIsNonZero(COOMatrix, runtime::NDArray row, runtime::NDArray col); runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */ /*! \brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row); int64_t COOGetRowNNZ(COOMatrix , int64_t row);
......
...@@ -368,6 +368,27 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -368,6 +368,27 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const = 0; virtual SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const = 0;
/*!
* \brief Get restrict sparse format of the graph.
*
* \return a string representing the sparse format: 'coo'/'csr'/'csc'/'any'
*/
virtual std::string GetRestrictFormat() const = 0;
/*!
* \brief Return the sparse format in use for the graph.
*
* \return a number of type dgl_format_code_t.
*/
virtual dgl_format_code_t GetFormatInUse() const = 0;
/*!
* \brief Return the graph in specified restrict format.
*
* \return The new graph.
*/
virtual HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const = 0;
/*! /*!
* \brief Get adjacency matrix in COO format. * \brief Get adjacency matrix in COO format.
* \param etype Edge type. * \param etype Edge type.
......
...@@ -939,7 +939,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -939,7 +939,7 @@ class ImmutableGraph: public GraphInterface {
/*! \return Save ImmutableGraph to stream, using out csr */ /*! \return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
void SortCSR() { void SortCSR() override {
GetInCSR()->SortCSR(); GetInCSR()->SortCSR();
GetOutCSR()->SortCSR(); GetOutCSR()->SortCSR();
} }
......
...@@ -122,7 +122,6 @@ class ObjectRef { ...@@ -122,7 +122,6 @@ class ObjectRef {
* Compare with the two are referencing to the same object (compare by address). * Compare with the two are referencing to the same object (compare by address).
* *
* \param other Another object ref. * \param other Another object ref.
* \param other Another object ref.
* \return the compare result. * \return the compare result.
*/ */
inline bool same_as(const ObjectRef& other) const; inline bool same_as(const ObjectRef& other) const;
......
...@@ -22,7 +22,7 @@ __all__ = [ ...@@ -22,7 +22,7 @@ __all__ = [
] ]
def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True, def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True,
restrict_format='any', index_dtype="int64", **kwargs): restrict_format='auto', index_dtype="int64", **kwargs):
"""Create a graph with one type of nodes and edges. """Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...@@ -53,8 +53,10 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True ...@@ -53,8 +53,10 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
If True, check if node ids are within cardinality, the check process may take If True, check if node ids are within cardinality, the check process may take
some time. (Default: True) some time. (Default: True)
If False and card is not None, user would receive a warning. If False and card is not None, user would receive a warning.
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
kwargs : key-word arguments, optional kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of: graph. It can consist of:
...@@ -150,7 +152,7 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True ...@@ -150,7 +152,7 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None, def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None,
validate=True, restrict_format='any', index_dtype='int64', **kwargs): validate=True, restrict_format='auto', index_dtype='int64', **kwargs):
"""Create a bipartite graph. """Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes The result graph is directed and edges must be from ``utype`` nodes
...@@ -187,8 +189,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non ...@@ -187,8 +189,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
If True, check if node ids are within cardinality, the check process may take If True, check if node ids are within cardinality, the check process may take
some time. (Default: True) some time. (Default: True)
If False and card is not None, user would receive a warning. If False and card is not None, user would receive a warning.
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
kwargs : key-word arguments, optional kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of: graph. It can consist of:
...@@ -292,8 +296,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non ...@@ -292,8 +296,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
return create_from_scipy( return create_from_scipy(
data, utype, etype, vtype, restrict_format=restrict_format, index_dtype=index_dtype) data, utype, etype, vtype, restrict_format=restrict_format, index_dtype=index_dtype)
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
return create_from_networkx_bipartite(data, utype, etype, return create_from_networkx_bipartite(data, utype, etype, vtype,
vtype, restrict_format=restrict_format, restrict_format=restrict_format,
index_dtype=index_dtype, **kwargs) index_dtype=index_dtype, **kwargs)
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
...@@ -410,7 +414,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None): ...@@ -410,7 +414,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
retg._edge_frames[i].update(rgrh._edge_frames[0]) retg._edge_frames[i].update(rgrh._edge_frames[0])
return retg return retg
def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'): def heterograph(data_dict, num_nodes_dict=None, restrict_format='auto', index_dtype='int64'):
"""Create a heterogeneous graph from a dictionary between edge types and edge lists. """Create a heterogeneous graph from a dictionary between edge types and edge lists.
Parameters Parameters
...@@ -428,6 +432,11 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'): ...@@ -428,6 +432,11 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
By default DGL infers the number of nodes for each node type from ``data_dict`` By default DGL infers the number of nodes for each node type from ``data_dict``
by taking the maximum node ID plus one for each node type. by taking the maximum node ID plus one for each node type.
restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns Returns
------- -------
...@@ -494,12 +503,17 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'): ...@@ -494,12 +503,17 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
elif srctype == dsttype: elif srctype == dsttype:
rel_graphs.append(graph( rel_graphs.append(graph(
data, srctype, etype, data, srctype, etype,
num_nodes=num_nodes_dict[srctype], validate=False, index_dtype=index_dtype)) num_nodes=num_nodes_dict[srctype],
validate=False,
restrict_format=restrict_format,
index_dtype=index_dtype))
else: else:
rel_graphs.append(bipartite( rel_graphs.append(bipartite(
data, srctype, etype, dsttype, data, srctype, etype, dsttype,
num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]),
validate=False, index_dtype=index_dtype)) validate=False,
restrict_format=restrict_format,
index_dtype=index_dtype))
return hetero_from_relations(rel_graphs, num_nodes_dict) return hetero_from_relations(rel_graphs, num_nodes_dict)
...@@ -772,7 +786,7 @@ def to_homo(G): ...@@ -772,7 +786,7 @@ def to_homo(G):
############################################################ ############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True, def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True,
restrict_format="any", index_dtype='int64'): restrict_format="auto", index_dtype='int64'):
"""Internal function to create a graph from incident nodes with types. """Internal function to create a graph from incident nodes with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -797,8 +811,10 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid ...@@ -797,8 +811,10 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
maximum of the destination node IDs in the edge list plus 1. (Default: None) maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional validate : bool, optional
If True, checks if node IDs are within range. If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns Returns
------- -------
...@@ -835,7 +851,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid ...@@ -835,7 +851,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
return DGLHeteroGraph(hgidx, [utype, vtype], [etype]) return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
validate=True, restrict_format='any', index_dtype='int64'): validate=True, restrict_format='auto', index_dtype='int64'):
"""Internal function to create a heterograph from a list of edge tuples with types. """Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -858,8 +874,10 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, ...@@ -858,8 +874,10 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
maximum of the destination node IDs in the edge list plus 1. (Default: None) maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional validate : bool, optional
If True, checks if node IDs are within range. If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns Returns
------- -------
...@@ -875,7 +893,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, ...@@ -875,7 +893,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
validate, restrict_format, index_dtype=index_dtype) validate, restrict_format, index_dtype=index_dtype)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False, def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
restrict_format='any', index_dtype='int64'): restrict_format='auto', index_dtype='int64'):
"""Internal function to create a heterograph from a scipy sparse matrix with types. """Internal function to create a heterograph from a scipy sparse matrix with types.
Parameters Parameters
...@@ -897,8 +915,10 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False, ...@@ -897,8 +915,10 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
are always assumed to be ordered by edge ID already. are always assumed to be ordered by edge ID already.
validate : bool, optional validate : bool, optional
If True, checks if node IDs are within range. If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns Returns
------- -------
...@@ -929,7 +949,8 @@ def create_from_networkx(nx_graph, ...@@ -929,7 +949,8 @@ def create_from_networkx(nx_graph,
edge_id_attr_name='id', edge_id_attr_name='id',
node_attrs=None, node_attrs=None,
edge_attrs=None, edge_attrs=None,
restrict_format='any', index_dtype='int64'): restrict_format='auto',
index_dtype='int64'):
"""Create a heterograph that has only one set of nodes and edges. """Create a heterograph that has only one set of nodes and edges.
Parameters Parameters
...@@ -946,8 +967,10 @@ def create_from_networkx(nx_graph, ...@@ -946,8 +967,10 @@ def create_from_networkx(nx_graph,
Names for node features to retrieve from the NetworkX graph (Default: None) Names for node features to retrieve from the NetworkX graph (Default: None)
edge_attrs : list of str edge_attrs : list of str
Names for edge features to retrieve from the NetworkX graph (Default: None) Names for edge features to retrieve from the NetworkX graph (Default: None)
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns Returns
------- -------
...@@ -1035,7 +1058,8 @@ def create_from_networkx_bipartite(nx_graph, ...@@ -1035,7 +1058,8 @@ def create_from_networkx_bipartite(nx_graph,
edge_id_attr_name='id', edge_id_attr_name='id',
node_attrs=None, node_attrs=None,
edge_attrs=None, edge_attrs=None,
restrict_format='any', index_dtype='int64'): restrict_format='auto',
index_dtype='int64'):
"""Create a heterograph that has one set of source nodes, one set of """Create a heterograph that has one set of source nodes, one set of
destination nodes and one set of edges. destination nodes and one set of edges.
...@@ -1059,8 +1083,10 @@ def create_from_networkx_bipartite(nx_graph, ...@@ -1059,8 +1083,10 @@ def create_from_networkx_bipartite(nx_graph,
Names for node features to retrieve from the NetworkX graph (Default: None) Names for node features to retrieve from the NetworkX graph (Default: None)
edge_attrs : list of str edge_attrs : list of str
Names for edge features to retrieve from the NetworkX graph (Default: None) Names for edge features to retrieve from the NetworkX graph (Default: None)
restrict_format : 'any', 'coo', 'csr', 'csc', optional restrict_format : 'any', 'coo', 'csr', 'csc', 'auto' optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use). Force the storage format. Default: 'auto' (i.e. let DGL decide what to use).
index_dtype : 'int32', 'int64', optional
Force the index data type. Default: 'int64'.
Returns Returns
------- -------
......
...@@ -1099,6 +1099,11 @@ class DGLHeteroGraph(object): ...@@ -1099,6 +1099,11 @@ class DGLHeteroGraph(object):
------- -------
backend dtype object backend dtype object
th.int32/th.int64 or tf.int32/tf.int64 etc. th.int32/th.int64 or tf.int32/tf.int64 etc.
See Also
--------
long
int
""" """
return getattr(F, self._graph.dtype) return getattr(F, self._graph.dtype)
...@@ -4150,6 +4155,60 @@ class DGLHeteroGraph(object): ...@@ -4150,6 +4155,60 @@ class DGLHeteroGraph(object):
"""Return if the graph is homogeneous.""" """Return if the graph is homogeneous."""
return len(self.ntypes) == 1 and len(self.etypes) == 1 return len(self.ntypes) == 1 and len(self.etypes) == 1
def format_in_use(self, etype=None, return_all=False):
"""Return the sparse formats in use of the given edge/relation type.
Returns
-------
list of string
Return all the formats currently in use (could be multiple).
See Also
--------
restrict_format
to_format
"""
return self._graph.format_in_use(self.get_etype_id(etype))
def restrict_format(self, etype=None):
"""Return the allowed sparse formats of the given edge/relation type.
Returns
-------
string : 'any', 'coo', 'csr', or 'csc'
'any' indicates all sparse formats are allowed in .
See Also
--------
format_in_use
to_format
"""
return self._graph.restrict_format(self.get_etype_id(etype))
def to_format(self, restrict_format):
"""Return a cloned graph but stored in the given restrict format.
If 'any' is given, the restrict formats of the returned graph is relaxed.
The returned graph share the same node/edge data of the original graph.
Parameters
----------
restrict_format : string
Desired restrict format ('any', 'coo', 'csr', 'csc').
Returns
-------
A new graph.
See Also
--------
format_in_use
restrict_format
"""
return DGLHeteroGraph(self._graph.to_format(restrict_format), self.ntypes, self.etypes,
self._node_frames,
self._edge_frames)
def long(self): def long(self):
"""Return a heterograph object use int64 as index dtype, """Return a heterograph object use int64 as index dtype,
with the ndata and edata as the original object with the ndata and edata as the original object
...@@ -4169,6 +4228,7 @@ class DGLHeteroGraph(object): ...@@ -4169,6 +4228,7 @@ class DGLHeteroGraph(object):
See Also See Also
-------- --------
int int
idtype
""" """
return DGLHeteroGraph(self._graph.asbits(64), self.ntypes, self.etypes, return DGLHeteroGraph(self._graph.asbits(64), self.ntypes, self.etypes,
self._node_frames, self._node_frames,
...@@ -4193,6 +4253,7 @@ class DGLHeteroGraph(object): ...@@ -4193,6 +4253,7 @@ class DGLHeteroGraph(object):
See Also See Also
-------- --------
long long
idtype
""" """
return DGLHeteroGraph(self._graph.asbits(32), self.ntypes, self.etypes, return DGLHeteroGraph(self._graph.asbits(32), self.ntypes, self.etypes,
self._node_frames, self._node_frames,
......
...@@ -901,6 +901,62 @@ class HeteroGraphIndex(ObjectBase): ...@@ -901,6 +901,62 @@ class HeteroGraphIndex(ObjectBase):
rev_order = rev_csr(2) rev_order = rev_csr(2)
return utils.toindex(order, self.dtype), utils.toindex(rev_order, self.dtype) return utils.toindex(order, self.dtype), utils.toindex(rev_order, self.dtype)
def format_in_use(self, etype):
"""Return the sparse formats in use of the given edge/relation type.
Parameters
----------
etype : int
The edge/relation type.
Returns
-------
list of string : return all the formats currently in use (could be multiple).
"""
format_code = _CAPI_DGLHeteroGetFormatInUse(self, etype)
ret = []
if format_code & 1:
ret.append('coo')
format_code >>= 1
if format_code & 1:
ret.append('csr')
format_code >>= 1
if format_code & 1:
ret.append('csc')
return ret
def restrict_format(self, etype):
"""Return restrict sparse format of the given edge/relation type.
Parameters
----------
etype : int
The edge/relation type.
Returns
-------
string : 'any', 'coo', 'csr', or 'csc'
"""
ret = _CAPI_DGLHeteroGetRestrictFormat(self, etype)
return ret
def to_format(self, restrict_format):
"""Return a clone graph index but stored in the given sparse format.
If 'any' is given, the restrict formats of the returned graph index
is relaxed.
Parameters
----------
restrict_format : string
Desired restrict format ('any', 'coo', 'csr', 'csc').
Returns
-------
A new graph index.
"""
return _CAPI_DGLHeteroGetFormatGraph(self, restrict_format)
@register_object('graph.HeteroSubgraph') @register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase): class HeteroSubgraphIndex(ObjectBase):
...@@ -942,6 +998,7 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -942,6 +998,7 @@ class HeteroSubgraphIndex(ObjectBase):
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self) ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data, self.graph.dtype) for v in ret] return [utils.toindex(v.data, self.graph.dtype) for v in ret]
################################################################# #################################################################
# Creators # Creators
################################################################# #################################################################
......
...@@ -113,7 +113,7 @@ class SAGEConv(nn.Block): ...@@ -113,7 +113,7 @@ class SAGEConv(nn.Block):
elif self._aggre_type == 'gcn': elif self._aggre_type == 'gcn':
check_eq_shape(feat) check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # saame as above if homogeneous graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# divide in degrees # divide in degrees
degs = graph.in_degrees().astype(feat_dst.dtype) degs = graph.in_degrees().astype(feat_dst.dtype)
......
...@@ -246,6 +246,30 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -246,6 +246,30 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
hgindex->num_verts_per_type_)); hgindex->num_verts_per_type_));
} }
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
if (ctx == g->Context()) {
return g;
}
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
CHECK_NOTNULL(hgindex);
std::vector<HeteroGraphPtr> rel_graphs;
for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
}
HeteroGraphPtr HeteroGraph::GetGraphInFormat(SparseFormat restrict_format) const {
std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
auto relgraph = std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
format_rels[etype] = relgraph->GetGraphInFormat(restrict_format);
}
return HeteroGraphPtr(new HeteroGraph(
meta_graph_, format_rels, NumVerticesPerType()));
}
FlattenedHeteroGraphPtr HeteroGraph::Flatten( FlattenedHeteroGraphPtr HeteroGraph::Flatten(
const std::vector<dgl_type_t>& etypes) const { const std::vector<dgl_type_t>& etypes) const {
const int64_t bits = NumBits(); const int64_t bits = NumBits();
......
...@@ -187,11 +187,23 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -187,11 +187,23 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->SelectFormat(0, preferred_format); return GetRelationGraph(etype)->SelectFormat(0, preferred_format);
} }
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
return std::string("");
}
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override; HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override; FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
GraphPtr AsImmutableGraph() const override; GraphPtr AsImmutableGraph() const override;
...@@ -205,6 +217,9 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -205,6 +217,9 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief Convert the graph to use the given number of bits for storage */ /*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
private: private:
// To create empty class // To create empty class
friend class Serializer; friend class Serializer;
......
...@@ -85,12 +85,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") ...@@ -85,12 +85,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype; CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
// Test if the heterograph is a unit graph. If so, return itself. auto unit_graph = hg->GetRelationGraph(etype);
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg.sptr()); auto meta_graph = unit_graph->meta_graph();
if (bg != nullptr) auto hgptr = CreateHeteroGraph(
*rv = bg; meta_graph, {unit_graph}, unit_graph->NumVerticesPerType());
else *rv = HeteroGraphRef(hgptr);
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
...@@ -430,7 +429,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") ...@@ -430,7 +429,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
DLContext ctx; DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx); HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
...@@ -474,6 +473,30 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") ...@@ -474,6 +473,30 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
*rv = ret_list; *rv = ret_list;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRestrictFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
*rv = hg->GetRelationGraph(etype)->GetRestrictFormat();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
*rv = hg->GetRelationGraph(etype)->GetFormatInUse();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const std::string restrict_format = args[1];
auto hgptr = hg->GetGraphInFormat(ParseSparseFormat(restrict_format));
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLInSubgraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
......
...@@ -334,6 +334,16 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -334,6 +334,16 @@ class UnitGraph::COO : public BaseHeteroGraph {
return SparseFormat::kAny; return SparseFormat::kAny;
} }
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for COO graph";
return std::string("");
}
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for COO graph";
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
...@@ -375,6 +385,11 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -375,6 +385,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return subg; return subg;
} }
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override {
LOG(FATAL) << "Not enabled for COO graph.";
return nullptr;
}
aten::COOMatrix adj() const { aten::COOMatrix adj() const {
return adj_; return adj_;
} }
...@@ -709,6 +724,16 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -709,6 +724,16 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return SparseFormat::kAny; return SparseFormat::kAny;
} }
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for CSR graph";
return std::string("");
}
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for CSR graph";
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
...@@ -730,6 +755,11 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -730,6 +755,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return {}; return {};
} }
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return nullptr;
}
aten::CSRMatrix adj() const { aten::CSRMatrix adj() const {
return adj_; return adj_;
} }
...@@ -1096,6 +1126,7 @@ HeteroGraphPtr UnitGraph::CreateFromCOO( ...@@ -1096,6 +1126,7 @@ HeteroGraphPtr UnitGraph::CreateFromCOO(
CHECK_EQ(mat.num_rows, mat.num_cols); CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, mat)); COOPtr coo(new COO(mg, mat));
return HeteroGraphPtr( return HeteroGraphPtr(
new UnitGraph(mg, nullptr, nullptr, coo, restrict_format)); new UnitGraph(mg, nullptr, nullptr, coo, restrict_format));
} }
...@@ -1182,14 +1213,25 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { ...@@ -1182,14 +1213,25 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format) SparseFormat restrict_format)
: BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
restrict_format_ = restrict_format; restrict_format_ = AutoDetectFormat(in_csr, out_csr, coo, restrict_format);
switch (restrict_format) {
// If the graph is hypersparse and in COO format, switch the restricted format to COO. case SparseFormat::kCSC:
// If the graph is given as CSR, the indptr array is already materialized so we don't in_csr_ = GetInCSR();
// care about restricting conversion anyway (even if it is hypersparse). coo_ = nullptr;
if (restrict_format == SparseFormat::kAny) { out_csr_ = nullptr;
if (coo && coo->IsHypersparse()) break;
restrict_format_ = SparseFormat::kCOO; case SparseFormat::kCSR:
out_csr_ = GetOutCSR();
coo_ = nullptr;
in_csr_ = nullptr;
break;
case SparseFormat::kCOO:
coo_ = GetCOO();
in_csr_ = nullptr;
out_csr_ = nullptr;
break;
default:
break;
} }
CHECK(GetAny()) << "At least one graph structure should exist."; CHECK(GetAny()) << "At least one graph structure should exist.";
...@@ -1219,50 +1261,80 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom( ...@@ -1219,50 +1261,80 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom(
return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format)); return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format));
} }
UnitGraph::CSRPtr UnitGraph::GetInCSR() const { UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
if (inplace)
if (restrict_format_ != SparseFormat::kAny &&
restrict_format_ != SparseFormat::kCSC)
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create CSC matrix.";
CSRPtr ret = in_csr_;
if (!in_csr_) { if (!in_csr_) {
if (out_csr_) { if (out_csr_) {
const auto& newadj = aten::CSRTranspose(out_csr_->adj()); const auto& newadj = aten::CSRTranspose(out_csr_->adj());
const_cast<UnitGraph*>(this)->in_csr_ = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret;
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const auto& adj = coo_->adj(); const auto& adj = coo_->adj();
const auto& newadj = aten::COOToCSR( const auto& newadj = aten::COOToCSR(
aten::COOMatrix{adj.num_cols, adj.num_rows, adj.col, adj.row}); aten::COOMatrix{adj.num_cols, adj.num_rows, adj.col, adj.row});
const_cast<UnitGraph*>(this)->in_csr_ = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret;
} }
} }
return in_csr_; return ret;
} }
/* !\brief Return out csr. If not exist, transpose the other one.*/ /* !\brief Return out csr. If not exist, transpose the other one.*/
UnitGraph::CSRPtr UnitGraph::GetOutCSR() const { UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace)
if (restrict_format_ != SparseFormat::kAny &&
restrict_format_ != SparseFormat::kCSR)
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create CSR matrix.";
CSRPtr ret = out_csr_;
if (!out_csr_) { if (!out_csr_) {
if (in_csr_) { if (in_csr_) {
const auto& newadj = aten::CSRTranspose(in_csr_->adj()); const auto& newadj = aten::CSRTranspose(in_csr_->adj());
const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret;
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const auto& newadj = aten::COOToCSR(coo_->adj()); const auto& newadj = aten::COOToCSR(coo_->adj());
const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret;
} }
} }
return out_csr_; return ret;
} }
/* !\brief Return coo. If not exist, create from csr.*/ /* !\brief Return coo. If not exist, create from csr.*/
UnitGraph::COOPtr UnitGraph::GetCOO() const { UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace)
if (restrict_format_ != SparseFormat::kAny &&
restrict_format_ != SparseFormat::kCOO)
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create COO matrix.";
COOPtr ret = coo_;
if (!coo_) { if (!coo_) {
if (in_csr_) { if (in_csr_) {
const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true)); const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj); ret = std::make_shared<COO>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->coo_ = ret;
} else { } else {
CHECK(out_csr_) << "Both CSR are missing."; CHECK(out_csr_) << "Both CSR are missing.";
const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true); const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj); ret = std::make_shared<COO>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->coo_ = ret;
} }
} }
return coo_; return ret;
} }
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const { aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
...@@ -1287,6 +1359,16 @@ HeteroGraphPtr UnitGraph::GetAny() const { ...@@ -1287,6 +1359,16 @@ HeteroGraphPtr UnitGraph::GetAny() const {
} }
} }
dgl_format_code_t UnitGraph::GetFormatInUse() const {
dgl_format_code_t ret = 0;
if (in_csr_) ret = ret | 1;
ret = ret << 1;
if (out_csr_) ret = ret | 1;
ret = ret << 1;
if (coo_) ret = ret | 1;
return ret;
}
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) { switch (format) {
case SparseFormat::kCSR: case SparseFormat::kCSR:
...@@ -1297,15 +1379,45 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { ...@@ -1297,15 +1379,45 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
return GetCOO(); return GetCOO();
case SparseFormat::kAny: case SparseFormat::kAny:
return GetAny(); return GetAny();
default: default: // SparseFormat::kAuto
LOG(FATAL) << "unsupported format code"; LOG(FATAL) << "Must specify a restrict format.";
return nullptr; return nullptr;
} }
} }
HeteroGraphPtr UnitGraph::GetGraphInFormat(SparseFormat restrict_format) const {
int64_t num_vtypes = NumVertexTypes();
switch (restrict_format) {
case SparseFormat::kCOO:
return CreateFromCOO(
num_vtypes, GetCOO(false)->adj(), restrict_format);
case SparseFormat::kCSC:
return CreateFromCSC(
num_vtypes, GetInCSR(false)->adj(), restrict_format);
case SparseFormat::kCSR:
return CreateFromCSR(
num_vtypes, GetOutCSR(false)->adj(), restrict_format);
case SparseFormat::kAny:
return HeteroGraphPtr(
new UnitGraph(meta_graph_, in_csr_, out_csr_, coo_, restrict_format));
default: // SparseFormat::kAuto
LOG(FATAL) << "Must specify a restrict format.";
return nullptr;
}
}
SparseFormat UnitGraph::AutoDetectFormat(
CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, SparseFormat restrict_format) const {
if (restrict_format != SparseFormat::kAuto)
return restrict_format;
if (coo && coo->IsHypersparse())
return SparseFormat::kCOO;
return SparseFormat::kAny;
}
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
if (restrict_format_ != SparseFormat::kAny) if (restrict_format_ != SparseFormat::kAny)
return restrict_format_; return restrict_format_; // force to select the restricted format
else if (preferred_format != SparseFormat::kAny) else if (preferred_format != SparseFormat::kAny)
return preferred_format; return preferred_format;
else if (in_csr_) else if (in_csr_)
......
...@@ -93,8 +93,10 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -93,8 +93,10 @@ class UnitGraph : public BaseHeteroGraph {
uint64_t NumVertices(dgl_type_t vtype) const override; uint64_t NumVertices(dgl_type_t vtype) const override;
inline std::vector<int64_t> NumVerticesPerType() const override { inline std::vector<int64_t> NumVerticesPerType() const override {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on unit graphs."; std::vector<int64_t> num_nodes_per_type;
return {}; for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)
num_nodes_per_type.push_back(NumVertices(vtype));
return num_nodes_per_type;
} }
uint64_t NumEdges(dgl_type_t etype) const override; uint64_t NumEdges(dgl_type_t etype) const override;
...@@ -169,31 +171,31 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -169,31 +171,31 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Create a graph from COO arrays */ /*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAny); IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAuto);
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
/*! \brief Create a graph from (out) CSR arrays */ /*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
/*! \brief Create a graph from (in) CSC arrays */ /*! \brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
/*! \brief Convert the graph to use the given number of bits for storage */ /*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
...@@ -201,14 +203,29 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -201,14 +203,29 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Copy the data to another context */ /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
/*! \return Return the in-edge CSR format. Create from other format if not exist. */ /*!
CSRPtr GetInCSR() const; * \brief Create in-edge CSR format of the unit graph.
* \param inplace if true and the in-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the in-edge CSR format. Create from other format if not exist.
*/
CSRPtr GetInCSR(bool inplace = true) const;
/*! \return Return the out-edge CSR format. Create from other format if not exist. */ /*!
CSRPtr GetOutCSR() const; * \brief Create out-edge CSR format of the unit graph.
* \param inplace if true and the out-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the out-edge CSR format. Create from other format if not exist.
*/
CSRPtr GetOutCSR(bool inplace = true) const;
/*! \return Return the COO format. Create from other format if not exist. */ /*!
COOPtr GetCOO() const; * \brief Create COO format of the unit graph.
* \param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the COO format. Create from other format if not exist.
*/
COOPtr GetCOO(bool inplace = true) const;
/*! \return Return the COO matrix form */ /*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override; aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;
...@@ -219,10 +236,22 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -219,10 +236,22 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return the out-edge CSR in the matrix form */ /*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override; aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
/*! \brief some heuristic rules to determine the restrict format. */
SparseFormat AutoDetectFormat(
CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, SparseFormat restrict_format) const;
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override { SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
return SelectFormat(preferred_format); return SelectFormat(preferred_format);
} }
std::string GetRestrictFormat() const override {
return ToStringSparseFormat(this->restrict_format_);
}
dgl_format_code_t GetFormatInUse() const override;
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;
/*! \return Load UnitGraph from stream, using CSRMatrix*/ /*! \return Load UnitGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs); bool Load(dmlc::Stream* fs);
...@@ -245,7 +274,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -245,7 +274,7 @@ class UnitGraph : public BaseHeteroGraph {
* \param coo coo * \param coo coo
*/ */
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
/*! /*!
* \brief constructor * \brief constructor
...@@ -264,7 +293,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -264,7 +293,7 @@ class UnitGraph : public BaseHeteroGraph {
bool has_in_csr, bool has_in_csr,
bool has_out_csr, bool has_out_csr,
bool has_coo, bool has_coo,
SparseFormat restrict_format = SparseFormat::kAny); SparseFormat restrict_format = SparseFormat::kAuto);
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
......
...@@ -347,7 +347,7 @@ void csrwrapper_switch(DGLArgValue argval, ...@@ -347,7 +347,7 @@ void csrwrapper_switch(DGLArgValue argval,
fn(wrapper); fn(wrapper);
} else if (argval.IsObjectType<HeteroGraphRef>()) { } else if (argval.IsObjectType<HeteroGraphRef>()) {
HeteroGraphRef g = argval; HeteroGraphRef g = argval;
auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr()); auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0));
CHECK_NOTNULL(bgptr); CHECK_NOTNULL(bgptr);
UnitGraphCSRWrapper wrapper(bgptr.get()); UnitGraphCSRWrapper wrapper(bgptr.get());
fn(wrapper); fn(wrapper);
......
import dgl import dgl
import backend as F import backend as F
import unittest
def tree1(): def tree1():
"""Generate a tree """Generate a tree
...@@ -102,6 +103,7 @@ def test_batch_unbatch1(): ...@@ -102,6 +103,7 @@ def test_batch_unbatch1():
assert F.allclose(t2.ndata['h'], rs2.ndata['h']) assert F.allclose(t2.ndata['h'], rs2.ndata['h'])
assert F.allclose(t2.edata['h'], rs2.edata['h']) assert F.allclose(t2.edata['h'], rs2.edata['h'])
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_batch_unbatch_frame(): def test_batch_unbatch_frame():
"""Test module of node/edge frames of batched/unbatched DGLGraphs. """Test module of node/edge frames of batched/unbatched DGLGraphs.
Also address the bug mentioned in https://github.com/dmlc/dgl/issues/1475. Also address the bug mentioned in https://github.com/dmlc/dgl/issues/1475.
...@@ -118,26 +120,25 @@ def test_batch_unbatch_frame(): ...@@ -118,26 +120,25 @@ def test_batch_unbatch_frame():
t2.ndata['h'] = F.randn((N2, D)) t2.ndata['h'] = F.randn((N2, D))
t2.edata['h'] = F.randn((E2, D)) t2.edata['h'] = F.randn((E2, D))
if F.backend_name != 'tensorflow': # tf's tensor is immutable b1 = dgl.batch([t1, t2])
b1 = dgl.batch([t1, t2]) b2 = dgl.batch([t2])
b2 = dgl.batch([t2]) b1.ndata['h'][:N1] = F.zeros((N1, D))
b1.ndata['h'][:N1] = F.zeros((N1, D)) b1.edata['h'][:E1] = F.zeros((E1, D))
b1.edata['h'][:E1] = F.zeros((E1, D)) b2.ndata['h'][:N2] = F.zeros((N2, D))
b2.ndata['h'][:N2] = F.zeros((N2, D)) b2.edata['h'][:E2] = F.zeros((E2, D))
b2.edata['h'][:E2] = F.zeros((E2, D)) assert not F.allclose(t1.ndata['h'], F.zeros((N1, D)))
assert not F.allclose(t1.ndata['h'], F.zeros((N1, D))) assert not F.allclose(t1.edata['h'], F.zeros((E1, D)))
assert not F.allclose(t1.edata['h'], F.zeros((E1, D))) assert not F.allclose(t2.ndata['h'], F.zeros((N2, D)))
assert not F.allclose(t2.ndata['h'], F.zeros((N2, D))) assert not F.allclose(t2.edata['h'], F.zeros((E2, D)))
assert not F.allclose(t2.edata['h'], F.zeros((E2, D)))
g1, g2 = dgl.unbatch(b1)
g1, g2 = dgl.unbatch(b1) _g2, = dgl.unbatch(b2)
_g2, = dgl.unbatch(b2) assert F.allclose(g1.ndata['h'], F.zeros((N1, D)))
assert F.allclose(g1.ndata['h'], F.zeros((N1, D))) assert F.allclose(g1.edata['h'], F.zeros((E1, D)))
assert F.allclose(g1.edata['h'], F.zeros((E1, D))) assert F.allclose(g2.ndata['h'], t2.ndata['h'])
assert F.allclose(g2.ndata['h'], t2.ndata['h']) assert F.allclose(g2.edata['h'], t2.edata['h'])
assert F.allclose(g2.edata['h'], t2.edata['h']) assert F.allclose(_g2.ndata['h'], F.zeros((N2, D)))
assert F.allclose(_g2.ndata['h'], F.zeros((N2, D))) assert F.allclose(_g2.edata['h'], F.zeros((E2, D)))
assert F.allclose(_g2.edata['h'], F.zeros((E2, D)))
def test_batch_unbatch2(): def test_batch_unbatch2():
# test setting/getting features after batch # test setting/getting features after batch
......
...@@ -1632,6 +1632,58 @@ def test_dtype_cast(index_dtype): ...@@ -1632,6 +1632,58 @@ def test_dtype_cast(index_dtype):
assert F.array_equal(g.ndata["feat"], g_cast.ndata["feat"]) assert F.array_equal(g.ndata["feat"], g_cast.ndata["feat"])
assert F.array_equal(g.edata["h"], g_cast.edata["h"]) assert F.array_equal(g.edata["h"], g_cast.edata["h"])
def test_format():
# single relation
g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], restrict_format='coo')
assert g.restrict_format() == 'coo'
assert g.format_in_use() == ['coo']
try:
spmat = g.adjacency_matrix(scipy_fmt="csr")
except:
print('test passed, graph with restrict_format coo should not create csr matrix.')
else:
assert False, 'cannot create csr when restrict_format is coo'
g1 = g.to_format('any')
assert g1.restrict_format() == 'any'
spmat = g1.adjacency_matrix(scipy_fmt='coo')
spmat = g1.adjacency_matrix(scipy_fmt='csr')
spmat = g1.adjacency_matrix(transpose=True, scipy_fmt='csr')
assert len(g1.restrict_format()) == 3
assert g.restrict_format() == 'coo'
assert g.format_in_use() == ['coo']
# multiple relation
g = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
('developer', 'develops', 'game'): [(0, 0), (1, 1)],
}, restrict_format='csr')
user_feat = F.randn((g['follows'].number_of_src_nodes(), 5))
g['follows'].srcdata['h'] = user_feat
for rel_type in ['follows', 'plays', 'develops']:
assert g.restrict_format(rel_type) == 'csr'
print(g.format_in_use(rel_type), g.restrict_format(rel_type))
assert g.format_in_use(rel_type) == ['csr']
try:
spmat = g[rel_type].adjacency_matrix(scipy_fmt='coo')
except:
print('test passed, graph with restrict_format csr should not create coo matrix')
else:
assert False, 'cannot create coo when restrict_ormat is csr'
g1 = g.to_format('csc')
# test frame
assert F.array_equal(g1['follows'].srcdata['h'], user_feat)
# test each relation graph
for rel_type in ['follows', 'plays', 'develops']:
assert g1.restrict_format(rel_type) == 'csc'
assert g1.format_in_use(rel_type) == ['csc']
assert g.restrict_format(rel_type) == 'csr'
assert g.format_in_use(rel_type) == ['csr']
if __name__ == '__main__': if __name__ == '__main__':
# test_create() # test_create()
# test_query() # test_query()
...@@ -1656,4 +1708,5 @@ if __name__ == '__main__': ...@@ -1656,4 +1708,5 @@ if __name__ == '__main__':
# test_stack_reduce() # test_stack_reduce()
# test_isolated_ntype() # test_isolated_ntype()
# test_bipartite() # test_bipartite()
test_dtype_cast() # test_dtype_cast()
test_format()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment