"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "9e3a04c3327972b041e3f9aefb7811b3f9b172b8"
Unverified Commit 4673b96f authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

fix set_n_repr grad problem (#38)

parent 0a78dbe1
...@@ -134,9 +134,9 @@ class DGLGraph(DiGraph): ...@@ -134,9 +134,9 @@ class DGLGraph(DiGraph):
else: else:
if isinstance(hu, dict): if isinstance(hu, dict):
for key, val in hu.items(): for key, val in hu.items():
self._node_frame[key][u] = val self._node_frame[key] = F.scatter_row(self._node_frame[key], u, val)
else: else:
self._node_frame[__REPR__][u] = hu self._node_frame[__REPR__] = F.scatter_row(self._node_frame[__REPR__], u, hu)
def get_n_repr(self, u=ALL): def get_n_repr(self, u=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -214,9 +214,9 @@ class DGLGraph(DiGraph): ...@@ -214,9 +214,9 @@ class DGLGraph(DiGraph):
eid = self.cached_graph.get_edge_id(u, v) eid = self.cached_graph.get_edge_id(u, v)
if isinstance(h_uv, dict): if isinstance(h_uv, dict):
for key, val in h_uv.items(): for key, val in h_uv.items():
self._edge_frame[key][eid] = val self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
else: else:
self._edge_frame[__REPR__][eid] = h_uv self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, h_uv, eid=ALL):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
...@@ -249,9 +249,9 @@ class DGLGraph(DiGraph): ...@@ -249,9 +249,9 @@ class DGLGraph(DiGraph):
else: else:
if isinstance(h_uv, dict): if isinstance(h_uv, dict):
for key, val in h_uv.items(): for key, val in h_uv.items():
self._edge_frame[key][eid] = val self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
else: else:
self._edge_frame[__REPR__][eid] = h_uv self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation. """Get node(s) representation.
......
import torch as th import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph from dgl.graph import DGLGraph
D = 5 D = 5
reduce_msg_shapes = set() reduce_msg_shapes = set()
def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(src, edge): def message_func(src, edge):
assert len(src['h'].shape) == 2 assert len(src['h'].shape) == 2
assert src['h'].shape[1] == D assert src['h'].shape[1] == D
...@@ -20,7 +26,7 @@ def update_func(node, accum): ...@@ -20,7 +26,7 @@ def update_func(node, accum):
assert node['h'].shape == accum.shape assert node['h'].shape == accum.shape
return {'h' : node['h'] + accum} return {'h' : node['h'] + accum}
def generate_graph(): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
for i in range(10): for i in range(10):
g.add_node(i) # 10 nodes. g.add_node(i) # 10 nodes.
...@@ -30,7 +36,7 @@ def generate_graph(): ...@@ -30,7 +36,7 @@ def generate_graph():
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
col = th.randn(10, D) col = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr({'h' : col}) g.set_n_repr({'h' : col})
return g return g
...@@ -112,6 +118,18 @@ def test_batch_setter_getter(): ...@@ -112,6 +118,18 @@ def test_batch_setter_getter():
v = th.tensor([3, 4, 5]) v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.] assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
# partial set
v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send(): def test_batch_send():
g = generate_graph() g = generate_graph()
def _fmsg(src, edge): def _fmsg(src, edge):
...@@ -180,6 +198,7 @@ def test_update_routines(): ...@@ -180,6 +198,7 @@ def test_update_routines():
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send() test_batch_send()
test_batch_recv() test_batch_recv()
test_update_routines() test_update_routines()
import torch as th import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph, __REPR__ from dgl.graph import DGLGraph, __REPR__
D = 32 D = 32
reduce_msg_shapes = set() reduce_msg_shapes = set()
def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(hu, e_uv): def message_func(hu, e_uv):
assert len(hu.shape) == 2 assert len(hu.shape) == 2
assert hu.shape[1] == D assert hu.shape[1] == D
...@@ -19,7 +25,7 @@ def update_func(hv, accum): ...@@ -19,7 +25,7 @@ def update_func(hv, accum):
assert hv.shape == accum.shape assert hv.shape == accum.shape
return hv + accum return hv + accum
def generate_graph(): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
for i in range(10): for i in range(10):
g.add_node(i) # 10 nodes. g.add_node(i) # 10 nodes.
...@@ -29,7 +35,7 @@ def generate_graph(): ...@@ -29,7 +35,7 @@ def generate_graph():
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
col = th.randn(10, D) col = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr(col) g.set_n_repr(col)
return g return g
...@@ -111,6 +117,18 @@ def test_batch_setter_getter(): ...@@ -111,6 +117,18 @@ def test_batch_setter_getter():
v = th.tensor([3, 4, 5]) v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.] assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
h1 = g.get_n_repr()
# partial set
v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr(hh, v)
h2 = g.get_n_repr()
h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send(): def test_batch_send():
g = generate_graph() g = generate_graph()
def _fmsg(hu, edge): def _fmsg(hu, edge):
...@@ -178,6 +196,8 @@ def test_update_routines(): ...@@ -178,6 +196,8 @@ def test_update_routines():
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send() test_batch_send()
test_batch_recv() test_batch_recv()
test_update_routines() test_update_routines()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment