Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f26ecb49
Commit
f26ecb49
authored
Sep 20, 2018
by
Gan Quan
Committed by
Minjie Wang
Sep 20, 2018
Browse files
Fix batching node-only graphs (#62)
* fixing batching with graphs with no edges * oops forgot test
parent
9b0a01db
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
1 deletion
+20
-1
python/dgl/batch.py
python/dgl/batch.py
+2
-1
tests/pytorch/test_graph_batch.py
tests/pytorch/test_graph_batch.py
+18
-0
No files found.
python/dgl/batch.py
View file @
f26ecb49
...
@@ -27,7 +27,8 @@ class BatchedDGLGraph(DGLGraph):
...
@@ -27,7 +27,8 @@ class BatchedDGLGraph(DGLGraph):
# in-order add relabeled edges
# in-order add relabeled edges
self
.
new_edge_list
=
[
np
.
array
(
g
.
edge_list
)
+
offset
self
.
new_edge_list
=
[
np
.
array
(
g
.
edge_list
)
+
offset
for
g
,
offset
in
zip
(
self
.
graph_list
,
self
.
node_offset
[:
-
1
])]
for
g
,
offset
in
zip
(
self
.
graph_list
,
self
.
node_offset
[:
-
1
])
if
len
(
g
.
edge_list
)
>
0
]
self
.
new_edges
=
np
.
concatenate
(
self
.
new_edge_list
)
self
.
new_edges
=
np
.
concatenate
(
self
.
new_edge_list
)
self
.
add_edges_from
(
self
.
new_edges
)
self
.
add_edges_from
(
self
.
new_edges
)
...
...
tests/pytorch/test_graph_batch.py
View file @
f26ecb49
...
@@ -141,8 +141,26 @@ def test_batched_edge_ordering():
...
@@ -141,8 +141,26 @@ def test_batched_edge_ordering():
r2
=
g1
.
get_e_repr
()[
g1
.
get_edge_id
(
4
,
5
)]
r2
=
g1
.
get_e_repr
()[
g1
.
get_edge_id
(
4
,
5
)]
assert
torch
.
equal
(
r1
,
r2
)
assert
torch
.
equal
(
r1
,
r2
)
def
test_batch_no_edge
():
g1
=
dgl
.
DGLGraph
()
g1
.
add_nodes_from
([
0
,
1
,
2
,
3
,
4
,
5
])
g1
.
add_edges_from
([(
4
,
5
),
(
4
,
3
),
(
2
,
3
),
(
2
,
1
),
(
0
,
1
)])
g1
.
edge_list
e1
=
torch
.
randn
(
5
,
10
)
g1
.
set_e_repr
(
e1
)
g2
=
dgl
.
DGLGraph
()
g2
.
add_nodes_from
([
0
,
1
,
2
,
3
,
4
,
5
])
g2
.
add_edges_from
([(
0
,
1
),
(
1
,
2
),
(
2
,
3
),
(
5
,
4
),
(
4
,
3
),
(
5
,
0
)])
e2
=
torch
.
randn
(
6
,
10
)
g2
.
set_e_repr
(
e2
)
g3
=
dgl
.
DGLGraph
()
g3
.
add_nodes_from
([
0
])
# no edges
g
=
dgl
.
batch
([
g1
,
g3
,
g2
])
# should not throw an error
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_batch_unbatch
()
test_batch_unbatch
()
test_batched_edge_ordering
()
test_batched_edge_ordering
()
test_batch_sendrecv
()
test_batch_sendrecv
()
test_batch_propagate
()
test_batch_propagate
()
test_batch_no_edge
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment