Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
882e2a7b
Commit
882e2a7b
authored
Sep 24, 2018
by
Minjie Wang
Browse files
Graph batching. Support convert nx graph attrs
parent
314a75f3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
7 deletions
+29
-7
python/dgl/batch.py
python/dgl/batch.py
+4
-4
python/dgl/data/tree.py
python/dgl/data/tree.py
+1
-1
python/dgl/graph.py
python/dgl/graph.py
+24
-2
No files found.
python/dgl/batch.py
View file @
882e2a7b
...
...
@@ -31,11 +31,11 @@ class BatchedDGLGraph(DGLGraph):
# NOTE: following code will materialize the columns of the input graphs.
batched_node_frame
=
FrameRef
()
for
gr
in
graph_list
:
cols
=
{
gr
.
_node_frame
[
key
]
for
key
in
node_attrs
}
cols
=
{
key
:
gr
.
_node_frame
[
key
]
for
key
in
node_attrs
}
batched_node_frame
.
append
(
cols
)
batched_edge_frame
=
FrameRef
()
for
gr
in
graph_list
:
cols
=
{
gr
.
_edge_frame
[
key
]
for
key
in
edge_attrs
}
cols
=
{
key
:
gr
.
_edge_frame
[
key
]
for
key
in
edge_attrs
}
batched_edge_frame
.
append
(
cols
)
super
(
BatchedDGLGraph
,
self
).
__init__
(
graph_data
=
batched_index
,
...
...
@@ -169,12 +169,12 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
node_attrs
=
[]
elif
is_all
(
node_attrs
):
node_attrs
=
graph_list
[
0
].
node_attr_schemes
()
elif
if
isinstance
(
node_attrs
,
str
):
elif
isinstance
(
node_attrs
,
str
):
node_attrs
=
[
node_attrs
]
if
edge_attrs
is
None
:
edge_attrs
=
[]
elif
is_all
(
edge_attrs
):
edge_attrs
=
graph_list
[
0
].
edge_attr_schemes
()
elif
if
isinstance
(
edge_attrs
,
str
):
elif
isinstance
(
edge_attrs
,
str
):
edge_attrs
=
[
edge_attrs
]
return
BatchedDGLGraph
(
graph_list
,
node_attrs
,
edge_attrs
)
python/dgl/data/tree.py
View file @
882e2a7b
...
...
@@ -67,7 +67,7 @@ class SST(object):
g
.
add_node
(
0
,
x
=
SST
.
PAD_WORD
,
y
=
int
(
root
.
label
()))
_rec_build
(
0
,
root
)
ret
=
DGLGraph
()
ret
.
from_networkx
(
g
)
ret
.
from_networkx
(
g
,
node_attrs
=
[
'x'
,
'y'
]
)
return
ret
def
__getitem__
(
self
,
idx
):
...
...
python/dgl/graph.py
View file @
882e2a7b
...
...
@@ -439,12 +439,34 @@ class DGLGraph(object):
----------
nx_graph : networkx.DiGraph
The nx graph
node_attrs : iterable of str, optional
The node attributes needs to be copied.
edge_attrs : iterable of str, optional
The edge attributes needs to be copied.
"""
self
.
clear
()
self
.
_graph
.
from_networkx
(
nx_graph
)
self
.
_msg_graph
.
add_nodes
(
self
.
_graph
.
number_of_nodes
())
#TODO: attributes
pass
def
_batcher
(
lst
):
if
isinstance
(
lst
[
0
],
Tensor
):
return
F
.
pack
([
F
.
unsqueeze
(
x
,
0
)
for
x
in
lst
])
else
:
return
F
.
tensor
(
lst
)
if
node_attrs
is
not
None
:
attr_dict
=
{
attr
:
[]
for
attr
in
node_attrs
}
for
nid
in
range
(
self
.
number_of_nodes
()):
for
attr
in
node_attrs
:
attr_dict
[
attr
].
append
(
nx_graph
.
nodes
[
nid
][
attr
])
for
attr
in
node_attrs
:
self
.
_node_frame
[
attr
]
=
_batcher
(
attr_dict
[
attr
])
if
edge_attrs
is
not
None
:
attr_dict
=
{
attr
:
[]
for
attr
in
edge_attrs
}
src
,
dst
,
_
=
self
.
_graph
.
edges
()
for
u
,
v
in
zip
(
src
.
tolist
(),
dst
.
tolist
()):
for
attr
in
edge_attrs
:
attr_dict
[
attr
].
append
(
nx_graph
.
edges
[
u
,
v
][
attr
])
for
attr
in
edge_attrs
:
self
.
_edge_frame
[
attr
]
=
_batcher
(
attr_dict
[
attr
])
def
node_attr_schemes
(
self
):
"""Return the node attribute schemes.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment