Unverified commit 4673b96f, authored by Minjie Wang, committed by GitHub

fix set_n_repr grad problem (#38)

parent 0a78dbe1
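The fix replaces in-place row assignment on the node and edge frames with the backend's out-of-place F.scatter_row, so that autograd can route gradients both to the original frame tensor and to the newly written rows. Below is a minimal PyTorch sketch of why this matters; it is not DGL's actual backend code: torch.Tensor.index_copy stands in for whatever F.scatter_row does internally, and requires_grad=True replaces the legacy Variable wrapper used in the tests further down.

import torch as th

D = 5
h = th.randn(10, D, requires_grad=True)   # original node features (a leaf tensor)
u = th.tensor([1, 2, 8])                  # rows to overwrite
hu = th.zeros(len(u), D, requires_grad=True)

# In-place row assignment, as in the old frame code, breaks autograd here:
#   h[u] = hu   # RuntimeError: a leaf Variable that requires grad is being
#               # used in an in-place operation
# Out-of-place update: build a new tensor instead of mutating the old one,
# so autograd tracks both inputs.
h2 = h.index_copy(0, u, hu)

h2.backward(th.ones(10, D) * 2)
print(h.grad[:, 0])    # 2. everywhere except rows 1, 2, 8, which were overwritten
print(hu.grad[:, 0])   # tensor([2., 2., 2.])

The same reasoning applies to set_e_repr and set_e_repr_by_id below, which index rows of the edge frame by edge id.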
@@ -134,9 +134,9 @@ class DGLGraph(DiGraph):
         else:
             if isinstance(hu, dict):
                 for key, val in hu.items():
-                    self._node_frame[key][u] = val
+                    self._node_frame[key] = F.scatter_row(self._node_frame[key], u, val)
             else:
-                self._node_frame[__REPR__][u] = hu
+                self._node_frame[__REPR__] = F.scatter_row(self._node_frame[__REPR__], u, hu)

     def get_n_repr(self, u=ALL):
         """Get node(s) representation.
@@ -214,9 +214,9 @@ class DGLGraph(DiGraph):
             eid = self.cached_graph.get_edge_id(u, v)
             if isinstance(h_uv, dict):
                 for key, val in h_uv.items():
-                    self._edge_frame[key][eid] = val
+                    self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
             else:
-                self._edge_frame[__REPR__][eid] = h_uv
+                self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)

     def set_e_repr_by_id(self, h_uv, eid=ALL):
         """Set edge(s) representation by edge id.
@@ -249,9 +249,9 @@ class DGLGraph(DiGraph):
         else:
             if isinstance(h_uv, dict):
                 for key, val in h_uv.items():
-                    self._edge_frame[key][eid] = val
+                    self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
             else:
-                self._edge_frame[__REPR__][eid] = h_uv
+                self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)

     def get_e_repr(self, u=ALL, v=ALL):
         """Get node(s) representation.
 import torch as th
+from torch.autograd import Variable
 import numpy as np
 from dgl.graph import DGLGraph

 D = 5
 reduce_msg_shapes = set()

+def check_eq(a, b):
+    assert a.shape == b.shape
+    assert th.sum(a == b) == int(np.prod(list(a.shape)))
+
 def message_func(src, edge):
     assert len(src['h'].shape) == 2
     assert src['h'].shape[1] == D
@@ -20,7 +26,7 @@ def update_func(node, accum):
     assert node['h'].shape == accum.shape
     return {'h' : node['h'] + accum}

-def generate_graph():
+def generate_graph(grad=False):
     g = DGLGraph()
     for i in range(10):
         g.add_node(i) # 10 nodes.
@@ -30,7 +36,7 @@ def generate_graph():
         g.add_edge(i, 9)
     # add a back flow from 9 to 0
     g.add_edge(9, 0)
-    col = th.randn(10, D)
+    col = Variable(th.randn(10, D), requires_grad=grad)
     g.set_n_repr({'h' : col})
     return g
@@ -112,6 +118,18 @@ def test_batch_setter_getter():
     v = th.tensor([3, 4, 5])
     assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]

+def test_batch_setter_autograd():
+    g = generate_graph(grad=True)
+    h1 = g.get_n_repr()['h']
+    # partial set
+    v = th.tensor([1, 2, 8])
+    hh = Variable(th.zeros((len(v), D)), requires_grad=True)
+    g.set_n_repr({'h' : hh}, v)
+    h2 = g.get_n_repr()['h']
+    h2.backward(th.ones((10, D)) * 2)
+    check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
+    check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
+
 def test_batch_send():
     g = generate_graph()
     def _fmsg(src, edge):
@@ -180,6 +198,7 @@ def test_update_routines():

 if __name__ == '__main__':
     test_batch_setter_getter()
+    test_batch_setter_autograd()
     test_batch_send()
     test_batch_recv()
     test_update_routines()
 import torch as th
+from torch.autograd import Variable
 import numpy as np
 from dgl.graph import DGLGraph, __REPR__

 D = 32
 reduce_msg_shapes = set()

+def check_eq(a, b):
+    assert a.shape == b.shape
+    assert th.sum(a == b) == int(np.prod(list(a.shape)))
+
 def message_func(hu, e_uv):
     assert len(hu.shape) == 2
     assert hu.shape[1] == D
@@ -19,7 +25,7 @@ def update_func(hv, accum):
     assert hv.shape == accum.shape
     return hv + accum

-def generate_graph():
+def generate_graph(grad=False):
     g = DGLGraph()
     for i in range(10):
         g.add_node(i) # 10 nodes.
@@ -29,7 +35,7 @@ def generate_graph():
         g.add_edge(i, 9)
     # add a back flow from 9 to 0
     g.add_edge(9, 0)
-    col = th.randn(10, D)
+    col = Variable(th.randn(10, D), requires_grad=grad)
     g.set_n_repr(col)
     return g
@@ -111,6 +117,18 @@ def test_batch_setter_getter():
     v = th.tensor([3, 4, 5])
     assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]

+def test_batch_setter_autograd():
+    g = generate_graph(grad=True)
+    h1 = g.get_n_repr()
+    # partial set
+    v = th.tensor([1, 2, 8])
+    hh = Variable(th.zeros((len(v), D)), requires_grad=True)
+    g.set_n_repr(hh, v)
+    h2 = g.get_n_repr()
+    h2.backward(th.ones((10, D)) * 2)
+    check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
+    check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
+
 def test_batch_send():
     g = generate_graph()
     def _fmsg(hu, edge):
@@ -178,6 +196,8 @@ def test_update_routines():
     reduce_msg_shapes.clear()

 if __name__ == '__main__':
     test_batch_setter_getter()
+    test_batch_setter_autograd()
     test_batch_send()
     test_batch_recv()
     test_update_routines()
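The new test_batch_setter_autograd, added in both test files above (once with a named 'h' field and once with the anonymous __REPR__ field), checks exactly this gradient routing: after rows 1, 2 and 8 are overwritten through set_n_repr, backpropagating a gradient of 2 through every entry of the read-back features leaves the original column with gradient 2 in the untouched rows and 0 in the overwritten ones, while hh receives gradient 2 for the three rows it supplied. With the old in-place assignment, PyTorch rejects mutating a leaf tensor that requires grad, which is presumably the grad problem the commit title refers to.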