Unverified Commit 70dc2ee9 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Feature] Add a request_format api and enrich related docstring. (#1528)

* upd

* upd

* lint

* fix

* bloody lint
parent ccd4eab0
...@@ -4167,17 +4167,50 @@ class DGLHeteroGraph(object): ...@@ -4167,17 +4167,50 @@ class DGLHeteroGraph(object):
"""Return if the graph is homogeneous.""" """Return if the graph is homogeneous."""
return len(self.ntypes) == 1 and len(self.etypes) == 1 return len(self.ntypes) == 1 and len(self.etypes) == 1
def format_in_use(self, etype=None, return_all=False): def format_in_use(self, etype=None):
"""Return the sparse formats in use of the given edge/relation type. """Return the sparse formats in use of the given edge/relation type.
Parameters
----------
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
list of string list of str
Return all the formats currently in use (could be multiple). Return all the formats currently in use (could be multiple).
Examples
--------
For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g.format_in_use()
['csr']
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='any')
>>> g.format_in_use('develops')
['coo']
>>> spmat = g['develops'].adjacency_matrix(
... transpose=True, scipy_fmt='csr')  # Create CSR representation.
>>> g.format_in_use('develops')
['coo', 'csr']
which is equivalent to:
>>> g['develops'].format_in_use()
['coo', 'csr']
See Also See Also
-------- --------
restrict_format restrict_format
request_format
to_format to_format
""" """
return self._graph.format_in_use(self.get_etype_id(etype)) return self._graph.format_in_use(self.get_etype_id(etype))
...@@ -4185,37 +4218,152 @@ class DGLHeteroGraph(object): ...@@ -4185,37 +4218,152 @@ class DGLHeteroGraph(object):
def restrict_format(self, etype=None): def restrict_format(self, etype=None):
"""Return the allowed sparse formats of the given edge/relation type. """Return the allowed sparse formats of the given edge/relation type.
Parameters
----------
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns Returns
------- -------
string : 'any', 'coo', 'csr', or 'csc' str : ``'any'``, ``'coo'``, ``'csr'``, or ``'csc'``
'any' indicates all sparse formats are allowed in . ``'any'`` indicates all sparse formats are allowed.
Examples
--------
For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g.restrict_format()
'csr'
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='any')
>>> g.restrict_format('develops')
'any'
which is equivalent to:
>>> g['develops'].restrict_format()
'any'
See Also See Also
-------- --------
format_in_use format_in_use
request_format
to_format to_format
""" """
return self._graph.restrict_format(self.get_etype_id(etype)) return self._graph.restrict_format(self.get_etype_id(etype))
def request_format(self, sparse_format, etype=None):
    """Create a sparse matrix representation in the given format immediately.

    When the restrict format of the given edge type is ``'any'``, sparse
    matrix representations are created on demand. In some cases the user may
    want a representation to be created immediately (e.g. in a multi-process
    data loader); this API is designed for that purpose.

    Parameters
    ----------
    sparse_format : str
        ``'coo'``, ``'csr'``, or ``'csc'``
    etype : str or tuple of str, optional
        The edge type. Can be omitted if there is only one edge type
        in the graph.

    Raises
    ------
    KeyError
        If the restrict format of the given edge type is not ``'any'``, or
        if ``sparse_format`` is not one of ``'coo'``, ``'csr'``, ``'csc'``.

    Examples
    --------
    For graph with only one edge type.

    >>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='any')
    >>> g.format_in_use()
    ['coo']
    >>> g.request_format('csr')
    >>> g.format_in_use()
    ['coo', 'csr']

    For a graph with multiple types.

    >>> g = dgl.heterograph({
    ...     ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
    ...     ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
    ...     }, restrict_format='any')
    >>> g.format_in_use('develops')
    ['coo']
    >>> g.request_format('csc', etype='develops')
    >>> g.format_in_use('develops')
    ['coo', 'csc']

    Another way to request format for a given etype is:

    >>> g['plays'].request_format('csr')
    >>> g['plays'].format_in_use()
    ['coo', 'csr']

    See Also
    --------
    format_in_use
    restrict_format
    to_format
    """
    # The restriction check runs first, so a restricted graph reports that
    # error even when the requested format name is also invalid.
    if self.restrict_format(etype) != 'any':
        raise KeyError("request_format is only available for "
                       "graph whose restrict_format is 'any'")
    if sparse_format not in ('coo', 'csr', 'csc'):
        # The message previously listed 'csr' twice and never mentioned 'csc'.
        raise KeyError("can only request coo/csr/csc.")
    return self._graph.request_format(sparse_format, self.get_etype_id(etype))
def to_format(self, restrict_format): def to_format(self, restrict_format):
"""Return a cloned graph but stored in the given restrict format. """Return a cloned graph but stored in the given restrict format.
If 'any' is given, the restrict formats of the returned graph is relaxed. If ``'any'`` is given, the restrict formats of the returned graph is relaxed.
The returned graph share the same node/edge data of the original graph. The returned graph share the same node/edge data of the original graph.
Parameters Parameters
---------- ----------
restrict_format : string restrict_format : str
Desired restrict format ('any', 'coo', 'csr', 'csc'). Desired restrict format (``'any'``, ``'coo'``, ``'csr'``, ``'csc'``).
Returns Returns
------- -------
A new graph. A new graph.
Examples
--------
For a graph with single edge type:
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g.ndata['h'] = th.ones(3, 3)
>>> g.restrict_format()
'csr'
>>> g1 = g.to_format('coo')
>>> g1.ndata
{'h': tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])}
>>> g1.restrict_format()
'coo'
For a graph with multiple edge types:
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='coo')
>>> g.restrict_format('develops')
'coo'
>>> g1 = g.to_format('any')
>>> g1.restrict_format('plays')
'any'
See Also See Also
-------- --------
format_in_use format_in_use
restrict_format restrict_format
request_format
""" """
return DGLHeteroGraph(self._graph.to_format(restrict_format), self.ntypes, self.etypes, return DGLHeteroGraph(self._graph.to_format(restrict_format), self.ntypes, self.etypes,
self._node_frames, self._node_frames,
......
...@@ -956,11 +956,23 @@ class HeteroGraphIndex(ObjectBase): ...@@ -956,11 +956,23 @@ class HeteroGraphIndex(ObjectBase):
Returns Returns
------- -------
string : 'any', 'coo', 'csr', or 'csc' string : ``'any'``, ``'coo'``, ``'csr'``, or ``'csc'``
""" """
ret = _CAPI_DGLHeteroGetRestrictFormat(self, etype) ret = _CAPI_DGLHeteroGetRestrictFormat(self, etype)
return ret return ret
def request_format(self, sparse_format, etype):
    """Ask the backend to build the given sparse format for one relation.

    The request is forwarded to ``_CAPI_DGLHeteroRequestFormat``; nothing is
    returned.

    Parameters
    ----------
    sparse_format : str
        ``'coo'``, ``'csr'``, or ``'csc'``
    etype : int
        The edge/relation type.
    """
    _CAPI_DGLHeteroRequestFormat(self, sparse_format, etype)
def to_format(self, restrict_format): def to_format(self, restrict_format):
"""Return a clone graph index but stored in the given sparse format. """Return a clone graph index but stored in the given sparse format.
...@@ -969,8 +981,8 @@ class HeteroGraphIndex(ObjectBase): ...@@ -969,8 +981,8 @@ class HeteroGraphIndex(ObjectBase):
Parameters Parameters
---------- ----------
restrict_format : string restrict_format : str
Desired restrict format ('any', 'coo', 'csr', 'csc'). Desired restrict format (``'any'``, ``'coo'``, ``'csr'``, ``'csc'``).
Returns Returns
------- -------
......
...@@ -489,6 +489,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse") ...@@ -489,6 +489,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse")
*rv = hg->GetRelationGraph(etype)->GetFormatInUse(); *rv = hg->GetRelationGraph(etype)->GetFormatInUse();
}); });
// Entry point registered as heterograph_index._CAPI_DGLHeteroRequestFormat.
// Arguments: (graph, sparse format name, edge type id). No value is written
// to rv; the call is made so that the relation graph of the given edge type
// holds the requested sparse format.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRequestFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    const std::string sparse_format = args[1];
    dgl_type_t etype = args[2];
    // Strictly less-than: an id equal to the number of edge types is already
    // out of range (assumes ids run from 0 to NumEdgeTypes() - 1).
    CHECK_LT(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
    auto bg = std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));
    // dynamic_pointer_cast yields nullptr when the relation graph is not a
    // UnitGraph; fail with a message instead of dereferencing it.
    CHECK(bg != nullptr) << "relation graph of edge type " << etype
                         << " is not a UnitGraph";
    // The returned graph is intentionally discarded.
    bg->GetFormat(ParseSparseFormat(sparse_format));
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
......
...@@ -248,6 +248,14 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -248,6 +248,14 @@ class UnitGraph : public BaseHeteroGraph {
return ToStringSparseFormat(this->restrict_format_); return ToStringSparseFormat(this->restrict_format_);
} }
/*!
 * \brief Return the graph in the given format. Perform format conversion if the
 * requested format does not exist.
 *
 * \param format The sparse format being requested.
 * \return A graph in the requested format.
 */
HeteroGraphPtr GetFormat(SparseFormat format) const;
dgl_format_code_t GetFormatInUse() const override; dgl_format_code_t GetFormatInUse() const override;
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override; HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;
...@@ -298,14 +306,6 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -298,14 +306,6 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
/*!
* \brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist.
*
* \return A graph in the requested format.
*/
HeteroGraphPtr GetFormat(SparseFormat format) const;
/*! /*!
* \brief Determine which format to use with a preference. * \brief Determine which format to use with a preference.
* *
......
...@@ -1637,20 +1637,18 @@ def test_format(): ...@@ -1637,20 +1637,18 @@ def test_format():
g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], restrict_format='coo') g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], restrict_format='coo')
assert g.restrict_format() == 'coo' assert g.restrict_format() == 'coo'
assert g.format_in_use() == ['coo'] assert g.format_in_use() == ['coo']
try: try:
spmat = g.adjacency_matrix(scipy_fmt="csr") spmat = g.adjacency_matrix(scipy_fmt="csr")
except: except:
print('test passed, graph with restrict_format coo should not create csr matrix.') print('test passed, graph with restrict_format coo should not create csr matrix.')
else: else:
assert False, 'cannot create csr when restrict_format is coo' assert False, 'cannot create csr when restrict_format is coo'
g1 = g.to_format('any') g1 = g.to_format('any')
assert g1.restrict_format() == 'any' assert g1.restrict_format() == 'any'
spmat = g1.adjacency_matrix(scipy_fmt='coo') g1.request_format('coo')
spmat = g1.adjacency_matrix(scipy_fmt='csr') g1.request_format('csr')
spmat = g1.adjacency_matrix(transpose=True, scipy_fmt='csr') g1.request_format('csc')
assert len(g1.restrict_format()) == 3 assert len(g1.format_in_use()) == 3
assert g.restrict_format() == 'coo' assert g.restrict_format() == 'coo'
assert g.format_in_use() == ['coo'] assert g.format_in_use() == ['coo']
...@@ -1664,15 +1662,13 @@ def test_format(): ...@@ -1664,15 +1662,13 @@ def test_format():
g['follows'].srcdata['h'] = user_feat g['follows'].srcdata['h'] = user_feat
for rel_type in ['follows', 'plays', 'develops']: for rel_type in ['follows', 'plays', 'develops']:
assert g.restrict_format(rel_type) == 'csr' assert g.restrict_format(rel_type) == 'csr'
print(g.format_in_use(rel_type), g.restrict_format(rel_type))
assert g.format_in_use(rel_type) == ['csr'] assert g.format_in_use(rel_type) == ['csr']
try: try:
spmat = g[rel_type].adjacency_matrix(scipy_fmt='coo') g[rel_type].request_format('coo')
except: except:
print('test passed, graph with restrict_format csr should not create coo matrix') print('test passed, graph with restrict_format csr should not create coo matrix')
else: else:
assert False, 'cannot create coo when restrict_ormat is csr' assert False, 'cannot create coo when restrict_format is csr'
g1 = g.to_format('csc') g1 = g.to_format('csc')
# test frame # test frame
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment