Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fa0ee46a
Unverified
Commit
fa0ee46a
authored
Dec 06, 2019
by
Zihao Ye
Committed by
GitHub
Dec 06, 2019
Browse files
[hotfix] node id validity check (#1073)
* fix * improve * fix lint * upd * fix * upd
parent
bc4f4352
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
9 deletions
+84
-9
python/dgl/convert.py
python/dgl/convert.py
+52
-9
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+32
-0
No files found.
python/dgl/convert.py
View file @
fa0ee46a
...
@@ -21,7 +21,7 @@ __all__ = [
...
@@ -21,7 +21,7 @@ __all__ = [
'to_networkx'
,
'to_networkx'
,
]
]
def
graph
(
data
,
ntype
=
'_N'
,
etype
=
'_E'
,
card
=
None
,
**
kwargs
):
def
graph
(
data
,
ntype
=
'_N'
,
etype
=
'_E'
,
card
=
None
,
validate
=
False
,
**
kwargs
):
"""Create a graph with one type of nodes and edges.
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...
@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
...
@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
card : int, optional
card : int, optional
Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None)
the largest node ID plus 1. (Default: None)
validate : bool, optional
If True, check whether node IDs are within cardinality. The check may take
some time.
If False and card is not None, the user will receive a warning. (Default: False)
kwargs : key-word arguments, optional
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
graph. It can consist of:
...
@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
...
@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
['follows']
['follows']
>>> g.canonical_etypes
>>> g.canonical_etypes
[('user', 'follows', 'user')]
[('user', 'follows', 'user')]
Check if node ids are within cardinality
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True)
>>> g
Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})
"""
"""
if
card
is
not
None
:
if
card
is
not
None
:
urange
,
vrange
=
card
,
card
urange
,
vrange
=
card
,
card
...
@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
...
@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
urange
,
vrange
=
None
,
None
urange
,
vrange
=
None
,
None
if
isinstance
(
data
,
tuple
):
if
isinstance
(
data
,
tuple
):
u
,
v
=
data
u
,
v
=
data
return
create_from_edges
(
u
,
v
,
ntype
,
etype
,
ntype
,
urange
,
vrange
)
return
create_from_edges
(
u
,
v
,
ntype
,
etype
,
ntype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
list
):
elif
isinstance
(
data
,
list
):
return
create_from_edge_list
(
data
,
ntype
,
etype
,
ntype
,
urange
,
vrange
)
return
create_from_edge_list
(
data
,
ntype
,
etype
,
ntype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
return
create_from_scipy
(
data
,
ntype
,
etype
,
ntype
)
return
create_from_scipy
(
data
,
ntype
,
etype
,
ntype
)
elif
isinstance
(
data
,
nx
.
Graph
):
elif
isinstance
(
data
,
nx
.
Graph
):
...
@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
...
@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
else
:
else
:
raise
DGLError
(
'Unsupported graph data type:'
,
type
(
data
))
raise
DGLError
(
'Unsupported graph data type:'
,
type
(
data
))
def
bipartite
(
data
,
utype
=
'_U'
,
etype
=
'_E'
,
vtype
=
'_V'
,
card
=
None
,
**
kwargs
):
def
bipartite
(
data
,
utype
=
'_U'
,
etype
=
'_E'
,
vtype
=
'_V'
,
card
=
None
,
validate
=
False
,
**
kwargs
):
"""Create a bipartite graph.
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
The result graph is directed and edges must be from ``utype`` nodes
...
@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
...
@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
card : pair of int, optional
card : pair of int, optional
Cardinality (number of nodes in the source and destination group). If None,
Cardinality (number of nodes in the source and destination group). If None,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional
If True, check whether node IDs are within cardinality. The check may take
some time.
If False and card is not None, the user will receive a warning. (Default: False)
kwargs : key-word arguments, optional
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
graph. It can consist of:
...
@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
...
@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
4
4
>>> g.edges()
>>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within cardinality
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True)
>>> g
Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3},
metagraph=[('_U', '_V')])
"""
"""
if
utype
==
vtype
:
if
utype
==
vtype
:
raise
DGLError
(
'utype should not be equal to vtype. Use ``dgl.graph`` instead.'
)
raise
DGLError
(
'utype should not be equal to vtype. Use ``dgl.graph`` instead.'
)
...
@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
...
@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
urange
,
vrange
=
None
,
None
urange
,
vrange
=
None
,
None
if
isinstance
(
data
,
tuple
):
if
isinstance
(
data
,
tuple
):
u
,
v
=
data
u
,
v
=
data
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
list
):
elif
isinstance
(
data
,
list
):
return
create_from_edge_list
(
data
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
return
create_from_edge_list
(
data
,
utype
,
etype
,
vtype
,
urange
,
vrange
,
validate
)
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
elif
isinstance
(
data
,
sp
.
sparse
.
spmatrix
):
return
create_from_scipy
(
data
,
utype
,
etype
,
vtype
)
return
create_from_scipy
(
data
,
utype
,
etype
,
vtype
)
elif
isinstance
(
data
,
nx
.
Graph
):
elif
isinstance
(
data
,
nx
.
Graph
):
...
@@ -667,7 +695,7 @@ def to_homo(G):
...
@@ -667,7 +695,7 @@ def to_homo(G):
# Internal APIs
# Internal APIs
############################################################
############################################################
def
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
):
def
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
,
validate
=
False
):
"""Internal function to create a graph from incident nodes with types.
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
utype could be equal to vtype
...
@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
...
@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
vrange : int, optional
The destination node ID range. If None, the value is the
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
Returns
-------
-------
...
@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
...
@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
"""
"""
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
if
validate
:
if
urange
is
not
None
and
urange
<=
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
))):
raise
DGLError
(
'Invalid node id {} (should be less than cardinality {}).'
.
format
(
urange
,
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
)))))
if
vrange
is
not
None
and
vrange
<=
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
))):
raise
DGLError
(
'Invalid node id {} (should be less than cardinality {}).'
.
format
(
vrange
,
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
)))))
urange
=
urange
or
(
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
)))
+
1
)
urange
=
urange
or
(
int
(
F
.
asnumpy
(
F
.
max
(
u
.
tousertensor
(),
dim
=
0
)))
+
1
)
vrange
=
vrange
or
(
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
)))
+
1
)
vrange
=
vrange
or
(
int
(
F
.
asnumpy
(
F
.
max
(
v
.
tousertensor
(),
dim
=
0
)))
+
1
)
if
utype
==
vtype
:
if
utype
==
vtype
:
...
@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
...
@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
else
:
else
:
return
DGLHeteroGraph
(
hgidx
,
[
utype
,
vtype
],
[
etype
])
return
DGLHeteroGraph
(
hgidx
,
[
utype
,
vtype
],
[
etype
])
def
create_from_edge_list
(
elist
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
):
def
create_from_edge_list
(
elist
,
utype
,
etype
,
vtype
,
urange
=
None
,
vrange
=
None
,
validate
=
False
):
"""Internal function to create a heterograph from a list of edge tuples with types.
"""Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype
utype could be equal to vtype
...
@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
...
@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
vrange : int, optional
The destination node ID range. If None, the value is the
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
Returns
-------
-------
...
@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
...
@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
u
,
v
=
zip
(
*
elist
)
u
,
v
=
zip
(
*
elist
)
u
=
list
(
u
)
u
=
list
(
u
)
v
=
list
(
v
)
v
=
list
(
v
)
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
return
create_from_edges
(
u
,
v
,
utype
,
etype
,
vtype
,
urange
,
vrange
,
validate
)
def
create_from_scipy
(
spmat
,
utype
,
etype
,
vtype
,
with_edge_id
=
False
):
def
create_from_scipy
(
spmat
,
utype
,
etype
,
vtype
,
with_edge_id
=
False
):
"""Internal function to create a heterograph from a scipy sparse matrix with types.
"""Internal function to create a heterograph from a scipy sparse matrix with types.
...
@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
...
@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
If True, the entries in the sparse matrix are treated as edge IDs.
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
(source, destination) order.
validate : bool, optional
If True, checks if node IDs are within range.
Returns
Returns
-------
-------
...
...
tests/compute/test_heterograph.py
View file @
fa0ee46a
...
@@ -6,6 +6,8 @@ import scipy.sparse as ssp
...
@@ -6,6 +6,8 @@ import scipy.sparse as ssp
import
itertools
import
itertools
import
backend
as
F
import
backend
as
F
import
networkx
as
nx
import
networkx
as
nx
from
dgl
import
DGLError
def
create_test_heterograph
():
def
create_test_heterograph
():
# test heterograph from the docstring, plus a user -- wishes -- game relation
# test heterograph from the docstring, plus a user -- wishes -- game relation
...
@@ -93,6 +95,36 @@ def test_create():
...
@@ -93,6 +95,36 @@ def test_create():
assert
g
.
number_of_nodes
(
'l1'
)
==
3
assert
g
.
number_of_nodes
(
'l1'
)
==
3
assert
g
.
number_of_nodes
(
'l2'
)
==
4
assert
g
.
number_of_nodes
(
'l2'
)
==
4
# test if validate flag works
# homo graph
fail
=
False
try
:
g
=
dgl
.
graph
(
([
0
,
0
,
0
,
1
,
1
,
2
],
[
0
,
1
,
2
,
0
,
1
,
2
]),
card
=
2
,
validate
=
True
)
except
DGLError
:
fail
=
True
finally
:
assert
fail
,
"should catch a DGLError because node ID is out of bound."
# bipartite graph
def
_test_validate_bipartite
(
card
):
fail
=
False
try
:
g
=
dgl
.
bipartite
(
([
0
,
0
,
1
,
1
,
2
],
[
1
,
1
,
2
,
2
,
3
]),
card
=
card
,
validate
=
True
)
except
DGLError
:
fail
=
True
finally
:
assert
fail
,
"should catch a DGLError because node ID is out of bound."
_test_validate_bipartite
((
3
,
3
))
_test_validate_bipartite
((
2
,
4
))
def
test_query
():
def
test_query
():
g
=
create_test_heterograph
()
g
=
create_test_heterograph
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment