Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4673b96f
Unverified
Commit
4673b96f
authored
Aug 09, 2018
by
Minjie Wang
Committed by
GitHub
Aug 09, 2018
Browse files
fix set_n_repr grad problem (#38)
parent
0a78dbe1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
10 deletions
+49
-10
python/dgl/graph.py
python/dgl/graph.py
+6
-6
tests/pytorch/test_batching.py
tests/pytorch/test_batching.py
+21
-2
tests/pytorch/test_batching_anonymous.py
tests/pytorch/test_batching_anonymous.py
+22
-2
No files found.
python/dgl/graph.py
View file @
4673b96f
...
...
@@ -134,9 +134,9 @@ class DGLGraph(DiGraph):
else
:
if
isinstance
(
hu
,
dict
):
for
key
,
val
in
hu
.
items
():
self
.
_node_frame
[
key
]
[
u
]
=
val
self
.
_node_frame
[
key
]
=
F
.
scatter_row
(
self
.
_node_frame
[
key
],
u
,
val
)
else
:
self
.
_node_frame
[
__REPR__
]
[
u
]
=
hu
self
.
_node_frame
[
__REPR__
]
=
F
.
scatter_row
(
self
.
_node_frame
[
__REPR__
],
u
,
hu
)
def
get_n_repr
(
self
,
u
=
ALL
):
"""Get node(s) representation.
...
...
@@ -214,9 +214,9 @@ class DGLGraph(DiGraph):
eid
=
self
.
cached_graph
.
get_edge_id
(
u
,
v
)
if
isinstance
(
h_uv
,
dict
):
for
key
,
val
in
h_uv
.
items
():
self
.
_edge_frame
[
key
]
[
eid
]
=
val
self
.
_edge_frame
[
key
]
=
F
.
scatter_row
(
self
.
_edge_frame
[
key
],
eid
,
val
)
else
:
self
.
_edge_frame
[
__REPR__
]
[
eid
]
=
h_uv
self
.
_edge_frame
[
__REPR__
]
=
F
.
scatter_row
(
self
.
_edge_frame
[
__REPR__
],
eid
,
h_uv
)
def
set_e_repr_by_id
(
self
,
h_uv
,
eid
=
ALL
):
"""Set edge(s) representation by edge id.
...
...
@@ -249,9 +249,9 @@ class DGLGraph(DiGraph):
else
:
if
isinstance
(
h_uv
,
dict
):
for
key
,
val
in
h_uv
.
items
():
self
.
_edge_frame
[
key
]
[
eid
]
=
val
self
.
_edge_frame
[
key
]
=
F
.
scatter_row
(
self
.
_edge_frame
[
key
],
eid
,
val
)
else
:
self
.
_edge_frame
[
__REPR__
]
[
eid
]
=
h_uv
self
.
_edge_frame
[
__REPR__
]
=
F
.
scatter_row
(
self
.
_edge_frame
[
__REPR__
],
eid
,
h_uv
)
def
get_e_repr
(
self
,
u
=
ALL
,
v
=
ALL
):
"""Get node(s) representation.
...
...
tests/pytorch/test_batching.py
View file @
4673b96f
import
torch
as
th
from
torch.autograd
import
Variable
import
numpy
as
np
from
dgl.graph
import
DGLGraph
D
=
5
reduce_msg_shapes
=
set
()
def
check_eq
(
a
,
b
):
assert
a
.
shape
==
b
.
shape
assert
th
.
sum
(
a
==
b
)
==
int
(
np
.
prod
(
list
(
a
.
shape
)))
def
message_func
(
src
,
edge
):
assert
len
(
src
[
'h'
].
shape
)
==
2
assert
src
[
'h'
].
shape
[
1
]
==
D
...
...
@@ -20,7 +26,7 @@ def update_func(node, accum):
assert
node
[
'h'
].
shape
==
accum
.
shape
return
{
'h'
:
node
[
'h'
]
+
accum
}
def
generate_graph
():
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
...
...
@@ -30,7 +36,7 @@ def generate_graph():
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
col
=
th
.
randn
(
10
,
D
)
col
=
Variable
(
th
.
randn
(
10
,
D
)
,
requires_grad
=
grad
)
g
.
set_n_repr
({
'h'
:
col
})
return
g
...
...
@@ -112,6 +118,18 @@ def test_batch_setter_getter():
v
=
th
.
tensor
([
3
,
4
,
5
])
assert
_pfc
(
g
.
get_e_repr
(
u
,
v
)[
'l'
])
==
[
1.
,
1.
,
1.
]
def
test_batch_setter_autograd
():
g
=
generate_graph
(
grad
=
True
)
h1
=
g
.
get_n_repr
()[
'h'
]
# partial set
v
=
th
.
tensor
([
1
,
2
,
8
])
hh
=
Variable
(
th
.
zeros
((
len
(
v
),
D
)),
requires_grad
=
True
)
g
.
set_n_repr
({
'h'
:
hh
},
v
)
h2
=
g
.
get_n_repr
()[
'h'
]
h2
.
backward
(
th
.
ones
((
10
,
D
))
*
2
)
check_eq
(
h1
.
grad
[:,
0
],
th
.
tensor
([
2.
,
0.
,
0.
,
2.
,
2.
,
2.
,
2.
,
2.
,
0.
,
2.
]))
check_eq
(
hh
.
grad
[:,
0
],
th
.
tensor
([
2.
,
2.
,
2.
]))
def
test_batch_send
():
g
=
generate_graph
()
def
_fmsg
(
src
,
edge
):
...
...
@@ -180,6 +198,7 @@ def test_update_routines():
if
__name__
==
'__main__'
:
test_batch_setter_getter
()
test_batch_setter_autograd
()
test_batch_send
()
test_batch_recv
()
test_update_routines
()
tests/pytorch/test_batching_anonymous.py
View file @
4673b96f
import
torch
as
th
from
torch.autograd
import
Variable
import
numpy
as
np
from
dgl.graph
import
DGLGraph
,
__REPR__
D
=
32
reduce_msg_shapes
=
set
()
def
check_eq
(
a
,
b
):
assert
a
.
shape
==
b
.
shape
assert
th
.
sum
(
a
==
b
)
==
int
(
np
.
prod
(
list
(
a
.
shape
)))
def
message_func
(
hu
,
e_uv
):
assert
len
(
hu
.
shape
)
==
2
assert
hu
.
shape
[
1
]
==
D
...
...
@@ -19,7 +25,7 @@ def update_func(hv, accum):
assert
hv
.
shape
==
accum
.
shape
return
hv
+
accum
def
generate_graph
():
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
...
...
@@ -29,7 +35,7 @@ def generate_graph():
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
col
=
th
.
randn
(
10
,
D
)
col
=
Variable
(
th
.
randn
(
10
,
D
)
,
requires_grad
=
grad
)
g
.
set_n_repr
(
col
)
return
g
...
...
@@ -111,6 +117,18 @@ def test_batch_setter_getter():
v
=
th
.
tensor
([
3
,
4
,
5
])
assert
_pfc
(
g
.
get_e_repr
(
u
,
v
))
==
[
1.
,
1.
,
1.
]
def
test_batch_setter_autograd
():
g
=
generate_graph
(
grad
=
True
)
h1
=
g
.
get_n_repr
()
# partial set
v
=
th
.
tensor
([
1
,
2
,
8
])
hh
=
Variable
(
th
.
zeros
((
len
(
v
),
D
)),
requires_grad
=
True
)
g
.
set_n_repr
(
hh
,
v
)
h2
=
g
.
get_n_repr
()
h2
.
backward
(
th
.
ones
((
10
,
D
))
*
2
)
check_eq
(
h1
.
grad
[:,
0
],
th
.
tensor
([
2.
,
0.
,
0.
,
2.
,
2.
,
2.
,
2.
,
2.
,
0.
,
2.
]))
check_eq
(
hh
.
grad
[:,
0
],
th
.
tensor
([
2.
,
2.
,
2.
]))
def
test_batch_send
():
g
=
generate_graph
()
def
_fmsg
(
hu
,
edge
):
...
...
@@ -178,6 +196,8 @@ def test_update_routines():
reduce_msg_shapes
.
clear
()
if
__name__
==
'__main__'
:
test_batch_setter_getter
()
test_batch_setter_autograd
()
test_batch_send
()
test_batch_recv
()
test_update_routines
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment