Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
596ca471
Commit
596ca471
authored
Oct 06, 2018
by
GaiYu0
Browse files
Merge branch 'cpp' of
https://github.com/jermainewang/dgl
into line-graph
parents
52ed6a45
72f63455
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
13 deletions
+11
-13
tests/pytorch/test_basics_anonymous.py
tests/pytorch/test_basics_anonymous.py
+0
-0
tests/pytorch/test_batched_graph.py
tests/pytorch/test_batched_graph.py
+0
-0
tests/pytorch/test_subgraph.py
tests/pytorch/test_subgraph.py
+11
-13
No files found.
tests/pytorch/test_ba
tching
_anonymous.py
→
tests/pytorch/test_ba
sics
_anonymous.py
View file @
596ca471
File moved
tests/pytorch/test_
graph_batc
h.py
→
tests/pytorch/test_
batched_grap
h.py
View file @
596ca471
File moved
tests/pytorch/test_subgraph.py
View file @
596ca471
...
...
@@ -5,13 +5,9 @@ from dgl.graph import DGLGraph
D
=
5
def
check_eq
(
a
,
b
):
return
a
.
shape
==
b
.
shape
and
np
.
allclose
(
a
.
numpy
(),
b
.
numpy
())
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
g
.
add_nodes
(
10
)
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
...
...
@@ -29,17 +25,19 @@ def test_basics():
h
=
g
.
get_n_repr
()[
'h'
]
l
=
g
.
get_e_repr
()[
'l'
]
nid
=
[
0
,
2
,
3
,
6
,
7
,
9
]
eid
=
[
2
,
3
,
4
,
5
,
10
,
11
,
12
,
13
,
16
]
sg
=
g
.
subgraph
(
nid
)
eid
=
{
2
,
3
,
4
,
5
,
10
,
11
,
12
,
13
,
16
}
assert
set
(
sg
.
parent_eid
.
numpy
())
==
eid
eid
=
sg
.
parent_eid
# the subgraph is empty initially
assert
len
(
sg
.
get_n_repr
())
==
0
assert
len
(
sg
.
get_e_repr
())
==
0
# the data is copied after explict copy from
sg
.
copy_from
(
g
)
sg
.
copy_from
_parent
()
assert
len
(
sg
.
get_n_repr
())
==
1
assert
len
(
sg
.
get_e_repr
())
==
1
sh
=
sg
.
get_n_repr
()[
'h'
]
assert
check_eq
(
h
[
nid
],
sh
)
assert
th
.
allclose
(
h
[
nid
],
sh
)
'''
s, d, eid
0, 1, 0
...
...
@@ -60,11 +58,11 @@ def test_basics():
8, 9, 15 3
9, 0, 16 1
'''
assert
check_eq
(
l
[
eid
],
sg
.
get_e_repr
()[
'l'
])
assert
th
.
allclose
(
l
[
eid
],
sg
.
get_e_repr
()[
'l'
])
# update the node/edge features on the subgraph should NOT
# reflect to the parent graph.
sg
.
set_n_repr
({
'h'
:
th
.
zeros
((
6
,
D
))})
assert
check_eq
(
h
,
g
.
get_n_repr
()[
'h'
])
assert
th
.
allclose
(
h
,
g
.
get_n_repr
()[
'h'
])
def
test_merge
():
g
=
generate_graph
()
...
...
@@ -85,10 +83,10 @@ def test_merge():
h
=
g
.
get_n_repr
()[
'h'
][:,
0
]
l
=
g
.
get_e_repr
()[
'l'
][:,
0
]
assert
check_eq
(
h
,
th
.
tensor
([
3.
,
0.
,
3.
,
3.
,
2.
,
0.
,
1.
,
1.
,
0.
,
1.
]))
assert
check_eq
(
l
,
assert
th
.
allclose
(
h
,
th
.
tensor
([
3.
,
0.
,
3.
,
3.
,
2.
,
0.
,
1.
,
1.
,
0.
,
1.
]))
assert
th
.
allclose
(
l
,
th
.
tensor
([
0.
,
0.
,
1.
,
1.
,
1.
,
1.
,
0.
,
0.
,
0.
,
3.
,
1.
,
4.
,
1.
,
4.
,
0.
,
3.
,
1.
]))
if
__name__
==
'__main__'
:
test_basics
()
test_merge
()
#
test_merge()
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment