Unverified Commit f5330cb6 authored by keli-wen's avatar keli-wen Committed by GitHub
Browse files

[Sparse] Clean formats function docstring. (#5851)

update the DGLGraph.formats docstring
parent 72ec1c95
...@@ -6066,14 +6066,20 @@ class DGLGraph(object): ...@@ -6066,14 +6066,20 @@ class DGLGraph(object):
self._edge_frames = old_eframes self._edge_frames = old_eframes
def formats(self, formats=None): def formats(self, formats=None):
r"""Get a cloned graph with the specified sparse format(s) or query r"""Get a cloned graph with the specified allowed sparse format(s) or
for the usage status of sparse formats query for the usage status of sparse formats.
The API copies both the graph structure and the features. The API copies both the graph structure and the features.
If the input graph has multiple edge types, they will have the same If the input graph has multiple edge types, they will have the same
sparse format. sparse format.
When ``formats`` is not None, if the intersection between `formats` and
the current graph's created sparse format(s) is not empty, the returned
cloned graph only retains all sparse format(s) in the intersection. If
the intersection is empty, a sparse format will be selected to be
created following the order of ``'coo' -> 'csr' -> 'csc'``.
Parameters Parameters
---------- ----------
formats : str or list of str or None formats : str or list of str or None
...@@ -6089,7 +6095,8 @@ class DGLGraph(object): ...@@ -6089,7 +6095,8 @@ class DGLGraph(object):
* If formats is None, the result will be a dict recording the usage * If formats is None, the result will be a dict recording the usage
status of sparse formats. status of sparse formats.
* Otherwise, a DGLGraph will be returned, which is a clone of the * Otherwise, a DGLGraph will be returned, which is a clone of the
original graph with the specified sparse format(s) ``formats``. original graph with the specified allowed sparse format(s)
``formats``.
Examples Examples
-------- --------
...@@ -6103,15 +6110,15 @@ class DGLGraph(object): ...@@ -6103,15 +6110,15 @@ class DGLGraph(object):
>>> g = dgl.graph(([0, 0, 1], [2, 3, 2])) >>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))
>>> g.ndata['h'] = torch.ones(4, 1) >>> g.ndata['h'] = torch.ones(4, 1)
>>> # Check status of format usage >>> # Check status of format usage.
>>> g.formats() >>> g.formats()
{'created': ['coo'], 'not created': ['csr', 'csc']} {'created': ['coo'], 'not created': ['csr', 'csc']}
>>> # Get a clone of the graph with 'csr' format >>> # Get a clone of the graph with 'csr' format.
>>> csr_g = g.formats('csr') >>> csr_g = g.formats('csr')
>>> # Only allowed formats will be displayed in the status query >>> # Only allowed formats will be displayed in the status query.
>>> csr_g.formats() >>> csr_g.formats()
{'created': ['csr'], 'not created': []} {'created': ['csr'], 'not created': []}
>>> # Features are copied as well >>> # Features are copied as well.
>>> csr_g.ndata['h'] >>> csr_g.ndata['h']
tensor([[1.], tensor([[1.],
[1.], [1.],
...@@ -6128,17 +6135,43 @@ class DGLGraph(object): ...@@ -6128,17 +6135,43 @@ class DGLGraph(object):
... }) ... })
>>> g.formats() >>> g.formats()
{'created': ['coo'], 'not created': ['csr', 'csc']} {'created': ['coo'], 'not created': ['csr', 'csc']}
>>> # Get a clone of the graph with 'csr' format >>> # Get a clone of the graph with 'csr' format.
>>> csr_g = g.formats('csr') >>> csr_g = g.formats('csr')
>>> # Only allowed formats will be displayed in the status query >>> # Only allowed formats will be displayed in the status query.
>>> csr_g.formats() >>> csr_g.formats()
{'created': ['csr'], 'not created': []} {'created': ['csr'], 'not created': []}
**When formats intersects with created formats**
>>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))
>>> g = g.formats(['coo', 'csr'])
>>> g.create_formats_()
>>> g.formats()
{'created': ['coo', 'csr'], 'not created': []}
>>> # Get a clone of the graph allowed formats 'csr' and 'csc'.
>>> csr_csc_g = g.formats(['csr', 'csc'])
>>> # Only the intersection 'csr' will be retained.
>>> csr_csc_g.formats()
{'created': ['csr'], 'not created': ['csc']}
**When formats doesn't intersect with created formats**
>>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))
>>> g = g.formats('coo')
>>> g.formats()
{'created': ['coo'], 'not created': []}
>>> # Get a clone of the graph allowed formats 'csr' and 'csc'.
>>> csr_csc_g = g.formats(['csr', 'csc'])
>>> # Since the intersection is empty, 'csr' will be created as it is
>>> # first in the order of 'coo' -> 'csr' -> 'csc'.
>>> csr_csc_g.formats()
{'created': ['csr'], 'not created': ['csc']}
""" """
if formats is None: if formats is None:
# Return the format information # Return the format information.
return self._graph.formats() return self._graph.formats()
else: else:
# Convert the graph to use another format # Convert the graph to use another allowed format.
ret = copy.copy(self) ret = copy.copy(self)
ret._graph = self._graph.formats(formats) ret._graph = self._graph.formats(formats)
return ret return ret
......
...@@ -1102,12 +1102,18 @@ class HeteroGraphIndex(ObjectBase): ...@@ -1102,12 +1102,18 @@ class HeteroGraphIndex(ObjectBase):
) )
def formats(self, formats=None): def formats(self, formats=None):
"""Get a graph index with the specified sparse format(s) or query """Get a graph index with the specified allowed sparse format(s) or
for the usage status of sparse formats query for the usage status of sparse formats.
If the graph has multiple edge types, they will have the same If the graph has multiple edge types, they will have the same
sparse format. sparse format.
When ``formats`` is not None, if the intersection between `formats` and
the current graph's created sparse format(s) is not empty, the returned
cloned graph only retains all sparse format(s) in the intersection. If
the intersection is empty, a sparse format will be selected to be
created following the order of ``'coo' -> 'csr' -> 'csc'``.
Parameters Parameters
---------- ----------
formats : str or list of str or None formats : str or list of str or None
...@@ -1123,7 +1129,8 @@ class HeteroGraphIndex(ObjectBase): ...@@ -1123,7 +1129,8 @@ class HeteroGraphIndex(ObjectBase):
* If formats is None, the result will be a dict recording the usage * If formats is None, the result will be a dict recording the usage
status of sparse formats. status of sparse formats.
* Otherwise, a GraphIndex will be returned, which is a clone of the * Otherwise, a GraphIndex will be returned, which is a clone of the
original graph with the specified sparse format(s) ``formats``. original graph with the specified allowed sparse format(s)
``formats``.
""" """
formats_allowed = _CAPI_DGLHeteroGetAllowedFormats(self) formats_allowed = _CAPI_DGLHeteroGetAllowedFormats(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment