Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fa0ee46a
Unverified
Commit
fa0ee46a
authored
Dec 06, 2019
by
Zihao Ye
Committed by
GitHub
Dec 06, 2019
Browse files
[hotfix] node id validity check (#1073)
* fix * improve * fix lint * upd * fix * upd
parent
bc4f4352
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
9 deletions
+84
-9
python/dgl/convert.py
python/dgl/convert.py
+52
-9
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+32
-0
No files found.
python/dgl/convert.py
View file @
fa0ee46a
...
...
@@ -21,7 +21,7 @@ __all__ = [
'to_networkx'
,
]
def
graph
(
data
,
ntype
=
'_N'
,
etype
=
'_E'
,
card
=
None
,
**
kwargs
):
def
graph
(
data
,
ntype
=
'_N'
,
etype
=
'_E'
,
card
=
None
,
validate
=
False
,
**
kwargs
):
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...
...
@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
card : int, optional
Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
...
...
@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
['follows']
>>> g.canonical_etypes
[('user', 'follows', 'user')]
Check if node ids are within cardinality
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True)
Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})
"""
if
card
is
not
None
:
urange
,
vrange
=
card
,
card
...
...
@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
urange
,
vrange
=
None
,
None
if
isinstance
(
data
,
tuple
):
u
,
v
=
data
return
create_from_edges
(
u
,
v
,
ntype
,
etype
,
ntype
,
urange
,
vrange
)
return
create_from_edges
(
u
,
v
,
ntype
,
etype
,
ntype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
list
):
return
create_from_edge_list
(
data
,
ntype
,
etype
,
ntype
,
urange
,
vrange
)
return
create_from_edge_list
(
data
,
ntype
,
etype
,
ntype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
return
create_from_scipy
(
data
,
ntype
,
etype
,
ntype
)
elif
isinstance
(
data
,
nx
.
Graph
):
...
...
@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
else
:
raise
DGLError
(
'Unsupported graph data type:'
,
type
(
data
))
def
bipartite
(
data
,
utype
=
'_U'
,
etype
=
'_E'
,
vtype
=
'_V'
,
card
=
None
,
**
kwargs
):
def
bipartite
(
data
,
utype
=
'_U'
,
etype
=
'_E'
,
vtype
=
'_V'
,
card
=
None
,
validate
=
False
,
**
kwargs
):
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
...
...
@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
card : pair of int, optional
Cardinality (number of nodes in the source and destination group). If None,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
...
...
@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
4
>>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within cardinality
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True)
>>> g
Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3},
metagraph=[('_U', '_V')])
"""
if
utype
==
vtype
:
raise
DGLError
(
'utype should not be equal to vtype. Use ``dgl.graph`` instead.'
)
...
...
@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
urange
,
vrange
=
None
,
None
if
isinstance
(
data
,
tuple
):
u
,
v
=
data
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
list
):
return
create_from_edge_list
(
data
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
return
create_from_edge_list
(
data
,
utype
,
etype
,
vtype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
return
create_from_scipy
(
data
,
utype
,
etype
,
vtype
)
elif
isinstance
(
data
,
nx
.
Graph
):
...
...
@@ -667,7 +695,7 @@ def to_homo(G):
# Internal APIs
############################################################
def
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
):
def
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
,
validate
=
False
):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
...
...
@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
...
...
@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
"""
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
if
validate
:
if
urange
is
not
None
and
urange
<=
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
))):
raise
DGLError
(
'Invalid node id {} (should be less than cardinality {}).'
.
format
(
urange
,
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
)))))
if
vrange
is
not
None
and
vrange
<=
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
))):
raise
DGLError
(
'Invalid node id {} (should be less than cardinality {}).'
.
format
(
vrange
,
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
)))))
urange
=
urange
or
(
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
)))
+
1
)
vrange
=
vrange
or
(
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
)))
+
1
)
if
utype
==
vtype
:
...
...
@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
else
:
return
DGLHeteroGraph
(
hgidx
,
[
utype
,
vtype
],
[
etype
])
def
create_from_edge_list
(
elist
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
):
def
create_from_edge_list
(
elist
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
,
validate
=
False
):
"""Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype
...
...
@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
...
...
@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
u
,
v
=
zip
(
*
elist
)
u
=
list
(
u
)
v
=
list
(
v
)
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
,
validate
)
def
create_from_scipy
(
spmat
,
utype
,
etype
,
vtype
,
with_edge_id
=
False
):
"""Internal function to create a heterograph from a scipy sparse matrix with types.
...
...
@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
...
...
tests/compute/test_heterograph.py
View file @
fa0ee46a
...
...
@@ -6,6 +6,8 @@ import scipy.sparse as ssp
import
itertools
import
backend
as
F
import
networkx
as
nx
from
dgl
import
DGLError
def
create_test_heterograph
():
# test heterograph from the docstring, plus a user -- wishes -- game relation
...
...
@@ -93,6 +95,36 @@ def test_create():
assert
g
.
number_of_nodes
(
'l1'
)
==
3
assert
g
.
number_of_nodes
(
'l2'
)
==
4
# test if validate flag works
# homo graph
fail
=
False
try
:
g
=
dgl
.
graph
(
([
0
,
0
,
0
,
1
,
1
,
2
],
[
0
,
1
,
2
,
0
,
1
,
2
]),
card
=
2
,
validate
=
True
)
except
DGLError
:
fail
=
True
finally
:
assert
fail
,
"should catch a DGLError because node ID is out of bound."
# bipartite graph
def
_test_validate_bipartite
(
card
):
fail
=
False
try
:
g
=
dgl
.
bipartite
(
([
0
,
0
,
1
,
1
,
2
],
[
1
,
1
,
2
,
2
,
3
]),
card
=
card
,
validate
=
True
)
except
DGLError
:
fail
=
True
finally
:
assert
fail
,
"should catch a DGLError because node ID is out of bound."
_test_validate_bipartite
((
3
,
3
))
_test_validate_bipartite
((
2
,
4
))
def
test_query
():
g
=
create_test_heterograph
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment