Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
70dc2ee9
Unverified
Commit
70dc2ee9
authored
May 16, 2020
by
Zihao Ye
Committed by
GitHub
May 16, 2020
Browse files
[Feature] Add a request_format api and enrich related docstring. (#1528)
* upd * upd * lint * fix * bloody lint
parent
ccd4eab0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
194 additions
and
28 deletions
+194
-28
python/dgl/heterograph.py
python/dgl/heterograph.py
+155
-7
python/dgl/heterograph_index.py
python/dgl/heterograph_index.py
+15
-3
src/graph/heterograph_capi.cc
src/graph/heterograph_capi.cc
+10
-0
src/graph/unit_graph.h
src/graph/unit_graph.h
+8
-8
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+6
-10
No files found.
python/dgl/heterograph.py
View file @
70dc2ee9
...
@@ -4167,17 +4167,50 @@ class DGLHeteroGraph(object):
...
@@ -4167,17 +4167,50 @@ class DGLHeteroGraph(object):
"""Return if the graph is homogeneous."""
"""Return if the graph is homogeneous."""
return
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
return
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
def
format_in_use
(
self
,
etype
=
None
,
return_all
=
False
):
def
format_in_use
(
self
,
etype
=
None
):
"""Return the sparse formats in use of the given edge/relation type.
"""Return the sparse formats in use of the given edge/relation type.
Parameters
----------
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
Returns
-------
-------
list of str
ing
list of str
Return all the formats currently in use (could be multiple).
Return all the formats currently in use (could be multiple).
Examples
--------
For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g.format_in_use()
['csr']
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='any')
>>> g.format_in_use('develops')
['coo']
>>> spmat = g['develops'].adjacency_matrix(
... transpose=True, scipy_fmt='csr') // Create CSR representation.
>>> g.format_in_use('develops')
['coo', 'csr']
which is equivalent to:
>>> g['develops'].restrict_format()
['coo', 'csr']
See Also
See Also
--------
--------
restrict_format
restrict_format
request_format
to_format
to_format
"""
"""
return
self
.
_graph
.
format_in_use
(
self
.
get_etype_id
(
etype
))
return
self
.
_graph
.
format_in_use
(
self
.
get_etype_id
(
etype
))
...
@@ -4185,37 +4218,152 @@ class DGLHeteroGraph(object):
...
@@ -4185,37 +4218,152 @@ class DGLHeteroGraph(object):
def
restrict_format
(
self
,
etype
=
None
):
def
restrict_format
(
self
,
etype
=
None
):
"""Return the allowed sparse formats of the given edge/relation type.
"""Return the allowed sparse formats of the given edge/relation type.
Parameters
----------
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
Returns
-------
-------
string : 'any', 'coo', 'csr', or 'csc'
str : ``'any'``, ``'coo'``, ``'csr'``, or ``'csc'``
'any' indicates all sparse formats are allowed in .
``'any'`` indicates all sparse formats are allowed in .
Examples
--------
For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g.restrict_format()
'csr'
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='any')
>>> g.restrict_format('develops')
'any'
which is equivalent to:
>>> g['develops'].restrict_format()
'any'
See Also
See Also
--------
--------
format_in_use
format_in_use
request_format
to_format
to_format
"""
"""
return
self
.
_graph
.
restrict_format
(
self
.
get_etype_id
(
etype
))
return
self
.
_graph
.
restrict_format
(
self
.
get_etype_id
(
etype
))
def
request_format
(
self
,
sparse_format
,
etype
=
None
):
"""Create a sparse matrix representation in given format immediately.
When the restrict format of the given edge type is ``any``, all formats of
sparse matrix representation are created in demand. In some cases user may
want a sparse matrix representation to be created immediately (e.g. in a
multi-process data loader), this API is designed for such purpose.
Parameters
----------
sparse_format : str
``'coo'``, ``'csr'``, or ``'csc'``
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Examples
--------
For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='any')
>>> g.format_in_use()
['coo']
>>> g.request_format('csr')
>>> g.format_in_use()
['coo', 'csr']
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='any')
>>> g.format_in_use('develops')
['coo']
>>> g.request_format('csc', etype='develops')
>>> g.format_in_use('develops')
['coo', 'csc']
Another way to request format for a given etype is:
>>> g['plays'].request_format('csr')
>>> g['plays'].format_in_use()
['coo', 'csr']
See Also
--------
format_in_use
restrict_format
to_format
"""
if
self
.
restrict_format
(
etype
)
!=
'any'
:
raise
KeyError
(
"request_format is only available for "
"graph whose restrict_format is 'any'"
)
if
not
sparse_format
in
[
'coo'
,
'csr'
,
'csc'
]:
raise
KeyError
(
"can only request coo/csr/csr."
)
return
self
.
_graph
.
request_format
(
sparse_format
,
self
.
get_etype_id
(
etype
))
def
to_format
(
self
,
restrict_format
):
def
to_format
(
self
,
restrict_format
):
"""Return a cloned graph but stored in the given restrict format.
"""Return a cloned graph but stored in the given restrict format.
If 'any' is given, the restrict formats of the returned graph is relaxed.
If
``
'any'
``
is given, the restrict formats of the returned graph is relaxed.
The returned graph share the same node/edge data of the original graph.
The returned graph share the same node/edge data of the original graph.
Parameters
Parameters
----------
----------
restrict_format : str
ing
restrict_format : str
Desired restrict format ('any'
, 'coo', 'csr',
'csc').
Desired restrict format (
``
'any'
``, ``'coo'``, ``'csr'``, ``
'csc'
``
).
Returns
Returns
-------
-------
A new graph.
A new graph.
Examples
--------
For a graph with single edge type:
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g.ndata['h'] = th.ones(3, 3)
>>> g.restrict_format()
'csr'
>>> g1 = g.to_format('coo')
>>> g1.ndata
{'h': tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])}
>>> g1.restrict_format()
'coo'
For a graph with multiple edge types:
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... }, restrict_format='coo')
>>> g.restrict_format('develops')
'coo'
>>> g1 = g.to_format('any')
>>> g1.restrict_format('plays')
'any'
See Also
See Also
--------
--------
format_in_use
format_in_use
restrict_format
restrict_format
request_format
"""
"""
return
DGLHeteroGraph
(
self
.
_graph
.
to_format
(
restrict_format
),
self
.
ntypes
,
self
.
etypes
,
return
DGLHeteroGraph
(
self
.
_graph
.
to_format
(
restrict_format
),
self
.
ntypes
,
self
.
etypes
,
self
.
_node_frames
,
self
.
_node_frames
,
...
...
python/dgl/heterograph_index.py
View file @
70dc2ee9
...
@@ -956,11 +956,23 @@ class HeteroGraphIndex(ObjectBase):
...
@@ -956,11 +956,23 @@ class HeteroGraphIndex(ObjectBase):
Returns
Returns
-------
-------
string : 'any'
, 'coo',
'csr', or 'csc'
string :
``
'any'
``, ``'coo'``, ``
'csr'
``
, or
``
'csc'
``
"""
"""
ret
=
_CAPI_DGLHeteroGetRestrictFormat
(
self
,
etype
)
ret
=
_CAPI_DGLHeteroGetRestrictFormat
(
self
,
etype
)
return
ret
return
ret
def
request_format
(
self
,
sparse_format
,
etype
):
"""Create a sparse matrix representation in given format immediately.
Parameters
----------
etype : int
The edge/relation type.
sparse_format : str
``'coo'``, ``'csr'``, or ``'csc'``
"""
_CAPI_DGLHeteroRequestFormat
(
self
,
sparse_format
,
etype
)
def
to_format
(
self
,
restrict_format
):
def
to_format
(
self
,
restrict_format
):
"""Return a clone graph index but stored in the given sparse format.
"""Return a clone graph index but stored in the given sparse format.
...
@@ -969,8 +981,8 @@ class HeteroGraphIndex(ObjectBase):
...
@@ -969,8 +981,8 @@ class HeteroGraphIndex(ObjectBase):
Parameters
Parameters
----------
----------
restrict_format : str
ing
restrict_format : str
Desired restrict format ('any'
, 'coo', 'csr',
'csc').
Desired restrict format (
``
'any'
``, ``'coo'``, ``'csr'``, ``
'csc'
``
).
Returns
Returns
-------
-------
...
...
src/graph/heterograph_capi.cc
View file @
70dc2ee9
...
@@ -489,6 +489,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse")
...
@@ -489,6 +489,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse")
*
rv
=
hg
->
GetRelationGraph
(
etype
)
->
GetFormatInUse
();
*
rv
=
hg
->
GetRelationGraph
(
etype
)
->
GetFormatInUse
();
});
});
DGL_REGISTER_GLOBAL
(
"heterograph_index._CAPI_DGLHeteroRequestFormat"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
HeteroGraphRef
hg
=
args
[
0
];
const
std
::
string
sparse_format
=
args
[
1
];
dgl_type_t
etype
=
args
[
2
];
CHECK_LE
(
etype
,
hg
->
NumEdgeTypes
())
<<
"invalid edge type "
<<
etype
;
auto
bg
=
std
::
dynamic_pointer_cast
<
UnitGraph
>
(
hg
->
GetRelationGraph
(
etype
));
bg
->
GetFormat
(
ParseSparseFormat
(
sparse_format
));
});
DGL_REGISTER_GLOBAL
(
"heterograph_index._CAPI_DGLHeteroGetFormatGraph"
)
DGL_REGISTER_GLOBAL
(
"heterograph_index._CAPI_DGLHeteroGetFormatGraph"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
HeteroGraphRef
hg
=
args
[
0
];
HeteroGraphRef
hg
=
args
[
0
];
...
...
src/graph/unit_graph.h
View file @
70dc2ee9
...
@@ -248,6 +248,14 @@ class UnitGraph : public BaseHeteroGraph {
...
@@ -248,6 +248,14 @@ class UnitGraph : public BaseHeteroGraph {
return
ToStringSparseFormat
(
this
->
restrict_format_
);
return
ToStringSparseFormat
(
this
->
restrict_format_
);
}
}
/*!
* \brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist.
*
* \return A graph in the requested format.
*/
HeteroGraphPtr
GetFormat
(
SparseFormat
format
)
const
;
dgl_format_code_t
GetFormatInUse
()
const
override
;
dgl_format_code_t
GetFormatInUse
()
const
override
;
HeteroGraphPtr
GetGraphInFormat
(
SparseFormat
restrict_format
)
const
override
;
HeteroGraphPtr
GetGraphInFormat
(
SparseFormat
restrict_format
)
const
override
;
...
@@ -298,14 +306,6 @@ class UnitGraph : public BaseHeteroGraph {
...
@@ -298,14 +306,6 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return any existing format. */
/*! \return Return any existing format. */
HeteroGraphPtr
GetAny
()
const
;
HeteroGraphPtr
GetAny
()
const
;
/*!
* \brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist.
*
* \return A graph in the requested format.
*/
HeteroGraphPtr
GetFormat
(
SparseFormat
format
)
const
;
/*!
/*!
* \brief Determine which format to use with a preference.
* \brief Determine which format to use with a preference.
*
*
...
...
tests/compute/test_heterograph.py
View file @
70dc2ee9
...
@@ -1637,20 +1637,18 @@ def test_format():
...
@@ -1637,20 +1637,18 @@ def test_format():
g
=
dgl
.
graph
([(
0
,
0
),
(
1
,
1
),
(
0
,
1
),
(
2
,
0
)],
restrict_format
=
'coo'
)
g
=
dgl
.
graph
([(
0
,
0
),
(
1
,
1
),
(
0
,
1
),
(
2
,
0
)],
restrict_format
=
'coo'
)
assert
g
.
restrict_format
()
==
'coo'
assert
g
.
restrict_format
()
==
'coo'
assert
g
.
format_in_use
()
==
[
'coo'
]
assert
g
.
format_in_use
()
==
[
'coo'
]
try
:
try
:
spmat
=
g
.
adjacency_matrix
(
scipy_fmt
=
"csr"
)
spmat
=
g
.
adjacency_matrix
(
scipy_fmt
=
"csr"
)
except
:
except
:
print
(
'test passed, graph with restrict_format coo should not create csr matrix.'
)
print
(
'test passed, graph with restrict_format coo should not create csr matrix.'
)
else
:
else
:
assert
False
,
'cannot create csr when restrict_format is coo'
assert
False
,
'cannot create csr when restrict_format is coo'
g1
=
g
.
to_format
(
'any'
)
g1
=
g
.
to_format
(
'any'
)
assert
g1
.
restrict_format
()
==
'any'
assert
g1
.
restrict_format
()
==
'any'
spmat
=
g1
.
adjacency_matrix
(
scipy_fmt
=
'coo'
)
g1
.
request_format
(
'coo'
)
spmat
=
g1
.
adjacency_matrix
(
scipy_fmt
=
'csr'
)
g1
.
request_format
(
'csr'
)
spmat
=
g1
.
adjacency_matrix
(
transpose
=
True
,
scipy_fmt
=
'cs
r
'
)
g1
.
request_format
(
'cs
c
'
)
assert
len
(
g1
.
restrict_format
())
==
3
assert
len
(
g1
.
format_in_use
())
==
3
assert
g
.
restrict_format
()
==
'coo'
assert
g
.
restrict_format
()
==
'coo'
assert
g
.
format_in_use
()
==
[
'coo'
]
assert
g
.
format_in_use
()
==
[
'coo'
]
...
@@ -1664,15 +1662,13 @@ def test_format():
...
@@ -1664,15 +1662,13 @@ def test_format():
g
[
'follows'
].
srcdata
[
'h'
]
=
user_feat
g
[
'follows'
].
srcdata
[
'h'
]
=
user_feat
for
rel_type
in
[
'follows'
,
'plays'
,
'develops'
]:
for
rel_type
in
[
'follows'
,
'plays'
,
'develops'
]:
assert
g
.
restrict_format
(
rel_type
)
==
'csr'
assert
g
.
restrict_format
(
rel_type
)
==
'csr'
print
(
g
.
format_in_use
(
rel_type
),
g
.
restrict_format
(
rel_type
))
assert
g
.
format_in_use
(
rel_type
)
==
[
'csr'
]
assert
g
.
format_in_use
(
rel_type
)
==
[
'csr'
]
try
:
try
:
spmat
=
g
[
rel_type
].
adjacency_matrix
(
scipy_fmt
=
'coo'
)
g
[
rel_type
].
request_format
(
'coo'
)
except
:
except
:
print
(
'test passed, graph with restrict_format csr should not create coo matrix'
)
print
(
'test passed, graph with restrict_format csr should not create coo matrix'
)
else
:
else
:
assert
False
,
'cannot create coo when restrict_ormat is csr'
assert
False
,
'cannot create coo when restrict_
f
ormat is csr'
g1
=
g
.
to_format
(
'csc'
)
g1
=
g
.
to_format
(
'csc'
)
# test frame
# test frame
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment