Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
44db98c4
Commit
44db98c4
authored
Sep 19, 2018
by
Minjie Wang
Browse files
remove nx inheritence
parent
916d375b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
406 additions
and
80 deletions
+406
-80
python/dgl/graph.py
python/dgl/graph.py
+406
-80
No files found.
python/dgl/graph.py
View file @
44db98c4
...
@@ -25,9 +25,9 @@ class DGLGraph(object):
...
@@ -25,9 +25,9 @@ class DGLGraph(object):
----------
----------
graph_data : graph data
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
Data to initialize graph. Same as networkx's semantics.
node_frame :
dgl.frame.
Frame
node_frame : Frame
Ref
Node feature storage.
Node feature storage.
edge_frame :
dgl.frame.
Frame
edge_frame : Frame
Ref
Edge feature storage.
Edge feature storage.
attr : keyword arguments, optional
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
Attributes to add to graph as key=value pairs.
...
@@ -37,14 +37,7 @@ class DGLGraph(object):
...
@@ -37,14 +37,7 @@ class DGLGraph(object):
node_frame
=
None
,
node_frame
=
None
,
edge_frame
=
None
,
edge_frame
=
None
,
**
attr
):
**
attr
):
# TODO(minjie): maintaining node/edge list is costly when graph is large.
# TODO: keyword attr
#nx_init(self,
# self._add_node_callback,
# self._add_edge_callback,
# self._del_node_callback,
# self._del_edge_callback,
# graph_data,
# **attr)
# graph
# graph
self
.
_graph
=
GraphIndex
(
graph_data
)
self
.
_graph
=
GraphIndex
(
graph_data
)
# frame
# frame
...
@@ -59,10 +52,380 @@ class DGLGraph(object):
...
@@ -59,10 +52,380 @@ class DGLGraph(object):
self
.
_apply_node_func
=
(
None
,
None
)
self
.
_apply_node_func
=
(
None
,
None
)
self
.
_apply_edge_func
=
(
None
,
None
)
self
.
_apply_edge_func
=
(
None
,
None
)
def
add_nodes
(
self
,
num
,
reprs
=
None
):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
reprs : dict
Optional node representations.
"""
self
.
_graph
.
add_nodes
(
num
)
#TODO(minjie): change frames
def
add_edge
(
self
,
u
,
v
,
repr
=
None
):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
repr : dict
Optional edge representation.
"""
self
.
_graph
.
add_edge
(
u
,
v
)
#TODO(minjie): change frames
def
add_edges
(
self
,
u
,
v
,
reprs
=
None
):
"""Add many edges.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
reprs : dict
Optional node representations.
"""
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
self
.
_graph
.
add_edges
(
u
,
v
)
#TODO(minjie): change frames
def
clear
(
self
):
"""Clear the graph and its storage."""
self
.
_graph
.
clear
()
self
.
_node_frame
.
clear
()
self
.
_edge_frame
.
clear
()
self
.
_msg_graph
.
clear
()
self
.
_msg_frame
.
clear
()
def
number_of_nodes
(
self
):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return
self
.
_graph
.
number_of_nodes
()
def
number_of_edges
(
self
):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return
self
.
_graph
.
number_of_edges
()
def
has_node
(
self
,
vid
):
"""Return true if the node exists.
Parameters
----------
vid : int
The nodes
Returns
-------
bool
True if the node exists
"""
return
self
.
has_node
(
vid
)
def
has_nodes
(
self
,
vids
):
"""Return true if the nodes exist.
Parameters
----------
vid : list, tensor
The nodes
Returns
-------
tensor
0-1 array indicating existence
"""
vids
=
utils
.
toindex
(
vids
)
rst
=
self
.
_graph
.
has_nodes
(
vids
)
return
rst
.
tousertensor
()
def
has_edge
(
self
,
u
,
v
):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return
self
.
_graph
.
has_edge
(
u
,
v
)
def
has_edges
(
self
,
u
,
v
):
"""Return true if the edge exists.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
0-1 array indicating existence
"""
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
rst
=
self
.
_graph
.
has_edges
(
u
,
v
)
return
rst
.
tousertensor
()
def
predecessors
(
self
,
v
,
radius
=
1
):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of predecessors
"""
return
self
.
_graph
.
predecessors
(
v
).
tousertensor
()
def
successors
(
self
,
v
,
radius
=
1
):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of successors
"""
return
self
.
_graph
.
successors
(
v
).
tousertensor
()
def
edge_id
(
self
,
u
,
v
):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
int
The edge id.
"""
return
self
.
_graph
.
edge_id
(
u
,
v
)
def
edge_ids
(
self
,
u
,
v
):
"""Return the edge ids.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
The edge id array.
"""
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
rst
=
self
.
_graph
.
edge_ids
(
u
,
v
)
return
rst
.
tousertensor
()
def
in_edges
(
self
,
v
):
"""Return the in edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v
=
utils
.
toindex
(
v
)
src
,
dst
,
eid
=
self
.
_graph
.
in_edges
(
v
)
return
src
.
tousertensor
(),
dst
.
tousertensor
(),
eid
.
tousertensor
()
def
out_edges
(
self
,
v
):
"""Return the out edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v
=
utils
.
toindex
(
v
)
src
,
dst
,
eid
=
self
.
_graph
.
out_edges
(
v
)
return
src
.
tousertensor
(),
dst
.
tousertensor
(),
eid
.
tousertensor
()
def
edges
(
self
,
sorted
=
False
):
"""Return all the edges.
Parameters
----------
sorted : bool
True if the returned edges are sorted by their ids.
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
src
,
dst
,
eid
=
self
.
_graph
.
edges
(
sorted
)
return
src
.
tousertensor
(),
dst
.
tousertensor
(),
eid
.
tousertensor
()
def
in_degree
(
self
,
v
):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return
self
.
_graph
.
in_degree
(
v
)
def
in_degrees
(
self
,
v
):
"""Return the in degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The in degree array.
"""
return
self
.
_graph
.
in_degrees
(
v
).
tousertensor
()
def
out_degree
(
self
,
v
):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return
self
.
_graph
.
out_degree
(
v
)
def
out_degrees
(
self
,
v
):
"""Return the out degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The out degree array.
"""
return
self
.
_graph
.
out_degrees
(
v
).
tousertensor
()
def
to_networkx
(
self
,
node_attrs
=
None
,
edge_attrs
=
None
):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Parameters
----------
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
Returns
-------
networkx.DiGraph
The nx graph
"""
nx_graph
=
self
.
_graph
.
to_networkx
()
#TODO: attributes
return
nx_graph
def
node_attr_schemes
(
self
):
def
node_attr_schemes
(
self
):
"""Return the node attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return
self
.
_node_frame
.
schemes
return
self
.
_node_frame
.
schemes
def
edge_attr_schemes
(
self
):
def
edge_attr_schemes
(
self
):
"""Return the edge attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return
self
.
_edge_frame
.
schemes
return
self
.
_edge_frame
.
schemes
def
set_n_repr
(
self
,
hu
,
u
=
ALL
):
def
set_n_repr
(
self
,
hu
,
u
=
ALL
):
...
@@ -113,7 +476,12 @@ class DGLGraph(object):
...
@@ -113,7 +476,12 @@ class DGLGraph(object):
Parameters
Parameters
----------
----------
u : node, container or tensor
u : node, container or tensor
The node(s).
The node(s).
Returns
-------
dict
Representation dict
"""
"""
if
is_all
(
u
):
if
is_all
(
u
):
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
...
@@ -133,7 +501,12 @@ class DGLGraph(object):
...
@@ -133,7 +501,12 @@ class DGLGraph(object):
Parameters
Parameters
----------
----------
key : str
key : str
The attribute name.
The attribute name.
Returns
-------
Tensor
The popped representation
"""
"""
return
self
.
_node_frame
.
pop
(
key
)
return
self
.
_node_frame
.
pop
(
key
)
...
@@ -229,6 +602,11 @@ class DGLGraph(object):
...
@@ -229,6 +602,11 @@ class DGLGraph(object):
The source node(s).
The source node(s).
v : node, container or tensor
v : node, container or tensor
The destination node(s).
The destination node(s).
Returns
-------
dict
Representation dict
"""
"""
u_is_all
=
is_all
(
u
)
u_is_all
=
is_all
(
u
)
v_is_all
=
is_all
(
v
)
v_is_all
=
is_all
(
v
)
...
@@ -254,6 +632,11 @@ class DGLGraph(object):
...
@@ -254,6 +632,11 @@ class DGLGraph(object):
----------
----------
key : str
key : str
The attribute name.
The attribute name.
Returns
-------
Tensor
The popped representation
"""
"""
return
self
.
_edge_frame
.
pop
(
key
)
return
self
.
_edge_frame
.
pop
(
key
)
...
@@ -264,6 +647,11 @@ class DGLGraph(object):
...
@@ -264,6 +647,11 @@ class DGLGraph(object):
----------
----------
eid : int, container or tensor
eid : int, container or tensor
The edge id(s).
The edge id(s).
Returns
-------
dict
Representation dict
"""
"""
if
is_all
(
eid
):
if
is_all
(
eid
):
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
...
@@ -368,6 +756,7 @@ class DGLGraph(object):
...
@@ -368,6 +756,7 @@ class DGLGraph(object):
new_repr
=
apply_node_func
(
self
.
get_n_repr
(
v
))
new_repr
=
apply_node_func
(
self
.
get_n_repr
(
v
))
self
.
set_n_repr
(
new_repr
,
v
)
self
.
set_n_repr
(
new_repr
,
v
)
else
:
else
:
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
v
):
if
is_all
(
v
):
v
=
self
.
nodes
()
v
=
self
.
nodes
()
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
...
@@ -441,6 +830,7 @@ class DGLGraph(object):
...
@@ -441,6 +830,7 @@ class DGLGraph(object):
self
.
_nonbatch_send
(
u
,
v
,
message_func
)
self
.
_nonbatch_send
(
u
,
v
,
message_func
)
def
_nonbatch_send
(
self
,
u
,
v
,
message_func
):
def
_nonbatch_send
(
self
,
u
,
v
,
message_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached_graph
.
edges
()
u
,
v
=
self
.
cached_graph
.
edges
()
else
:
else
:
...
@@ -505,6 +895,7 @@ class DGLGraph(object):
...
@@ -505,6 +895,7 @@ class DGLGraph(object):
self
.
_nonbatch_update_edge
(
u
,
v
,
edge_func
)
self
.
_nonbatch_update_edge
(
u
,
v
,
edge_func
)
def
_nonbatch_update_edge
(
self
,
u
,
v
,
edge_func
):
def
_nonbatch_update_edge
(
self
,
u
,
v
,
edge_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached_graph
.
edges
()
u
,
v
=
self
.
cached_graph
.
edges
()
else
:
else
:
...
@@ -587,6 +978,7 @@ class DGLGraph(object):
...
@@ -587,6 +978,7 @@ class DGLGraph(object):
self
.
apply_nodes
(
u
,
apply_node_func
,
batchable
)
self
.
apply_nodes
(
u
,
apply_node_func
,
batchable
)
def
_nonbatch_recv
(
self
,
u
,
reduce_func
):
def
_nonbatch_recv
(
self
,
u
,
reduce_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
):
if
is_all
(
u
):
u
=
list
(
range
(
0
,
self
.
number_of_nodes
()))
u
=
list
(
range
(
0
,
self
.
number_of_nodes
()))
else
:
else
:
...
@@ -916,75 +1308,9 @@ class DGLGraph(object):
...
@@ -916,75 +1308,9 @@ class DGLGraph(object):
self
.
_edge_frame
.
num_rows
,
self
.
_edge_frame
.
num_rows
,
reduce_func
)
reduce_func
)
def
draw
(
self
):
"""Plot the graph using dot."""
from
networkx.drawing.nx_agraph
import
graphviz_layout
pos
=
graphviz_layout
(
self
,
prog
=
'dot'
)
nx
.
draw
(
self
,
pos
,
with_labels
=
True
)
@
property
def
msg_graph
(
self
):
# TODO: dirty flag when mutated
if
self
.
_msg_graph
is
None
:
self
.
_msg_graph
=
CachedGraph
()
self
.
_msg_graph
.
add_nodes
(
self
.
number_of_nodes
())
return
self
.
_msg_graph
def
clear_messages
(
self
):
def
clear_messages
(
self
):
if
self
.
_msg_graph
is
not
None
:
self
.
_msg_graph
.
clear
()
self
.
_msg_graph
=
CachedGraph
()
self
.
_msg_frame
.
clear
()
self
.
_msg_graph
.
add_nodes
(
self
.
number_of_nodes
())
self
.
_msg_frame
.
clear
()
@
property
def
edge_list
(
self
):
"""Return edges in the addition order."""
return
self
.
_edge_list
def
get_edge_id
(
self
,
u
,
v
):
"""Return the continuous edge id(s) assigned.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
Returns
-------
eid : tensor
The tensor contains edge id(s).
"""
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
return
self
.
cached_graph
.
get_edge_id
(
u
,
v
)
def
_add_node_callback
(
self
,
node
):
#print('New node:', node)
self
.
_cached_graph
=
None
def
_del_node_callback
(
self
,
node
):
#print('Del node:', node)
raise
RuntimeError
(
'Node removal is not supported currently.'
)
node
=
utils
.
convert_to_id_tensor
(
node
)
self
.
_node_frame
.
delete_rows
(
node
)
self
.
_cached_graph
=
None
def
_add_edge_callback
(
self
,
u
,
v
):
#print('New edge:', u, v)
self
.
_edge_list
.
append
((
u
,
v
))
self
.
_cached_graph
=
None
def
_del_edge_callback
(
self
,
u
,
v
):
#print('Del edge:', u, v)
raise
RuntimeError
(
'Edge removal is not supported currently.'
)
u
=
utils
.
convert_to_id_tensor
(
u
)
v
=
utils
.
convert_to_id_tensor
(
v
)
eid
=
self
.
get_edge_id
(
u
,
v
)
self
.
_edge_frame
.
delete_rows
(
eid
)
self
.
_cached_graph
=
None
def
_get_repr
(
attr_dict
):
def
_get_repr
(
attr_dict
):
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment