Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
882e2a7b
"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "6a363378d50c4c1d6c01b1942b170050286e2923"
Commit
882e2a7b
authored
Sep 24, 2018
by
Minjie Wang
Browse files
Graph batching. Support convert nx graph attrs
parent
314a75f3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
7 deletions
+29
-7
python/dgl/batch.py
python/dgl/batch.py
+4
-4
python/dgl/data/tree.py
python/dgl/data/tree.py
+1
-1
python/dgl/graph.py
python/dgl/graph.py
+24
-2
No files found.
python/dgl/batch.py
View file @
882e2a7b
...
...
@@ -31,11 +31,11 @@ class BatchedDGLGraph(DGLGraph):
# NOTE: following code will materialize the columns of the input graphs.
batched_node_frame
=
FrameRef
()
for
gr
in
graph_list
:
cols
=
{
gr
.
_node_frame
[
key
]
for
key
in
node_attrs
}
cols
=
{
key
:
gr
.
_node_frame
[
key
]
for
key
in
node_attrs
}
batched_node_frame
.
append
(
cols
)
batched_edge_frame
=
FrameRef
()
for
gr
in
graph_list
:
cols
=
{
gr
.
_edge_frame
[
key
]
for
key
in
edge_attrs
}
cols
=
{
key
:
gr
.
_edge_frame
[
key
]
for
key
in
edge_attrs
}
batched_edge_frame
.
append
(
cols
)
super
(
BatchedDGLGraph
,
self
).
__init__
(
graph_data
=
batched_index
,
...
...
@@ -169,12 +169,12 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
node_attrs
=
[]
elif
is_all
(
node_attrs
):
node_attrs
=
graph_list
[
0
].
node_attr_schemes
()
elif
if
isinstance
(
node_attrs
,
str
):
elif
isinstance
(
node_attrs
,
str
):
node_attrs
=
[
node_attrs
]
if
edge_attrs
is
None
:
edge_attrs
=
[]
elif
is_all
(
edge_attrs
):
edge_attrs
=
graph_list
[
0
].
edge_attr_schemes
()
elif
if
isinstance
(
edge_attrs
,
str
):
elif
isinstance
(
edge_attrs
,
str
):
edge_attrs
=
[
edge_attrs
]
return
BatchedDGLGraph
(
graph_list
,
node_attrs
,
edge_attrs
)
python/dgl/data/tree.py
View file @
882e2a7b
...
...
@@ -67,7 +67,7 @@ class SST(object):
g
.
add_node
(
0
,
x
=
SST
.
PAD_WORD
,
y
=
int
(
root
.
label
()))
_rec_build
(
0
,
root
)
ret
=
DGLGraph
()
ret
.
from_networkx
(
g
)
ret
.
from_networkx
(
g
,
node_attrs
=
[
'x'
,
'y'
]
)
return
ret
def
__getitem__
(
self
,
idx
):
...
...
python/dgl/graph.py
View file @
882e2a7b
...
...
@@ -439,12 +439,34 @@ class DGLGraph(object):
----------
nx_graph : networkx.DiGraph
The nx graph
node_attrs : iterable of str, optional
The node attributes needs to be copied.
edge_attrs : iterable of str, optional
The edge attributes needs to be copied.
"""
self
.
clear
()
self
.
_graph
.
from_networkx
(
nx_graph
)
self
.
_msg_graph
.
add_nodes
(
self
.
_graph
.
number_of_nodes
())
#TODO: attributes
pass
def
_batcher
(
lst
):
if
isinstance
(
lst
[
0
],
Tensor
):
return
F
.
pack
([
F
.
unsqueeze
(
x
,
0
)
for
x
in
lst
])
else
:
return
F
.
tensor
(
lst
)
if
node_attrs
is
not
None
:
attr_dict
=
{
attr
:
[]
for
attr
in
node_attrs
}
for
nid
in
range
(
self
.
number_of_nodes
()):
for
attr
in
node_attrs
:
attr_dict
[
attr
].
append
(
nx_graph
.
nodes
[
nid
][
attr
])
for
attr
in
node_attrs
:
self
.
_node_frame
[
attr
]
=
_batcher
(
attr_dict
[
attr
])
if
edge_attrs
is
not
None
:
attr_dict
=
{
attr
:
[]
for
attr
in
edge_attrs
}
src
,
dst
,
_
=
self
.
_graph
.
edges
()
for
u
,
v
in
zip
(
src
.
tolist
(),
dst
.
tolist
()):
for
attr
in
edge_attrs
:
attr_dict
[
attr
].
append
(
nx_graph
.
edges
[
u
,
v
][
attr
])
for
attr
in
edge_attrs
:
self
.
_edge_frame
[
attr
]
=
_batcher
(
attr_dict
[
attr
])
def
node_attr_schemes
(
self
):
"""Return the node attribute schemes.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment